当前位置: 首页 > news >正文

【大语言模型】—— 自注意力机制及其变体(交叉注意力、因果注意力、多头注意力)的代码实现

【大语言模型】—— 注意力机制及其变体的代码实现

  • 摘要
  • Self-Attention
    • 为什么 dim=-1
  • CrossAttention
    • 交叉注意力的作用
  • CausalAttention
    • unsqueeze(0)的作用​​
  • MultiHeadSelfAttention
    • 多头自注意力(Multi-Head Attention)的核心思想

摘要

本文介绍了注意力机制的几种变体及其PyTorch代码实现。主要包括:

  1. Self-Attention:基础自注意力机制,通过Q、K、V计算注意力权重,适用于序列内部建模。
  2. CrossAttention:让一个序列关注另一个序列,典型应用于Transformer解码器-编码器交互和多模态任务。
  3. CausalAttention:通过三角掩码实现因果性,确保只能关注当前位置之前的token,适用于自回归生成任务。
  4. MultiHeadSelfAttention:多头注意力机制,将输入分割到多个子空间并行计算注意力,最后合并结果。

代码实现中详细展示了各注意力的关键操作,包括线性变换、注意力分数计算、softmax归一化和掩码处理等。特别解释了dim=-1的作用、unsqueeze(0)的广播机制等实现细节。

Self-Attention


class SelfAttention(nn.module):
# torch.matmul     专用于批量矩阵乘法,适用于形状为 (batch_size, n, m) 和 (batch_size, m, p) 的 3D 张量。
# torch.matmul  支持更灵活的张量乘法运算def __init__(self, input_dim, dim_k,dim_v):super().__init__()self.q = nn.Linear(input_dim, dim_k)self.k = nn.Linear(input_dim, dim_k)selv.v = nn.Linear(input_dim, dim_v)self.scale = np.sqrt(dim_k)def forward(self, x):Q = self.q(x)K = self.k(x)V = self.v(x)atten = torch.softmax(torch.matmul(Q, K.permute(0,2,1))/self.scale, dim=-1)return torch.matmul(atten, V)

为什么 dim=-1

在自注意力机制中,nn.Softmax(dim=-1)的作用是对 ​​注意力分数矩阵​​ 进行归一化,使得每一行的权重之和为 1。这里 dim=-1表示在最后一个维度(即 seq_len维度)上进行 Softmax 计算。

Q K T QK^T QKT计算后,得到的注意力分数矩阵的形状是 [batch_size, seq_len, seq_len],其中:
​​第 1 个 seq_len(dim=1)​​:代表 Q的序列长度(即当前 token 的位置)。
​​第 2 个 seq_len(dim=2)​​:代表 K的序列长度(即被计算注意力的 token 的位置)。
dim=-1(即 dim=2)表示 ​​对每个 token 计算它对所有 token 的注意力权重​​,即 ​​对每一行进行 Softmax​​,使得:
每一行的所有值加起来等于 1(概率分布)。
这样,每个 token 的注意力权重是独立计算的。
假设 Q K T QK^T QKT的结果是:

