Gemma 3 报告中的蒸馏
Gemma 3 报告中的蒸馏
flyfish
先看效果
模型对比
-
Gemma 3-4B-IT 与 Gemma 2-27B-IT 的对比
Gemma3-4B-IT competitive with Gemma2-27B-IT,即经过蒸馏训练后,Gemma 3的40亿参数指令微调版(Gemma 3-4B-IT)能够与Gemma 2的270亿参数指令微调版(Gemma 2-27B-IT)相抗衡。这一对比体现了蒸馏技术的效果——通过从更大型模型蒸馏知识,小参数模型实现了对前代大参数模型的性能追赶。 -
Gemma 3-27B-IT 与 Gemini-1.5-Pro 的对比
Gemma3-27B-IT comparable to Gemini-1.5-Pro across benchmarks,即Gemma 3的270亿参数指令微调版(Gemma 3-27B-IT)在各项基准测试中可与Gemini-1.5-Pro相媲美。这说明蒸馏技术帮助Gemma 3的大参数模型达到了同级别前沿模型的性能水平。
小参数新模型(Gemma 3-4B-IT)对标前代大参数模型(Gemma 2-27B-IT);
大参数新模型(Gemma 3-27B-IT)对标同级别前沿模型(Gemini-1.5-Pro)。
两者均通过蒸馏技术实现性能跃升,且教师模型限定为 Gemini 系列前沿模型
The pre-training optimization recipe is similar to Gemma 2, with some modifications in the architecture design. We use the same tokenizer as Gemini 2.0, and we also revisit our data mixture to improve the multilingual capabilities of the models, while introducing image understanding. All Gemma 3 models are trained with knowledge distillation。
预训练优化方案与 Gemma 2 类似,但在架构设计上做了一些调整。我们使用与 Gemini 2.0 相同的分词器,同时重新调整了数据混合方式,以提升模型的多语言能力,此外还新增了图像理解功能。所有 Gemma 3 模型均采用知识蒸馏技术进行训练。
蒸馏具体做法
We sample 256 logits per token, weighted by teacher probabilities. The student learns the teacher’s distribution within these samples via cross-entropy loss. The teacher’s target distribution is set to zero probability for non-sampled logits, and renormalized.
基于logit采样的交叉熵优化策略
-
定向采样机制:针对每个输入token,从教师模型输出的全部logits(即未归一化的预测分数)中,按教师模型对各logit的概率分布进行加权采样,选取256个logits。这一操作的核心是让学生模型优先关注教师模型赋予高概率的关键信息,避免无差别学习低价值内容,从而提升学习效率。
-
目标分布修正:对于未被选中的logits,教师模型的目标概率被设定为0;同时,对已选中的256个logits的概率进行重新归一化处理(即调整其概率值,使总和为1)。这一修正确保了目标分布的规范性(满足概率之和为1的基本要求),同时明确了学生模型的学习边界——仅需拟合教师模型在关键logits上的分布。
-
损失驱动学习:学生模型通过交叉熵损失函数计算自身对这256个采样logits的预测分布与教师模型修正后目标分布的差异,并基于此差异调整参数,使自身分布尽可能逼近教师分布。这一过程通过量化“学生与教师的判断差异”,精准驱动学生模型继承教师的核心知识判断逻辑。