Reptile元学习算法复现实战:在Omniglot数据集上的少样本学习探索
在深度学习领域中,元学习(Meta-Learning)一直是一个充满挑战性的研究方向。最近我尝试复现了OpenAI提出的Reptile算法,这是一个相对简单但有效的一阶元学习方法。虽然最终的实验结果与论文原始数据存在一定差距,但这个复现过程让我对元学习有了更深入的理解。
初识Reptile算法
当我第一次接触到《On First-Order Meta-Learning Algorithms》这篇论文时,就被Reptile算法的简洁性所吸引。与需要计算二阶导数的MAML算法相比,Reptile的核心思想异常简单:在每个任务上进行几步梯度下降,然后将模型参数朝着任务优化后的参数方向移动。这种"先学习再调整"的策略,让模型能够快速适应新的任务。
元学习的魅力在于它试图解决一个根本性问题:如何让机器像人类一样快速学习新知识。人类在看到几个新字符的样本后,往往能够快速识别相似的字符,这正是少样本学习想要达到的效果。Reptile算法通过在多个相关任务上的训练,让模型学会一个良好的初始化参数,使其能够通过少量梯度步骤快速适应新任务。
代码架构设计
整个项目的代码结构相当清晰,主要由四个Python文件组成。utils.py
提供了一些基础的工具函数,包括文件列表获取和最新检查点查找等功能。models.py
实现了Reptile算法的核心模型类,其中最关键的是point_grad_to
方法,它将梯度设置为当前模型与目标模型参数的差值,这正是Reptile算法的精髓所在。
omniglot.py
负责数据集的处理,这个文件让我印象深刻的地方在于它不仅实现了数据加载,还包含了自动下载Omniglot数据集的功能。当我第一次运行代码时,系统自动从GitHub下载了background和evaluation两个数据集,并智能地合并了数据结构。这种自动化的设计大大简化了环境搭建的复杂度。
最后的train_omniglot.py
是整个训练流程的核心,包含了完整的元学习训练循环。值得注意的是,代码还贴心地提供了TensorBoard支持和断点恢复功能,这在长时间训练中非常有用。
# utils.py
import os
import re# Those two functions are taken from torchvision code because they are not available on pip as of 0.2.0
def list_dir(root, prefix=False):"""List all directories at a given rootArgs:root (str): Path to directory whose folders need to be listedprefix (bool, optional): If true, prepends the path to each result, otherwiseonly returns the name of the directories found"""root = os.path.expanduser(root)directories = list(filter(lambda p: os.path.isdir(os.path.join(root, p)),os.listdir(root)))if prefix is True:directories = [os.path.join(root, d) for d in directories]return directoriesdef list_files(root, suffix, prefix=False):"""List all files ending with a suffix at a given rootArgs:root (str): Path to directory whose folders need to be listedsuffix (str or tuple): Suffix of the files to match, e.g. '.png' or ('.jpg', '.png').It uses the Python "str.endswith" method and is passed directlyprefix (bool, optional): If true, prepends the path to each result, otherwiseonly returns the name of the files found"""root = os.path.expanduser(root)files = list(filter(lambda p: os.path.isfile(os.path.join(root, p)) and p.endswith(suffix),os.listdir(root)))if prefix is True:files = [os.path.join(root, d) for d in files]return filesdef find_latest_file(folder):files = []for fname in os.listdir(folder):s = re.findall(r'\d+', fname)if len(s) == 1:files.append((int(s[0]), fname))if files:return max(files)[1]else:return None
# models.py
import torch
from torch import nnclass ReptileModel(nn.Module):def __init__(self):nn.Module.__init__(self)def point_grad_to(self, target):'''Set .grad attribute of each parameter to be proportionalto the difference between self and target'''for p, target_p in zip(self.parameters(), target.parameters()):if p.grad is None:if self.is_cuda():p.grad = torch.zeros(p.size(), device=p.device)else:p.grad = torch.zeros(p.size())p.grad.data.zero_() # not sure this is requiredp.grad.data.add_(p.data - target_p.data)def is_cuda(self):return next(self.parameters()).is_cudaclass OmniglotModel(ReptileModel):"""A model for Omniglot classification."""def __init__(self, num_classes):ReptileModel.__init__(self)self.num_classes = num_classesself.conv = nn.Sequential(# 28 x 28 - 1nn.Conv2d(1, 64, 3, 2, 1),nn.BatchNorm2d(64),nn.ReLU(True),# 14 x 14 - 64nn.Conv2d(64, 64, 3, 2, 1),nn.BatchNorm2d(64),nn.ReLU(True),# 7 x 7 - 64nn.Conv2d(64, 64, 3, 2, 1),nn.BatchNorm2d(64),nn.ReLU(True),# 4 x 4 - 64nn.Conv2d(64, 64, 3, 2, 1),nn.BatchNorm2d(64),nn.ReLU(True),# 2 x 2 - 64)self.classifier = nn.Sequential(# 2 x 2 x 64 = 256nn.Linear(256, num_classes),nn.LogSoftmax(1))def forward(self, x):out = x.view(-1, 1, 28, 28)out = self.conv(out)out = out.view(len(out), -1)out = self.classifier(out)return outdef predict(self, prob):__, argmax = prob.max(1)return argmaxdef clone(self):clone = OmniglotModel(self.num_classes)clone.load_state_dict(self.state_dict())if self.is_cuda():clone.cuda()return cloneif __name__ == '__main__':model = OmniglotModel(20)x = torch.zeros(5, 28*28)y = model(x)print('x', x.size())print('y', y.size())
# omniglot.py
from torch.utils import data
import os
import numpy as np
from PIL import Image
from torchvision import transforms
import urllib.request
import zipfile
import shutilfrom utils import list_files, list_dir# 自动下载数据集的URLs
BACKGROUND_URL = "https://github.com/brendenlake/omniglot/raw/master/python/images_background.zip"
EVALUATION_URL = "https://github.com/brendenlake/omniglot/raw/master/python/images_evaluation.zip"def download_and_extract_omniglot(root='omniglot'):"""自动下载和提取Omniglot数据集"""if os.path.exists(root) and