当前位置: 首页 > news >正文

KD论文阅读

1.摘要

background

      在机器学习领域,集成(ensemble)多个模型来做预测通常能取得比单一模型更好的性能。然而,这种方法的缺点也非常明显:它需要巨大的计算资源和存储空间,导致模型难以部署到对延迟和算力有严格要求的生产环境中,例如移动设备。因此,核心问题是如何将一个强大的集成模型(或一个非常大的单一模型)的知识“压缩”到一个更小、更高效、易于部署的单一模型中,同时尽量不损失其性能。

innovation

      1.知识蒸馏 (Knowledge Distillation): 本文提出了一种名为“知识蒸馏”的模型压缩技术。其核心思想是,使用一个已经训练好的、复杂的“教师模型”(cumbersome model)来指导一个轻量的“学生模型”(distilled model)的训练。

      2.软目标 (Soft Targets): 传统训练使用独热编码(one-hot)的硬目标 (hard targets),只告诉模型哪个是正确答案。而创新之处在于使用教师模型输出的类别概率向量作为“软目标”。这些软目标不仅包含了正确答案,还揭示了类别之间的相似性信息(例如,一张宝马的图片被误认为卡车的概率远高于被误认为胡萝卜的概率)。这种蕴含在错误答案概率中的信息被称为“暗知识 (dark knowledge)”,能为学生模型的训练提供更丰富、更有效的监督信号。

      3.温度系数 (Temperature): 为了让教师模型输出的概率分布更“软”,从而暴露更多暗知识,作者在 softmax 函数中引入了温度系数 T。T > 1 时,概率分布会变得更平滑,使得小概率的负标签也能对损失函数产生影响,从而更好地指导学生学习。学生模型在训练时也使用同样的高温 T 来匹配软目标。

      4.与相关工作对比: Caruana 等人的工作 开创了类似的想法,他们通过匹配教师模型 softmax 层之前的 logits 来训练学生模型。本文指出,匹配 logits 是蒸馏在高温 T 极限下的一种特例。蒸馏是一个更通用的框架,并且通过调整温度 T,可以控制忽略那些非常大的负 logits(可能包含噪声),这在实践中可能更有利。

2. 方法 Method

总体流程 (Pipeline)

1.训练教师模型: 首先,在一个大规模数据集上训练一个性能强大但结构复杂的“教师模型”。这个模型可以是一个单一的大型深度网络,也可以是多个模型的集成。

2.生成软目标: 将训练数据(或一个单独的“迁移集”)输入到训练好的教师模型中。对教师模型输出的 logits 使用一个较高的温度 T 通过 softmax 函数,生成软目标概率分布。

3.训练学生模型: 设计一个参数量更少、结构更简单的“学生模型”。其训练的损失函数由两部分加权组成:

蒸馏损失 (Distillation Loss): 学生模型在同样的高温 T 下输出的概率分布与教师模型生成的软目标之间的交叉熵。这部分损失函数引导学生模型模仿教师模型的泛化能力。

学生损失 (Student Loss): 学生模型在温度 T=1(即标准 softmax)下输出的概率分布与真实标签(硬目标)之间的交叉熵。这部分损失函数确保学生模型能从真实数据中学到知识,尤其是在教师模型也可能犯错的情况下,能起到修正作用。

4.部署: 训练完成后,学生模型在推理时使用标准的 T=1 进行预测,从而实现高效部署。

各部分细节

输入:

1.一个预训练好的、高性能的教师模型。

2.一个轻量级的、未训练的学生模型。

3.用于训练的迁移数据集(可以和训练教师模型的数据集相同)。

核心计算:

Softmax with Temperature: qi = exp(zi/T) / Σj exp(zj/T),其中 zi 是 logits,T 是温度。

Total Loss: L = α * L_soft(student_logits/T, teacher_logits/T) + (1 - α) * L_hard(student_logits, true_labels)。L_soft 和 L_hard 都是交叉熵损失函数,α 是超参数,用于平衡两个损失的权重。

输出: 一个训练好的、轻量级的学生模型,其性能接近(有时甚至超过)教师模型,但推理速度更快、占用资源更少。

3. 实验 Experimental Results

数据集:

1.MNIST: 手写数字识别经典数据集。

