当前位置: 首页 > news >正文

Transformer超详细全解!含代码实战

目录

一. Transformer的整体架构

二. Transformer 的输入

2.1 单词 Embedding(词嵌入层)

2.2  位置 Embedding(位置编码)

(1)正余弦位置编码(论文采用)

(2)可学习位置编码

三. Self-Attention(自注意力机制)和Multi-Head Attention(多头自注意力)

3.1  Self-Attention(自注意力机制)

3.2  Multi-Head Attention(多头注意力)

四.Encoder (编码器)结构

4.1  Add & Norm层

4.2  Feed Forward Network(FFN)层

4.3  完整的Encoder(编码器)架构

五.Decoder (解码器)结构

5.1  Masked 操作

5.2  Cross-Attention(交叉注意力)  

5.3  完整的Decoder(解码器)架构

六.Transformer完整代码实现


参考:
Transformer模型详解(图解最完整版) - 知乎https://zhuanlan.zhihu.com/p/338817680GitHub - liaoyanqing666/transformer_pytorch: 完整的原版transformer程序,complete origin transformer programhttps://github.com/liaoyanqing666/transformer_pytorcharxiv.org/pdf/1706.03762https://arxiv.org/pdf/1706.03762

一. Transformer的整体架构

Transformer 由 Encoder (编码器)和 Decoder (解码器)两个部分组成,Encoder 和 Decoder 都包含 6 个 block(块)。Transformer 的工作流程大体如下:

第一步:获取输入句子的每一个单词的表示向量 X,X由单词本身的 Embedding(Embedding就是从原始数据提取出来的特征(Feature)) 和单词位置的 Embedding 相加得到。

第二步:将得到的单词表示向量矩阵 (如上图所示,每一行是一个单词的表示 x)传入 Encoder 中,经过 6 个 Encoder block (编码器块)后可以得到句子所有单词的编码信息矩阵 C。如下图,单词向量矩阵用 X_{n\times d}表示, n 是句子中单词个数,d 是表示向量的维度(论文中 d=512)。每一个 Encoder block (编码器块)输出的矩阵维度与输入完全一致。

第三步:将 Encoder (编码器)输出的编码信息矩阵 C传递到 Decoder(解码器)中,Decoder(解码器) 依次会根据当前翻译过的单词 1~ i 翻译下一个单词 i+1,如下图所示。在使用的过程中,翻译到单词 i+1 的时候需要通过 Mask (掩盖) 操作遮盖住 i+1 之后的单词。

上图 Decoder 接收了 Encoder 的编码矩阵 C,然后首先输入一个翻译开始符 "<Begin>",预测第一个单词 "I";然后输入翻译开始符 "<Begin>" 和单词 "I",预测单词 "have",以此类推。

二. Transformer 的输入

Transformer 中单词的输入表示 单词本身的 Embedding 和单词位置 Embedding (Positional Encoding)相加得到。

2.1 单词 Embedding(词嵌入层)

单词本身的 Embedding 有很多种方式可以获取,例如可以采用 Word2Vec、Glove 等算法预训练得到,也可以在 Transformer 中训练得到。

self.embedding = nn.Embedding(vocabulary, dim)

功能解释:

  1. 作用:将离散的整数索引(单词ID)转换为连续的向量表示

  2. 输入:形状为 [sequence_length] 的整数张量

  3. 输出:形状为 [sequence_length, dim] 的浮点数张量(X_{n\times d},n是序列长度,d是特征维度)

参数详解:

参数含义示例值说明
vocabulary词汇表大小10000表示模型能处理的不同单词/符号总数
dim嵌入维度512每个单词被表示成的向量长度

工作原理:

  1. 创建一个可学习的嵌入矩阵[vocabulary, dim],例如当 vocabulary=10000dim=512 时,是一个 10000×512 的矩阵;

  2. 每个整数索引对应矩阵中的一行:

# 假设单词"apple"的ID=42
apple_vector = embedding_matrix[42]  # 形状 [512]

在Transformer中的具体作用:

# 输入:src = torch.randint(0, 10000, (2, 10))
# 形状:[batch_size=2, seq_len=10]src_embedded = self.embedding(src)# 输出形状变为:[2, 10, 512]
# 每个整数单词ID被替换为512维的向量

可视化表现:

