当前位置: 首页 > news >正文

论文阅读笔记:《Dataset Distillation by Matching Training Trajectories》

论文阅读笔记:《Dataset Distillation by Matching Training Trajectories》

    • 1.动机与背景
    • 2.核心方法:轨迹匹配(Trajectory Matching)
    • 3.实验与效果
    • 4.个人思考与启发
    • 主体代码
    • 算法逻辑总结

一句话总结:

这篇论文通过让合成数据”教“学生网络沿着专家轨迹走,从而在极小数据量下实现高性能,开创了数据集蒸馏的新范式。后面很多工作都基于这篇工作来进行改进


CVPR2022 github在这里插入图片描述

1.动机与背景

  • 数据集蒸馏(Dataset Distillation):用一个极小的合成数据集DsynD_{syn}Dsyn训练模型,使其在真实测试集上的性能接近用完整训练集DrealD_{real}Dreal训练的模型。
  • 局限性:先前的梯度匹配方法(详细可看另外一篇博客)只对齐”每一步的梯度“,忽视了模型训练的长程动态;而完全展开多步优化又代价太高、易不稳定。

2.核心方法:轨迹匹配(Trajectory Matching)

  1. 专家轨迹(Expert Trajectories)

    • 离线预先训练若干网络,每隔一个epoch保存一次模型参数{θt∗}t=0T\{\theta_{t}^{*} \}_{t=0}^{T}{θt}t=0T, 得到”专家轨迹“。
    • 这些轨迹代表真实数据训练的”理想路径“,可重复复用,避免再蒸馏时重新训练大模型。
  2. 轨迹对齐(Long-Range Match)

    • 在合成数据上,从专家轨迹的某个起点 θt∗\theta_t^*θt 初始化”学生“参数 θt^\hat{\theta_t}θt^
    • DsynD_{syn}Dsyn上做N步梯度下降:

    θ^t+n+1=θ^t+n−α∇θℓ(Dsyn;θ^t+n)\hat{\theta}_{t+n+1}=\hat{\theta}_{t+n}-α∇_{θ}ℓ(D_{syn};\hat{\theta}_{t+n})θ^t+n+1=θ^t+nαθ(Dsyn;θ^t+n)

    • 对齐学生在步t+N的参数θ^t+N\hat{\theta}_{t+N}θ^t+N与专家在更远的参数θt+M∗\theta_{t+M}^*θt+M, 损失为:
      L=∥θ^t+N−θt+M∗∥22∥θt∗−θt+M∗∥22\mathcal{L} = \frac{\left\| \hat{\theta}_{t+N} - \theta_{t+M}^{*} \right\|_{2}^{2}}{\left\| \theta_{t}^{*} - \theta_{t+M}^{*} \right\|_{2}^{2}} L=θtθt+M22θ^t+Nθt+M22
      分母做归一化,放大信号并自动平衡各层尺度。

    • 外循环(更新合成数据)+内循环(在DsynD_{syn}Dsyn上模拟N步)结果,借助create_graph=True保留计算图,将对齐损失反向传播到合成图像及可学的学生学习率α\alphaα

  3. 内存优化

    • 不一次性对合成集做匹配,而是在学生网络的内循环中按小批次(跨类别但每类少量)更新,既保证”每张图像都被看过“,又大幅节省显存。

3.实验与效果

  • 小样本极端场景:CIFAR-10/100、SVHN 上每类仅 1 或 10 张合成样本,轨迹匹配比梯度匹配提升约 5–10%。
  • 多分辨率验证:Tiny-ImageNet (64×64)、ImageNette/ImageWoof (128×128) 均取得显著增益。
  • 跨架构泛化:虽针对某一网络训练,合成集在 ResNet-18、VGG、AlexNet 等不同模型上依旧表现稳健。
  • 消融分析:轨迹长度 MM、内循环步数 NN、匹配目标(参数 vs 输出)、专家轨迹数量等均对性能有明显影响,验证设计合理性。

4.个人思考与启发

  • ”长程轨迹对齐“胜于”短程梯度对齐“:对齐训练轨迹(”路径“)往往比对齐某一步”梯度“更能保证学习行为一致。
  • 虽然可以通过预存储轨迹和小批次策略提高效率,但是仍然很耗内存。在训练教师模型的时候,需要把多个教师轨迹存储下来,在训练学生模型的时候需要把训练的参数记录下来,占用大量的内存与显存。同时复杂的双层优化,难以避免复杂度高。