2.ASR (Automatic Speech Recognition): 一个大规模的内部语音识别数据集,包含约2000小时的语音数据。

3.JFT: 一个谷歌内部的大规模图像数据集,包含1亿张图片和15,000个类别。

实验结论:

1.MNIST 基础验证:

实验目的: 验证知识蒸馏的基本有效性。

结论: 一个大型教师网络达到67个测试错误。一个同样结构但从零开始训练的小型网络有146个错误。而通过蒸馏训练的同一个小型网络,测试错误降至74个。这表明蒸馏成功地将教师模型的泛化能力(例如从数据抖动中学到的知识)迁移到了学生模型。

2.MNIST 迁移学习能力验证:

实验目的: 测试学生模型能否学习到从未见过的类别知识。

结论: 从训练集中移除所有数字“3”的样本后,蒸馏模型仍然能够通过其他数字的软目标学会识别“3”,在调整偏置后,对测试集中“3”的正确率高达98.6%。这有力地证明了软目标中包含了丰富的类别间关系信息。

3.语音识别的实用性验证:

实验目的: 验证蒸馏在真实、大规模商业系统中的效果。

结论: 单个基线模型的词错误率 (WER) 为10.9%,10个模型的集成达到了10.7%。而通过蒸馏得到的单个模型,其 WER 也是10.7%,几乎完全吸收了集成的性能提升,同时部署成本远低于集成模型。

4.JFT数据集上的专家模型:

实验目的: 解决在超大规模数据集上训练集成模型不可行的问题。

结论: 训练一个通用模型和61个专注于区分易混淆类别的“专家模型”。通过结合专家模型,系统的整体测试准确率获得了4.4%的相对提升。这为提升超大模型性能提供了一个可并行的有效路径。

5.软目标作为正则化器:

实验目的: 证明软目标可以有效防止模型在小数据集上过拟合。

结论: 在仅使用3%语音数据的情况下,用硬目标训练的模型严重过拟合(测试准确率44.5%)。而用教师模型(在100%数据上训练)生成的软目标来训练同一个模型,测试准确率达到了57.0%,几乎恢复了在全部数据上训练的性能(58.9%)。

4. 总结 Conclusion

       知识蒸馏是一种非常有效且通用的模型压缩和知识迁移框架。它能够将一个复杂模型(或模型集成)所学到的“暗知识”提炼并迁移到一个更小、更快的模型中,使得高性能模型在资源受限环境下的部署成为可能,是连接模型研究与实际应用的重要桥梁。

http://www.dtcms.com/a/420522.html

相关文章:

  • wordpress设计的网站厦门网站建设公司哪家好
  • 南阳网站推广价格dede织梦网站
  • 【Docker】DockerHub拉取镜像
  • 跨域问题产生的原因及解决方法
  • Python的typing模块:类型提示 (Type Hinting)
  • 建设岗位考试网站投资公司注册资金多少
  • 建设部资质升级网站天津建设
  • WebSocket实战:打造AI流式对话的实时通信基础
  • 安徽品质网站建设创新哈尔滨快速建站模板
  • 二十二、RJ45黄绿指示灯闪烁的“底层逻辑”
  • 网站运营怎样做php小网站
  • 闵行网站设计博敏 网站开发
  • 莱芜论坛莱芜都市网单页面优化
  • html写手机网站吗哪个网站做的w7系统好
  • 用dw做网站的菜单栏网站程序源码下载
  • 简单免费自建网站达州网站开发
  • 手机建站最好的网站wordpress中运行程序
  • 真题题解:国王金币发放模型解析(洛谷P2669)
  • DSP28335 SCI 串口回显功能案例解析
  • dede电影网站网站建设与网络编辑综合实训课程指导手册
  • 网站建设费税率多少韶关做网站公司
  • 四川省建设主管部门网站有源码如何搭建app
  • logo在线设计图片seo外链平台热狗
  • 做网站 阿里云滨湖区建设局官方网站
  • 广州建站服务电商公司网站建设流程
  • 专门建设网站的公司网上商城官网入口
  • 建网站设置网站首页网站版面布局设计的原则
  • Altium Designer(AD)PCB拼版——两种方法教程
  • 模板板网站网站备案需要多久时间
  • 百度搜索引擎网站开发公司资质等级