当前位置: 首页 > news >正文

小杰-自然语言处理(seven)——transformer系列——自注意力(self-attention)

1.引入

Self - attention(自注意力)是一种注意力机制,其核心在于同一输入序列生成 Q、K、V。通过计算序列内元素间的相关性(Q 与 K 的权重矩阵),对元素(V)加权求和得到新的上下文表示,能捕捉长距离依赖,是 Transformer 的核心组件,广泛应用于 NLP、CV 等领域。

2.1 self-attention矩阵运算过程

假设输入序列有 3 个词,词向量维度 2(输入矩阵 X 为 3×2)。随机初始化 3 个 2×4 的权重矩阵 W_Q、W_K、W_V,与 X 相乘得到 3×4 的 Q、K、V。

计算 Q 与 K 转置的点乘,得 3×3 注意力分数矩阵,除以(其中

缩放后,经 softmax 得到权重矩阵,再与 V 相乘,得到 3×4 的输出矩阵 context,即为 Self-Attention 结果,融合了序列内各元素信息。

代码实现

import torch
import torch.nn as nn
import torch.nn.functional as F
#定义Self-Attention 类
class SelfAttention(nn.Module):def __init__(self,input_dim,qkv_dim):"""input_dim: 输入向量维度(如图中每个 x 的维度,假设输入是单个向量拼接,这里设为图中隐含的维度,比如示例里可理解为每个 x 是 2 维,实际按需求调整)qkv_dim: Q、K、V 的维度(对应图中 W_Q、W_K、W_V 输出维度,图中示意生成的 Q/K/V 是更高维,这里设为 4 演示,可按需改)"""super(SelfAttention,self).__init__()# 初始化 Q、K、V 对应的权重矩阵self.W_Q=nn.Linear(input_dim,qkv_dim)self.W_K = nn.Linear(input_dim, qkv_dim)self.W_V = nn.Linear(input_dim, qkv_dim)def forward(self,x):"""x: 输入序列,形状为 [seq_len, input_dim],对应图中 x0、x1、x2 拼接后的输入,这里 seq_len=3(3 个元素),input_dim 是单个元素维度"""Q=self.W_Q(x)# [3, qkv_dim]  对应图中 Q 矩阵,每个 x 映射到 Q 空间K = self.W_K(x)# [3, qkv_dim]  对应图中 K 矩阵,每个 x 映射到 K 空间V = self.W_V(x) # [3, qkv_dim]  对应图中 V 矩阵,每个 x 映射到 V 空间# 2. 计算注意力分数(Q 与 K^T 点乘)scores = torch.matmul(Q, K.transpose(0, 1))# 3. 缩放(可选,图中未体现但实际常用,这里按公式加一下,d_k 取 qkv_dim 开根号)d_k = Q.size(-1)  # 获取 K 的维度,即 qkv_dimscores = scores / torch.sqrt(torch.tensor(d_k, dtype=torch.float32))# 4. 计算注意力权重(softmax 归一化),得到 [3, 3] 权重矩阵,对应图中“注意力权重”attention_weights = F.softmax(scores, dim=-1)# 5. 加权求和得到 context(权重与 V 相乘),结果形状 [3, qkv_dim],对应图中“context”context = torch.matmul(attention_weights, V)return context, attention_weightsif __name__ == '__main__':#模拟输入input_seq=torch.tensor([[1.1, 2.2],  # x0[3.4, 4.4],  # x1[5.4, 6.4]], # x2dtype=torch.float32)# 初始化 Self - Attention 模块self_attention=SelfAttention(input_dim=2, qkv_dim=4)# 前向传播context, attention_weights = self_attention(input_seq)# 打印结果查看print("Q 矩阵:\n", self_attention.W_Q(input_seq).shape)print("K 矩阵:\n", self_attention.W_K(input_seq).shape)print("V 矩阵:\n", self_attention.W_V(input_seq).shape)print("注意力分数矩阵:\n",torch.matmul(self_attention.W_Q(input_seq), self_attention.W_K(input_seq).transpose(0, 1)).shape)print("注意力权重矩阵:\n", attention_weights.shape)print("最终 context 输出:\n", context.shape)

实验结果为:

http://www.dtcms.com/a/528167.html

相关文章:

  • 带代码示例的 HTML 标签实操手册
  • fastapi 前端文件配置 python fastapi服务添加前端打包好的静态文件
  • 计算机网络自顶向下方法3——详解分组交换时延、丢包、吞吐量
  • 外贸 网站推广青岛胶南市城乡建设局网站
  • GitHub等平台形成的开源文化正在重塑推荐人
  • webrtc代码走读(六)-QOS-FEC冗余度配置
  • F036 vue+flask中医热性药知识图谱可视化系统vue+flask+echarts+mysql
  • 电脑效能跃升利器 金士顿KVR内存焕新机
  • Bootstrap UI 编辑器
  • MySQL联合查询详解
  • webrtc代码走读(五)-QOS-FEC原理
  • 车载诊断架构 ---DTC快照中DID大小顺序是怎么要求的?
  • Windows 10 下 VS Code 配置 C++ 开发环境(MinGW)
  • 天津低价网站建设怎样做淘宝联盟的网站
  • 福建网站建建设方案太原关键词优化报价
  • 深耕 Rust:核心技术解析、生态实践与高性能开发指南
  • 深入浅出 Tokio 源码:掌握 Rust 异步编程的底层逻辑
  • 北京网站建设管庄1天学会搭建营销网站
  • 基于SEH的异常捕获与MiniDumpWriteDump深度解析
  • C语言练习题
  • Postman应用实战
  • Vue-Loader 深度解析:原理、使用与最佳实践
  • HCIP第二次作业(VRRP/STP/VLAN/Eth-trunk/NAT)
  • 外国设计网站推荐自己学网站建设
  • ASP.NET Core中创建中间件的几种方式
  • Docker安装思源笔记使用指南
  • 需求登记网站怎么做免费高清图片素材网站推荐
  • SpringBoot集成Elasticsearch | Java High Level Rest Client(HLRC)方式
  • 《神领物流》day07-线路规划之线路管理_完整代码【简单易懂注释版】
  • 使用Ansys Polyflow对泡沫聚合物挤出进行建模