主体代码

   ''' training '''# 将合成图像与LR设为可优化image_syn = image_syn.detach().to(args.device).requires_grad_(True)syn_lr = syn_lr.detach().to(args.device).requires_grad_(True)optimizer_img = torch.optim.SGD([image_syn], lr=args.lr_img, momentum=0.5)# 学习率也设置为可优化optimizer_lr = torch.optim.SGD([syn_lr], lr=args.lr_lr, momentum=0.5)optimizer_img.zero_grad()criterion = nn.CrossEntropyLoss().to(args.device)print('%s training begins'%get_time())# 专家轨迹路径expert_dir = os.path.join(args.buffer_path, args.dataset)if args.dataset == "ImageNet":expert_dir = os.path.join(expert_dir, args.subset, str(args.res))if args.dataset in ["CIFAR10", "CIFAR100"] and not args.zca:expert_dir += "_NO_ZCA"expert_dir = os.path.join(expert_dir, args.model)print("Expert Dir: {}".format(expert_dir))# 加载或部分加载专家轨迹if args.load_all:buffer = []n = 0while os.path.exists(os.path.join(expert_dir, "replay_buffer_{}.pt".format(n))):buffer = buffer + torch.load(os.path.join(expert_dir, "replay_buffer_{}.pt".format(n)))n += 1if n == 0:raise AssertionError("No buffers detected at {}".format(expert_dir))else:expert_files = []n = 0while os.path.exists(os.path.join(expert_dir, "replay_buffer_{}.pt".format(n))):expert_files.append(os.path.join(expert_dir, "replay_buffer_{}.pt".format(n)))n += 1if n == 0:raise AssertionError("No buffers detected at {}".format(expert_dir))file_idx = 0expert_idx = 0random.shuffle(expert_files)if args.max_files is not None:expert_files = expert_files[:args.max_files]print("loading file {}".format(expert_files[file_idx]))buffer = torch.load(expert_files[file_idx])if args.max_experts is not None:buffer = buffer[:args.max_experts]random.shuffle(buffer)# 记录最佳精度与方差best_acc = {m: 0 for m in model_eval_pool}best_std = {m: 0 for m in model_eval_pool}# --- 蒸馏迭代主循环 ---for it in range(0, args.Iteration+1):save_this_it = False   # 标记本次迭代是否是要保存的最佳合成数据# 将当前迭代进度记录到 Weights & Biases (W&B)# writer.add_scalar('Progress', it, it)wandb.log({"Progress": it}, step=it)''' Evaluate synthetic data '''# 如果当前迭代在预设的评估点列表中,则评估合成数据在随机模型上的表现if it in eval_it_pool:for model_eval in model_eval_pool:print('-------------------------\nEvaluation\nmodel_train = %s, model_eval = %s, iteration = %d'%(args.model, model_eval, it))# 打印使用的数据增强策略if args.dsa:print('DSA augmentation strategy: \n', args.dsa_strategy)print('DSA augmentation parameters: \n', args.dsa_param.__dict__)else:print('DC augmentation parameters: \n', args.dc_aug_param)accs_test = []  # 存储每次评估的测试准确率accs_train = []  # 存储每次评估的训练准确率# 重复num_eval 次随机初始化的模型评估,以平均化随机性for it_eval in range(args.num_eval):# 随机初始化一个新模型net_eval = get_network(model_eval, channel, num_classes, im_size).to(args.device) # get a random modeleval_labs = label_syn# 固定合成图像与标签,避免在评估时被意外修改with torch.no_grad():image_save = image_synimage_syn_eval, label_syn_eval = copy.deepcopy(image_save.detach()), copy.deepcopy(eval_labs.detach()) # avoid any unaware modification# 将当前合成学习率传递给评估函数args.lr_net = syn_lr.item()# 用合成数据训练并评估 net_eval,返回 (loss, train_acc, test_acc)_, acc_train, acc_test = evaluate_synset(it_eval, net_eval, image_syn_eval, label_syn_eval, testloader, args, texture=args.texture)accs_test.append(acc_test)accs_train.append(acc_train)accs_test = np.array(accs_test)accs_train = np.array(accs_train)acc_test_mean = np.mean(accs_test)acc_test_std = np.std(accs_test)# 如果有新的最佳平均准确率,则更新best_acc并标记保存if acc_test_mean > best_acc[model_eval]:best_acc[model_eval] = acc_test_meanbest_std[model_eval] = acc_test_stdsave_this_it = Trueprint('Evaluate %d random %s, mean = %.4f std = %.4f\n-------------------------'%(len(accs_test), model_eval, acc_test_mean, acc_test_std))# 将评估结果记录到 W&Bwandb.log({'Accuracy/{}'.format(model_eval): acc_test_mean}, step=it)wandb.log({'Max_Accuracy/{}'.format(model_eval): best_acc[model_eval]}, step=it)wandb.log({'Std/{}'.format(model_eval): acc_test_std}, step=it)wandb.log({'Max_Std/{}'.format(model_eval): best_std[model_eval]}, step=it)# 如果评估改进或周期点,保存合成图像到 W&B 与本地if it in eval_it_pool and (save_this_it or it % 1000 == 0):with torch.no_grad():image_save = image_syn.cuda()save_dir = os.path.join(".", "logged_files", args.dataset, wandb.run.name)if not os.path.exists(save_dir):os.makedirs(save_dir)# 保存当前迭代的合成图像与标签torch.save(image_save.cpu(), os.path.join(save_dir, "images_{}.pt".format(it)))torch.save(label_syn.cpu(), os.path.join(save_dir, "labels_{}.pt".format(it)))# 如果达成新最佳,还额外保存为 bestif save_this_it:torch.save(image_save.cpu(), os.path.join(save_dir, "images_best.pt".format(it)))torch.save(label_syn.cpu(), os.path.join(save_dir, "labels_best.pt".format(it)))# 将像素分布记录为 W&B 直方图wandb.log({"Pixels": wandb.Histogram(torch.nan_to_num(image_syn.detach().cpu()))}, step=it)# 可视化合成图像:若 ipc<50 或 强制保存,则进行网格化展示if args.ipc < 50 or args.force_save:upsampled = image_saveif args.dataset != "ImageNet":# 针对 CIFAR 类数据,将低分辨率图像放大 4 倍以便观察upsampled = torch.repeat_interleave(upsampled, repeats=4, dim=2)upsampled = torch.repeat_interleave(upsampled, repeats=4, dim=3)grid = torchvision.utils.make_grid(upsampled, nrow=10, normalize=True, scale_each=True)wandb.log({"Synthetic_Images": wandb.Image(torch.nan_to_num(grid.detach().cpu()))}, step=it)wandb.log({'Synthetic_Pixels': wandb.Histogram(torch.nan_to_num(image_save.detach().cpu()))}, step=it)for clip_val in [2.5]:std = torch.std(image_save)mean = torch.mean(image_save)upsampled = torch.clip(image_save, min=mean-clip_val*std, max=mean+clip_val*std)if args.dataset != "ImageNet":upsampled = torch.repeat_interleave(upsampled, repeats=4, dim=2)upsampled = torch.repeat_interleave(upsampled, repeats=4, dim=3)grid = torchvision.utils.make_grid(upsampled, nrow=10, normalize=True, scale_each=True)wandb.log({"Clipped_Synthetic_Images/std_{}".format(clip_val): wandb.Image(torch.nan_to_num(grid.detach().cpu()))}, step=it)# 如果使用 ZCA 预处理,还需要保存和可视化反变换后的图像if args.zca:image_save = image_save.to(args.device)image_save = args.zca_trans.inverse_transform(image_save)image_save.cpu()torch.save(image_save.cpu(), os.path.join(save_dir, "images_zca_{}.pt".format(it)))upsampled = image_saveif args.dataset != "ImageNet":upsampled = torch.repeat_interleave(upsampled, repeats=4, dim=2)upsampled = torch.repeat_interleave(upsampled, repeats=4, dim=3)grid = torchvision.utils.make_grid(upsampled, nrow=10, normalize=True, scale_each=True)wandb.log({"Reconstructed_Images": wandb.Image(torch.nan_to_num(grid.detach().cpu()))}, step=it)wandb.log({'Reconstructed_Pixels': wandb.Histogram(torch.nan_to_num(image_save.detach().cpu()))}, step=it)for clip_val in [2.5]:std = torch.std(image_save)mean = torch.mean(image_save)upsampled = torch.clip(image_save, min=mean - clip_val * std, max=mean + clip_val * std)if args.dataset != "ImageNet":upsampled = torch.repeat_interleave(upsampled, repeats=4, dim=2)upsampled = torch.repeat_interleave(upsampled, repeats=4, dim=3)grid = torchvision.utils.make_grid(upsampled, nrow=10, normalize=True, scale_each=True)wandb.log({"Clipped_Reconstructed_Images/std_{}".format(clip_val): wandb.Image(torch.nan_to_num(grid.detach().cpu()))}, step=it)# 记录当前合成学习率到 W&Bwandb.log({"Synthetic_LR": syn_lr.detach().cpu()}, step=it)# --- 学生模型初始化与专家轨迹抽样 ---# 随机初始化学生网络并转换为 ReparamModule 以支持扁平化权重student_net = get_network(args.model, channel, num_classes, im_size, dist=False).to(args.device)  # get a random modelstudent_net = ReparamModule(student_net)if args.distributed:student_net = torch.nn.DataParallel(student_net)student_net.train()# 计算网络参数总数,用于后续损失归一化num_params = sum([np.prod(p.size()) for p in (student_net.parameters())])# 从 buffer 中轮询或随机获取一条专家轨迹if args.load_all:expert_trajectory = buffer[np.random.randint(0, len(buffer))]else:expert_trajectory = buffer[expert_idx]expert_idx += 1if expert_idx == len(buffer):expert_idx = 0file_idx += 1# 如果切换到下一个 buffer 文件,则重新加载并打乱if file_idx == len(expert_files):file_idx = 0random.shuffle(expert_files)print("loading file {}".format(expert_files[file_idx]))if args.max_files != 1:del bufferbuffer = torch.load(expert_files[file_idx])if args.max_experts is not None:buffer = buffer[:args.max_experts]random.shuffle(buffer)# 从专家轨迹中随机选择起始epoch和目标epoch参数start_epoch = np.random.randint(0, args.max_start_epoch)starting_params = expert_trajectory[start_epoch]target_params = expert_trajectory[start_epoch+args.expert_epochs]# 将参数列表展平成单个向量target_params = torch.cat([p.data.to(args.device).reshape(-1) for p in target_params], 0)student_params = [torch.cat([p.data.to(args.device).reshape(-1) for p in starting_params], 0).requires_grad_(True)]starting_params = torch.cat([p.data.to(args.device).reshape(-1) for p in starting_params], 0)syn_images = image_syn # 合成图像集合y_hat = label_syn.to(args.device)# 准备列表保存中间参数损失与距离param_loss_list = []param_dist_list = []indices_chunks = [] # 用于分批操作的索引缓存# --- 合成数据多步梯度更新模拟 ---for step in range(args.syn_steps):# 如果当前无可用indices_chunks,则重新打乱并拆分if not indices_chunks:indices = torch.randperm(len(syn_images))indices_chunks = list(torch.split(indices, args.batch_syn))these_indices = indices_chunks.pop()x = syn_images[these_indices]   # 取当前批次的合成图像this_y = y_hat[these_indices]   # 对应标签# texture 模式下,进行随机平移并裁剪模拟纹理拼接if args.texture:x = torch.cat([torch.stack([torch.roll(im, (torch.randint(im_size[0]*args.canvas_size, (1,)), torch.randint(im_size[1]*args.canvas_size, (1,))), (1,2))[:,:im_size[0],:im_size[1]] for im in x]) for _ in range(args.canvas_samples)])this_y = torch.cat([this_y for _ in range(args.canvas_samples)])# 可微增强替代普通数据增强if args.dsa and (not args.no_aug):x = DiffAugment(x, args.dsa_strategy, param=args.dsa_param)if args.distributed:forward_params = student_params[-1].unsqueeze(0).expand(torch.cuda.device_count(), -1)else:forward_params = student_params[-1]# 前向计算 logitsx = student_net(x, flat_param=forward_params)ce_loss = criterion(x, this_y)# 计算损失对扁平化参数的梯度(保留图以继续反向到合成图像)grad = torch.autograd.grad(ce_loss, student_params[-1], create_graph=True)[0]# 更新学生参数向下一个步长student_params.append(student_params[-1] - syn_lr * grad)# --- 计算参数匹配损失 ---param_loss = torch.tensor(0.0).to(args.device)param_dist = torch.tensor(0.0).to(args.device)param_loss += torch.nn.functional.mse_loss(student_params[-1], target_params, reduction="sum")param_dist += torch.nn.functional.mse_loss(starting_params, target_params, reduction="sum")param_loss_list.append(param_loss)param_dist_list.append(param_dist)# 归一化:先按参数总数,再除以起点-目标距离param_loss /= num_paramsparam_dist /= num_paramsparam_loss /= param_distgrand_loss = param_loss# --- 更新合成图像与合成学习率 ---optimizer_img.zero_grad()optimizer_lr.zero_grad()grand_loss.backward()optimizer_img.step()optimizer_lr.step()# 记录损失与起始 epochwandb.log({"Grand_Loss": grand_loss.detach().cpu(),"Start_Epoch": start_epoch})# 清理中间梯度缓存,避免显存泄漏for _ in student_params:del _# 每 10 次迭代打印一次损失信息if it%10 == 0:print('%s iter = %04d, loss = %.4f' % (get_time(), it, grand_loss.item()))wandb.finish()

