Transformer内容详解(通透版)
论文链接:[1706.03762] Attention Is All You Need
参考视频:Transformer论文逐段精读【论文精读】_哔哩哔哩_bilibili
代码推荐看这篇,把每个部分都单独提取了出来:(25 封私信) Transformer代码及解析(Pytorch) - 知乎
相关数学知识
这部分很重要!!!!论文中很多涉及到dk维度的计算,例如为什么注意力机制要除以根号dk?为什么embedding后要乘根号dk? 如果想要彻底搞懂,对正态分布需要理解。
正态分布
正态分布(Normal distribution),又叫高斯分布,是统计学和机器学习中最常用的概率分布之一。其概率密度函数为:
- μ:均值(mean),表示分布的中心。
- σ^2:方差(variance),表示分布的离散程度。
- σ:标准差(standard deviation),等于方差的平方根。
方差用来衡量一组数据的“分散程度”,标准差因为和原始量纲一致,更加直观。
方差大,数据分布“宽”,数值可能出现极大和极小,更层的输出如果方差越来越大,容易导致梯度爆炸;
方差小,数据分布“窄”,数值集中于均值附近,各层输出方差越来越小,容易导致梯度消失。
Xavier初始化
Xavier 初始化是一种广泛应用于深度神经网络中的参数初始化方法。
假设网络中的某一层,输入 x,输出 y,权重W--------> y=Wx;
设输入为 d_in维,输出为 d_out 维。
假设输入 x 的均值为 0,方差为 Var[x],权重 W 均值也是 0,方差为 Var[W]:
那么输出 y 的第 j 个元素:
由于都是独立同分布的,输出 y_j的方差为:
因为希望每层输入输出的方差一致,即,则有:
但这样只保证了前向传播的稳定。
反向传播类似分析,取d_out,希望梯度也不消失/爆炸,建议:
Xavier 初始化综合考虑了前向和反向传播,采用两者的平均:
为什么要提出Transformer?它解决了什么问题?
Transformer的提出主要是为了解决序列建模领域中传统方法(RNN,LSTM等)并行效率低、难以捕捉长距离依赖、模型扩展能力差的问题。
序列建模领域
常见的序列有:
- 文本(自然语言中的单词或字符序列)
- 语音(随时间变化的声音信号)
- 时间序列数据(如股票价格、传感器信号等)
- 基因序列(DNA碱基对序列)
- 视频帧序列
序列建模是指对序列本身的结构和规律进行刻画和预测,即学习序列中每个元素之间的相关性和依赖关系。
传统序列建模方法
可参看这篇博客:ELMo——Embeddings from Language Models原理速学-CSDN博客
RNN——循环神经网络
RNN具有“记忆”能力,在处理当前输入的同时会考虑历史信息。
hₜ = f(Wₓxₜ + Wₕhₜ₋₁ + b) yₜ = g(Wₒhₜ + c)
- 输入序列:x₁, x₂, ... xₜ,表示当前时刻的输入
- 隐藏状态:h,hₜ₋₁ 表示上一个时刻的隐藏状态
- 输出:yₜ
- Wₓ、Wₕ、Wₒ和b、c是需要学习的参数
其中,f和g一般为非线性激活函数(如tanh、softmax)。
RNN在训练时反向传播过程中,因为误差梯度需要逐步从序列末端向前传播到最开始。在长序列中,多个时刻的梯度连乘会导致梯度爆炸(梯度>1)或者梯度消失(梯度<1)。
梯度计算中包含激活函数的倒数以及权重矩阵W。
多数情况梯度更容易快速衰减几乎消失,导致网络几乎无法感知早期的输入,训练困难;
并且RNN常用的激活函数(tanh和sigmoid),倒数最大值都小于1,加剧了梯度消失现象。
总结就是,由于梯度消失,RNN对于距离很远的信息,梯度几乎传递不到,所以模型记不住早期的信息。
LSTM——长短时记忆网络
LSTM是RNN的一种变体,专门为了解决RNN在长序列训练中出现的梯度消失问题,能够更好捕捉序列中长期的依赖关系。
LSTM引入了“记忆单元cell”和三个门控结构(输入门、遗忘门、输出门):
遗忘门(Forget Gate):决定哪些信息需要丢弃
输入门(Input Gate):决定哪些新信息需要加入记忆单元
输出门(Output Gate):决定哪些信息影响当前输出和下一个隐状态。
fₜ = σ(W_f·[hₜ₋₁, xₜ] + b_f) // 遗忘门 iₜ = σ(W_i·[hₜ₋₁, xₜ] + b_i) // 输入门 oₜ = σ(W_o·[hₜ₋₁, xₜ] + b_o) // 输出门 c̃ₜ = tanh(W_c·[hₜ₋₁, xₜ] + b_c) // 新的候选记忆 cₜ = fₜ * cₜ₋₁ + iₜ * c̃ₜ // 新的记忆状态 hₜ = oₜ * tanh(cₜ) // 新的隐藏状态
LSTM通过门控机制,有选择地保留/丢弃信息,使得重要信息在“记忆通道”中可以长距离无损传递,反向传播时梯度也能顺畅地流动,不会很快消失。
为什么传统序列建模方法无法并行计算?
因为像RNN,LSTM这种传统序列建模方法的本质是依赖之前时刻的隐藏状态和当前的输入,导致有严格的依赖链,每一步都必须先计算上一步,无法提前获取。
所以,无论是前向传播(推理)还是反向传播(训练),都只能一个时刻一个时刻地处理。
不能像卷积网络那样,把所有像素“同时”喂给网络,因为序列模型每一步的输入不仅是当前的x,还必须等前面状态的输出结果。
CNN处理时序建模任务又有什么缺点?
前面提到RNN无法并行计算,那么CNN是可以并行计算的,为什么CNN不适合做长序列任务?
CNN的特性
局部感受野:CNN的卷积窗口只能看到输入序列的一个“窗口”内容,即只关注局部信息;
弱顺序性:虽然能通过多层堆叠扩大感受野,但本质是并行处理各个片段,对长距离依赖的建模能力有限。
计算量和梯度:普通CNN只能通过堆叠多个层才能覆盖长距离信息,这样会导致深度过大代码计算量和梯度问题。
Transformer结构
Transformer整体是Encoder-Decoder结构,是一个自回归模型。
什么是自回归模型?
自回归模型Autoregressive Model 是常用的一种时间序列建模方法,核心思想是:当前时刻的值可以用前面若干时刻的值线性组合来预测。
编码器Encoder
Encoder使用的是六个完全一样的层,每一层有两个子层,每一个子层用残差连接。
第一个子层是多头注意力机制(Multi-Head Attention);
第二个子层是前馈神经网络MLP。
每一个子层都可以表示为,公式中的sublayer可以是MLP或者多头注意力:
因为使用了残差连接,输入x 和输出sublayer(x)的维度必须相等,所以直接设置了hidden_dim把每一层的输出维度都固定(512)。
残差连接的作用
缓解深层网络退化问题:随着深层网络加深,训练更加困难,损失可能不降反增,通过引入shotcut跨层传递,极大减缓深层网络退化问题;减轻梯度消失问题,利于梯度传播;
学习到恒等映射:即如果后面几层没有益处,残差为0;
提升模型表现和泛化能力:大量实验和论文证明,加入残差连接后,不仅能训练得更深,还获得了更好的准确率和泛化性能。
为什么要归一化?
归一化的目的是使得数据被限制在一定范围内,从而消除奇异样本导致的不良影响。
缓解梯度消失与爆炸:归一化让数据保持在合适的范围内,稳定梯度(防止梯度变得特别大或者特别小),让网络容易训练,训练过程更稳定;
避免某些特征“主导”网络:归一化让每个特征在初期阶段都有相似的影响力(防止某些输入的特征过大,其他特征小,模型容易过度依赖大数值特征),提高模型泛化能力;
加速收敛:归一化让不同特征分布接近一直,优化器能使用更合适的学习率,从而提高训练速度,减少训练轮次。
LayerNorm vs BatchNorm
LayerNorm——层归一化
LayerNorm是对每个样本自身的特征维度进行归一化(也就是对一个样本的所有神经元输出做归一化)
- 均值计算:
- 方差计算:
- 归一化:
(
是一个很小的常数,防止分母为0)
- 缩放偏移:
(
和
是可学习参数,维度和H一致)
BatchNorm——批归一化
在每一层对小批量batch数据的统一特征通道做归一化。
对同一个通道c,在batch和空间维上归一化:
为什么Transformer归一化使用LayerNorm?
因为Transformer用来处理序列数据,输入的每个时间步token维度通常不大(512);
采用按token的LayerNorm可以更好地适应不同长度、不同批次的输出,且不依赖batch的分布。
特点 LayerNorm(层归一化) BatchNorm(批归一化) 归一化范围 对每个样本的特征维度归一化 对每一维度在整个batch上归一化 公式 对每个样本,沿feature维度 对每个feature,沿batch维度 对 batch size 敏感 不敏感 敏感(batch太小效果不好) 适用场景 RNN、Transformer、NLP等 CV/CNN,大batch size时效果好 训练/推理一致性 训练和推理流程一致 推理用running mean/var,略有差异 对序列长度 可变长度也能用 一般固定长度或padding处理 常见用途 NLP、序列建模、Transformer等 图像分类、卷积网络等
解码器Decoder
同样是六个完全一样的层组成,每一层除了和Encoder一样的Multi-Head Attn + Feed Forward外,中间还多了一个Multi-head Attn。
每一个子层中仍然使用残差链接和LayerNorm。
Decoder部分因为做的是自回归训练,用之前时刻的信息预测后面的内容;也就是说模型在处理序列数据的时候,只能关注当前位置以及之前的元素,不能看到位置之后的元素,不然就作弊了,所以这里使用了掩码机制。
掩码机制
Transformer中有两种掩码,一种是Padding Mask(填充掩码),另一种是Sequence Mask(序列掩码)。
Padding Mask
处理输入的序列,每一个句子可能长度不一样,所以这里补齐到一致的长度方便统一输入;加入Padding Mask是告诉模型哪些位置是不用计算的。
Sequence Mask
序列掩码主要是用在Decoder中第一个多头注意力的输入部分,因为做自回归训练,需要讲当前元素后面的元素遮住再预测。
这部分很容易误理解为下面的情况,以为每一个句子都对应独立的sequence mask!!但其实都是上面统一的sequence mask:
注意力Attention
注意力机制可以描述为:将一个查询Quary 和 一组“Key-Value”键值对映射为一个输出,其中Q,K,V和输出都是向量。
输出是通过对Value加权求和得到,其中每个Value的权重是由Query和对应的Key的相似函数计算得出。
注意力机制是作用在句子之间的!!!也就是一个句子中的单词之间的相关性,而不是句子与句子之间的相关性。一个batch虽然后多条语句,但是他们相互独立,放在一起只是为了并行运算提高计算效率。
不同的相似函数导致不同的注意力机制版本,Transformer使用的是点乘注意力机制。
【1】输入由 维度为d_k的query 和 key + 维度为d_v的value 组成;
【2】先计算query 和 key 的点积,求相似度,得到注意力分数矩阵;
【3】再做scale操作,即➗;
【4】可选Mask操作,再自回归阶段,屏蔽掉不能看的信息;
【5】对每一行也就是每个query的分数做softmax归一化操作,得到一组概率(权重),所有的key对应的权重和为1;
【6】将得到的权重与Value计算得到最终的输出,这个输出就表示从Value中提取出来的最重要的信息。
计算相似度为什么直接点积而不使用余弦相似度?
余弦相似度
余弦相似度可以理解为计算两个向量在空间上,方向的一致性。
- 分子是点积
- 分母是模长的乘积(归一化),屏蔽了向量长度信息
余弦相似度取值范围是[−1,1],1表示同向,0表示正交,-1表示反向。
直接点积
- 保留了长度信息,不需要计算两个向量的模长,减少了计算量;
- 后面添加了softmax归一化操作[0,1]范围内,模长范围[0,d],用
代替模长积;
为什么要➗
?
Q,K的维度为[seq_len,d_model];
提取q,k也就是Q,K的第i个token,一个d_model维的向量。
假设q,k的原始分布式均值为0,方差为1的随机变量,那么根据方差的性质:
也就是说,qk方差会随着维度d的增大而增大,这会导致注意力分数矩阵S中的元素值变得非常大,那么softmax(s)分布就会很靠近1或者0,容易出现梯度消失。
为了稳定结果,需要让qk➗
那么:
原始的点乘注意力机制其实没有除法的运算。
如果没有scale操作,假设向量维度d_k增大,QK^T的值可能变得非常大,经过softmax函数时,可能会导致梯度消失。
因为softmax对输入数值范围敏感,输入值过大或过小会导致softmax输出趋于1或0,梯度变得非常小,难以有效传播。!!!!!!!!!!!!!!!!!!!!!!!!
import torch
import torch.nn as nnclass Self_Attention(nn.Module):def __init__(self, dim, dk, dv):super().__init__()self.scale = dk ** -0.5self.q = nn.Linear(dim, dk)self.k = nn.Linear(dim, dk)self.v = nn.Linear(dim, dv)def forward(self, x):# x: [batch, seq_len, dim]q = self.q(x) # [batch, seq_len, dk]k = self.k(x) # [batch, seq_len, dk]v = self.v(x) # [batch, seq_len, dv]attn = (q @ k.transpose(-2, -1)) * self.scale # [batch, seq_len, seq_len]attn = attn.softmax(dim=-1)out = attn @ v # [batch, seq_len, dv]return out
多头注意力机制
import torch
import torch.nn as nnclass MultiHeadAttention(nn.Module):def __init__(self, dim, dk, dv, num_heads):super().__init__()self.num_heads = num_headsself.dk = dkself.dv = dvself.q_linear = nn.Linear(dim, dk * num_heads)self.k_linear = nn.Linear(dim, dk * num_heads)self.v_linear = nn.Linear(dim, dv * num_heads)self.out_linear = nn.Linear(dv * num_heads, dim)def forward(self, x):B, N, _ = x.shape # batch, seq_len, dimQ = self.q_linear(x).view(B, N, self.num_heads, self.dk).transpose(1, 2) # [B, heads, N, dk]K = self.k_linear(x).view(B, N, self.num_heads, self.dk).transpose(1, 2)V = self.v_linear(x).view(B, N, self.num_heads, self.dv).transpose(1, 2)# Attentionattn = (Q @ K.transpose(-2, -1)) / (self.dk ** 0.5) # [B, heads, N, N]attn = attn.softmax(dim=-1)out = attn @ V # [B, heads, N, dv]out = out.transpose(1, 2).reshape(B, N, self.num_heads * self.dv) # [B, N, heads*dv]out = self.out_linear(out) # [B, N, dim]return out
为什么要使用多头注意力?
因为原本的点乘注意力机制(左图)中没有什么可学习参数,就是点乘计算;
为了增加一些参数,也就是想多学习一些东西,所以分为了多个头(右图),linear操作,高维映射到多个低维,这个过程需要学习一些参数。也就是给你n_head次机会,希望学到不一样的投影方法,使得每一种投影方法匹配不同的模式。
position-wise feed-forward networks
这里的Feed Forward其实就是一个MLP多层感知机(两个linear线形变换),中间带有一个ReLU激活函数。
- xW1+b1-----> 线性层;
- max------>relu激活层;
- maxW2+b2---->再一个线性层;
为什么要用FFN?
在Transformer中,Attention的作用就是把整个序列里面的信息提取出来
x对应的维度通常是512,W1将512投影到2048,W2又再次投影回512。
FFN过程就是维度先扩大再缩小,先展开看细节,再抓重点。增加非线性关系,学习更多的潜在表达。
ReLU激活函数
如果 x>0,则输出是 x
如果 x≤0,则输出是 0
但注意,这里的pisition-wise,意思是对输入序列中的每一个位置(每一个token),都单独的使用一个FFN处理。也就是说FFN只作用在最后一个维度。
position-wise举例
假设序列有token A、B、C,输入分别是向量x1,x2,x3。
那么FFN的处理是:
- FFN(x1)
- FFN(x2)
- FFN(x3)
三者互不干扰,但参数共享,使用的是同一套参数(即同一个FFN的权重和偏置)
其中 W1,W2,b1,b2 就是参数。
不管是x1、x2、还是x3,它们都被送入同一个FFN(同一组W1,W2,b1,b2参数),各自独立运算,互不影响。
为什么要使用position-wise 方法?
Attention提取整个序列的信息,也就是汇聚加权(全局);
经过MLP做映射,因为Attention的输出token已经包含了整体的序列信息,所以只需要分开对每一个单独的token做映射到想要的语义空间就可以了。
Embedding
可看:Qwen2.5-vl源码解读系列:LLM的Embedding层-CSDN博客
文本embedding可以理解文将一个token用一个向量表示,从离散对象(单词)抽象为一个连续的向量空间。
Transformer中有两个Embedding + 一个逆向Embedding,这三个共享参数。
Input Embedding
首先在模型开始训练之前,会有一个词汇表,汇总了模型训练的能够传入的单词或词根,每一个单词对应唯一的token_id,然后使用Embedding将每一个单词都映射为一个向量。假设词汇表中有15374个单词,每一个单词用768维度的向量表示。得到一个大型矩阵,也就是查找表。
根据输入的单词到词汇表中查找token_id,也就找到了对应的嵌入向量。如果输入单词是6个,那么经过Embedding,输入Encoder的尺寸就是[6,768]。
OutPut Embedding
和Input Embedding同理,但注意Decoder部分输入加了Sequence Mask.
Linear
Decoder的输出是[6,768],也就是6个位置的向量信息。需要解码得到的抽象表示转换为我们可以理解的单词。
通过linear把768维向量投影到词汇表大小的向量(例如15374),这样就可以使用softmax把这个向量变成一组概率分布,也就是词汇表中每一个单词的概率。
因为单词<---->向量的映射关系唯一,所以这三个地方共用参数。
注意!在这三个embedding过后,都需要将向量再✖
为什么要乘
?
稳定数值,为了让 embedding 层的输出的数值幅度,与后续 self-attention 机制的缩放一致,避免数值过大或过小影响训练稳定性。
Embedding本质是一个大型矩阵,nn.embedding默认使用的是均匀分布初始化,区间是
,这其实是Xavier初始化的简化版(省略了输入、输出维度加和那一步)。
所以再乘上
使得分布区间回到[-1,1]。
Position Encoding
Transformer中使用的是绝对位置编码中的正余弦位置编码。
用一组不同频率的正弦和余弦波,对每个位置编码成一个dmodeldmodel维向量,把顺序信息注入到Transformer输入中,保证模型具备序列建模能力且无须训练。
- pos:序列中的位置
- i:维度索引(i=0,1,...,d_model/2−1)
- 2i和 2i+1 分别为偶数和奇数下标。
正余弦编码通俗理解:
i越大,sin/cos的周期越长,变化越慢,编码变化越平滑,对远距离元素越不敏感。
sin/cos把数值规范在[-1, 1]的区间,避免数值爆炸,同时方便模型捕捉顺序和相对距离。
为什么要使用位置编码?
再Attention中,其实就是计算每个向量和序列中其他向量之间的关系,本身对序列的顺序不敏感。但从语义层面,单词前后顺序的不同可能语义完全不同,所以需要加入位置编码。
为什么使用自注意力?
如下图,作者做了实验,比较了四种层:自注意力、循环层、卷积层、受限卷积层。
用三种参数做比较:计算复杂度、顺序的计算(越少越好)、信息从一个数据点走到另一个数据点要走多远;
结果如下: