当前位置: 首页 > news >正文

Transformer中为什么要使用多头注意力?

参考视频:面试必刷:大模型为什么要使用多头注意力?_哔哩哔哩_bilibili

详解文章:Transformer内容详解(通透版)-CSDN博客

单头注意力的劣势:单头注意力只能从一个角度“看”输入序列,计算得到的注意力权重反映的是一种特定的关注模式。

多头注意力将注意力分为了多个“头”,每个头独立计算注意力,关注输入的不同子空间或不同方面的特征。这样,模型能够并行地捕捉到多种不同类型的语义关系

将输入投射到多个不同的低维空间,分别计算注意力,最后再concat拼接,通过线性变换融合丰富了模型的表达能力,使得Transformer能够学习复杂的组合特征。同时,每个注意力头的参数量和计算复杂度降低,有助于提升训练的稳定性和效率,有利于收敛

单头注意力:

import torch
import torch.nn as nnclass Self_Attention(nn.Module):def __init__(self, dim, dk, dv):super().__init__()self.scale = dk ** -0.5self.q = nn.Linear(dim, dk)self.k = nn.Linear(dim, dk)self.v = nn.Linear(dim, dv)def forward(self, x):# x: [batch, seq_len, dim]q = self.q(x)  # [batch, seq_len, dk]k = self.k(x)  # [batch, seq_len, dk]v = self.v(x)  # [batch, seq_len, dv]attn = (q @ k.transpose(-2, -1)) * self.scale  # [batch, seq_len, seq_len]attn = attn.softmax(dim=-1)out = attn @ v  # [batch, seq_len, dv]return out

多头注意力:

import torch
import torch.nn as nnclass MultiHeadAttention(nn.Module):def __init__(self, dim, dk, dv, num_heads):super().__init__()self.num_heads = num_headsself.dk = dkself.dv = dvself.q_linear = nn.Linear(dim, dk * num_heads)self.k_linear = nn.Linear(dim, dk * num_heads)self.v_linear = nn.Linear(dim, dv * num_heads)self.out_linear = nn.Linear(dv * num_heads, dim)def forward(self, x):B, N, _ = x.shape  # batch, seq_len, dimQ = self.q_linear(x).view(B, N, self.num_heads, self.dk).transpose(1, 2)  # [B, heads, N, dk]K = self.k_linear(x).view(B, N, self.num_heads, self.dk).transpose(1, 2)V = self.v_linear(x).view(B, N, self.num_heads, self.dv).transpose(1, 2)# Attentionattn = (Q @ K.transpose(-2, -1)) / (self.dk ** 0.5)  # [B, heads, N, N]attn = attn.softmax(dim=-1)out = attn @ V  # [B, heads, N, dv]out = out.transpose(1, 2).reshape(B, N, self.num_heads * self.dv)  # [B, N, heads*dv]out = self.out_linear(out)  # [B, N, dim]return out

http://www.dtcms.com/a/392600.html

相关文章:

  • 《嵌入式硬件(十六):基于IMX6ULL的I2C的操作》
  • AI.工作助手.工作提效率
  • 【开题答辩全过程】以 Louis宠物商城为例,包含答辩的问题和答案
  • 微服务-网络模型与服务通信方式openfein
  • 如何快速定位局域网丢包设备?
  • 算法<java>——排序(冒泡、插入、选择、归并、快速、计数、堆、桶、基数)
  • 深入浅出CMMI:从混乱到卓越的研发管理体系化之路
  • Docker一键部署prometheus并实现飞书告警详解
  • 基于“开源AI大模型AI智能名片S2B2C商城小程序”的多平台资源位传播对直播营销流量转化的影响研究
  • 【设计模式】适配器模式 在java中的应用
  • 2013/07 JLPT听力原文 问题四
  • MyBatis 缓存体系剖析
  • MySQL 主从复制 + MyCat 读写分离 — 原理详解与实战
  • Vmake AI:美图推出的AI电商商品图编辑器,快速生成AI时装模特和商品图
  • Debian13 钉钉无法打开问题解决
  • 02.容器架构
  • Diffusion Model与视频超分(1):解读淘宝开源的视频增强模型Vivid-VR
  • 通过提示词工程(Prompt Engineering)方法重新生成从Ollama下载的模型
  • 有没有可以检测反爬虫机制的工具?
  • 大模型为什么需要自注意力机制?
  • 长度为K子数组中的最大和-定长滑动窗口
  • Linux安装Kafka(无Zookeeper模式)保姆级教程,云服务器安装部署,Windows内存不够可以看看
  • WEEX编译|续写加密市场叙事
  • 为 Element UI 表格增添排序功能
  • 点评项目(Redis中间件)第四部分缓存常见问题
  • 动态水印也能去除?ProPainter一键视频抠图整合包下载
  • DevSecOps 意识不足会导致哪些问题
  • LeetCode:27.合并两个有序链表
  • 适用于双节锂电池的充电管理IC选型参考
  • 格式说明符