Transformer中的三种注意力机制
原本想学习一下交叉注意力,结果发现交叉注意力并不是一种特殊的注意力模型,只是Transformer中的一种注意力,因此这里都整理一下,方便以后查阅。
Transformer中的三种注意力机制
- 1. Self-Attention(自注意力机制)
- 1.1 定义
- Scaled Dot-Product Attention(缩放点积注意力)
- Self Attention(自注意力)
- Multi-Head Self Attention(多头自注意力)
- 1.2 应用场景
- 1.3 优点 and 缺点
- 优点
- 缺点
- 2. Cross-Attention(交叉注意力机制)
- 2.1 定义
- 2.2 工作原理
- 2.3 应用场景
- 2.4 优点 and 缺点
- 优点
- 缺点
- 3. Causal-Attention(因果注意力机制)
- 3.1 定义
- Predict The Next Word(预测下一个词)
- Masked Language Model(掩码语言模型)
- Autoregressive(自回归)
- Causal Attention(因果注意力)
- 3.2 工作原理
- 3.3 应用场景
- 3.4 优点 and 缺点
- 优点
- 缺点
- 4. 对比总结
- 5. 代码示例(PyTorch)
- 6. 总结
- 参考资料
Self-Attention、Cross-Attention 和 Causal-Attention 是深度学习中注意力机制的三种重要变体,广泛应用于 Transformer 模型等架构中。它们在功能、输入来源和应用场景上各有不同。
1. Self-Attention(自注意力机制)
1.1 定义
Self-Attention 是一种注意力机制,允许模型在处理一个输入序列时,关注序列内部的每个元素之间的关系。每个元素既作为查询(Query),又作为键(Key)和值(Value),通过计算自身与其他元素的相关性来更新表示
。
Scaled Dot-Product Attention(缩放点积注意力)
- 输入:单一序列 X X X,形状为 ( n , d ) (n, d) (n,d),其中 n n n 是序列长度, d d d 是嵌入维度。
- 计算:
- 生成查询、键、值: Q = X W q Q = X W_q Q=XWq, K = X W k K = X W_k K=XWk, V = X W v V = X W_v V=XWv。
- 计算注意力分数: Score = Q K T d k \text{Score} = \frac{Q K^T}{\sqrt{d_k}} Score=dkQKT。
- 应用 softmax: Attention Weights = softmax ( Score ) \text{Attention Weights} = \text{softmax}(\text{Score}) Attention Weights=softmax(Score)。
- 加权求和: Output = Attention Weights ⋅ V \text{Output} = \text{Attention Weights} \cdot V Output=Attention Weights⋅V。
体现如何计算注意力分数,关注Q、K、V计算公式。
Self Attention(自注意力)
对同一个序列
,通过缩放点积注意力计算注意力分数,最终对值向量进行加权求和,从而得到输入序列中每个位置的加权表示。
表达的是一种注意力机制,如何使用缩放点积注意力对同一个序列计算注意力分数,从而得到同一序列中每个位置的注意力权重。
Multi-Head Self Attention(多头自注意力)
多个注意力头并行运行
,每个头都会独立地计算注意力权重和输出,然后将所有头的输出拼接起来得到最终的输出。
强调的是一种实操方法,实际操作中我们并不会使用单个维度来执行单一的注意力函数,而是通过h=8个头分别计算,然后加权平均。这样为了避免单个计算的误差。
1.2 应用场景
- 自然语言处理:如 BERT 的编码器,用于理解句子中词与词之间的上下文关系(例如,捕捉“bank”在“river bank”和“bank account”中的不同含义)。
- 计算机视觉:如 Vision Transformer (ViT),用于建模图像中不同区域之间的关系。
- 序列建模:适用于需要全局上下文的任务,如文本分类、语义表示学习。
1.3 优点 and 缺点
优点
- 捕捉序列内部的长距离依赖关系。
- 并行计算,相比 RNN 更高效。
- 灵活性强,适用于多种任务。
缺点
- 计算复杂度为 O ( n 2 ) O(n^2) O(n2),对长序列计算成本高。
- 缺乏时间顺序约束,可能不适合生成任务。
2. Cross-Attention(交叉注意力机制)
2.1 定义
Cross-Attention 用于建模两个不同序列之间的关系
。一个序列提供查询(Query),另一个序列提供键(Key)和值(Value)。它通常用于需要融合来自不同数据源或模态的信息的任务。
2.2 工作原理
- 输入:
- 查询序列 X q X_q Xq,形状为 ( n , d q ) (n, d_q) (n,dq),通常来自目标序列。
- 键-值序列 X k v X_{kv} Xkv, 形状为 ( m , d k v ) (m, d_{kv}) (m,dkv),通常来自源序列。
- 计算:
- 生成查询、键、值: Q = X q W q Q = X_q W_q Q=XqWq, K = X k v W k K = X_{kv} W_k K=XkvWk, V = X k v W v V = X_{kv} W_v V=XkvWv。
- 计算注意力分数: Score = Q K T d k \text{Score} = \frac{Q K^T}{\sqrt{d_k}} Score=dkQKT。
- 应用 softmax: Attention Weights = softmax ( Score ) \text{Attention Weights} = \text{softmax}(\text{Score}) Attention Weights=softmax(Score)。
- 加权求和: Output = Attention Weights ⋅ V \text{Output} = \text{Attention Weights} \cdot V Output=Attention Weights⋅V。
- 特点:查询和键-值来自不同序列,输出反映了查询序列对源序列的关注。
- 多头机制:同样支持多头注意力以增强表达能力。
2.3 应用场景
- 机器翻译:在 Transformer 解码器中,查询来自目标语言序列,键-值来自源语言序列(如将“Je t’aime”翻译为“I love you”时对齐“aime”和“love”)。
- 视觉-语言模型:如 CLIP(对齐图像和文本特征)或 DALL·E(将文本描述融入图像生成)。
- 问答系统:从文档(键-值)中提取与问题(查询)相关的答案。
- 跨模态任务:如图像-文本检索、视频-文本对齐。
2.4 优点 and 缺点
优点
- 有效融合来自不同序列或模态的信息。
- 适合跨模态或跨语言任务。
- 支持并行计算,效率高。
缺点
- 计算复杂度为 O ( n ⋅ m ) O(n \cdot m) O(n⋅m),长序列时仍可能成本高。
- 依赖输入序列的质量,噪声可能影响对齐效果。
3. Causal-Attention(因果注意力机制)
3.1 定义
Causal-Attention(也称为 Masked Self-Attention)是自注意力的一种变体,通过引入掩码(Mask)限制模型只关注序列中当前及之前的元素,防止“看到未来”的信息
。它通常用于自回归生成任务,确保生成过程符合时间顺序。
Predict The Next Word(预测下一个词)
模型通常需要基于已经生成的词来预测下一个词。这种特性要求模型在预测时不能“看到”未来的信息,以避免预测受到未来信息的影响。
Masked Language Model(掩码语言模型)
遮盖一些词语来让模型学习预测被遮盖的词语,从而帮助模型学习语言规律。
Autoregressive(自回归)
在生成序列的某个词时,解码器会考虑已经生成的所有词,包括当前正在生成的这个词本身。为了保持自回归属性,即模型在生成序列时只能基于已经生成的信息进行预测,需要防止解码器中的信息向左流动
。换句话说,当解码器在生成第t个词时,它不应该看到未来(即第t+1, t+2,…等位置)的信息。
Causal Attention(因果注意力)
为了确保模型在生成序列时,只依赖于之前的输入信息,而不会受到未来信息的影响。Causal Attention通过掩盖(mask)未来的位置来实现这一点,使得模型在预测某个位置的输出时,只能看到该位置及其之前的输入。
3.2 工作原理
- 输入:单一序列 X X X,形状为 ( n , d ) (n, d) (n,d)。
- 计算:
- 生成查询、键、值: Q = X W q Q = X W_q Q=XWq, K = X W k K = X W_k K=XWk, V = X W v V = X W_v V=XWv。
- 计算注意力分数: Score = Q K T d k \text{Score} = \frac{Q K^T}{\sqrt{d_k}} Score=dkQKT。 (加掩码)。
- 掩码:在 softmax 之前,对分数矩阵的上三角(未来位置)施加负无穷大(或 0),确保每个位置只关注自身及之前的位置。
- 应用 softmax: Attention Weights = softmax ( Score with Mask ) \text{Attention Weights} = \text{softmax}(\text{Score with Mask}) Attention Weights=softmax(Score with Mask)。
- 加权求和: Output = Attention Weights ⋅ V \text{Output} = \text{Attention Weights} \cdot V Output=Attention Weights⋅V.
- 特点:通过掩码实现因果约束,适合自回归生成任务。
- 多头机制:同样支持多头注意力。
3.3 应用场景
- 语言生成:如 GPT 系列模型,用于生成连贯的文本(例如,生成下一个词时只依赖之前的词)。
- 机器翻译:在 Transformer 解码器中,用于确保生成目标序列时遵循时间顺序。
- 语音生成:如 WaveNet,用于生成音频序列。
3.4 优点 and 缺点
优点
- 适合自回归任务,保证生成过程的时间顺序。
- 保留自注意力的并行计算优势。
- 能捕捉序列中之前的上下文信息。
缺点
- 计算复杂度仍为 O ( n 2 ) O(n^2) O(n2)。
- 无法利用序列中的未来信息,限制了某些任务的性能。
4. 对比总结
特性 | Self-Attention | Cross-Attention | Causal-Attention |
---|---|---|---|
输入来源 | 单一序列(Q, K, V 均来自同一序列) | 两个序列(Q 来自目标,K, V 来自源) | 单一序列(带因果掩码) |
注意力范围 | 全局(所有位置) | 目标序列关注源序列 | 当前及之前位置 |
计算复杂度 | O ( n 2 ) O(n^2) O(n2), n n n 为序列长度 | O ( n ⋅ m ) O(n \cdot m) O(n⋅m), n , m n, m n,m为序列长度 | O ( n 2 ) O(n^2) O(n2), n n n 为序列长度 |
主要应用 | 语义表示(如 BERT)、图像分类 | 翻译、跨模态任务(如 CLIP) | 语言生成(如 GPT)、自回归任务 |
时间顺序约束 | 无 | 无 | 有(通过掩码实现) |
典型模型 | BERT, ViT Rosformer | Transformer (解码器), CLIP, DALL·E | GPT, Transformer (解码器部分) |
5. 代码示例(PyTorch)
以下是一个简单的 PyTorch 实现,展示三种注意力机制的区别:
import torch
import torch.nn as nn
import torch.nn.functional as Fclass Attention(nn.Module):def __init__(self, dim, num_heads):super(Attention, self).__init__()self.num_heads = num_headsself.head_dim = dim // num_headsself.query = nn.Linear(dim, dim)self.key = nn.Linear(dim, dim)self.value = nn.Linear(dim, dim)self.out = nn.Linear(dim, dim)def forward(self, query, key, value, mask=None, causal=False):batch_size = query.size(0)seq_len_q, seq_len_k = query.size(1), key.size(1)# 线性变换Q = self.query(query) # (batch_size, seq_len_q, dim)K = self.key(key) # (batch_size, seq_len_k, dim)V = self.value(value) # (batch_size, seq_len_k, dim)# 分割多头Q = Q.view(batch_size, seq_len_q, self.num_heads, self.head_dim).transpose(1, 2)K = K.view(batch_size, seq_len_k, self.num_heads, self.head_dim).transpose(1, 2)V = V.view(batch_size, seq_len_k, self.num_heads, self.head_dim).transpose(1, 2)# 计算注意力分数scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5)# 因果掩码(仅 Causal-Attention 使用)if causal:mask = torch.triu(torch.ones(seq_len_q, seq_len_k), diagonal=1).bool()scores = scores.masked_fill(mask[None, None, :, :], float('-inf'))# 普通掩码(可选,用于 padding 或其他场景)if mask is not None:scores = scores.masked_fill(mask[None, None, :, :], float('-inf'))# Softmaxattn_weights = F.softmax(scores, dim=-1)# 加权求和out = torch.matmul(attn_weights, V)# 合并多头out = out.transpose(1, 2).contiguous().view(batch_size, seq_len_q, -1)out = self.out(out)return out# 示例用法
dim = 512
num_heads = 8
batch_size = 32
seq_len_q, seq_len_k = 10, 20# 准备输入
query = torch.randn(batch_size, seq_len_q, dim)
key = torch.randn(batch_size, seq_len_k, dim)
value = torch.randn(batch_size, seq_len_k, dim)# 初始化注意力模块
attn = Attention(dim, num_heads)# Self-Attention
self_attn_output = attn(query, query, query)
print("Self-Attention Output Shape:", self_attn_output.shape) # (32, 10, 512)# Cross-Attention
cross_attn_output = attn(query, key, value)
print("Cross-Attention Output Shape:", cross_attn_output.shape) # (32, 10, 512)# Causal-Attention
causal_attn_output = attn(query, query, query, causal=True)
print("Causal-Attention Output Shape:", causal_attn_output.shape) # (32, 10, 512)
6. 总结
- Self-Attention 适合需要捕捉序列内部全局关系的任务,如语义表示学习。
- Cross-Attention 专为跨序列或跨模态任务设计,如翻译和多模态融合。
- Causal-Attention 适用于自回归生成任务,确保时间顺序约束。
选择哪种机制取决于任务需求:
- 如果需要全局上下文,使用 Self-Attention。
- 如果涉及两个序列的交互,使用 Cross-Attention。
- 如果需要生成序列并遵循时间顺序,使用 Causal-Attention。
参考资料
神经网络算法 - 一文搞懂Transformer中的三种注意力机制
第四篇:一文搞懂Transformer架构的三种注意力机制
三种Transformer模型中的注意力机制介绍及Pytorch实现:从自注意力到因果自注意力