原始输入 (单词ID):
[ [ 25,  198, 3000, ... ],   # 句子1[ 1,   42,  999,  ... ] ]  # 句子2经过嵌入层后 (向量表示):
[ [ [0.2, -0.5, ..., 1.3],   # ID=25的向量[0.8, 0.1, ..., -0.9],   # ID=198的向量... ],[ [0.9, -0.2, ..., 0.4],   # ID=1的向量[0.3, 0.7, ..., -1.2],   # ID=42的向量... ] ]

为什么需要词嵌入:

  • 语义表示:相似的单词会有相似的向量表示

  • 降维:将离散的ID映射到连续空间(one-hot编码需要10000维 → 嵌入只需512维)

  • 可学习:在训练过程中,这些向量会不断调整以更好地表示语义关系

2.2  位置 Embedding(位置编码)

Transformer 的位置编码(Positional Encoding,PE)是模型的关键创新之一,它解决了传统序列模型(如 RNN)固有的顺序处理问题。Transformer 的自注意力机制本身不具备感知序列位置的能力,位置编码通过向输入嵌入添加位置信息,使模型能够理解序列中元素的顺序关系。位置编码计算之后的输出维度和词嵌入层相同,均为(X_{n\times d})。

位置编码的核心作用:

  1. 注入位置信息:让模型区分不同位置的相同单词(如 "bank" 在句首 vs 句尾)

  2. 保持距离关系:编码相对位置和绝对位置信息

  3. 支持并行计算:避免像 RNN 那样依赖顺序处理

为什么需要位置编码?

  1. 自注意力的位置不变性
    Attention(Q,K,V)=softmax\left ( \frac{QK^{T}}{\sqrt{d_k}} \right )V,计算过程不包含位置信息

  2. 序列顺序的重要性

  • 自然语言:"猫追狗" ≠ "狗追猫"
  • 时序数据:股价序列的顺序决定趋势替代方案对比
方法优点缺点
正弦/余弦泛化性好,理论保证固定模式不灵活
可学习适应任务特定模式长度受限,需训练
相对位置直接建模相对距离实现复杂

位置编码的实际效果

  1. 早期层作用:帮助模型建立位置感知

  2. 后期层作用:位置信息被融合到语义表示中

  3. 可视化示例

Input:    [The,   cat,   sat,   on,   mat]
Embed:    [E_The, E_cat, E_sat, E_on, E_mat]
Position: [P0,    P1,    P2,    P3,   P4]Final: [E_The+P0, E_cat+P1, ... E_mat+P4]
(1)正余弦位置编码(论文采用)

正余弦位置编码的计算公式:

其中:

  •  `pos` 是token在序列中的位置(从0开始)
  •  `d_model` 是模型的嵌入维度(即每个token的向量维度)
  •  `i` 是维度的索引(从0到d_model/2-1)

特点:

  • 波长几何级数:覆盖不同频率
  • 相对位置可学习:位置偏移的线性变换 PE_{pos+k} 可表示为 PE_{pos} 的线性函数
  • 泛化性强:可处理比训练时更长的序列
  • 对称性:sin/cos 组合允许模型学习相对位置

代码实现:

class PositionalEncoding(nn.Module):# Sine-cosine positional codingdef __init__(self, emb_dim, max_len, freq=10000.0):super(PositionalEncoding, self).__init__()assert emb_dim > 0 and max_len > 0, 'emb_dim and max_len must be positive'self.emb_dim = emb_dimself.max_len = max_lenself.pe = torch.zeros(max_len, emb_dim)pos = torch.arange(0, max_len).unsqueeze(1)# pos: [max_len, 1]div = torch.pow(freq, torch.arange(0, emb_dim, 2) / emb_dim)# div: [ceil(emb_dim / 2)]self.pe[:, 0::2] = torch.sin(pos / div)# torch.sin(pos / div): [max_len, ceil(emb_dim / 2)]self.pe[:, 1::2] = torch.cos(pos / (div if emb_dim % 2 == 0 else div[:-1]))# torch.cos(pos / div): [max_len, floor(emb_dim / 2)]def forward(self, x, len=None):if len is None:len = x.size(-2)return x + self.pe[:len, :]

例如,指定emb_dim=512和max_len=100,句子长度为10,则位置embedding的数值计算如下(三角函数取弧度制):

\begin{bmatrix} sin\left ( \frac{0}{10000^{\frac{0}{512}}} \right ) & cos\left ( \frac{0}{10000^{\frac{0}{512}}} \right ) & sin\left ( \frac{0}{10000^{\frac{2}{512}}} \right ) & ... & cos\left ( \frac{0}{10000^{\frac{508}{512}}} \right ) & sin\left ( \frac{0}{10000^{\frac{510}{512}}} \right ) & cos\left ( \frac{0}{10000^{\frac{510}{512}}} \right )\\ sin\left ( \frac{1}{10000^{\frac{0}{512}}} \right ) & cos\left ( \frac{1}{10000^{\frac{0}{512}}} \right ) & sin\left ( \frac{1}{10000^{\frac{2}{512}}} \right ) & ... & cos\left ( \frac{1}{10000^{\frac{508}{512}}} \right ) & sin\left ( \frac{1}{10000^{\frac{510}{512}}} \right ) & cos\left ( \frac{1}{10000^{\frac{510}{512}}} \right )\\ sin\left ( \frac{2}{10000^{\frac{0}{512}}} \right ) & cos\left ( \frac{2}{10000^{\frac{0}{512}}} \right ) & sin\left ( \frac{2}{10000^{\frac{2}{512}}} \right ) & ... & cos\left ( \frac{2}{10000^{\frac{508}{512}}} \right ) & sin\left ( \frac{2}{10000^{\frac{510}{512}}} \right ) & cos\left ( \frac{2}{10000^{\frac{510}{512}}} \right )\\ ... & ... & ... & ... & ... & ... & ...\\ sin\left ( \frac{7}{10000^{\frac{0}{512}}} \right ) & cos\left ( \frac{7}{10000^{\frac{0}{512}}} \right ) & sin\left ( \frac{7}{10000^{\frac{2}{512}}} \right ) & ... & cos\left ( \frac{7}{10000^{\frac{508}{512}}} \right ) & sin\left ( \frac{7}{10000^{\frac{510}{512}}} \right ) & cos\left ( \frac{7}{10000^{\frac{510}{512}}} \right )\\ sin\left ( \frac{8}{10000^{\frac{0}{512}}} \right ) & cos\left ( \frac{8}{10000^{\frac{0}{512}}} \right ) & sin\left ( \frac{8}{10000^{\frac{2}{512}}} \right ) & ... & cos\left ( \frac{8}{10000^{\frac{508}{512}}} \right ) & sin\left ( \frac{8}{10000^{\frac{510}{512}}} \right ) & cos\left ( \frac{8}{10000^{\frac{510}{512}}} \right )\\ sin\left ( \frac{9}{10000^{\frac{0}{512}}} \right ) & cos\left ( \frac{9}{10000^{\frac{0}{512}}} \right ) & sin\left ( \frac{9}{10000^{\frac{2}{512}}} \right ) & ... & cos\left ( \frac{9}{10000^{\frac{508}{512}}} \right ) & sin\left ( \frac{9}{10000^{\frac{510}{512}}} \right ) & cos\left ( \frac{9}{10000^{\frac{510}{512}}} \right )\\ \end{bmatrix}_{10\times 512}=\begin{bmatrix} 0 & 1 & 0 & ... & 1 & 0 & 1\\ 0.8415 & 0.5403 & 0.8219 & ... & 1.0000 & 1.0366\times 10^{-4} & 1.0000\\ 0.9093 & -0.4161 & 0.9364 & ... & 1.0000 & 2.0733\times 10^{-4} & 1.0000\\ ... & ... & ... & ... & ... & ... & ...\\ 0.6570& 0.7539 & 0.4524 & ... & 1.0000 & 7.2564\times 10^{-4} & 1.0000\\ 0.9894 & -0.1455 & 0.9907 & ... & 1.0000 & 8.2931\times 10^{-4} & 1.0000\\ 0.4121 & -0.9111 & 0.6764 & ... & 1.0000 & 9.3297\times 10^{-4} & 1.0000 \end{bmatrix}_{10\times 512}

(2)可学习位置编码
class LearnablePositionalEncoding(nn.Module):# Learnable positional encodingdef __init__(self, emb_dim, len):super(LearnablePositionalEncoding, self).__init__()assert emb_dim > 0 and len > 0, 'emb_dim and len must be positive'self.emb_dim = emb_dimself.len = lenself.pe = nn.Parameter(torch.zeros(len, emb_dim))def forward(self, x):return x + self.pe[:x.size(-2), :]

特性

  • 直接学习位置嵌入:作为模型参数训练
  • 灵活性高:可适应特定任务的位置模式
  • 长度受限:只能处理预定义的最大长度
  • 计算效率高:直接查表无需计算

三. Self-Attention(自注意力机制)和Multi-Head Attention(多头自注意力)

Transformer 的内部结构图,左侧为 Encoder block(编码器),右侧为 Decoder block(解码器)。可以看到:

(1)Encoder block 包含一个 Multi-Head Attention;

(2)Decoder block 包含两个 Multi-Head Attention (其中有一个用到 Masked)。Multi-Head Attention 上方还包括一个 Add & Norm 层,Add 表示残差连接(Residual Connection),用于防止网络退化,Norm 表示Layer Normalization,用于对每一层的激活值进行归一化。

Multi-Head Attention 是 Transformer 的重点,它由 Self-Attention 演变而来,我们先从 Self-Attention 讲起。

3.1  Self-Attention(自注意力机制)

Self-Attention(自注意力)是 Transformer 架构的核心创新,它彻底改变了序列建模的方式。与传统的循环神经网络(RNN)和卷积神经网络(CNN)不同,self-attention 能够直接捕捉序列中任意两个元素之间的关系,无论它们之间的距离有多远:

Self-Attention 的输入用矩阵X_{n\times d}(n是序列长度,d是特征维度)进行表示,计算如下:

(1)通过可学习的权重矩阵生成Q(查询),K(键值),V(值):

\left\{\begin{matrix} Q = XW^Q \\ K = XW^K \\ V = XW^V \end{matrix}\right.

其中W^Q,W^K,W^V是可学习参数。

(2)计算 Self-Attention 的输出:Attention(Q,K,V)=softmax\left ( \frac{QK^{T}}{\sqrt{d_k}} \right )V

步骤分解:

  1. 相似度计算QK^T计算所有查询-键对之间的点积相似度,QK^T得到的矩阵行列数都为 n,n为句子单词数,这个矩阵可以表示单词之间的 attention 强度。

  2. 缩放:除以\sqrt{d_k}防止点积过大导致梯度消失

  3. 归一化:softmax 将相似度转换为概率分布

  4. 加权求和:用注意力权重对值向量加权求和,得到最终的输出

输入序列: [x1, x2, x3, x4]步骤1: 为每个输入生成Q,K,V向量
x1 → q1, k1, v1
x2 → q2, k2, v2
x3 → q3, k3, v3
x4 → q4, k4, v4步骤2: 计算注意力权重 (以x1为例)
权重1 = softmax(q1·k1 / √d_k)
权重2 = softmax(q1·k2 / √d_k)
权重3 = softmax(q1·k3 / √d_k)
权重4 = softmax(q1·k4 / √d_k)步骤3: 加权求和
输出1 = 权重1*v1 + 权重2*v2 + 权重3*v3 + 权重4*v4

3.2  Multi-Head Attention(多头注意力)

Transformer 使用多头机制增强模型表达能力:

MultiHead(Q,K,V)=Concat(head_1,head_2...head_h)W^O

其中每个注意力头:

head_i=Attention(QW_{i}^{Q},KW_{i}^{K},VW_{i}^{V})

  • h:注意力头的数量

  • W_i^Q, W_i^K, W_i^V:每个头的独立参数

  • W^O:输出投影矩阵

代码实现:

(1)多头分割处理:使用view将特征维度分割为多个头,确保每个头的维度:dim_head = dim_qk // num_heads

q = self.w_q(q).view(-1, len_q, self.num_heads, self.dim_qk // self.num_heads)
k = ... # 类似处理
v = ... # 类似处理

(2)高效的矩阵运算:使用矩阵乘法并行计算所有位置的注意力分数

attn = torch.matmul(q, k.transpose(-2, -1)) / (self.dim_qk ** 0.5)

(3)多头合并:使用view合并多头:num_heads * d_v = dim_v

output = output.transpose(1, 2)
output = output.contiguous().view(-1, len_q, self.dim_v)

完整Multi-Head Attention(多头注意力)的代码实现,这里已经考虑了掩码处理的实现,关于掩码将在后面第5节介绍。

class MultiHeadAttention(nn.Module):def __init__(self, dim, dim_qk=None, dim_v=None, num_heads=1, dropout=0.):super(MultiHeadAttention, self).__init__()dim_qk = dim if dim_qk is None else dim_qkdim_v = dim if dim_v is None else dim_vassert dim % num_heads == 0 and dim_v % num_heads == 0 and dim_qk % num_heads == 0, 'dim must be divisible by num_heads'self.dim = dimself.dim_qk = dim_qkself.dim_v = dim_vself.num_heads = num_headsself.dropout = nn.Dropout(dropout)self.w_q = nn.Linear(dim, dim_qk)self.w_k = nn.Linear(dim, dim_qk)self.w_v = nn.Linear(dim, dim_v)def forward(self, q, k, v, mask=None):# q: [B, len_q, D]# k: [B, len_kv, D]# v: [B, len_kv, D]assert q.ndim == k.ndim == v.ndim == 3, 'input must be 3-dimensional'len_q, len_k, len_v = q.size(1), k.size(1), v.size(1)assert q.size(-1) == k.size(-1) == v.size(-1) == self.dim, 'dimension mismatch'assert len_k == len_v, 'len_k and len_v must be equal'len_kv = len_vq = self.w_q(q).view(-1, len_q, self.num_heads, self.dim_qk // self.num_heads)k = self.w_k(k).view(-1, len_kv, self.num_heads, self.dim_qk // self.num_heads)v = self.w_v(v).view(-1, len_kv, self.num_heads, self.dim_v // self.num_heads)# q: [B, len_q, num_heads, dim_qk//num_heads]# k: [B, len_kv, num_heads, dim_qk//num_heads]# v: [B, len_kv, num_heads, dim_v//num_heads]# The following 'dim_(qk)//num_heads' is writen as d_(qk)q = q.transpose(1, 2)k = k.transpose(1, 2)v = v.transpose(1, 2)# q: [B, num_heads, len_q, d_qk]# k: [B, num_heads, len_kv, d_qk]# v: [B, num_heads, len_kv, d_v]attn = torch.matmul(q, k.transpose(-2, -1)) / (self.dim_qk ** 0.5)# attn: [B, num_heads, len_q, len_kv]if mask is not None:attn = attn.transpose(0, 1).masked_fill(mask, float('-1e20')).transpose(0, 1)attn = torch.softmax(attn, dim=-1)attn = self.dropout(attn)output = torch.matmul(attn, v)# output: [B, num_heads, len_q, d_v]output = output.transpose(1, 2)# output: [B, len_q, num_heads, d_v]output = output.contiguous().view(-1, len_q, self.dim_v)# output: [B, len_q, num_heads * d_v] = [B, len_q, dim_v]return output

四.Encoder (编码器)结构

上图红色部分是 Transformer 的 Encoder block (编码器)结构,可以看到是由 Multi-Head Attention, Add & Norm, Feed Forward Network, 第二个Add & Norm 组成的。刚刚已经了解了 Multi-Head Attention 的计算过程,现在介绍 Add & Norm 和 Feed Forward 部分。

4.1  Add & Norm层

它实际上由两个独立的操作组成:残差连接(Add) 和 层归一化(Layer Normalization, Norm)。这两个操作协同工作,极大地提升了 Transformer 的训练稳定性、收敛速度和最终性能。

(1)残差连接(Add)

目的: 解决深层神经网络中普遍存在的梯度消失/爆炸问题,并促进信息的直接流动。

操作: 将子层(自注意力层或 FFN)的输入直接加到该子层的输出上。

  • 公式:Output_Add = Sublayer_Input + Sublayer_Output(Sublayer_Input)

  • 其中 Sublayer_Output(Sublayer_Input) 代表自注意力或 FFN(Feed Forward Network,下一节介绍) 对输入进行计算后产生的结果。

直观理解: 想象一下学习一项新技能(子层要学习的新变换)。残差连接允许模型先保留已经掌握的旧技能(原始输入),然后在这个基础上只学习需要做出的增量调整修正(子层输出的变化量)。这样,即使新学的调整很小或者学习过程遇到困难,至少旧技能(原始信息)也能完整无损地传递下去。

关键优势:

  • 缓解梯度消失: 在反向传播时,梯度可以直接通过残差路径(加号)无损地流回浅层网络,确保深层参数也能得到有效的更新信号。

  • 信息高速公路: 为模型提供了一条“捷径”,使得网络即使很深,浅层的信息也能相对容易地传递到深层,反之亦然。

  • 模型更容易优化: 让网络更容易学习到一个恒等映射(即输出等于输入),这通常是深层网络一个不错的起点。如果子层需要做出改变,它只需要学习相对于输入的“残差”(Residual)即可。

 (2)层归一化(Layer Normalization, Norm)

目的: 稳定训练过程,加速收敛。它通过减少层内激活值的内部协变量偏移(Internal Covariate Shift)来实现这一点,即减少同一层内不同特征维度上激活值的分布变化。

操作: 对单个样本所有特征维度上进行归一化。

  • 计算该样本在该层所有神经元/特征维度上的激活值的均值(μ)和标准差(σ)。

  • 使用计算出的均值和标准差对该样本的所有激活值进行标准化Normalized_Activations = (Activations - μ) / √(σ² + ε) (ε 是一个很小的常数,防止除以零)。

  • 引入可学习的缩放参数 γ(gamma)和平移参数 β(beta),允许模型根据需要对标准化后的值进行缩放和偏移:Output_Norm = γ * Normalized_Activations + β。这非常重要,因为它赋予了模型恢复归一化可能丢失的表示能力(比如改变分布的形状)。

  • 与批量归一化(Batch Normalization, BN)的区别: BN 是在一个 Batch 内,对单个特征维度在所有样本上进行归一化(计算该特征在所有样本上的均值和方差)。而 Layer Norm 是在单个样本上,对所有特征维度进行归一化。

为什么 Layer Norm 更适合 Transformer/RNN?

  • 对 Batch Size 不敏感: LN 的计算不依赖于 Batch Size,即使 Batch Size 很小(甚至是 1)也能工作。BN 在小 Batch Size 下效果不稳定。

  • 处理变长序列: 在 NLP 任务中,序列长度通常是变化的。LN 在每个样本(一个序列)内部独立计算统计量,完美处理了变长问题。BN 需要处理同一特征在不同序列位置上的统计量,对于变长序列实现起来复杂且效果可能不佳。

  • 时序独立性: LN 对序列中不同时间步的归一化是独立的(虽然它们共享 γ 和 β),这更适合 RNN/Transformer 这种时序模型的结构。BN 在同一特征维度上混合了不同时间步的信息。

(3)为什么 Add & Norm 如此有效?(协同效应)

  1. 残差连接保信息、通梯度: Add 操作确保原始输入信息不被后续复杂的非线性变换完全覆盖,并提供了一条低阻力的梯度回传路径。

  2. 层归一化稳分布、加速收敛: Norm 操作接收来自 Add 的输出(原始输入 + 子层变换)。这个组合信号的分布可能会剧烈变化,尤其是在训练初期。LN 将其“拉回”到一个相对稳定(均值为0,方差为1)的分布,极大地减少了训练的不稳定性,显著加快了收敛速度。γ 和 β 让模型能自适应地调整这个分布。

  3. 顺序合理: 先 Add 后 Norm 是标准做法。先 Norm 后 Add 理论上也可以尝试,但实践表明效果不如前者。一个重要的原因是:Norm 操作(尤其是带 γ 和 β 的)本身就是一个复杂的、可学习的变换。如果把 Norm 放在残差分支里(即 Output = x + Norm(Sublayer(x))),那么残差路径就失去了其“纯净的恒等映射”的本质优势,梯度流可能再次受阻。

4.2  Feed Forward Network(FFN)层

FFN 层位于每个 Transformer 编码器层解码器层的内部。具体来说,在一个标准的层中,它紧跟在 Add & Norm 层之后(该 Add & Norm 层处理了自注意力或交叉注意力的输出)。

目的:

  • 引入非线性: 自注意力层主要执行的是线性变换(加权求和),即使有 Softmax 也是在不同位置之间计算权重。FFN 通过激活函数(如 ReLU、GELU)为模型引入了关键的非线性变换能力,极大地增强了模型的表达能力和拟合复杂函数的能力。

  • 特征变换与交互: 它对自注意力层输出的每个位置(token)的表示向量进行独立、相同的处理。它可以将注意力机制聚合到的上下文信息进行深度加工、转换和抽象,学习更复杂、更高级的特征表示。你可以理解为它对每个词向量进行了“独立思考”和“内部消化”。

  • 维度扩展与压缩: FFN 通常采用“瓶颈”结构(先升维再降维),允许模型在更高维度的空间中进行特征交互和变换,然后再投影回原始维度。这有助于捕获更丰富的模式。

  • 位置独立性: 关键点在于 Position-wise。它对序列中的每个位置(token)单独、并行地应用相同的操作。这意味着:

    • 处理位置 i 的向量时,不需要依赖位置 j (j ≠ i) 的向量。

    • 这使得 FFN 的计算可以高度并行化,效率非常高。

    • 不处理序列中不同位置之间的关系(这是自注意力层的任务),而是专注于单个位置内部的特征转换。

标准的 FFN 层由两个线性变换(全连接层) 和一个非线性激活函数组成,结构非常简单:

FFN(x)=ActivationFunction(x * W1 + b1) * W2 + b2

在论文中,激活函数采用RELU函数:

FFN(x)=max(0, x * W1 + b1) * W2 + b2

总结:

  • 核心作用: 引入非线性变换能力,对自注意力层聚合到的上下文信息进行深度加工和抽象

  • 关键特性: 位置独立(Position-wise)、高度并行化、采用瓶颈结构(升维 d_ff 再降维回 d_model)、通常是模型中参数量最大的部分。

  • 与自注意力的关系: 分工协作。自注意力处理位置间关系(沟通),FFN 处理位置内特征(思考)。

  • 为什么重要? 它是 Transformer 具备强大表达能力和能够学习复杂模式不可或缺的组件。没有非线性 FFN 的 Transformer 将退化为一个表达能力有限的线性模型。

4.3  完整的Encoder(编码器)架构

输入: x (嵌入向量 + 位置编码)

Multi-Head Self-Attention: attn_output = Attention(x)

Add & Norm 1: y = LayerNorm(x + attn_output)

Feed Forward Network: ffn_output = FFN(y) // y 就是 FFN 的输入 x

Add & Norm 2: output = LayerNorm(y + ffn_output)

第一个 Encoder block 的输入为句子单词的表示向量矩阵,后续 Encoder block 的输入是前一个 Encoder block 的输出,最后一个 Encoder block 输出的矩阵就是编码信息矩阵 C,这一矩阵后续会用到 Decoder (第5节)中。

代码实现如下:

class Feedforward(nn.Module):def __init__(self, dim, hidden_dim=2048, dropout=0., activate=nn.ReLU()):super(Feedforward, self).__init__()self.dim = dimself.hidden_dim = hidden_dimself.dropout = nn.Dropout(dropout)self.fc1 = nn.Linear(dim, hidden_dim)self.fc2 = nn.Linear(hidden_dim, dim)self.act = activatedef forward(self, x):x = self.act(self.fc1(x))x = self.dropout(x)x = self.fc2(x)return xclass EncoderLayer(nn.Module):def __init__(self, dim, dim_qk=None, num_heads=1, dropout=0., pre_norm=False):super(EncoderLayer, self).__init__()self.attn = MultiHeadAttention(dim, dim_qk=dim_qk, num_heads=num_heads, dropout=dropout)self.ffn = Feedforward(dim, dim * 4, dropout)self.pre_norm = pre_normself.norm1 = nn.LayerNorm(dim)self.norm2 = nn.LayerNorm(dim)def forward(self, x, mask=None):if self.pre_norm:res1 = self.norm1(x)x = x + self.attn(res1, res1, res1, mask)res2 = self.norm2(x)x = x + self.ffn(res2)else:x = self.attn(x, x, x, mask) + xx = self.norm1(x)x = self.ffn(x) + xx = self.norm2(x)return x

五.Decoder (解码器)结构

 Transformer 的 Decoder block 结构,与 Encoder block 相似,但是存在一些区别:

  • 包含两个 Multi-Head Attention 层。
  • 第一个 Multi-Head Attention 层采用了 Masked 操作。
  • 第二个 Multi-Head Attention 层(又称交叉注意力,Cross-Attention)的K, V矩阵使用 Encoder 的编码信息矩阵C进行计算,而Q使用上一个 Decoder block 的输出计算。
  • 最后有一个 Softmax 层计算下一个翻译单词的概率。

5.1  Masked 操作

Decoder block 的第一个 Multi-Head Attention 采用了 Masked 操作,因为在翻译的过程中是顺序翻译的,即翻译完第 i 个单词,才可以翻译第 i+1 个单词。通过 Masked 操作可以防止第 i 个单词知道 i+1 个单词之后的信息。

如图,在 Decoder 的时候,是需要根据之前的翻译,求解当前最有可能的翻译,如下图所示。首先根据输入 "<Begin>" 预测出第一个单词为 "I",然后根据输入 "<Begin> I" 预测下一个单词 "have"。

Transformer 实现 Masked 操作的关键步骤在于修改注意力分数矩阵。具体步骤如下:

(1)计算 Query, Key, Value: 像普通的自注意力一样,解码器输入 X_dec 经过线性变换得到 Query (Q), Key (K), Value (V) 矩阵。

(2)计算注意力分数 (S): S = Q * K^T / sqrt(d_k)。此时的 S 是一个 [target_seq_len, target_seq_len] 的矩阵。S[i, j] 表示第 i 个目标位置(Query)与第 j 个目标位置(Key)之间的原始相关性分数。

  • 在这个矩阵中,S[i, j] 在 j > i 时,表示位置 i 的 Query 在关注位置 j 的 Key,而 j 在序列中位于 i 之后(未来)。这就是需要屏蔽的信息。

(3)应用掩码 (M): 构造一个与 S 形状相同的掩码矩阵 M

  • M[i, j] = 0 当 j <= i (允许看到当前位置 i 及之前 j <= i 的所有位置)。
  • M[i, j] = -inf (负无穷) 当 j > i (屏蔽当前位置 i 之后 j > i 的所有位置)。

(4)修改注意力分数: S_masked = S + M

  • 对于允许的位置 (j <= i),S_masked[i, j] = S[i, j] + 0 = S[i, j] (分数不变)。
  • 对于要屏蔽的位置 (j > i),S_masked[i, j] = S[i, j] + (-inf) = -inf (分数变为负无穷)。

(5)应用 Softmax: A = softmax(S_masked)

  • Softmax 函数会对输入进行指数运算。exp(-inf) = 0
  • 因此,对于被屏蔽的位置 (j > i),A[i, j] = 0
  • 对于允许的位置 (j <= i),A[i, j] 的值基于 S[i, j] 计算,其和为 1。

(6)计算输出: Output = A * V,由于屏蔽位置的注意力权重 A[i, j] = 0 (j > i),在计算位置 i 的输出向量时,只加权求和了位置 j <= i (当前位置及之前位置) 的 Value (V_j)。位置 j > i 的 Value 被乘以 0,对输出没有贡献。

可视化理解:

假设目标序列长度为 4:

注意力分数矩阵 S (计算后):
[[ s11, s12, s13, s14],[ s21, s22, s23, s24],[ s31, s32, s33, s34],[ s41, s42, s43, s44]]掩码矩阵 M (下三角为0,上三角为-inf):
[[ 0, -inf, -inf, -inf],[ 0,    0, -inf, -inf],[ 0,    0,    0, -inf],[ 0,    0,    0,    0]]掩码后的分数 S_masked = S + M:
[[ s11, -inf, -inf, -inf],   -> 位置1只能看到位置1[ s21,  s22, -inf, -inf],   -> 位置2只能看到位置1、2[ s31,  s32,  s33, -inf],   -> 位置3只能看到位置1、2、3[ s41,  s42,  s43,  s44]]   -> 位置4可以看到所有位置1、2、3、4应用 Softmax(S_masked) 得到注意力权重矩阵 A:
[[ a11,   0,   0,   0],   -> a11 = exp(s11)/exp(s11) = 1[ a21, a22,   0,   0],   -> a21 = exp(s21)/(exp(s21)+exp(s22)), a22 = exp(s22)/(exp(s21)+exp(s22))[ a31, a32, a33,   0],[ a41, a42, a43, a44]]计算输出:每一行只对非零权重对应的 Value 进行加权求和。

代码实现:

def attn_mask(len):""":param len: length of sequence:return: mask tensor, False for not replaced, True for replaced as -infe.g. attn_mask(3) =tensor([[[False,  True,  True],[False, False,  True],[False, False, False]]])"""mask = torch.triu(torch.ones(len, len, dtype=torch.bool), 1)return mask

除了上面介绍的掩码操作,这里再介绍另一种重要的掩码:填充掩码 (Padding Mask)。

目的: 处理变长序列。在批处理中,不同序列长度不同,需要填充 ([PAD]) 到相同长度。填充位置不应该参与注意力计算,也不应该获得注意力。

应用位置: 编码器解码器所有注意力层(自注意力和交叉注意力)。

实现:

  • 构造一个掩码向量/矩阵,标识哪些位置是真实 token (1),哪些是填充 token (0)。
  • 在计算注意力分数后、应用 Softmax 前: 对于需要屏蔽的位置(填充位置),将其注意力分数设置为一个很大的负数(如 -1e9)。
  • 应用 Softmax。被屏蔽位置的权重会变为 0。

与 Masked Self-Attention 的区别:

  • Masked Self-Attention (又称为Causal Mask): 屏蔽“未来”信息(位置 j > i),用于解码器自回归生成
  • Padding Mask: 屏蔽“无效”信息(填充位置),用于处理变长序列,在编码器和解码器的所有注意力层都需要。

代码实现:

def padding_mask(pad_q, pad_k):""":param pad_q: pad label of query (0 is padding, 1 is not padding), [B, len_q]:param pad_k: pad label of key (0 is padding, 1 is not padding), [B, len_k]:return: mask tensor, False for not replaced, True for replaced as -infe.g. pad_q = tensor([[1, 1, 0]], [1, 0, 1])padding_mask(pad_q, pad_q) =tensor([[[False, False,  True],[False, False,  True],[ True,  True,  True]],[[False,  True, False],[ True,  True,  True],[False,  True, False]]])"""assert pad_q.ndim == pad_k.ndim == 2, 'pad_q and pad_k must be 2-dimensional'assert pad_q.size(0) == pad_k.size(0), 'batch size mismatch'mask = pad_q.bool().unsqueeze(2) * pad_k.bool().unsqueeze(1)mask = ~mask# mask: [B, len_q, len_k]return mask

5.2  Cross-Attention(交叉注意力)  

位置: 在 Transformer 解码器的每一层中,Cross-Attention 位于 Masked Multi-Head Self-Attention 层和 Add & Norm 层之后,以及前馈神经网络(FFN)层之前。即图中第二个Multi-Head Attention。

核心目的: 允许解码器在生成目标序列的当前词(位置)时,有选择地关注编码器输出的整个源序列(输入序列)的表示。这是序列到序列(Seq2Seq)任务(如机器翻译、文本摘要、语音识别)的核心机制。

直观理解: 想象你在翻译一句话。当你在写目标语言的某个词时,你需要参考源语言句子的哪些部分提供了最相关的信息?Cross-Attention 就是让解码器模型自动学习这种“软对齐”(Soft Alignment)的过程。它动态地为每个目标词位置计算一个权重分布,这个分布告诉模型源序列的哪些词或短语对生成当前目标词最重要。

输入来源:Cross-Attention 接收 两个不同来源 的输入:

(1)Query (Q): 来自解码器自身

  • 具体来源:是解码器上一层的输出,即经过 Masked Multi-Head Self-Attention 层 和随后的 Add & Norm 层 处理后的表示。
  • 代表什么? Q 代表了解码器当前的状态,特别是它正在尝试生成的目标序列当前位置的信息(包含了之前已生成目标词的信息以及解码器自身对该位置的理解)。
  • 形状: [batch_size, target_seq_len, d_model]target_seq_len 是目标序列的长度(可能包括起始符和填充)。

(2)Key (K) 和 Value (V): 来自编码器

  • 具体来源:是最后一个编码器层的输出(通常是该层 FFN 之后的 Add & Norm 层的输出)。
  • 代表什么? K 和 V 代表了编码器对整个源序列的完整理解。编码器已经将源序列的信息压缩、转换成了丰富的上下文表示。K 用于计算相关性,V 是实际被加权的信息。
  • 形状: [batch_size, source_seq_len, d_model]source_seq_len 是源序列的长度(可能包括填充)。
  • 关键点: K 和 V 是同一个张量!在标准的 Transformer 中,编码器的输出被同时用作 K 和 V 的输入(当然,它们会经过不同的线性变换矩阵 W^K 和 W^V 进行投影)。这体现了源序列的表示既用于计算与查询的相关性(K),也作为信息本身被提取(V)。

读者可以对比原论文中的图,需要特别注意的是,原论文Scaled Dot-Product Attention图中是按照Q,K,V排列,而Multi-Head Attention是按V,K,Q排列的,所以Query (Q): 来自解码器自身,Key (K) 和 Value (V): 来自编码器。

关键特性与意义:

  1. 信息桥梁: 这是唯一将编码器学到的源序列信息显式注入解码器的机制。没有它,解码器就只能基于目标序列自身的历史信息生成,无法完成翻译等任务。

  2. 动态软对齐: 其核心是为每个目标词位置动态计算一个在整个源序列上的注意力分布。这模仿了人类翻译/理解时对原文不同部分的关注变化。模型自动学习这种对齐关系,无需预设规则。

  3. 多头机制: 不同的注意力头可以学习关注源序列的不同方面(如词义、语法结构、语义角色),从而捕捉更丰富的对齐模式。

  4. 位置独立性 (对源序列): 解码器在计算目标位置 i 的上下文时,可以访问源序列的任何位置 j(1 到 source_seq_len),不受位置顺序限制。模型完全根据相关性决定关注哪里。

  5. 无 Causal Masking: 再次强调,Cross-Attention 不需要也不应该使用 Causal Mask (屏蔽未来)。源序列是完整已知的输入,解码器需要利用其所有信息来预测当前目标词。

  6. 与 Masked Self-Attention 的协同: Masked Self-Attention 让解码器专注于已生成的目标序列历史。Cross-Attention 则让解码器基于这个“当前状态”去查询相关的源信息。两者结合,解码器才能既保持目标序列的连贯性,又确保内容与源序列一致。

5.3  完整的Decoder(解码器)架构

一个完整的解码器层流程如下:

(1)输入: 嵌入的目标序列 + 位置编码(训练时可能是整个目标序列右移,预测时是已生成序列)。

(2)Masked Multi-Head Self-Attention:

  • 输入:嵌入的目标序列。
  • 作用:让目标序列的每个位置关注其之前的所有位置(防止信息泄露)。
  • 输出:MaskedAttn_Output

(3)Add & Norm 1:Residual1 = LayerNorm(Embedded_Target + MaskedAttn_Output)

(4)Multi-Head Encoder-Decoder Attention (Cross-Attention):

  • Q = Residual1 (来自解码器自身,代表当前目标位置状态)。
  • K = Encoder_OutputV = Encoder_Output (来自编码器,代表完整源序列信息)。
  • 计算如上所述,输出 CrossAttn_Output

(5)Add & Norm 2:Residual2 = LayerNorm(Residual1 + CrossAttn_Output)

(6)Position-wise Feed Forward Network:输入:Residual2,输出:FFN_Output

(7)Add & Norm 3:Decoder_Layer_Output = LayerNorm(Residual2 + FFN_Output)

(8)输出: Decoder_Layer_Output 传递给下一层解码器或最终的线性层+Softmax。

代码实现:

class DecoderLayer(nn.Module):def __init__(self, dim, dim_qk=None, num_heads=1, dropout=0., pre_norm=False):super(DecoderLayer, self).__init__()self.attn1 = MultiHeadAttention(dim, dim_qk=dim_qk, num_heads=num_heads, dropout=dropout)self.attn2 = MultiHeadAttention(dim, dim_qk=dim_qk, num_heads=num_heads, dropout=dropout)self.ffn = Feedforward(dim, dim * 4, dropout)self.pre_norm = pre_normself.norm1 = nn.LayerNorm(dim)self.norm2 = nn.LayerNorm(dim)self.norm3 = nn.LayerNorm(dim)def forward(self, x, enc, self_mask=None, pad_mask=None):if self.pre_norm:res1 = self.norm1(x)x = x + self.attn1(res1, res1, res1, self_mask)res2 = self.norm2(x)x = x + self.attn2(res2, enc, enc, pad_mask)res3 = self.norm3(x)x = x + self.ffn(res3)else:x = self.attn1(x, x, x, self_mask) + xx = self.norm1(x)x = self.attn2(x, enc, enc, pad_mask) + xx = self.norm2(x)x = self.ffn(x) + xx = self.norm3(x)return x

六.Transformer完整代码实现

import torch
import torch.nn as nnclass LearnablePositionalEncoding(nn.Module):# Learnable positional encodingdef __init__(self, emb_dim, len):super(LearnablePositionalEncoding, self).__init__()assert emb_dim > 0 and len > 0, 'emb_dim and len must be positive'self.emb_dim = emb_dimself.len = lenself.pe = nn.Parameter(torch.zeros(len, emb_dim))def forward(self, x):return x + self.pe[:x.size(-2), :]class PositionalEncoding(nn.Module):# Sine-cosine positional codingdef __init__(self, emb_dim, max_len, freq=10000.0):super(PositionalEncoding, self).__init__()assert emb_dim > 0 and max_len > 0, 'emb_dim and max_len must be positive'self.emb_dim = emb_dimself.max_len = max_lenself.pe = torch.zeros(max_len, emb_dim)pos = torch.arange(0, max_len).unsqueeze(1)# pos: [max_len, 1]div = torch.pow(freq, torch.arange(0, emb_dim, 2) / emb_dim)# div: [ceil(emb_dim / 2)]self.pe[:, 0::2] = torch.sin(pos / div)# torch.sin(pos / div): [max_len, ceil(emb_dim / 2)]self.pe[:, 1::2] = torch.cos(pos / (div if emb_dim % 2 == 0 else div[:-1]))# torch.cos(pos / div): [max_len, floor(emb_dim / 2)]def forward(self, x, len=None):if len is None:len = x.size(-2)print(self.pe[:len, :])return x + self.pe[:len, :]class MultiHeadAttention(nn.Module):def __init__(self, dim, dim_qk=None, dim_v=None, num_heads=1, dropout=0.):super(MultiHeadAttention, self).__init__()dim_qk = dim if dim_qk is None else dim_qkdim_v = dim if dim_v is None else dim_vassert dim % num_heads == 0 and dim_v % num_heads == 0 and dim_qk % num_heads == 0, 'dim must be divisible by num_heads'self.dim = dimself.dim_qk = dim_qkself.dim_v = dim_vself.num_heads = num_headsself.dropout = nn.Dropout(dropout)self.w_q = nn.Linear(dim, dim_qk)self.w_k = nn.Linear(dim, dim_qk)self.w_v = nn.Linear(dim, dim_v)def forward(self, q, k, v, mask=None):# q: [B, len_q, D]# k: [B, len_kv, D]# v: [B, len_kv, D]assert q.ndim == k.ndim == v.ndim == 3, 'input must be 3-dimensional'len_q, len_k, len_v = q.size(1), k.size(1), v.size(1)assert q.size(-1) == k.size(-1) == v.size(-1) == self.dim, 'dimension mismatch'assert len_k == len_v, 'len_k and len_v must be equal'len_kv = len_vq = self.w_q(q).view(-1, len_q, self.num_heads, self.dim_qk // self.num_heads)k = self.w_k(k).view(-1, len_kv, self.num_heads, self.dim_qk // self.num_heads)v = self.w_v(v).view(-1, len_kv, self.num_heads, self.dim_v // self.num_heads)# q: [B, len_q, num_heads, dim_qk//num_heads]# k: [B, len_kv, num_heads, dim_qk//num_heads]# v: [B, len_kv, num_heads, dim_v//num_heads]# The following 'dim_(qk)//num_heads' is writen as d_(qk)q = q.transpose(1, 2)k = k.transpose(1, 2)v = v.transpose(1, 2)# q: [B, num_heads, len_q, d_qk]# k: [B, num_heads, len_kv, d_qk]# v: [B, num_heads, len_kv, d_v]attn = torch.matmul(q, k.transpose(-2, -1)) / (self.dim_qk ** 0.5)# attn: [B, num_heads, len_q, len_kv]if mask is not None:attn = attn.transpose(0, 1).masked_fill(mask, float('-1e20')).transpose(0, 1)attn = torch.softmax(attn, dim=-1)attn = self.dropout(attn)output = torch.matmul(attn, v)# output: [B, num_heads, len_q, d_v]output = output.transpose(1, 2)# output: [B, len_q, num_heads, d_v]output = output.contiguous().view(-1, len_q, self.dim_v)# output: [B, len_q, num_heads * d_v] = [B, len_q, dim_v]return outputclass Feedforward(nn.Module):def __init__(self, dim, hidden_dim=2048, dropout=0., activate=nn.ReLU()):super(Feedforward, self).__init__()self.dim = dimself.hidden_dim = hidden_dimself.dropout = nn.Dropout(dropout)self.fc1 = nn.Linear(dim, hidden_dim)self.fc2 = nn.Linear(hidden_dim, dim)self.act = activatedef forward(self, x):x = self.act(self.fc1(x))x = self.dropout(x)x = self.fc2(x)return xdef attn_mask(len):""":param len: length of sequence:return: mask tensor, False for not replaced, True for replaced as -infe.g. attn_mask(3) =tensor([[[False,  True,  True],[False, False,  True],[False, False, False]]])"""mask = torch.triu(torch.ones(len, len, dtype=torch.bool), 1)return maskdef padding_mask(pad_q, pad_k):""":param pad_q: pad label of query (0 is padding, 1 is not padding), [B, len_q]:param pad_k: pad label of key (0 is padding, 1 is not padding), [B, len_k]:return: mask tensor, False for not replaced, True for replaced as -infe.g. pad_q = tensor([[1, 1, 0]], [1, 0, 1])padding_mask(pad_q, pad_q) =tensor([[[False, False,  True],[False, False,  True],[ True,  True,  True]],[[False,  True, False],[ True,  True,  True],[False,  True, False]]])"""assert pad_q.ndim == pad_k.ndim == 2, 'pad_q and pad_k must be 2-dimensional'assert pad_q.size(0) == pad_k.size(0), 'batch size mismatch'mask = pad_q.bool().unsqueeze(2) * pad_k.bool().unsqueeze(1)mask = ~mask# mask: [B, len_q, len_k]return maskclass EncoderLayer(nn.Module):def __init__(self, dim, dim_qk=None, num_heads=1, dropout=0., pre_norm=False):super(EncoderLayer, self).__init__()self.attn = MultiHeadAttention(dim, dim_qk=dim_qk, num_heads=num_heads, dropout=dropout)self.ffn = Feedforward(dim, dim * 4, dropout)self.pre_norm = pre_normself.norm1 = nn.LayerNorm(dim)self.norm2 = nn.LayerNorm(dim)def forward(self, x, mask=None):if self.pre_norm:res1 = self.norm1(x)x = x + self.attn(res1, res1, res1, mask)res2 = self.norm2(x)x = x + self.ffn(res2)else:x = self.attn(x, x, x, mask) + xx = self.norm1(x)x = self.ffn(x) + xx = self.norm2(x)return xclass Encoder(nn.Module):def __init__(self, dim, dim_qk=None, num_heads=1, num_layers=1, dropout=0., pre_norm=False):super(Encoder, self).__init__()self.layers = nn.ModuleList([EncoderLayer(dim, dim_qk, num_heads, dropout, pre_norm) for _ in range(num_layers)])def forward(self, x, mask=None):for layer in self.layers:x = layer(x, mask)return xclass DecoderLayer(nn.Module):def __init__(self, dim, dim_qk=None, num_heads=1, dropout=0., pre_norm=False):super(DecoderLayer, self).__init__()self.attn1 = MultiHeadAttention(dim, dim_qk=dim_qk, num_heads=num_heads, dropout=dropout)self.attn2 = MultiHeadAttention(dim, dim_qk=dim_qk, num_heads=num_heads, dropout=dropout)self.ffn = Feedforward(dim, dim * 4, dropout)self.pre_norm = pre_normself.norm1 = nn.LayerNorm(dim)self.norm2 = nn.LayerNorm(dim)self.norm3 = nn.LayerNorm(dim)def forward(self, x, enc, self_mask=None, pad_mask=None):if self.pre_norm:res1 = self.norm1(x)x = x + self.attn1(res1, res1, res1, self_mask)res2 = self.norm2(x)x = x + self.attn2(res2, enc, enc, pad_mask)res3 = self.norm3(x)x = x + self.ffn(res3)else:x = self.attn1(x, x, x, self_mask) + xx = self.norm1(x)x = self.attn2(x, enc, enc, pad_mask) + xx = self.norm2(x)x = self.ffn(x) + xx = self.norm3(x)return xclass Decoder(nn.Module):def __init__(self, dim, dim_qk=None, num_heads=1, num_layers=1, dropout=0., pre_norm=False):super(Decoder, self).__init__()self.layers = nn.ModuleList([DecoderLayer(dim, dim_qk, num_heads, dropout, pre_norm) for _ in range(num_layers)])def forward(self, x, enc, self_mask=None, pad_mask=None):for layer in self.layers:x = layer(x, enc, self_mask, pad_mask)return xclass Transformer(nn.Module):def __init__(self, dim, vocabulary, num_heads=1, num_layers=1, dropout=0., learnable_pos=False, pre_norm=False):super(Transformer, self).__init__()self.dim = dimself.vocabulary = vocabularyself.num_heads = num_headsself.num_layers = num_layersself.dropout = dropoutself.learnable_pos = learnable_posself.pre_norm = pre_normself.embedding = nn.Embedding(vocabulary, dim)self.pos_enc = LearnablePositionalEncoding(dim, 100) if learnable_pos else PositionalEncoding(dim, 100)self.encoder = Encoder(dim, dim // num_heads, num_heads, num_layers, dropout, pre_norm)self.decoder = Decoder(dim, dim // num_heads, num_heads, num_layers, dropout, pre_norm)self.linear = nn.Linear(dim, vocabulary)def forward(self, src, tgt, src_mask=None, tgt_mask=None, pad_mask=None):# src.shape: torch.Size([2, 10])src = self.embedding(src)# src.shape: torch.Size([2, 10, 512])src = self.pos_enc(src)# src.shape: torch.Size([2, 10, 512])src = self.encoder(src, src_mask)# src.shape: torch.Size([2, 10, 512])# tgt.shape: torch.Size([2, 8])tgt = self.embedding(tgt)# tgt.shape: torch.Size([2, 8, 512])tgt = self.pos_enc(tgt)# tgt.shape: torch.Size([2, 8, 512])tgt = self.decoder(tgt, src, tgt_mask, pad_mask)# tgt.shape: torch.Size([2, 8, 512])output = self.linear(tgt)# output.shape: torch.Size([2, 8, 10000])return outputdef get_mask(self, tgt, src_pad=None):# Under normal circumstances, tgt_pad will perform mask processing when calculating loss, and it isn't necessarily in decoderif src_pad is not None:src_mask = padding_mask(src_pad, src_pad)else:src_mask = Nonetgt_mask = attn_mask(tgt.size(1))if src_pad is not None:pad_mask = padding_mask(torch.zeros_like(tgt), src_pad)else:pad_mask = None# src_mask: [B, len_src, len_src]# tgt_mask: [len_tgt, len_tgt]# pad_mask: [B, len_tgt, len_src]return src_mask, tgt_mask, pad_maskif __name__ == '__main__':model = Transformer(dim=512, vocabulary=10000, num_heads=8, num_layers=6, dropout=0.1, learnable_pos=False, pre_norm=True)src = torch.randint(0, 10000, (2, 10))  # torch.Size([2, 10])tgt = torch.randint(0, 10000, (2, 8))   # torch.Size([2, 8])src_pad = torch.randint(0, 2, (2, 10))  # torch.Size([2, 10])src_mask, tgt_mask, pad_mask = model.get_mask(tgt, src_pad)model(src, tgt, src_mask, tgt_mask, pad_mask)# output.shape: torch.Size([2, 8, 10000])

相关文章:

  • Java面试宝典:基础三
  • 新生代潜力股刘小北:演艺路上的璀璨新星
  • 用户行为序列建模(篇七)-【阿里】DIN
  • Linux下基于C++11的socket网络编程(基础)个人总结版
  • 学习日志02 ETF 基础数据可视化分析与简易管理系统
  • BERT 模型详解:结构、原理解析
  • 视频跳帧播放器设计与实现
  • Java I/O 模型详解:BIO、NIO 和 AIO
  • [Python 基础课程]Hello World
  • 实战四:基于PyTorch实现猫狗分类的web应用【2/3】
  • 【Linux庖丁解牛】— 文件系统!
  • 什么是RAG检索生成增强?
  • 利用deepseek学术搜索
  • 「Java案例」华氏摄氏温度转换
  • XIP (eXecute In Place)
  • 双指针的用法
  • Nginx漏洞处理指南
  • [database] Closure computation | e-r diagram | SQL
  • llama.cpp学习笔记:后端加载
  • VMware设置虚拟机为固定IP