理解掩码多头注意力机制与教师强制训练
目录
掩码多头注意力机制(Masked Multi-Head Attention)
核心作用:防止模型偷看答案
工作原理:效率与顺序的平衡
具体实现:
工作流程:
实际效果:严格的时间顺序
教师强制训练(Teacher Forcing)
基本概念:相当于训练时的"参考答案"
数据准备:错位设计
为什么需要教师强制?
训练 vs 推理:两种不同模式
完整工作流程
训练阶段,相当于开卷考试:
推理阶段,相当于闭卷考试:
核心要点总结
掩码注意力:防止作弊的机制
教师强制:高效学习的技巧
掩码多头注意力机制(Masked Multi-Head Attention)
核心作用:防止模型偷看答案
在序列生成任务如机器翻译、文本生成中,模型需要像人一样逐步思考,不能提前知道后面的内容。
工作原理:效率与顺序的平衡
-
矛盾点:计算机希望一次性计算所有位置(并行计算提高效率),但序列生成必须按顺序进行
-
解决方案:通过掩码矩阵控制信息流动,既保持并行计算效率,又保证生成顺序
具体实现:
M = [[0, -inf, -inf], # 位置1:只能看自己[0, 0, -inf], # 位置2:只能看位置1-2[0, 0, 0] # 位置3:可以看全部历史 ]
工作流程:
-
计算注意力分数:正常计算 QK^T,得到每个词与其他词的关联程度
-
应用掩码:加上掩码矩阵 M,将未来位置的分数设为负无穷
-
权重归一化:经过 Softmax,未来位置的权重自动变为0
-
生成输出:与 Value 矩阵相乘,得到每个位置的最终表示
实际效果:严格的时间顺序
-
生成第2个词时,只能基于第1个词的信息
-
生成第3个词时,只能基于前两个词的信息
-
就像写文章时,只能根据已写内容继续写,不能提前知道结尾
教师强制训练(Teacher Forcing)
基本概念:相当于训练时的"参考答案"
在训练阶段,我们直接把正确答案作为解码器的输入,帮助模型学习正确的映射关系。
数据准备:错位设计
真实答案: ["I", "love", "you"] 解码器输入: ["<BOS>", "I", "love"] # 右移一位 + 起始符 训练目标: ["I", "love", "you"] # 原始答案
为什么需要教师强制?
没有教师强制的可能会导致曝光偏差:
-
模型第一步预测错误,那么错误结果作为下一步输入,会导致错误不断累积放大
-
就像用错误答案学习,越学越偏
-
训练过程极不稳定,难以收敛
教师强制的优势:
-
稳定学习环境:始终基于正确答案学习
-
避免错误传播:单步错误不会影响后续学习
-
高效训练:可以并行处理整个序列
-
快速收敛:学习曲线更平滑
训练 vs 推理:两种不同模式
阶段 | 输入来源 | 处理方式 | 特点 |
---|---|---|---|
训练 | 真实答案(右移) | 并行处理 | 高效、稳定、有参考答案 |
推理 | 模型自身输出 | 串行生成 | 自主、逐步、无参考答案 |
完整工作流程
训练阶段,相当于开卷考试:
-
编码器理解源序列(如英文句子"I love you")
-
解码器接收"参考答案"(["<BOS>", "I", "love"])
-
模型并行预测每个位置的下一个词
-
与真实答案(["I", "love", "you"])对比计算误差
-
根据误差调整模型参数
推理阶段,相当于闭卷考试:
-
输入起始符
<BOS>
-
模型自主预测第一个词
-
将预测结果加入输入序列
-
重复预测直到生成结束符
<EOS>
-
输出完整序列
核心要点总结
掩码注意力:防止作弊的机制
-
核心作用:确保模型生成时严格遵守时间顺序
-
实现方式:通过上三角掩码矩阵屏蔽未来信息
-
位置:解码器的第一个注意力层
-
效果:保证每个词只能看到它之前的词
教师强制:高效学习的技巧
-
核心作用:提供稳定的训练环境
-
使用阶段:仅在训练时使用,推理时不用
-
关键优势:避免错误累积,大幅提升训练效率
-
数据技巧:正确答案右移一位作为输入