算法逻辑总结

  1. 准备”专家示范“
    • 先用真实大数据集训练一个(或多组)模型,把模型再每个训练轮/每步的所有参数都记下来,这条参数随实践变化的记录叫做”专家轨迹“。
  2. 初始化”学生“
    • 选轨迹上某个时间点,把学生网络的参数初始化为老师当时的状态。这样学生和老师从同一个起点出发。
  3. 学生用合成数据学N步
    • 用我们的小合成数据集让学生网络做N步梯度下降(就是跑N个小批次的训练)。
    • 记录学生跑完这N步后得到的新参数。
  4. 对齐”未来“
    • 看老师在真实训练中,从同一个起点走M步后参数是怎么样,把学生此刻的参数和老师未来第M步的参数做对比。
    • 差距越小,说明学生越像老师;差距越大,说明合成数据还不够好。
  5. 更新合成数据
    • 把这个”未来对齐“误差当作损失,反向传播回去,去调整我们的小合成图像(和一个”学生学习率“参数)。
    • 目的就是让下一轮学生训练时候,能更快更准确地朝着老师的轨迹走。
  6. 重复很多轮
    • 每轮都重新从专家轨迹选一个起点,反复做上面四步,让小合成数据不断进化、越来越”聪明“。
http://www.dtcms.com/a/314335.html

