当前位置: 首页 > news >正文

Reptile元学习算法复现实战:在Omniglot数据集上的少样本学习探索

在深度学习领域中,元学习(Meta-Learning)一直是一个充满挑战性的研究方向。最近我尝试复现了OpenAI提出的Reptile算法,这是一个相对简单但有效的一阶元学习方法。虽然最终的实验结果与论文原始数据存在一定差距,但这个复现过程让我对元学习有了更深入的理解。

初识Reptile算法

当我第一次接触到《On First-Order Meta-Learning Algorithms》这篇论文时,就被Reptile算法的简洁性所吸引。与需要计算二阶导数的MAML算法相比,Reptile的核心思想异常简单:在每个任务上进行几步梯度下降,然后将模型参数朝着任务优化后的参数方向移动。这种"先学习再调整"的策略,让模型能够快速适应新的任务。

元学习的魅力在于它试图解决一个根本性问题:如何让机器像人类一样快速学习新知识。人类在看到几个新字符的样本后,往往能够快速识别相似的字符,这正是少样本学习想要达到的效果。Reptile算法通过在多个相关任务上的训练,让模型学会一个良好的初始化参数,使其能够通过少量梯度步骤快速适应新任务。

代码架构设计

整个项目的代码结构相当清晰,主要由四个Python文件组成。utils.py提供了一些基础的工具函数,包括文件列表获取和最新检查点查找等功能。models.py实现了Reptile算法的核心模型类,其中最关键的是point_grad_to方法,它将梯度设置为当前模型与目标模型参数的差值,这正是Reptile算法的精髓所在。

omniglot.py负责数据集的处理,这个文件让我印象深刻的地方在于它不仅实现了数据加载,还包含了自动下载Omniglot数据集的功能。当我第一次运行代码时,系统自动从GitHub下载了background和evaluation两个数据集,并智能地合并了数据结构。这种自动化的设计大大简化了环境搭建的复杂度。

最后的train_omniglot.py是整个训练流程的核心,包含了完整的元学习训练循环。值得注意的是,代码还贴心地提供了TensorBoard支持和断点恢复功能,这在长时间训练中非常有用。

