自注意力机制(Self-Attention)简介
Transformer 是一种基于**自注意力机制(Self-Attention)**的深度学习模型架构,最初由 Google 在 2017 年的论文《Attention Is All You Need》中提出。它彻底改变了自然语言处理(NLP)领域,成为 BERT、GPT 等大模型的基础。
一、Transformer 原理详解
1. 模型整体结构
Transformer 是一个 Encoder-Decoder 架构,由以下主要模块组成:
Input → Embedding + Positional Encoding → Encoder → Decoder → Output
- Encoder:由 N 个相同的层堆叠而成(论文中 N=6)
- Decoder:也由 N 个相同的层堆叠而成
- 每一层都包含:
  - 多头自注意力机制(Multi-Head Self-Attention)
  - 前馈神经网络(Feed-Forward Network)
  - 残差连接(Residual Connection)和层归一化(LayerNorm)
- Decoder 的每一层还额外包含一个对 Encoder 输出的交叉注意力(Encoder-Decoder Attention),且其自注意力使用因果掩码
2. 核心组件详解
(1)Self-Attention(自注意力)
自注意力机制允许模型在处理序列中某个位置时,直接关注(加权聚合)序列中所有位置的信息,从而建模任意距离的依赖关系。
给定输入向量序列 $X = (x_1, x_2, \dots, x_n)$,对每个位置 $i$,计算其输出为:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$
其中:
- $Q = XW_Q$,$K = XW_K$,$V = XW_V$
- $d_k$:K 的维度,用于缩放,防止点积过大导致 softmax 梯度消失
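下面是缩放点积注意力公式的一个最小 PyTorch 示意(仅用于验证上式,函数名 `scaled_dot_product_attention` 为示意命名,假设 Q、K、V 已由线性投影得到):

```python
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(Q, K, V):
    d_k = Q.size(-1)
    scores = Q @ K.transpose(-2, -1) / d_k ** 0.5  # (batch, n, n)
    weights = F.softmax(scores, dim=-1)            # 每行和为 1 的注意力权重
    return weights @ V                             # 按权重聚合 V

Q = K = V = torch.randn(2, 5, 64)  # (batch, seq_len, d_k)
print(scaled_dot_product_attention(Q, K, V).shape)  # torch.Size([2, 5, 64])
```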
(2)Multi-Head Attention(多头注意力)
将 Q、K、V 投影到多个子空间,分别进行注意力计算,然后拼接:
$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \dots, \text{head}_h)W^O$$
其中:
- $\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)$
多头让模型在不同表示子空间中学习不同特征。
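可以直接用 PyTorch 内置的 `nn.MultiheadAttention` 快速体会多头注意力的输入输出形状(手写实现见第四节),下面是一个简单的自注意力示意:

```python
import torch
import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=512, num_heads=8, batch_first=True)
x = torch.randn(2, 10, 512)     # (batch, seq_len, d_model)
out, attn = mha(x, x, x)        # 自注意力:Q = K = V = x
print(out.shape, attn.shape)    # (2, 10, 512) (2, 10, 10),attn 为各头平均后的权重
```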
(3)Positional Encoding(位置编码)
由于 Transformer 没有 RNN 或 CNN 的顺序结构,必须加入位置信息。使用正弦和余弦函数生成:
$$PE_{(pos, 2i)} = \sin\left(\frac{pos}{10000^{2i/d_{model}}}\right)$$

$$PE_{(pos, 2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{model}}}\right)$$
(4)Feed-Forward Network(前馈网络)
每个位置独立地通过两个线性变换和一个 ReLU 激活:
$$\text{FFN}(x) = W_2 \cdot \text{ReLU}(W_1 x + b_1) + b_2$$
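按论文默认维度(d_model=512、dim_feedforward=2048),FFN 可以写成如下的最小示意:

```python
import torch.nn as nn

# 两个线性层夹一个 ReLU,对序列中每个位置独立作用
ffn = nn.Sequential(
    nn.Linear(512, 2048),
    nn.ReLU(),
    nn.Linear(2048, 512),
)
```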
(5)残差连接与层归一化
每一子层输出为:
$$\text{LayerNorm}(x + \text{Sublayer}(x))$$
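作为示意,可以把"残差 + LayerNorm"封装成一个小模块(这里的 `SublayerConnection` 是示意用的假设命名,采用与原论文一致的 Post-LN 写法,并按论文在子层输出上加 dropout):

```python
import torch.nn as nn

class SublayerConnection(nn.Module):
    """LayerNorm(x + Dropout(Sublayer(x))) 的封装(示意)"""
    def __init__(self, d_model, dropout=0.1):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        # sublayer 可以是自注意力或 FFN 等任意子层(以可调用对象传入)
        return self.norm(x + self.dropout(sublayer(x)))
```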
二、PyTorch 实现一个简化版 Transformer
下面实现一个简化版的 Transformer 模型,用于序列到序列任务(如机器翻译)。
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

# -----------------------------
# 1. Positional Encoding
# -----------------------------
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # (1, max_len, d_model)
        self.register_buffer('pe', pe)

    def forward(self, x):
        return x + self.pe[:, :x.size(1), :]

# -----------------------------
# 2. Transformer Model
# -----------------------------
class TransformerModel(nn.Module):
    def __init__(self, vocab_size, d_model=512, nhead=8, num_encoder_layers=6,
                 num_decoder_layers=6, dim_feedforward=2048, max_len=5000, dropout=0.1):
        super(TransformerModel, self).__init__()
        self.d_model = d_model

        # Embedding + Positional Encoding
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model, max_len)

        # Transformer
        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True  # 使用 batch_first=True 更直观
        )

        # Output projection
        self.fc_out = nn.Linear(d_model, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src, tgt, src_mask=None, tgt_mask=None,
                src_key_padding_mask=None, tgt_key_padding_mask=None):
        """
        src: (batch_size, src_len)
        tgt: (batch_size, tgt_len)
        """
        # Embedding + Positional Encoding
        src_emb = self.dropout(self.pos_encoder(self.embedding(src) * math.sqrt(self.d_model)))
        tgt_emb = self.dropout(self.pos_encoder(self.embedding(tgt) * math.sqrt(self.d_model)))

        # Transformer
        output = self.transformer(
            src=src_emb,
            tgt=tgt_emb,
            src_mask=src_mask,
            tgt_mask=tgt_mask,
            src_key_padding_mask=src_key_padding_mask,
            tgt_key_padding_mask=tgt_key_padding_mask
        )

        # Output
        logits = self.fc_out(output)
        return logits

    def generate_square_subsequent_mask(self, sz):
        """生成因果掩码(防止看到未来信息)"""
        mask = torch.triu(torch.ones(sz, sz), diagonal=1)
        mask = mask.masked_fill(mask == 1, float('-inf'))
        return mask

# -----------------------------
# 3. 使用示例
# -----------------------------
if __name__ == "__main__":
    # 参数
    vocab_size = 1000
    d_model = 512
    max_seq_len = 10

    # 创建模型
    model = TransformerModel(vocab_size=vocab_size, d_model=d_model)

    # 模拟数据
    src = torch.randint(1, 100, (2, max_seq_len))  # (batch, src_len)
    tgt = torch.randint(1, 100, (2, max_seq_len))  # (batch, tgt_len)

    # 生成目标序列的因果掩码
    tgt_mask = model.generate_square_subsequent_mask(max_seq_len).to(src.device)

    # 前向传播
    output = model(src, tgt, tgt_mask=tgt_mask)
    print("Output shape:", output.shape)  # (2, 10, 1000)
```
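在上面示例的基础上,下面是一个简化的单步训练示意(假设 `TransformerModel` 已如上定义、词表中 0 为 padding,且对目标序列做了输入/标签错位;实际训练还需要 padding mask、学习率调度等):

```python
import torch
import torch.nn as nn

model = TransformerModel(vocab_size=1000)
criterion = nn.CrossEntropyLoss(ignore_index=0)   # 忽略 padding 位置
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

src = torch.randint(1, 100, (2, 10))              # (batch, src_len)
tgt = torch.randint(1, 100, (2, 11))              # (batch, tgt_len + 1)
tgt_in, tgt_out = tgt[:, :-1], tgt[:, 1:]         # 解码器输入 / 预测目标(teacher forcing)
tgt_mask = model.generate_square_subsequent_mask(tgt_in.size(1))

optimizer.zero_grad()
logits = model(src, tgt_in, tgt_mask=tgt_mask)    # (2, 10, vocab_size)
loss = criterion(logits.reshape(-1, logits.size(-1)), tgt_out.reshape(-1))
loss.backward()
optimizer.step()
print("loss:", loss.item())
```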
三、说明
1. 关键点解释
- `batch_first=True`:让输入形状为 `(batch, seq_len, d_model)`,更符合直觉。
- `math.sqrt(d_model)`:缩放嵌入,稳定训练。
- `generate_square_subsequent_mask`:生成因果掩码,防止解码器看到未来 token(形状见下方示例)。
- `nn.Transformer` 是 PyTorch 内置的高效实现,封装了 Encoder 和 Decoder。
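例如,长度为 4 的因果掩码如下(与 `generate_square_subsequent_mask` 的实现相同,0 表示可见,`-inf` 表示屏蔽未来位置):

```python
import torch

sz = 4
mask = torch.triu(torch.ones(sz, sz), diagonal=1)
mask = mask.masked_fill(mask == 1, float('-inf'))
print(mask)
# tensor([[0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0., 0., -inf],
#         [0., 0., 0., 0.]])
```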
四、扩展(手动实现 Multi-Head Attention)
如果你想从零实现 Multi-Head Attention,可以用下面的模块替换 `nn.Transformer` 内部对应的注意力部分:
```python
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # 每个头的维度

        # Q、K、V 以及输出的线性投影
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        # 缩放点积注意力:softmax(QK^T / sqrt(d_k)) V
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        attn = F.softmax(scores, dim=-1)
        return torch.matmul(attn, V)

    def split_heads(self, x, batch_size):
        # (batch, seq_len, d_model) -> (batch, num_heads, seq_len, d_k)
        x = x.view(batch_size, -1, self.num_heads, self.d_k)
        return x.transpose(1, 2)

    def combine_heads(self, x, batch_size):
        # (batch, num_heads, seq_len, d_k) -> (batch, seq_len, d_model)
        x = x.transpose(1, 2).contiguous()
        return x.view(batch_size, -1, self.d_model)

    def forward(self, Q, K, V, mask=None):
        batch_size = Q.size(0)
        Q = self.split_heads(self.W_q(Q), batch_size)
        K = self.split_heads(self.W_k(K), batch_size)
        V = self.split_heads(self.W_v(V), batch_size)
        attn = self.scaled_dot_product_attention(Q, K, V, mask)
        output = self.combine_heads(attn, batch_size)
        return self.W_o(output)
```
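下面是这个手写 `MultiHeadAttention` 的一个简单自注意力用法示意(形状假设与上文一致:d_model=512、8 个头):

```python
import torch

mha = MultiHeadAttention(d_model=512, num_heads=8)
x = torch.randn(2, 10, 512)   # (batch, seq_len, d_model)
out = mha(x, x, x)            # 自注意力:Q = K = V = x
print(out.shape)              # torch.Size([2, 10, 512])
```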
五、总结
| 组件 | 作用 |
| --- | --- |
| Self-Attention | 捕捉序列内部依赖 |
| Multi-Head | 多子空间学习不同特征 |
| Positional Encoding | 加入位置信息 |
| FFN | 非线性变换 |
| Residual + LayerNorm | 稳定训练 |
Transformer 的核心思想是:用注意力机制替代循环和卷积,实现并行化和长距离依赖建模。