当前位置: 首页 > news >正文

小杰-自然语言处理(four)——transformer系列——注意力机制

注意力机制

种注意力形式及核心逻辑:

  1. 加性注意力:通过神经网络计算相似度,适应不同维度向量,灵活性高。
  2. 点积注意力:直接算查询向量与键向量的点积,点积越大相似度越高,计算高效。
  3. 缩放点积注意力:点积除以键向量维度平方根,避免高维下数值过大,Transformer 核心组件。
1.2.1 加性注意力

加性注意力是注意力机制的一种形式,它通过计算两个输入向量的相似度来确定权重。对于给定的查询向量 q、键向量 k,加性注意力分数的计算过程如下:

代码实现

import torch
import torch.nn as nn
import  torch.nn.functional as F#加性注意力类
class AdditiveAttention(nn.Module):def __init__(self,hidden_dim):"""hidden_dim:"""super(AdditiveAttention,self).__init__()#线性变换矩阵 q kself.w_q=nn.Linear(hidden_dim,1)self.w_k=nn.Linear(hidden_dim,1)#权重向量v^Tself.v=nn.Linear(hidden_dim,1)def forward(self,q,k,v):Q=self.w_q(q) #[batch_size,q_num,hidden_dim]K=self.w_k(k) #[batch_size,k_num,hidden_dim]V=self.v(v)#2,加性计算sum_qk=Q+Ktanh_out=torch.tanh(sum_qk)# 3.与v^T相乘,计算注意力分数random_input = torch.randn(q.size(0),q.size(1), q.size(2))v_t=self.v(random_input)score= torch.bmm(tanh_out,v_t.transpose(1,2))# 4. softmax 归一化权重atten_weight=F.softmax(score,dim=-1)#维度归一化# 5. 权重与    V 加权求和,得到 contextcontext=torch.bmm(atten_weight,V)# [batch_size, q_num, hidden_dim]return context# 模拟输入(batch_size=1,简化维度)
q=torch.tensor([[[0.59,0.84], [0.55,0.71], [0.57,0.80]]], dtype=torch.float32)
k = torch.tensor([[[0.56,0.70], [0.58,0.81], [0.60,0.87]]], dtype=torch.float32)
v = k
#计算加性注意力
atten=AdditiveAttention(hidden_dim=2)
context = atten(q, k, v)
print("上下文向量 context:\n", context.shape)
1.2.3 缩放点积注意力

作用:通过缩放因子平衡高维向量的点积结果,使 softmax 分布更平滑,提升训练稳定性。这是 Transformer 模型的核心创新之一。

代码实现

import torch
import  torch.nn.functional as Fclass ScaledDotProductAttention(torch.nn.Module):def __init__(self):super(ScaledDotProductAttention,self).__init__()self.V=torch.nn.Linear(2,1)def forward(self,Q,K,V):# 获取键向量的维度d_k = K.size(-1)  # 即图示中的维度 d_k=2# 计算点积注意力分数:Q与K^T相乘scores=torch.bmm(Q,K.transpose(1,2))# 缩放点积:除以根号(d_k),防止梯度消失或爆炸scaled_scores = scores / torch.sqrt(torch.tensor(d_k, dtype=torch.float32))# 应用 softmax 计算注意力权重attention_weights = F.softmax(scaled_scores, dim=-1)# 权重与 V 相乘得到 contextV=self.V(V)context = torch.bmm(attention_weights, V)return context, attention_weightsq_before=torch.tensor([[[5.9, 8.4], [5.5, 7.1], [5.7, 8.0]]], dtype=torch.float32)
k_before = torch.tensor([[[5.6, 7.0], [5.8, 8.1], [6.0, 8.7]]], dtype=torch.float32)
# 线性变换(图示 Wq, Wk, Wv)
Wq = torch.randn((2, 2))  # 2×2 变换矩阵
Wk = torch.randn(2, 2)
Wv = torch.randn(2, 2)
Q=torch.bmm(q_before,Wq.unsqueeze(0).repeat(q_before.size(0),1,1))
K=torch.bmm(k_before,Wk.unsqueeze(0).repeat(k_before.size(0),1,1))
V = torch.bmm(k_before, Wv.unsqueeze(0).repeat(k_before.size(0), 1, 1))
#注意力模块并计算
attention = ScaledDotProductAttention()
context, attn_weights = attention(Q, K, V)print("缩放后的注意力权重矩阵:\n", attn_weights)
print("上下文向量 context:\n", context)
http://www.dtcms.com/a/524200.html

相关文章:

  • Java SpringAOP --- AOP的使用,AOP的源码
  • 阿里云渠道商:如何设置阿里云的安全组规则?
  • 网站设计速成如何让百度快速收录网站文章
  • 北京平台网站建设多少钱学院网站建设的特色
  • 外贸soho建站多少钱山东省住房和城乡建设厅官方网站
  • 芯科科技推出智能开发工具Simplicity Ecosystem软件开发套件开启物联网开发的新高度
  • 报错: lfstackPack redeclared in this block / go版本混乱,清理旧版本
  • 和鲸科技入选《大模型一体机产业图谱》,以一体机智驱科研、重塑教学
  • Go语言:关于怎么在线学习go语言的建议
  • 树 B树和B+树
  • 【arXiv2025】Real-Time Object Detection Meets DINOv3
  • 绍兴网站建设专业的公司4000-262-怎么在百度上发帖推广
  • AH2203输入12v输出3v 6v 9v/2A同步降压LED驱动器芯片
  • C如何调用Go
  • 使用Mathematica编写一个高效的Langevin方程求解器
  • 中国软件企业出海,为什么80%都选择这家服务商?
  • 《红黑树核心机制解析:C++ STL中map/set高效实现原理与工程实践》
  • Spring Boot 使用 Redis 实现消息队列
  • 从renderToString到hydrate,从0~1手写一个SSR框架
  • git报错no new changes、does not match any
  • 公司做网站的费用怎么做账望野作品
  • 【第五章:计算机视觉-项目实战之推荐/广告系统】2.粗排算法-(4)粗排算法模型多目标算法(Multi Task Learning)及目标融合
  • Prometheus 监控系统全维度指南
  • Gradle 增量构建与构建缓存:自定义 Task 如何实现 “只构建变化内容”?
  • 【笑脸惹桃花】1024,阶段性回望与小结
  • 农产品网站建设策划方案网站获取qq号码 代码
  • 网站服务器的作用和功能有哪些福田欧辉是国企吗
  • R语言高效数据处理-变量批量统计检验
  • 云图-地基云图
  • R语言基于selenium模拟浏览器抓取ASCO数据-连载NO.03