# utils.py
import os
import re# Those two functions are taken from torchvision code because they are not available on pip as of 0.2.0
def list_dir(root, prefix=False):"""List all directories at a given rootArgs:root (str): Path to directory whose folders need to be listedprefix (bool, optional): If true, prepends the path to each result, otherwiseonly returns the name of the directories found"""root = os.path.expanduser(root)directories = list(filter(lambda p: os.path.isdir(os.path.join(root, p)),os.listdir(root)))if prefix is True:directories = [os.path.join(root, d) for d in directories]return directoriesdef list_files(root, suffix, prefix=False):"""List all files ending with a suffix at a given rootArgs:root (str): Path to directory whose folders need to be listedsuffix (str or tuple): Suffix of the files to match, e.g. '.png' or ('.jpg', '.png').It uses the Python "str.endswith" method and is passed directlyprefix (bool, optional): If true, prepends the path to each result, otherwiseonly returns the name of the files found"""root = os.path.expanduser(root)files = list(filter(lambda p: os.path.isfile(os.path.join(root, p)) and p.endswith(suffix),os.listdir(root)))if prefix is True:files = [os.path.join(root, d) for d in files]return filesdef find_latest_file(folder):files = []for fname in os.listdir(folder):s = re.findall(r'\d+', fname)if len(s) == 1:files.append((int(s[0]), fname))if files:return max(files)[1]else:return None
# models.py
import torch
from torch import nnclass ReptileModel(nn.Module):def __init__(self):nn.Module.__init__(self)def point_grad_to(self, target):'''Set .grad attribute of each parameter to be proportionalto the difference between self and target'''for p, target_p in zip(self.parameters(), target.parameters()):if p.grad is None:if self.is_cuda():p.grad = torch.zeros(p.size(), device=p.device)else:p.grad = torch.zeros(p.size())p.grad.data.zero_()  # not sure this is requiredp.grad.data.add_(p.data - target_p.data)def is_cuda(self):return next(self.parameters()).is_cudaclass OmniglotModel(ReptileModel):"""A model for Omniglot classification."""def __init__(self, num_classes):ReptileModel.__init__(self)self.num_classes = num_classesself.conv = nn.Sequential(# 28 x 28 - 1nn.Conv2d(1, 64, 3, 2, 1),nn.BatchNorm2d(64),nn.ReLU(True),# 14 x 14 - 64nn.Conv2d(64, 64, 3, 2, 1),nn.BatchNorm2d(64),nn.ReLU(True),# 7 x 7 - 64nn.Conv2d(64, 64, 3, 2, 1),nn.BatchNorm2d(64),nn.ReLU(True),# 4 x 4 - 64nn.Conv2d(64, 64, 3, 2, 1),nn.BatchNorm2d(64),nn.ReLU(True),# 2 x 2 - 64)self.classifier = nn.Sequential(# 2 x 2 x 64 = 256nn.Linear(256, num_classes),nn.LogSoftmax(1))def forward(self, x):out = x.view(-1, 1, 28, 28)out = self.conv(out)out = out.view(len(out), -1)out = self.classifier(out)return outdef predict(self, prob):__, argmax = prob.max(1)return argmaxdef clone(self):clone = OmniglotModel(self.num_classes)clone.load_state_dict(self.state_dict())if self.is_cuda():clone.cuda()return cloneif __name__ == '__main__':model = OmniglotModel(20)x = torch.zeros(5, 28*28)y = model(x)print('x', x.size())print('y', y.size())
# omniglot.py
from torch.utils import data
import os
import numpy as np
from PIL import Image
from torchvision import transforms
import urllib.request
import zipfile
import shutilfrom utils import list_files, list_dir# 自动下载数据集的URLs
BACKGROUND_URL = "https://github.com/brendenlake/omniglot/raw/master/python/images_background.zip"
EVALUATION_URL = "https://github.com/brendenlake/omniglot/raw/master/python/images_evaluation.zip"def download_and_extract_omniglot(root='omniglot'):"""自动下载和提取Omniglot数据集"""if os.path.exists(root) and 
http://www.dtcms.com/a/289928.html

相关文章:

  • 【AlphaFold3】网络架构篇(1)|概览+预测算法
  • 面试总结第54天微服务开始
  • 基础神经网络模型搭建
  • AI效能之AI单测(一)
  • MCP协议解析:如何通过Model Context Protocol 实现高效的AI客户端与服务端交互
  • c++ duiLib 使用xml文件编写界面布局
  • MyBatis Plus高效开发指南
  • 【PyTorch】图像二分类项目
  • JWT原理及利用手法
  • XTTS实现语音克隆:精确控制音频格式与生成流程【TTS的实战指南】
  • `SearchTransportService` 是 **协调节点与数据节点之间“搜索子请求”通信的运输层**
  • 如何用immich将苹果手机中的照片备份到指定文件夹
  • 开发工具缓存目录
  • 零基础学习性能测试第一章:核心性能指标-响应时间
  • 单链表的手动实现+相关OJ题
  • PostgreSQL 字段类型速查与 Java 枚举映射
  • 【硬件】GalaxyTabPro10.1(SM-T520)刷机/TWRP/LineageOS14/安卓7升级全过程
  • 讲座|人形机器人多姿态站起控制HoST及宇树G1部署
  • C++ 并发 future, promise和async
  • 2025年AIR SCI1区TOP,缩减因子分数阶蜣螂优化算法FORDBO,深度解析+性能实测
  • 基于51单片机的温湿度检测系统Protues仿真设计
  • 创建一个触发csrf的恶意html
  • 低速信号设计之I3C篇
  • windows11环境配置torch-points-kernels库编译安装详细教程
  • 【前端】懒加载(组件/路由/图片等)+预加载 汇总
  • NJU 凸优化导论(10) Approximation+Projection逼近与投影的应用(完结撒花)
  • InfluxDB 数据模型:桶、测量、标签与字段详解(二)
  • springboot --大事件--文章管理接口开发
  • 简洁高效的C++终端日志工具类
  • 响应式编程入门教程第七节:响应式架构与 MVVM 模式在 Unity 中的应用