当前位置: 首页 > news >正文

Dataset Distillation by Matching Training Trajectories(2203.11932)

1. 遇到的问题与解决的方案

遇到的问题
  • 现有方法适用范围局限:多数数据集蒸馏方法仅适用于 MNIST、CIFAR 等低分辨率 “玩具” 数据集,在处理 ImageNet 等真实高分辨率图像时性能大幅下降。

  • 优化与计算效率瓶颈:端到端训练方法(如 DD、KIP)需大量计算和内存,且存在优化不稳定性;单步匹配方法(如 DC、DSA)在多步训练中误差累积,导致性能衰减。

  • 理论与实践脱节:基于核方法(如 KIP)依赖 “无限宽网络” 假设,与实际有限宽度网络的训练动态存在差距。

解决的方案
  • 预计算专家训练轨迹:在真实数据集上预训练多组专家网络,记录每轮迭代的参数快照,形成 “专家轨迹” 作为蒸馏参考基准,避免重复计算。

  • 长程参数轨迹匹配:从专家轨迹中随机采样参数初始化学生网络,通过优化蒸馏数据,使学生网络在 N 步训练后的参数与专家网络 M 步(M>>N)后的参数接近,减少多步误差累积。

  • 内存与计算优化:采用 mini-batch 采样降低内存消耗,引入可训练学习率自动平衡学生与专家更新步数,支持高分辨率数据蒸馏。

1. 专家轨迹的预计算(Expert Trajectories)
  • 核心步骤:在真实数据集上训练多个神经网络,记录每个 epoch 的参数快照,形成 “专家轨迹”(\tau^{}={\theta_{t}^{}}_{0}^{T}),作为蒸馏的参考基准。

  • 实现细节

    • 预训练大量专家网络,保存每轮迭代的参数序列,代表真实数据训练的 “最优轨迹”。

    • 专家轨迹预先计算并存储,避免蒸馏过程中的重复计算,提升效率。

2. 长程参数匹配的优化流程
  • 核心目标:通过优化蒸馏数据(D{syn}),使模型在(D{syn})上训练的参数轨迹与专家轨迹尽可能接近。

  • 具体步骤

    1. 初始化与采样

      • 从预计算的专家轨迹中随机采样某一时刻的参数

        $$
        \(\theta_{t}^{*}\),初始化学生网络参数\(\hat{\theta}_{t} = \theta_{t}^{*}\)
        $$

      • 设定最大起始 epoch (T^{+}),避免使用专家轨迹后期参数变化小的无效片段。

    2. 学生网络更新

      • 在蒸馏数据(D_{syn})上进行N次梯度下降更新,公式为:

        $$
        \(\hat{\theta}_{t+n+1} = \hat{\theta}_{t+n} - \alpha \nabla \ell(\mathcal{A}(b_{t+n}); \hat{\theta}_{t+n})\)
        $$

         

        其中(\mathcal{A})为可微分数据增强,(\alpha)为可训练学习率,(b{t+n})为从(D{syn})采样的 mini-batch。

    3. 损失函数计算

      • 计算学生网络更新N步后的参数(\hat{\theta}{t+N})与专家轨迹中M步后的参数(\theta{t+M}^{*})的归一化 L2 距离:

        $$
        \(\mathcal{L} = \frac{\|\hat{\theta}_{t+N} - \theta_{t+M}^{*}\|_{2}^{2}}{\|\theta_{t}^{*} - \theta_{t+M}^{*}\|_{2}^{2}}\)
        $$

         

        归一化操作确保后期训练中参数变化小时仍有有效梯度信号。

    4. 反向传播与更新

      • 基于损失(\mathcal{L})反向传播,更新蒸馏数据(D_{syn})的像素值和学习率(\alpha),迭代优化直至收敛。

3. 内存优化技术
  • mini-batch 采样

    • 在学生网络更新过程中,每次迭代从(D_{syn})中采样 mini-batch(而非全部图像),减少内存消耗,同时保证同类蒸馏图像的多样性。

  • 可训练学习率

    • 学习率(\alpha)作为可训练参数,自动平衡学生网络N步更新与专家网络M步更新的差异,减少手动调参成本。

