当前位置: 首页 > news >正文

multi-head attention 多头注意力实现细节

论文中关于多头注意力的描述

1706.03762

代码实现

import torch
import torch.nn as nn
import torch.nn.functional as F
import mathclass MultiHeadAttention(nn.Module):def __init__(self, d_model, num_heads):super(MultiHeadAttention, self).__init__()self.d_model = d_modelself.num_heads = num_headsself.d_k = d_model // num_headsassert d_model % num_heads == 0, "d_model must be divisible by num_heads"# Linear projections for Q, K, Vself.W_q = nn.Linear(d_model, d_model)self.W_k = nn.Linear(d_model, d_model)self.W_v = nn.Linear(d_model, d_model)self.W_o = nn.Linear(d_model, d_model)def scaled_dot_product_attention(self, Q, K, V, mask=None):# Q, K, V: (batch_size, num_heads, seq_len, d_k)scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)if mask is not None:scores = scores.masked_fill(mask == 0, float('-inf'))attn = F.softmax(scores, dim=-1)return torch.matmul(attn, V)def split_heads(self, x, batch_size):# x: (batch_size, seq_len, d_model)x = x.view(batch_size, -1, self.num_heads, self.d_k)  # (batch_size, seq_len, num_heads, d_k)return x.transpose(1, 2)  # (batch_size, num_heads, seq_len, d_k)def combine_heads(self, x, batch_size):# x: (batch_size, num_heads, seq_len, d_k)x = x.transpose(1, 2).contiguous()  # (batch_size, seq_len, num_heads, d_k)return x.view(batch_size, -1, self.d_model)  # (batch_size, seq_len, d_model)def forward(self, Q, K, V, mask=None):batch_size = Q.size(0)Q = self.W_q(Q)  # (batch_size, seq_len, d_model)K = self.W_k(K)V = self.W_v(V)Q = self.split_heads(Q, batch_size)  # (batch_size, num_heads, seq_len, d_k)K = self.split_heads(K, batch_size)V = self.split_heads(V, batch_size)attn_output = self.scaled_dot_product_attention(Q, K, V, mask)output = self.combine_heads(attn_output, batch_size)return self.W_o(output)  # Final linear projection

会发现其实代码和论文不是完全一样的,论文看起来是每个头有单独的W去乘,但是代码里是所有头共用W再拆分。其实两者是等价的。要注意一下,在multi-head attention中,输入是不被拆分的,它的shape一直是[L,D_model],拆分的是W,把[D_model, D_model]的矩阵拆分成K个[D_k, D_model]的矩阵。

根据矩阵的乘法定义

Y = X W = X [W₁  W₂] = [X W₁   X W₂]

乘之前拆分还是乘之后拆分,是一样的。代码用大矩阵来乘,可以加快计算。

http://www.dtcms.com/a/482216.html

相关文章:

  • 第七章 完整的模型训练
  • 08-Vue3组合式API最佳实践指南
  • 山东网站建设网站做全世界的生意的网站
  • 【文献分享】通过基于大型语言模型嵌入的蛋白质的 k 均值聚类来探索同源性检测
  • Redis 黑马点评-优惠券秒杀
  • 网站页面示意图怎么做宁波北仑做网站
  • ffmpeg转化mp3至wav格式
  • 不同类型的 3D 文件格式
  • ElasticSearch 实战:全文检索与数据聚合分析的完整指南​
  • Day62 设备驱动程序开发基础与LED控制
  • 支持Word (doc/docx) 和 PDF 转成一张垂直拼接的长PNG图片工具类
  • JAVA同城预约服务家政服务美容美发洗车保洁搬家维修家装系统源码小程序+公众号+h5
  • 正规拼多多代运营公司如何优化网站结构
  • 三层前馈神经网络实战:MNIST手写数字识别
  • 深度学习(四)
  • 学习HAL库STM32F103C8T6(MQTT报文)
  • 【C++】C++11特性学习(1)——列表初始化 | 右值引用与移动语义
  • 网站布局 种类手机商城页面设计
  • 如何建设手机端网站电力公司建设安全文化
  • 红色 VR 大空间:技术赋能红色文化传承的运营价值与实践路径
  • 网络协议工程 - eNSP及相关软件安装 - [eNSP, VirtualBox, WinPcap, Wireshark, Win7]
  • WHAT - 前端性能指标(交互和响应性能指标)
  • 专业的媒体发稿网
  • dede旅游网站模板wordpress教学主题
  • 做网站的技术性说明怎么自己做微网站吗
  • VScode安装以及C/C++环境配置20251014
  • 黄页网站大全通俗易懂wordpress 数据库配置错误
  • 常规的红外工业镜头有哪些?能做什么?
  • 一文读懂分子结合位点的预测:为双荧光素酶实验铺路
  • SM4密码核心知识点