【人工智能99问】transformer的编码器和解码器是如何协同工作的?(15/99)
文章目录
- transformer的编码器和解码器是如何协同工作的?
- 一、 训练阶段
- 1.1 编码器的工作流程
- 1.2 解码器的工作流程
- 1.3 协同工作特点
- 二、 推理阶段
- 2.1 编码器的工作流程
- 2.2 解码器的工作流程
- 2.3 协同工作特点
- 三、 关键差异总结
- 四、 可视化示例
- 训练阶段数据流:
- 推理阶段数据流:
- 五、 技术细节补充
transformer的编码器和解码器是如何协同工作的?
Transformer 的编码器(Encoder)和解码器(Decoder)在训练和推理时的协同工作方式有显著差异,主要体现在数据流动、注意力机制的应用以及对上下文信息的处理上。以下是详细分析:
一、 训练阶段
在训练时,模型通过并行处理输入输出序列,利用真实标签(Ground Truth)进行监督学习,并通过自回归(Autoregressive)和交叉注意力(Cross-Attention)机制学习映射关系。
1.1 编码器的工作流程
-
输入处理:
- 输入序列(如源语言句子
X = [x₁, x₂, ..., xₙ]
)经过嵌入层和位置编码后,输入编码器。 - 编码器由多层自注意力(Self-Attention)和前馈网络(FFN)组成,计算每个 token 的上下文表示。
- 输出编码后的特征矩阵
H_enc = [h₁, h₂, ..., hₙ]
,捕获输入序列的全局信息。
- 输入序列(如源语言句子
-
自注意力机制:
- 编码器自注意力允许每个 token 关注输入序列的所有其他 token,建立全局依赖关系。
1.2 解码器的工作流程
-
输入处理:
- 目标序列(如目标语言句子
Y = [y₁, y₂, ..., yₘ]
)被右移(Shifted Right)并添加起始符(如<SOS>
),形成解码器输入Y_shifted = [<SOS>, y₁, ..., yₘ₋₁]
。 - 嵌入层和位置编码后输入解码器。
- 目标序列(如目标语言句子
-
掩码自注意力(Masked Self-Attention):
- 解码器第一层是掩码自注意力,确保每个 token 仅关注当前位置及之前的 token(防止信息泄露)。
-
交叉注意力(Cross-Attention):
- 解码器的第二层通过交叉注意力将
H_enc
作为 Key 和 Value,解码器上一层的输出作为 Query,对齐编码器和解码器的信息。 - 例如,翻译时解码器通过交叉注意力聚焦源语言的相关部分(如单词对齐)。
- 解码器的第二层通过交叉注意力将
-
输出预测:
- 解码器的最终输出通过线性层和 Softmax 生成概率分布,预测下一个 token
yₜ
。 - 损失函数(如交叉熵)基于预测结果和真实标签
yₜ
计算。
- 解码器的最终输出通过线性层和 Softmax 生成概率分布,预测下一个 token
1.3 协同工作特点
- 并行训练:解码器一次性处理整个目标序列(掩码保证自回归性质),而非逐 token 生成。
- 教师强制(Teacher Forcing):解码器使用真实标签
Y_shifted
作为输入,而非自身预测结果,加速收敛。
二、 推理阶段
在推理时,模型需逐 token 生成输出序列,依赖前一步的预测结果,无法并行处理。
2.1 编码器的工作流程
- 与训练阶段相同:输入序列
X
编码为H_enc
,仅需计算一次(可缓存以提升效率)。
2.2 解码器的工作流程
-
初始化:
- 输入为起始符
<SOS>
,初始上下文为空。
- 输入为起始符
-
自回归生成:
- 每一步
t
,解码器基于已生成的[y₁, ..., yₜ₋₁]
预测下一个 tokenyₜ
:- 通过掩码自注意力处理当前序列。
- 通过交叉注意力结合编码器输出
H_enc
。 - 输出层生成
yₜ
的概率分布,采样(如贪婪搜索或 Beam Search)得到yₜ
。
- 将
yₜ
拼接到输入中,继续生成yₜ₊₁
,直到输出终止符(如<EOS>
)。
- 每一步
-
缓存机制:
- 解码器的自注意力键值(K/V)可缓存,避免重复计算历史 token。
2.3 协同工作特点
- 串行生成:解码器必须等待上一步完成才能继续,速度较慢。
- 动态输入:解码器输入随生成过程动态增长,依赖编码器的固定输出
H_enc
。
三、 关键差异总结
方面 | 训练阶段 | 推理阶段 |
---|---|---|
解码器输入 | 使用真实标签(右移) | 使用自身预测结果(自回归) |
并行性 | 编码器和解码器均并行处理序列 | 解码器必须串行生成 |
注意力掩码 | 掩码未来 token(仅解码器) | 同训练,但动态扩展掩码 |
编码器输出利用 | 每次迭代重新计算 | 计算一次并缓存 |
效率 | 高(批量处理) | 低(逐 token 生成) |
四、 可视化示例
训练阶段数据流:
编码器: X → Self-Attention → H_enc
解码器: Y_shifted → Masked Self-Attention → Cross-Attention(H_enc) → yₜ_pred
推理阶段数据流:
Step 1: 编码器 X → H_enc
Step 2: 解码器 <SOS> → y₁
Step 3: 解码器 <SOS>, y₁ → y₂
...
Step N: 解码器 <SOS>, y₁, ..., yₙ → <EOS>
五、 技术细节补充
- 位置编码:确保模型感知 token 位置,在训练和推理中均需添加。
- Beam Search:推理时可通过束搜索提高生成质量,保留多条候选路径。
- 缓存优化:如 Transformer-XL 或 KV Cache 技术加速推理。
通过这种设计,Transformer 在训练时高效学习长程依赖,在推理时通过自回归生成保持一致性。