transformer详解(位置编码+attention+残差连接+全连接网络)
2017年谷歌发表的transformer is all you need论文中, transformer网络架构如下图所示. 其中主要包括位置编码(position embedding), 注意力机制(attention), 残差连接&层归一化(add & norm), 全连接神经网络等模块(feed forward), 接下来对这几个模块展开分析.
残差连接&层归一化
Output=LayerNorm(x+SubLayer(x))\text{Output} = \text{LayerNorm}(x + \text{SubLayer(x)}) Output=LayerNorm(x+SubLayer(x))
残差连接x+SubLayer(x)x + \text{SubLayer(x)}x+SubLayer(x), 防止深层网络训练时梯爆炸
层归一化LayerNorm(⋅)=xˉi=xi−μσ2+ϵ\text{LayerNorm}(\cdot) = \bar{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2} + \epsilon}LayerNorm(⋅)=xˉi=σ2+ϵxi−μ 保证数值范围稳定, 缓解训练不稳定
前馈全连接网络
FFN(x)=σ(xW1+b1)W2+b2FFN(x) = \sigma(xW_1 + b_1) W_2 + b_2 FFN(x)=σ(xW1+b1)W2+b2
引入σ(⋅)\sigma(\cdot)σ(⋅)非线性变换增强模型表达能力
注意力机制(attention)
公式
单头自注意力机制公式:
Attention(Q,K,V)=softmax(QKTdK)V\text{Attention}(Q, K, V) = \text{softmax}(\frac{QK^T}{\sqrt{d_K}})V Attention(Q,K,V)=softmax(dKQKT)V
多头自注意力机制公式:
headi=Attention(Qi,Ki,Vi)MultiHead(Q,K,V)=Concact(head1,...headn)head_i = \text{Attention}(Q_i, K_i, V_i) \\ \text{MultiHead}(Q, K, V) = \text{Concact}(head_1, ... head_n) headi=Attention(Qi,Ki,Vi)MultiHead(Q,K,V)=Concact(head1,...headn)
自注意力机制解释: Q和K参数根据上下文token的语义向量查询各token相互之间的关联系数并用softmax归一化至0~1内, 可认为各token之间的关联权重, softmax(QKTdK)\text{softmax}(\frac{QK^T}{\sqrt{d_K}})softmax(dKQKT). 用V参数根据关联权重系数更新每个token在当前上下文对应的语义向量, 相当于把V参数根据权重分配给各token的语义向量.
多头自注意力机制: 把单头的特征维度平均拆分为n个头, 相当于n个专家, 各自关注不同维度的特征, 形象解释就是有的注意力头关注动词关系, 有的注意力头关注名词关系.
缩放因子: 通过dK\sqrt{d_K}dK控制点积QKTQK^TQKT的均值为0, 方差为1, 避免softmax进入梯度饱和区间, 稳定训练过程, 具体数学推导可以参考这篇博客.
掩码
填充掩码和特殊掩码都好理解, attention的掩码主要指因果注意力掩码, 是为了训练时能够高效并行训练而引入的, 举例:
样本: “今天天气晴朗”, 对应token序列: [今天, 天气, 晴朗, EOS]
在训练时, 该样本可以拆分成以下三条子样本进行训练:
- 输入: [今天], 输出: [天气, 晴朗, EOS]
- 输入: [今天, 天气], 输出: [晴朗, EOS]
- 输入: [今天, 天气, 晴朗], 输出: [EOS]
如果没有causal mask, 三条子样本需要分别输入到模型中进行训练, 例如第一次输入[今天], 以[天气, 晴朗, EOS]为label, 计算loss并反向传播. 第二次输入[今天, 天气], 第三次输入[今天, 天气, 晴朗].
而引入causal mask后, 可以用mask遮住每条子样本未来的信息, 把所有子样本组成一个矩阵输入模型并行进行训练, 提高训练效率. mask后的token矩阵如下所示
[今天, 遮住, 遮住, 遮住]
[今天, 天气, 遮住, 遮住]
[今天, 天气, 很好, 遮住]
[今天, 天气, 很好, EOS]
这样的话训练时可以输入样本的整个序列并行训练, 且当前token无法关注到未来的信息, 只能关注到当前和过去的信息, 与推理时的情况一致, 只能通过已有的token生成新的token.
单头自注意力机制手撕
import torch
from torch import nnclass Attention(nn.Module):def __init__(self, embed_size):super().__init__()self.q_proj = nn.Linear(embed_size, embed_size, bias=True)self.k_proj = nn.Linear(embed_size, embed_size, bias=True)self.v_proj = nn.Linear(embed_size, embed_size, bias=True)self.softmax = nn.Softmax(dim=-1)self.dk_sqrt = torch.sqrt(torch.tensor(embed_size, dtype=torch.float32))def forward(x: torch.Tensor, use_causal_mask: bool=False):xq = self.q_proj(x)xk = self.k_proj(x)xv = self.v_proj(x)# x shape: batch_size * seq_len * embed_size# softmax(xq xk^T / \sqrt(dk)) xvscores = torch.matmul(xq, xk.transpose(-1, -2)) / self.dk_sqrtif use_causal_mask:seq_len = x.shape[1]causal_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()scores = scores.masked_fill(causal_mask, float('-inf'))weights = self.softmax(scores)output = torch.matmul(weights, xv)return output
多头自注意力机制手撕
import torch
form torch import nndef MultiheadAttention(nn.Module):def __init__(self, embed_size: int, n_head: int):self.q_proj = nn.Linear(embed_size, embed_size)self.k_proj = nn.Linear(embed_size, embed_size)self.v_proj = nn.Linear(embed_size, embed_size)self.o_proj = nn.Linear(embed_size, embed_size)self.softmax = nn.Softmax(-1)self.n_head = n_headself.hidden_dim_per_head = embed_size // n_headself.dk_sqrt = torch.sqrt(torch.tensor(self.hidden_dim_per_head, dtype=torch.float32))def forward(self, x: torch.Tensor, use_causal_mask: bool=False):# x shape: batch_size * seq_len * embed_sizebatch_size = x.shape[0]seq_len = x.shape[1]embed_size = x.shape[2]xq = self.q_proj(x)xk = self.k_proj(x)xv = self.v_proj(x)# xq xk xv shape: batch_size * seq_len * n_head * hidden_dim_per_head -> batch_size * n_head * seq_len * hidden_dim_per_headxq = xq.view(batch_size, seq_len, self.n_head, self.hidden_dim_per_head).transpose(1, 2)xk = xk.view(batch_size, seq_len, self.n_head, self.hidden_dim_per_head).transpose(1, 2)xv = xv.view(batch_size, seq_len, self.n_head, self.hidden_dim_per_head).transpose(1, 2)# softmax((xq xk^T)/\sqrt(dk)) xv# scores shape: batch_size * n_head * seq_len * seq_lenscores = torch.matmul(xq, xk.transpose(-1, -2)) / self.dk_sqrtif use_causal_mask:causal_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()scores = scores.masked_fill(causal_mask, float('-inf'))# weights shape: batch_size * n_head * seq_len * seq_lenweights = self.softmax(scores, -1)# context shape: batch_size * n_head * seq_len * hidden_dim_per_head -> batch_size * seq_len * n_head * hidden_dim_per_headcontext = torch.matmul(weights, xv).transpose(1, 2)# context shape: batch_size * seq_len * n_head * hidden_dim_per_head -> batch_size * seq_len * embed_sizecontext = context.view(batch_size, seq_len, embed_size)output = self.o_proj(context)return output
位置编码
位置编码是为了解决transformer的置换不变性, 将输入打乱, 输出仅变换顺序, 导致无法捕捉语句中不同位置代表的不同信息. 位置编码有很多种, 目前最常用的是旋转位置编码(RoPE)
RoPE的公式为:
⟨qmRoPE,knRoPE⟩=qmR((m−n)θ)knTθ={θi=10000−2(i−1)/d,i∈[1,2,...,d/2]}⟨q_m^{RoPE}, k_n^{RoPE}⟩ = q_m R((m-n)\theta) k_n^T \\ \theta = \left\{\theta_i=10000^{-2(i-1)/d}, i\in [1, 2, ..., d/2]\right\} ⟨qmRoPE,knRoPE⟩=qmR((m−n)θ)knTθ={θi=10000−2(i−1)/d,i∈[1,2,...,d/2]}
式中qmq_mqm表示第m个位置token的query向量, knk_nkn表示第n个位置token的key向量, R(⋅)R(\cdot)R(⋅)表示旋转矩阵, θ\thetaθ表示一组旋转频率, 与语义向量中的元素维度iii有关.
关于RoPE网上解释很多, 一些细节问题不再过多描述, 可以参考网上其它的资料. 很多资料看下来有两个问题解得为不是很清楚: 1.为什么对于某个token的语义向量, 用不同频率θ\thetaθ进行旋转; 2.为什么低维用大θ\thetaθ, 高维用小θ\thetaθ, 不能反一下呢.
-
为什么对于某个token的语义向量, 不同维度用不同频率进行旋转
形象解释, 可以把一个token的语义向量想象成一组齿轮, 用多个不同转速的齿轮组合来编码位置, 即使某个齿轮转完了一圈, 其它位置的齿轮仍然不同, 能够唯一确定整体位置, 防止重复. 如果一个语义向量按照固定转速旋转, 由于旋转的周期性, 序列足够长的话总会与后续的token位置编码重复.
-
为什么低维用大θ\thetaθ, 高维用小θ\thetaθ
首先θ\thetaθ越大相当于旋转越多, 位置信息区分越明显. 对于语义向量来说, 低维表征了最容易被区分的语义特征(重要性高), 高维表征了更多细腻和抽象的信息(重要性低).
位置编码的目的是期望相邻词能够通过强信号来区分不同语法角色(包括语义和位置), 所以对于低维特征需要大θ\thetaθ旋转, 一方面区分语义另一方面区分位置. 高维特征用小θ\thetaθ一方面保证了位置编码的唯一性, 另一方面使得抽象的语义特征在长尺度上是连贯的.
如果高维的稀疏特征用大θ\thetaθ旋转, 会弱化位置信息导致无法区分. 这样设计就像汽车上面油门踏板很小, 而喇叭很大, 弱化重要功能而放大次要功能, 系统无法协调工作.
手撕
# 整体思路为计算旋转变换后的最终张量输出到表格中, 每个元素为xcos\theta + ysin\theta的形式
# 1. 计算旋转变换的角度
# 2. 生成旋转变换对应的cos张量和sin张量, 大小分别为batch_size * seq_len * embed_size
# 3. 生成带旋转矩阵方向的张量, 用于乘旋转矩阵的第二列, 旋转矩阵为:
# [cos, -sin]
# [sin, cos]
# 4. 输入张量与cos sin张量对应位置元素相乘得到rope后的旋转张量xq_rope和xk_rope
def get_rope_sin_cos(batch_size: int, seq_len: int, embed_size: int):freq = 1.0 / (10000 ** (torch.arange(0, embed_size, 2)) / embed_size)ids = torch.arange(0, seq_len)# R(ids theta)freqs = torch.einsum("i,j->ij", ids, freq) # 得到ids行, freq列的矩阵, 对应元素相乘emb = torch.repeat_interleave(freqs, 2) # 交错堆叠return emb.cos(), emb.sin()def apply_rope(x: torch.Tensor):batch_size = x.shape[0]seq_len = x.shape[1]embed_size = x.shape[2]x1 = x[..., 0::2] # 偶数维度x2 = x[..., 1::2] # 奇数维度x_rot = torch.stack([-x2, x1], dim=-1).reshape_as(x) # 旋转90度cos, sin = get_rope_sin_cos(batch_size=batch_size, seq_len=seq_len, embed_size=embed_size)cos = cos.unsqueeze(0) # [1, seq, dim]sin = sin.unsqueeze(0)return x * cos + x_rot * sin # 矩阵与cos及sin表的对应位置元素相乘