4. 算法流程总结(Algorithm 1)
  1. 输入专家轨迹集合、更新步数(M/N)、数据增强函数等;

  2. 初始化蒸馏数据(D_{syn})和可训练学习率(\alpha);

  3. 循环采样专家轨迹、初始化学生网络、执行N次更新、计算损失并反向传播;

  4. 输出优化后的蒸馏数据(D_{syn})。

关键公式与引用段落
  • 学生网络更新公式:1-48🔷

  • 损失函数定义:1-52🔷

  • 内存优化的 mini-batch 更新:1-60🔷

2. 背景

数据集蒸馏的定义与目标
  • 目标是生成小规模合成数据集,使模型在其上训练后的测试精度与完整真实数据集训练结果相当。

  • 与模型蒸馏(Hinton 2015)不同,数据集蒸馏聚焦于 “蒸馏数据” 而非 “模型”,旨在保留数据中与任务相关的关键判别特征。

发展现状与应用
  • 2018 年由 Wang 等人首次提出,后续研究通过学习软标签、梯度匹配等方法改进,但多限于低分辨率场景。

  • 应用场景包括持续学习、神经架构搜索、联邦学习和隐私保护机器学习等。

现有方法的不足
  • 短视性匹配:单步匹配方法(如 DC、DSA)仅优化单步训练行为,无法捕捉长期参数演化规律。

  • 计算成本高:端到端训练需展开多轮迭代(如 DD)或大规模核计算(如 KIP),难以扩展至真实场景。

3. 问题

  • 如何高效模仿真实数据的长程训练动态:现有方法要么局限于单步匹配(误差累积),要么因完整轨迹优化计算昂贵而不可行。

  • 如何突破高分辨率数据蒸馏的瓶颈:高分辨率图像的蒸馏面临内存消耗大、优化难度高的挑战,现有方法难以处理。

  • 如何缩小理论假设与实际训练的差距:基于无限宽网络的方法(如 KIP)在有限宽度网络中性能受限,需更贴近实际训练动态的方法。

4. 动机

  • 专家轨迹的 “黄金标准” 价值:真实数据训练的专家网络轨迹代表数据集蒸馏的理论上限,若蒸馏数据能诱导相似轨迹,可实现性能接近。

  • 平衡优化复杂度与效果:通过匹配轨迹片段(而非完整轨迹),避免端到端优化的计算开销,同时克服单步匹配的短视性。

  • 推动数据集蒸馏的实际应用:使方法适用于高分辨率真实数据(如 ImageNet),拓展其在计算机视觉等领域的实用性。

5. 贡献和结果

核心贡献
  • 方法创新:提出基于训练轨迹匹配的数据集蒸馏框架,通过预计算专家轨迹和长程参数匹配实现高效蒸馏。

  • 性能突破:在 CIFAR-10/100、Tiny ImageNet、ImageNet 子集上显著超越现有方法,如 CIFAR-10 单类 1 图像准确率达 46.3%(原 SOTA 为 28.8%)。

  • 高分辨率蒸馏突破:首次实现 128×128 分辨率 ImageNet 子集的蒸馏,生成可识别的合成图像。

  • 跨架构泛化能力:蒸馏数据在 ResNet、VGG 等不同架构上保持性能,验证方法的鲁棒性。

关键实验结果
  • 低分辨率数据:CIFAR-10 单类 50 图像准确率 71.5%,CIFAR-100 单类 1 图像准确率 24.3%。

  • 中分辨率数据:Tiny ImageNet 单类 10 图像准确率 23.2%,远超同期方法 DM 的 12.9%。

  • 高分辨率数据:ImageNet 子集(128×128)单类 10 图像准确率最高达 63.0%(ImageNette)。

6. 文章结构

  • 摘要(1-4 至 1-5):介绍数据集蒸馏目标、提出轨迹匹配方法及性能优势。

  • 引言(1-6 至 1-19):对比模型蒸馏与数据集蒸馏,分析现有挑战,展示高分辨率蒸馏示例。

  • 相关工作(1-20 至 1-29):综述数据集蒸馏、模仿学习、核心集选择等领域进展。

  • 方法(1-30 至 1-61)

    • 专家轨迹:预计算真实数据训练的参数轨迹。

    • 长程参数匹配:通过损失函数优化蒸馏数据的轨迹相似度。

    • 内存优化:mini-batch 采样与动态学习率。

  • 实验(1-62 至 1-118):在多数据集上对比现有方法,分析长程匹配效果及跨架构泛化。

  • 讨论与局限(1-119 至 1-123):总结方法优势,指出专家轨迹训练的计算成本。

  • 附录(1-179 至 1-245):补充可视化、超参数细节及消融实验结果。

7. 专有名词解释

  • 数据集蒸馏(Dataset Distillation):生成小规模合成数据集,使模型训练后性能接近真实数据训练结果。

  • 专家轨迹(Expert Trajectories):真实数据训练的网络参数随时间变化的序列,作为蒸馏的参考 “黄金标准”。

  • 长程参数匹配(Long-Range Parameter Matching):通过优化蒸馏数据,使模型在 N 步训练后的参数与真实数据 M 步(M >> N)后的参数接近。

  • 可微分增强(Differentiable Augmentation):可反向传播的图像增强技术,用于蒸馏过程中调整合成数据。

  • 核诱导点(KIP, Kernel Inducing Point):基于无限宽网络核方法的数据集蒸馏技术,与有限宽度网络存在性能差距。

8.局限性

  1. 预计算轨迹的存储与计算成本高

    • 为生成专家轨迹,需预先训练大量模型并存储参数快照,导致计算和存储开销显著。例如:

      • CIFAR 数据集的专家训练约需 8 GPU 小时,每个专家轨迹占用约 60MB 存储;

      • ImageNet 子集的专家训练需 15 GPU 小时,每个专家轨迹占用约 120MB 存储。

    • 尽管预计算可重复使用,但首次训练专家网络的时间和资源成本较高,尤其对大规模数据集(如 ImageNet)而言负担较重。

  2. 计算资源需求较高

    • 蒸馏过程中需反向传播通过多轮梯度更新,当处理高分辨率数据(如 128×128 ImageNet)时,内存消耗显著。例如,最大实验使用 6×RTX6000 GPU(144GB 显存)才能支持大规模数据的优化。

    • 虽然通过 mini-batch 采样降低内存压力,但高分辨率图像的蒸馏仍依赖高性能 GPU 集群,限制了方法在资源有限场景下的应用

http://www.dtcms.com/a/265444.html

相关文章:

  • Eclipse主题拓展
  • mysql索引的底层原理是什么?如何回答?
  • Go语言的sync.Once和sync.Cond
  • Redis 源码 tar 包安装 Redis 哨兵模式(Sentinel)
  • Go调度器的抢占机制:从协作式到异步抢占的演进之路|Go语言进阶(7)
  • 价值实证:数字化转型标杆案例深度解析
  • 网络地址与子网划分:一次性搞清 CIDR、VLSM 和子网掩码
  • 分类树查询性能优化:从 2 秒到 0.1 秒的技术蜕变之路
  • 如何在 IDEA 中设置类路径
  • 探索具身智能新高度——机器人在数据收集与学习策略中的优势和机会
  • Objective-C UI事件处理全解析
  • c++中的绑定器
  • 如何使用AI改进论文写作 ---- 引言篇(2)
  • 设计模式系列(10):结构型模式 - 桥接模式(Bridge)
  • AutoMedPrompt的技术,自动优化提示词
  • 【小技巧】Python + PyCharm 小智AI配置MCP接入点使用说明(内测)( PyInstaller打包成 .exe 可执行文件)
  • Spring Boot + 本地部署大模型实现:基于 Ollama 的集成实践
  • Jetson边缘计算主板:Ubuntu 环境配置 CUDA 与 cudNN 推理环境 + OpenCV 与 C++ 进行目标分类
  • 【Note】《深入理解Linux内核》Chapter 9 :深入理解 Linux 内核中的进程地址空间管理机制
  • MySQL数据库----DML语句
  • 深度学习新星:Mamba网络模型与核心模块深度解析
  • Python入门Day2
  • 【第三章:神经网络原理详解与Pytorch入门】01.神经网络算法理论详解与实践-(3)神经网络中的前向传播、反向传播的原理与实现
  • Python中`import` 语句的执行涉及多个步骤
  • 【Python】批量提取超声波检查图片的某一行数据
  • Docker 容器如何实现资源限制(如 CPU 和内存)
  • MacOS Safari 如何打开F12 开发者工具 Developer Tools
  • 【C++】状态模式
  • 好用的自带AI功能的国产IDE
  • Go与Python爬虫对比及模板实现