KD论文阅读
1.摘要
background
在机器学习领域,集成(ensemble)多个模型来做预测通常能取得比单一模型更好的性能。然而,这种方法的缺点也非常明显:它需要巨大的计算资源和存储空间,导致模型难以部署到对延迟和算力有严格要求的生产环境中,例如移动设备。因此,核心问题是如何将一个强大的集成模型(或一个非常大的单一模型)的知识“压缩”到一个更小、更高效、易于部署的单一模型中,同时尽量不损失其性能。
innovation
1.知识蒸馏 (Knowledge Distillation): 本文提出了一种名为“知识蒸馏”的模型压缩技术。其核心思想是,使用一个已经训练好的、复杂的“教师模型”(cumbersome model)来指导一个轻量的“学生模型”(distilled model)的训练。
2.软目标 (Soft Targets): 传统训练使用独热编码(one-hot)的硬目标 (hard targets),只告诉模型哪个是正确答案。而创新之处在于使用教师模型输出的类别概率向量作为“软目标”。这些软目标不仅包含了正确答案,还揭示了类别之间的相似性信息(例如,一张宝马的图片被误认为卡车的概率远高于被误认为胡萝卜的概率)。这种蕴含在错误答案概率中的信息被称为“暗知识 (dark knowledge)”,能为学生模型的训练提供更丰富、更有效的监督信号。
3.温度系数 (Temperature): 为了让教师模型输出的概率分布更“软”,从而暴露更多暗知识,作者在 softmax 函数中引入了温度系数 T。T > 1 时,概率分布会变得更平滑,使得小概率的负标签也能对损失函数产生影响,从而更好地指导学生学习。学生模型在训练时也使用同样的高温 T 来匹配软目标。
4.与相关工作对比: Caruana 等人的工作 开创了类似的想法,他们通过匹配教师模型 softmax 层之前的 logits 来训练学生模型。本文指出,匹配 logits 是蒸馏在高温 T 极限下的一种特例。蒸馏是一个更通用的框架,并且通过调整温度 T,可以控制忽略那些非常大的负 logits(可能包含噪声),这在实践中可能更有利。
2. 方法 Method
总体流程 (Pipeline)
1.训练教师模型: 首先,在一个大规模数据集上训练一个性能强大但结构复杂的“教师模型”。这个模型可以是一个单一的大型深度网络,也可以是多个模型的集成。
2.生成软目标: 将训练数据(或一个单独的“迁移集”)输入到训练好的教师模型中。对教师模型输出的 logits 使用一个较高的温度 T 通过 softmax 函数,生成软目标概率分布。
3.训练学生模型: 设计一个参数量更少、结构更简单的“学生模型”。其训练的损失函数由两部分加权组成:
蒸馏损失 (Distillation Loss): 学生模型在同样的高温 T 下输出的概率分布与教师模型生成的软目标之间的交叉熵。这部分损失函数引导学生模型模仿教师模型的泛化能力。
学生损失 (Student Loss): 学生模型在温度 T=1(即标准 softmax)下输出的概率分布与真实标签(硬目标)之间的交叉熵。这部分损失函数确保学生模型能从真实数据中学到知识,尤其是在教师模型也可能犯错的情况下,能起到修正作用。
4.部署: 训练完成后,学生模型在推理时使用标准的 T=1 进行预测,从而实现高效部署。
各部分细节
输入:
1.一个预训练好的、高性能的教师模型。
2.一个轻量级的、未训练的学生模型。
3.用于训练的迁移数据集(可以和训练教师模型的数据集相同)。
核心计算:
Softmax with Temperature: qi = exp(zi/T) / Σj exp(zj/T),其中 zi 是 logits,T 是温度。
Total Loss: L = α * L_soft(student_logits/T, teacher_logits/T) + (1 - α) * L_hard(student_logits, true_labels)。L_soft 和 L_hard 都是交叉熵损失函数,α 是超参数,用于平衡两个损失的权重。
输出: 一个训练好的、轻量级的学生模型,其性能接近(有时甚至超过)教师模型,但推理速度更快、占用资源更少。
3. 实验 Experimental Results
数据集:
1.MNIST: 手写数字识别经典数据集。
2.ASR (Automatic Speech Recognition): 一个大规模的内部语音识别数据集,包含约2000小时的语音数据。
3.JFT: 一个谷歌内部的大规模图像数据集,包含1亿张图片和15,000个类别。
实验结论:
1.MNIST 基础验证:
实验目的: 验证知识蒸馏的基本有效性。
结论: 一个大型教师网络达到67个测试错误。一个同样结构但从零开始训练的小型网络有146个错误。而通过蒸馏训练的同一个小型网络,测试错误降至74个。这表明蒸馏成功地将教师模型的泛化能力(例如从数据抖动中学到的知识)迁移到了学生模型。
2.MNIST 迁移学习能力验证:
实验目的: 测试学生模型能否学习到从未见过的类别知识。
结论: 从训练集中移除所有数字“3”的样本后,蒸馏模型仍然能够通过其他数字的软目标学会识别“3”,在调整偏置后,对测试集中“3”的正确率高达98.6%。这有力地证明了软目标中包含了丰富的类别间关系信息。
3.语音识别的实用性验证:
实验目的: 验证蒸馏在真实、大规模商业系统中的效果。
结论: 单个基线模型的词错误率 (WER) 为10.9%,10个模型的集成达到了10.7%。而通过蒸馏得到的单个模型,其 WER 也是10.7%,几乎完全吸收了集成的性能提升,同时部署成本远低于集成模型。
4.JFT数据集上的专家模型:
实验目的: 解决在超大规模数据集上训练集成模型不可行的问题。
结论: 训练一个通用模型和61个专注于区分易混淆类别的“专家模型”。通过结合专家模型,系统的整体测试准确率获得了4.4%的相对提升。这为提升超大模型性能提供了一个可并行的有效路径。
5.软目标作为正则化器:
实验目的: 证明软目标可以有效防止模型在小数据集上过拟合。
结论: 在仅使用3%语音数据的情况下,用硬目标训练的模型严重过拟合(测试准确率44.5%)。而用教师模型(在100%数据上训练)生成的软目标来训练同一个模型,测试准确率达到了57.0%,几乎恢复了在全部数据上训练的性能(58.9%)。
4. 总结 Conclusion
知识蒸馏是一种非常有效且通用的模型压缩和知识迁移框架。它能够将一个复杂模型(或模型集成)所学到的“暗知识”提炼并迁移到一个更小、更快的模型中,使得高性能模型在资源受限环境下的部署成为可能,是连接模型研究与实际应用的重要桥梁。