相关文章:

  • 【数据结构初阶】--算法复杂度详解
  • 登录弹窗,cv直接使用
  • 【FreeRTOS】系统时钟配置
  • HTTP基本结构
  • ICCV 2025|单视频生成动态4D场景!中科大微软突破4D生成瓶颈,动画效果炸裂来袭!
  • ICCV 2025|可灵团队新作 ReCamMaster:从单视频到多视角生成,多角度看好莱坞大片
  • socket与udp
  • 折叠屏网页布局挑战:响应式设计在工业平板与PC端的弹性适配策略
  • 【Mac】OrbStack:桌面端虚拟机配置与使用
  • LeetCode 140:单词拆分 II
  • 【MySQL03】:MySQL约束
  • mac 技巧
  • 零售消费行业研究系列报告
  • Java-基础-统计投票信息
  • Linux下载安装mysql,客户端(Navicat)连接Linux中的mysql
  • allegro建库--1
  • 【Redis】移动设备离线通知推送全流程实现:系统推送服务与Redis的协同应用
  • 模型学习系列之考试
  • 机器学习(8):线性回归
  • 基于落霞归雁思维框架的自动化测试实践与探索
  • OpenLayers 入门指南【五】:Map 容器
  • Unity发布Android平台实现网页打开应用并传参
  • 如何查看 iOS 电池与电耗:入门指南与实战工具推荐
  • 期权投资盈利之道书籍推荐
  • Codeforces Round 1008 (Div. 2)
  • Chrontel【CH7214C-BF】CH7214C USB Type C Logic Controller
  • 【Java线程池深入解析:从入门到精通】
  • Memcached 缓存详解及常见问题解决方案
  • 【深度学习新浪潮】近三年城市级数字孪生的研究进展一览
  • 【音视频】WebRTC 一对一通话-实现概述