Implementing the Transformer Model by Hand
This article uses the PyTorch library to implement by hand the sub-modules of the classic Transformer model: multi-head self-attention, residual connection with layer normalization (Add & Norm), the feed-forward layer, the encoder, and the decoder, and then assembles them into a complete Transformer model.
"""
@Title: Dissecting the Transformer
@Time: 2025/5/10
@Author: Michael Jie
"""import mathimport torch
import torch.nn.functional as F
from torch import nn, Tensor


# Scaled Dot-Product Attention
class Attention(nn.Module):
    def __init__(self, causal: bool = True) -> None:
        """
        Attention formula: Attention(Q, K, V) = softmax(Q · K^T / sqrt(d_k)) · V

        Args:
            causal: whether to auto-generate a causal mask, defaults to True
        """
        super(Attention, self).__init__()
        self.causal = causal

    def forward(self,
                q: Tensor,
                k: Tensor,
                v: Tensor,
                padding_mask: Tensor = None,
                attn_mask: Tensor = None) -> tuple[Tensor, Tensor]:
        """
        Padding mask: handles variable-length sequences so that padded positions do not affect the attention computation.
        Causal mask: prevents the decoder from seeing future positions during training.

        Args:
            q: queries, shape=(..., seq_len_q, d_k)
            k: keys, shape=(..., seq_len_k, d_k)
            v: values, shape=(..., seq_len_k, d_v)
            padding_mask: padding mask, shape=(..., seq_len_k)
            attn_mask: causal mask, shape=(..., seq_len_q, seq_len_k)

        Returns:
            output: output, shape=(..., seq_len_q, d_v)
            weights: attention weights, shape=(..., seq_len_q, seq_len_k)
        """
        # Attention scores
        d_k = q.size(-1)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
        # Apply the padding mask
        if padding_mask is not None:
            # Broadcast to (..., 1, seq_len_k)
            scores = scores.masked_fill(padding_mask.unsqueeze(-2), float("-inf"))
        # Auto-generate a causal mask; a user-supplied causal mask takes precedence
        seq_len_q, seq_len_k = q.size(-2), k.size(-2)
        if self.causal and attn_mask is None:
            attn_mask = torch.triu(torch.ones(seq_len_q, seq_len_k, device=q.device), diagonal=1).bool()
        # Apply the causal mask
        if attn_mask is not None:
            scores = scores.masked_fill(attn_mask, float("-inf"))
        # Attention weights
        weights = F.softmax(scores, dim=-1)
        # Re-apply the padding mask so that padded key positions get zero attention weight
        if padding_mask is not None:
            weights = weights.masked_fill(padding_mask.unsqueeze(-2), 0)
        # Multiply by v to obtain the output
        output = torch.matmul(weights, v)
        return output, weights
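# Optional sanity check (illustration only, not part of the original module): for the plain
# causal case without padding, the result should match PyTorch's built-in
# F.scaled_dot_product_attention (available in PyTorch >= 2.0), since both compute
# softmax(Q · K^T / sqrt(d_k)) · V under a causal mask.
#
#     q = k = v = torch.rand(2, 4, 8)
#     out_ref = F.scaled_dot_product_attention(q, k, v, is_causal=True)
#     out_ours, _ = Attention(causal=True)(q, k, v)
#     torch.testing.assert_close(out_ours, out_ref)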
# Self-Attention
class SelfAttention(nn.Module):
    def __init__(self, d_model: int = 512) -> None:
        """
        Self-attention is a special form of attention in which Q, K, and V all come from the
        same input sequence. It captures relationships between elements within the sequence and
        models long-range dependencies directly, without relying on RNNs or CNNs.

        Args:
            d_model: feature dimension, defaults to 512
        """
        super(SelfAttention, self).__init__()
        self.attention = Attention()  # attention mechanism
        # A single linear layer that produces Q, K, and V together
        self.linear_qkv = nn.Linear(d_model, d_model * 3)
        self.linear_out = nn.Linear(d_model, d_model)

    def forward(self,
                x: Tensor,
                padding_mask: Tensor = None,
                attn_mask: Tensor = None) -> Tensor:
        """
        Args:
            x: word embeddings, shape=(batch_size, seq_len, d_model)
            padding_mask: padding mask, shape=(batch_size, seq_len)
            attn_mask: causal mask, shape=(seq_len, seq_len)

        Returns:
            output: output, shape=(batch_size, seq_len, d_model)
        """
        # Generate Q, K, and V with a single linear layer
        qkv = self.linear_qkv(x)
        q, k, v = qkv.chunk(3, dim=-1)  # (batch_size, seq_len, d_model)
        # Apply the attention mechanism
        output, weights = self.attention(q, k, v, padding_mask, attn_mask)
        return self.linear_out(output)
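# Side note (illustration only, not used by the model): the single (d_model -> 3 * d_model)
# projection in linear_qkv is equivalent to three independent d_model -> d_model projections
# stacked together; chunking the output simply slices out W_q, W_k, and W_v.
#
#     lin = nn.Linear(512, 512 * 3)
#     x = torch.rand(2, 5, 512)
#     q, _, _ = lin(x).chunk(3, dim=-1)
#     q_ref = F.linear(x, lin.weight[:512], lin.bias[:512])  # the first third acts as W_q
#     torch.testing.assert_close(q, q_ref)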
# Multi-Head Self-Attention
class MultiHeadSelfAttention(nn.Module):
    def __init__(self, d_model: int = 512, num_heads: int = 8, causal: bool = True) -> None:
        """
        Multi-head self-attention extends self-attention by splitting the input features into
        several heads, computing attention independently in each head, and concatenating the
        results, which lets the model attend from multiple representation subspaces.

        Args:
            d_model: feature dimension, defaults to 512
            num_heads: number of heads, defaults to 8
            causal: whether to auto-generate a causal mask, defaults to True
        """
        super(MultiHeadSelfAttention, self).__init__()
        if d_model % num_heads != 0:
            raise ValueError(f"d_model must be divisible by num_heads, but got {d_model} and {num_heads}")
        self.num_heads = num_heads
        self.attention = Attention(causal)  # attention mechanism
        # Separate linear projections for Q, K, and V
        self.linear_q = nn.Linear(d_model, d_model)
        self.linear_k = nn.Linear(d_model, d_model)
        self.linear_v = nn.Linear(d_model, d_model)
        self.linear_out = nn.Linear(d_model, d_model)

    def forward(self,
                q: Tensor,
                k: Tensor,
                v: Tensor,
                padding_mask: Tensor = None,
                attn_mask: Tensor = None) -> Tensor:
        """
        Q, K, and V may come from different sources depending on where the module is used:
        in encoder-decoder (cross) attention, Q comes from the decoder input while K and V
        come from the encoder output; for self-attention, seq_len_k == seq_len.

        Args:
            q: queries, shape=(batch_size, seq_len, d_model)
            k: keys, shape=(batch_size, seq_len_k, d_model)
            v: values, shape=(batch_size, seq_len_k, d_model)
            padding_mask: padding mask, shape=(batch_size, seq_len_k)
            attn_mask: causal mask, shape=(seq_len, seq_len_k)

        Returns:
            output: output, shape=(batch_size, seq_len, d_model)
        """
        q = self.linear_q(q)
        k = self.linear_k(k)
        v = self.linear_v(v)
        batch_size, seq_len, seq_len_k = q.size(0), q.size(1), k.size(1)
        # (batch_size, num_heads, seq_len, d_k)
        q = q.view(batch_size, seq_len, self.num_heads, -1).transpose(1, 2)
        k = k.view(batch_size, seq_len_k, self.num_heads, -1).transpose(1, 2)
        v = v.view(batch_size, seq_len_k, self.num_heads, -1).transpose(1, 2)
        # Reshape the masks so they broadcast across the heads
        if padding_mask is not None:
            padding_mask = padding_mask.unsqueeze(1)  # (batch_size, 1, seq_len_k)
        if attn_mask is not None:
            attn_mask = attn_mask.unsqueeze(0)  # (1, seq_len, seq_len_k)
        # Apply the attention mechanism
        output, weights = self.attention(q, k, v, padding_mask, attn_mask)
        # Concatenate the heads
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
        return self.linear_out(output)
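# Shape walk-through for the head split/merge used above (illustration only): with d_model=512
# and num_heads=8, each head works on d_k = 512 / 8 = 64 features, and the split and merge are
# exact inverses of each other.
#
#     x = torch.rand(2, 5, 512)                          # (batch_size, seq_len, d_model)
#     h = x.view(2, 5, 8, 64).transpose(1, 2)            # (batch_size, num_heads, seq_len, d_k)
#     merged = h.transpose(1, 2).contiguous().view(2, 5, 512)
#     assert torch.equal(x, merged)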
# Residual Connection and Layer Normalization (Add & Norm)
class AddNorm(nn.Module):
    def __init__(self, d_model: int = 512) -> None:
        """
        The Add & Norm layer combines two operations, a residual connection and layer
        normalization, which keeps training stable and makes it possible to stack many such
        layers to build deeper models.

        Args:
            d_model: feature dimension, defaults to 512
        """
        super(AddNorm, self).__init__()
        self.norm = nn.LayerNorm(d_model)  # layer normalization

    def forward(self, x: Tensor, y: Tensor) -> Tensor:
        return self.norm(x + y)
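# Note (illustration only): nn.LayerNorm(d_model) normalizes each position over the feature
# dimension, independently of the batch, so every token vector ends up with roughly zero mean.
#
#     y = nn.LayerNorm(4)(torch.tensor([[1.0, 2.0, 3.0, 4.0]]))
#     print(y.mean(-1))  # ~0 for each position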
# Feed-Forward Layer
class FeedForward(nn.Module):
    def __init__(self,
                 input_dim: int = 512,
                 hidden_dim: int = 2048,
                 activation: str = "relu",
                 dropout: float = 0.1) -> None:
        """
        Linear layer (expand the dimension) -> activation -> linear layer (restore the original dimension).
        The non-linear transformation further extracts and enriches features, giving the model
        stronger pattern-recognition and semantic-composition capacity.

        Args:
            input_dim: input dimension, defaults to 512
            hidden_dim: hidden dimension, defaults to 2048
            activation: activation function, defaults to "relu"
                supported: "sigmoid", "tanh", "relu", "gelu", "leaky_relu", "elu"
            dropout: dropout rate, defaults to 0.1
        """
        super(FeedForward, self).__init__()
        match activation:  # select the activation function
            case "sigmoid":
                activation = nn.Sigmoid()
            case "tanh":
                activation = nn.Tanh()
            case "relu":
                activation = nn.ReLU()
            case "gelu":
                activation = nn.GELU()
            case "leaky_relu":
                activation = nn.LeakyReLU()
            case "elu":
                activation = nn.ELU()
            case _:
                raise ValueError(f"Unsupported activation function: {activation}")
        # Linear -> activation -> Dropout -> Linear
        self.ffn = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            activation,
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, input_dim),
        )

    def forward(self, x: Tensor) -> Tensor:
        return self.ffn(x)
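# Quick usage check (illustration only): the feed-forward block expands to hidden_dim
# internally but always restores the input shape, which is what allows the residual
# connection that follows it.
#
#     ffn = FeedForward(512, 2048, activation="gelu")
#     assert ffn(torch.rand(2, 5, 512)).shape == (2, 5, 512)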
# Encoder Layer
class EncoderLayer(nn.Module):
    def __init__(self,
                 d_model: int = 512,
                 num_heads: int = 8,
                 dim_feedforward: int = 2048,
                 dropout: float = 0.1) -> None:
        """
        MultiHeadSelfAttention -> AddNorm -> FeedForward -> AddNorm

        Args:
            d_model: feature dimension, defaults to 512
            num_heads: number of heads, defaults to 8
            dim_feedforward: FFN hidden dimension, defaults to 2048
            dropout: dropout rate, defaults to 0.1
        """
        super(EncoderLayer, self).__init__()
        # Multi-head self-attention layer (bidirectional: the encoder uses no causal mask)
        self.attn = MultiHeadSelfAttention(d_model, num_heads, causal=False)
        # Add & Norm layers
        self.norm1 = AddNorm(d_model)
        self.norm2 = AddNorm(d_model)
        # Feed-forward layer
        self.ffn = FeedForward(d_model, dim_feedforward, dropout=dropout)
        # Dropout layers
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self,
                x: Tensor,
                padding_mask: Tensor = None,
                attn_mask: Tensor = None) -> Tensor:
        x = self.norm1(x, self.dropout1(self.attn(x, x, x, padding_mask, attn_mask)))
        x = self.norm2(x, self.dropout2(self.ffn(x)))
        return x
# Encoder
class Encoder(nn.Module):
    def __init__(self, num_layers: int = 6, **params) -> None:
        """
        The encoder is a stack of encoder layers; every layer has the same structure but does
        not share parameters.

        Args:
            num_layers: number of layers, defaults to 6
            **params: encoder-layer parameters, see EncoderLayer
        """
        super(Encoder, self).__init__()
        self.layers = nn.ModuleList([
            EncoderLayer(**params)
            for _ in range(num_layers)
        ])

    def forward(self,
                x: Tensor,
                padding_mask: Tensor = None,
                attn_mask: Tensor = None) -> Tensor:
        for layer in self.layers:  # pass through the layers one by one
            x = layer(x, padding_mask, attn_mask)
        return x
# Decoder Layer
class DecoderLayer(nn.Module):
    def __init__(self,
                 d_model: int = 512,
                 num_heads: int = 8,
                 dim_feedforward: int = 2048,
                 dropout: float = 0.1) -> None:
        """
        MultiHeadSelfAttention -> AddNorm -> MultiHeadSelfAttention -> AddNorm -> FeedForward -> AddNorm

        Args:
            d_model: feature dimension, defaults to 512
            num_heads: number of heads, defaults to 8
            dim_feedforward: FFN hidden dimension, defaults to 2048
            dropout: dropout rate, defaults to 0.1
        """
        super(DecoderLayer, self).__init__()
        # Multi-head attention layers: masked self-attention and (non-causal) cross-attention
        self.attn = MultiHeadSelfAttention(d_model, num_heads)
        self.cross_attn = MultiHeadSelfAttention(d_model, num_heads, causal=False)
        # Add & Norm layers
        self.norm1 = AddNorm(d_model)
        self.norm2 = AddNorm(d_model)
        self.norm3 = AddNorm(d_model)
        # Feed-forward layer
        self.ffn = FeedForward(d_model, dim_feedforward, dropout=dropout)
        # Dropout layers
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

    def forward(self,
                y: Tensor,
                memory: Tensor,
                padding_mask_y: Tensor = None,
                padding_mask_memory: Tensor = None,
                attn_mask_y: Tensor = None,
                attn_mask_memory: Tensor = None) -> Tensor:
        x = y
        x = self.norm1(x, self.dropout1(self.attn(x, x, x, padding_mask_y, attn_mask_y)))
        x = self.norm2(x, self.dropout2(self.cross_attn(x, memory, memory, padding_mask_memory, attn_mask_memory)))
        x = self.norm3(x, self.dropout3(self.ffn(x)))
        return x
# Decoder
class Decoder(nn.Module):
    def __init__(self, num_layers: int = 6, **params) -> None:
        """
        The decoder is a stack of decoder layers; every layer has the same structure but does
        not share parameters.

        Args:
            num_layers: number of layers, defaults to 6
            **params: decoder-layer parameters, see DecoderLayer
        """
        super(Decoder, self).__init__()
        self.layers = nn.ModuleList([
            DecoderLayer(**params)
            for _ in range(num_layers)
        ])

    def forward(self,
                y: Tensor,
                memory: Tensor,
                padding_mask_y: Tensor = None,
                padding_mask_memory: Tensor = None,
                attn_mask_y: Tensor = None,
                attn_mask_memory: Tensor = None) -> Tensor:
        x = y
        for layer in self.layers:  # pass through the layers one by one
            x = layer(x, memory, padding_mask_y, padding_mask_memory, attn_mask_y, attn_mask_memory)
        return x
# Transformer
class Transformer(nn.Module):
    def __init__(self,
                 num_encoder_layers: int = 6,
                 num_decoder_layers: int = 6,
                 **params) -> None:
        """
        The Transformer is the standard encoder-decoder architecture.

        Args:
            num_encoder_layers: number of encoder layers, defaults to 6
            num_decoder_layers: number of decoder layers, defaults to 6
            **params: encoder/decoder layer parameters, see EncoderLayer and DecoderLayer
        """
        super(Transformer, self).__init__()
        self.encoder = Encoder(num_encoder_layers, **params)  # encoder
        self.decoder = Decoder(num_decoder_layers, **params)  # decoder

    def forward(self,
                x: Tensor,
                y: Tensor,
                padding_mask_x: Tensor = None,
                padding_mask_y: Tensor = None,
                padding_mask_memory: Tensor = None,
                attn_mask_x: Tensor = None,
                attn_mask_y: Tensor = None,
                attn_mask_memory: Tensor = None) -> Tensor:
        memory = self.encoder(x, padding_mask_x, attn_mask_x)
        output = self.decoder(y, memory, padding_mask_y, padding_mask_memory, attn_mask_y, attn_mask_memory)
        return output


if __name__ == '__main__':
    # attention = Attention(True)
    # t1, t2 = attention(
    #     torch.rand((2, 3, 64)),
    #     torch.rand((2, 5, 64)),
    #     torch.rand((2, 5, 512)),
    #     torch.tensor([[False, True, True, True, True],
    #                   [False, False, False, False, True]])
    # )
    # print(t1.shape, t2.shape)

    # self_attention = SelfAttention()
    # t3 = self_attention(
    #     torch.rand((2, 5, 512)),
    #     torch.tensor([[False, False, False, True, True],
    #                   [False, False, True, True, True]])
    # )
    # print(t3.shape)

    # multi_head_self_attention = MultiHeadSelfAttention(num_heads=2)
    # t4 = multi_head_self_attention(
    #     torch.rand((2, 3, 512)),
    #     torch.rand((2, 5, 512)),
    #     torch.rand((2, 5, 512)),
    #     torch.tensor([[False, False, False, True, True],
    #                   [False, False, True, True, True]])
    # )
    # print(t4.shape)

    # encoder_layer = EncoderLayer()
    # t5 = encoder_layer(
    #     torch.rand((2, 5, 512)),
    #     torch.tensor([[False, False, False, True, True],
    #                   [False, False, True, True, True]])
    # )
    # print(t5.shape)

    # encoder = Encoder(dropout=0.2)
    # t6 = encoder(
    #     torch.rand((2, 5, 512)),
    #     torch.tensor([[False, False, False, True, True],
    #                   [False, False, True, True, True]])
    # )
    # print(t6.shape)

    # decoder_layer = DecoderLayer()
    # t7 = decoder_layer(
    #     torch.rand((2, 3, 512)),
    #     torch.rand((2, 5, 512)),
    #     torch.tensor([[False, False, False],
    #                   [False, False, True]]),
    #     torch.tensor([[False, False, False, True, True],
    #                   [False, False, True, True, True]])
    # )
    # print(t7.shape)

    # decoder = Decoder()
    # t8 = decoder(
    #     torch.rand((2, 3, 512)),
    #     torch.rand((2, 5, 512)),
    #     torch.tensor([[False, False, False],
    #                   [False, False, True]]),
    #     torch.tensor([[False, False, False, True, True],
    #                   [False, False, True, True, True]])
    # )
    # print(t8.shape)

    transformer = Transformer()
    t9 = transformer(
        torch.rand((2, 5, 512)),
        torch.rand((2, 3, 512)),
        torch.tensor([[False, False, False, True, True],
                      [False, False, True, True, True]]),
        torch.tensor([[False, False, False],
                      [False, False, True]]),
    )
    print(t9.shape)  # torch.Size([2, 3, 512])
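A note on constructing the masks used above: the padding masks are boolean tensors in which True marks a padded position that attention should ignore (this matches the masked_fill calls in Attention). When starting from token-id tensors rather than the random embeddings used in the test, such a mask can be derived directly from the padding id. The sketch below is only an illustration and assumes a hypothetical pad_id of 0; it is not part of the model itself.

import torch

pad_id = 0  # hypothetical padding token id
token_ids = torch.tensor([[5, 9, 3, 0, 0],
                          [7, 2, 0, 0, 0]])
padding_mask = token_ids.eq(pad_id)  # True where the sequence is padded
# tensor([[False, False, False,  True,  True],
#         [False, False,  True,  True,  True]])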