[[[1.0, 0.5, 0.2],  # Token 0 对所有 token 的注意力分数[0.3, 1.2, 0.7],  # Token 1 对所有 token 的注意力分数[0.1, 0.4, 1.5]]  # Token 2 对所有 token 的注意力分数
]

应用 nn.Softmax(dim=-1)后:

[[[0.55, 0.27, 0.18],  # Token 0 的注意力权重(总和=1)[0.16, 0.58, 0.26],  # Token 1 的注意力权重(总和=1)[0.07, 0.20, 0.73]]  # Token 2 的注意力权重(总和=1)
]

这样,每个 token 的注意力权重是独立的,且所有 token 对它的影响权重之和为 1。

CrossAttention

# 查询通常来自解码器,键和值通常来自编码器
import torch
import torch.nn as nn
import numpy as npclass CrossAttention(nn.Module):def __init__(self, input_dim, dim_k, dim_v):super().__init__()  # 必须调用父类初始化self.q = nn.Linear(input_dim, dim_k)self.k = nn.Linear(input_dim, dim_k)self.v = nn.Linear(input_dim, dim_v)self.scale = np.sqrt(dim_k)def forward(self, x1, x2):Q1 = self.q(x1)  # [batch_size, seq_len1, dim_k]K2 = self.k(x2)  # [batch_size, seq_len2, dim_k]V2 = self.v(x2)  # [batch_size, seq_len2, dim_v]# 计算注意力分数atten = torch.softmax(torch.matmul(Q1, K2.permute(0, 2, 1)) / self.scale, dim=-1)  # [batch_size, seq_len1, seq_len2]# 加权求和return torch.matmul(atten, V2)  # [batch_size, seq_len1, dim_v]

交叉注意力的作用

交叉注意力用于 让一个序列 x 1 x1 x1关注另一个序列 x 2 x2 x2,典型应用包括:

  1. ​​Transformer 解码器​​:
    • x1= 解码器的输入(当前生成的 token)
    • x2= 编码器的输出(源序列的表示)
    • 解码器通过交叉注意力关注编码器的信息。
  2. ​​多模态任务​​(如视觉-语言模型):
    • x1= 文本序列
    • x2= 图像特征
    • 文本通过交叉注意力关注图像的关键区域。

CausalAttention

class CausalAttention(nn.Module):def __init__(self,input_dim, dim_k,dim_v):super().__init__()self.q = nn.Linear(input_dim, dim_k)self.k = nn.Linear(input_dim, dim_k)self.v = nn.Linear(input_dim, dim_v)self.scale = np.sqrt(dim_k)def forward(self, x):# x: [batch_size, seq_len, input_dim]Q = self.q(x)  # [batch_size, seq_len, dim_k]K = self.k(x)  # [batch_size, seq_len, dim_k]V = self.v(x)  # [batch_size, seq_len, dim_v]# 注意力分数atten = torch.matmul(Q, K.permute(0, 2, 1)) / self.scale  # [batch, seq, seq]# 下三角 mask,确保因果性(只能看到之前的token)seq_len = atten.size(-1)mask = torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(0)atten = atten.masked_fill(mask == 0, float('-inf'))# softmax 归一化atten = nn.Softmax(dim=-1)(atten)# 输出加权和return torch.matmul(atten, V)  # [batch, seq, dim_v]

unsqueeze(0)的作用​​

假设我们有一个 2D 张量 mask,形状是 [seq_len, seq_len]

mask = torch.tril(torch.ones(seq_len, seq_len))  # 形状 [seq_len, seq_len]

如果我们想让它变成 [1, seq_len, seq_len](即增加一个 batch 维度),可以使用:

mask = mask.unsqueeze(0)  # 形状变为 [1, seq_len, seq_len]

这样做的目的是:

  • 匹配注意力分数矩阵的形状​​(atten的形状是 [batch_size, seq_len, seq_len])。
  • ​​支持批量计算​​,因为 mask需要广播到所有 batch 样本。

这样:
mask的形状变成 [1, seq_len, seq_len]
PyTorch 会自动广播 mask到 [batch_size, seq_len, seq_len],使其与 atten形状匹配。

MultiHeadSelfAttention

class MultiHeadAttention(nn.Module):def __init__(self, heads, input_dim, dim_k, dim_v):super().__init__()self.heads = headsself.dim_k_per_head = dim_k // headsself.dim_v_per_head = dim_v // headsself.q = nn.Linear(input_dim, dim_k)self.k = nn.Linear(input_dim, dim_k)self.v = nn.Linear(input_dim, dim_v)self.scale = np.sqrt(self.dim_k_per_head)self.out = nn.Linear(dim_v, input_dim)def forward(self, x):batch_size = x.size(0)Q = self.q(x)#[batch_size, seq_len, dim_k]K = self.k(x)V = self.v(x)#[batch_size, seq_len, heads, dim_k_per_head] #		--> [batch_size, heads, seq_len, dim_k_per_head]Q = Q.view(batch_size, -1, self.heads, self.dim_k_per_head).permute(0,2,1,3)#[batch_size, seq_len, heads, dim_k_per_head] #		--> [batch_size, heads, seq_len, dim_k_per_head]K = K.view(batch_size, -1, self.heads, self.dim_k_per_head).permute(0,2,1,3)#[batch_size, seq_len, heads, dim_v_per_head] #		--> [batch_size, heads, seq_len, dim_v_per_head]V = V.view(batch_size, -1, self.heads, self.dim_v_per_head).permute(0,2,1,3)#转置[batch_size, heads, seq_len, dim_k_per_head] #		--> [batch_size, heads, dim_k_per_head, seq_len]K = K.permute(0, 1, 3, 2)# [batch_size, heads, seq_len, seq_len]atten = torch.softmax(torch.matmul(Q,K) / self.scale, dim = -1)# [batch_size, heads, seq_len, dim_v_per_head]out = torch.matmul(atten, V)# [batch_size, seq_len, heads, dim_v_per_head]out = out.permute(0, 2, 1, 3).contiguous()# [batch_size, seq_len, heads* dim_v_per_head]out = out.view(batch_size, -1, self.heads * self.dim_v_per_head)return self.out(out) # [batch_size, seq_len, input_dim]

多头自注意力(Multi-Head Attention)的核心思想

多头自注意力(Multi-Head Attention)的核心思想是将输入向量分别映射为查询 Q、键 K、值 V,再按照头数切分到多个子空间中;每个头独立计算注意力分数并得到加权表示,最后拼接各头的结果,通过线性层 out 映射回输入维度,从而捕捉序列中多角度的相关性。

在使用时需要注意以下几点:

  1. 维度整除:要确保 dim_kdim_v 能被 heads 整除,否则 view 时会报错。
  2. 缩放因子:缩放应该基于每个头的维度 sqrt(dim_k_per_head),而不是整体的 dim_k
  3. Softmax 顺序:正确做法是 Softmax(QK^T / scale),不要写成 Softmax(QK^T) / scale
  4. 张量连续性permute 之后用 .contiguous().view(),否则可能报错;或者用 reshape 自动处理。
  5. 输出层作用self.out 的作用是把多头拼接后的结果重新映射回输入维度,保持层间维度一致。
http://www.dtcms.com/a/503122.html

相关文章:

  • TensorFlow2 Python深度学习 - 生成对抗网络(GAN)简介
  • 珠海网站品牌设计公司简介厦门网页
  • 房子网站有哪些在线企业查询系统
  • 临颖网站建设漳州做网站建设
  • Linux oops时进行panic
  • 【Docker】Docker Image(镜像)
  • 重生之我拿捏Linux——《三、shell脚本使用》
  • Altium Designer(AD24)Windows窗口功能总结
  • C++进阶:重载类型转换
  • SKY77645 导致的Rach failure问题
  • C++模版:模板初阶及STL简介
  • 微网站策划方案厦门的网站建设公司
  • 织梦网站404页面模板成都全网推广哪家专业
  • Solidity智能合约存储与数据结构精要
  • 生活化讲解Controller - 餐厅的“前台接待员“
  • AI大事记12:Transformer 架构——重塑 NLP 的革命性技术(下)
  • 微信公众号登录wordpress网站湛江企业网站怎么建设
  • 智慧校园总体解决方案PPT(98页)
  • ComfyUI-DynamiCrafterWrapper:开启ComfyUI动图创作新时代
  • 关于国家授时中心遭受美国国家安全局网络攻击事件的技术分析报告
  • PyGAD使用指南
  • 洛谷 B3841 [GESP202306 二级] 自幂数判断
  • 英诺赛科(02577.HK)
  • 做网站服务器什么配置如何建设一个简易的网站
  • 在网站上做承诺书工程平台公司做什么的
  • 深入学习Spring Boot框架
  • 深度拷贝详解
  • 李宏毅机器学习笔记21-26周汇总
  • 特别分享:IOPaint概念及基础知识
  • 【微服务】(2) 环境和工程搭建