当前位置: 首页 > news >正文

手写self-attention的三重境界

引言

self-attention在实现过程中有很多细节,不同的面试对self-attention实现的要求也不一样。所以我们要学会多种self-attention实现的方式,以此来告诉面试官,我们是了解self-attention的细节的。

self-attention的公式

Attention(Q,K,V)=softmax(QKTdk)V\text{Attention}(Q,K,V) = \text{softmax}( \frac{QK^T}{\sqrt{d_k}})VAttention(Q,K,V)=softmax(dkQKT)V

代码实现

第一重境界:简化版本

import math
import torch
import torch.nn as nnclass SelfAttentionV1(nn.Module):def __init__(self, hidden_dim: int = 728) -> None:super().__init__()self.hidden_dim = hidden_dim# 初始化三个不同的线性应用层self.query_proj = nn.Linear(hidden_dim, hidden_dim)self.key_proj = nn.Linear(hidden_dim, hidden_dim)self.value_proj = nn.Linear(hidden_dim, hidden_dim)def forward(self, x):# x shape is: (batch_size, seq_len, hidden_dim)# 获取不同的Q, K, VQ = self.query_proj(x)K = self.key_proj(x)V = self.value_proj(x)# Q, K, V shape: (batch_size, seq_len, hidden_dim)# (batch_size, seq_len, hidden_dim) * (batch_size, hidden_dim, seq_len) = (batch_size, seq_len, seq_len)attention_value = torch.matmul(Q, K.transpose(-1, -2))# 计算注意力分数attention_weight = torch.softmax(attention_value / math.sqrt(self.hidden_dim), dim=-1)# 计算结果 shape: (batch_size, seq_len, hidden_dim)output = torch.matmul(attention_weight, V)return output

第一重境界比较简单,完全对着公式实现就可以了。

第二重境界:效率优化

对QKV矩阵进行合并,然后再拆分

class SelfAttentionV2(nn.Module):def __init__(self, hidden_dim):super().__init__()self.hidden_dim = hidden_dimself.proj = nn.Linear(hidden_dim, hidden_dim * 3)def forward(self, x):# X shape: (batch_size, seq_len, hidden_dim)# QKV shape (batch_size, seq_len, hidden_dim*3)QKV = self.proj(X)Q, K, V = torch.split(QKV, self.hidden_dim, dim=-1)attention_weight = torch.softmax(torch.matmul(Q, K.transpose(-1, -2)) / math.sqrt(self.hidden_dim), dim=-1)output = attention_weight @ Vreturn output

第三重:加入一些细节(面试写法)

除了公式外,还有一些细节:

  • 加入dropout
  • 每个句子长度不一样,加入attention mask
  • output矩阵映射
class SelfAttentionV3(nn.Module):def __init__(self, hidden_dim, dropout_rate=0.1):super().__init__()self.hidden_dim = hidden_dimself.proj = nn.Linear(hidden_dim, hidden_dim * 3)self.attention_dropout = nn.Dropout(dropout_rate)self.output_proj = nn.Linear(hidden_dim, hidden_dim)def forward(self, x, attention_mask=None):# x shape: (batch_size, seq_len, hidden_dim)QKV = self.proj(x)Q, K, V = torch.split(QKV, self.hidden_dim, dim=-1)attention_weight = Q @ K.transpose(-1, -2) / math.sqrt(self.hidden_dim)# 如果attention_mask不是None,那就要给那些被mask掉的词语一个非常非常小的值,这样做完softmax以后这些值就是0if attention_mask is not None:attention_weight = attention_weight.masked_fill(attention_mask == 0,float("-1e20"))attention_weight = torch.softmax(attention_weight, dim=-1)# 做dropoutattention_weight = self.attention_dropout(attention_weight)attention_result = attention_weight @ Voutput = self.output_proj(attention_result)return output

从 V1 到 V3 的核心优化脉络(迭代逻辑)

  1. 第一阶段:工程效率优化(V1 → V2)
    • 优化点:将 3 个独立线性层合并为 1 个合并线性层,再 split 拆分 QKV。
    • 核心逻辑:数学上完全等价(仅权重拼接),但减少了内核启动次数、内存碎片化,提升硬件并行效率(GPU 更易利用批量矩阵乘法算力)。
    • 价值:从 “教学级冗余实现” 转向 “工程化高效实现”,无性能损失,仅提升效率。
  2. 第二阶段:功能完整性优化(V2 → V3)
    • 优化点 1:新增 attention_mask 支持(修正笔误后用 masked_fill)。
    • 解决问题:适配实际场景(NLP 批量 Padding、生成任务因果掩码),屏蔽无效位置干扰。
    • 优化点 2:新增注意力权重 Dropout。
    • 解决问题:正则化,防止模型过度依赖少数关键位置,缓解过拟合。
    • 优化点 3:新增输出线性投影 output_proj。
    • 解决问题:对注意力聚合后的特征做 “精炼”,增强模型表达能力,适配深层网络堆叠。
    • 价值:从 “仅追求效率” 转向 “可落地工业级功能”,覆盖批量训练、泛化能力等核心需求。
http://www.dtcms.com/a/597446.html

相关文章:

  • 功能安全/ASPICE合规保障:高效模型测试驱动零缺陷开发
  • k8s DaemonSet 控制器从原理到实践
  • 睢宁做网站公司WordPress同步某个表
  • Note:高电压工况下温度测量:挑战与应对策略全解析
  • PostgreSQL 实战分析:UPDATE 语句性能异常与缓存击穿诊断
  • java接口自动化之allure本地生成报告
  • 基于spring boot房屋租赁管理系统的设计与实现
  • Android中使用SQLCipher加密GreenDao数据库不成功
  • AI泡沫量化预警:基于多因子模型的1999年互联网泡沫历史回溯与风险映射
  • 网站建设多少钱一个平台wordpress 查看菜单
  • 网站导航设置婚恋网站建设教程
  • 黑马JAVAWeb - Maven高级-分模块设计与开发-继承-版本锁定-聚合-私服
  • 34.来自Transformers的双向编码器表示(BERT)
  • 风啸之上,科技为盾——VR台风避险体验
  • 免费个人网站域名外贸wordpress模板下载
  • 如何在PHP框架中高效处理HTTP请求:从基础到最佳实践!
  • 语义抽取逻辑概念
  • 【大数据技术06】大数据技术
  • 即刻搜索收录网站重庆网站建设推广优化
  • 高明骏域网站建设特定ip段访问网站代码
  • 数组有哪些算法?
  • PCB之电源完整性之电源网络的PDN仿真CST---07
  • 学校网站的页头图片做有没有专业做咖啡店设计的网站
  • Dify Docker Compose 安装指南
  • Spring Boot 2.x 集成 Knife4j (OpenAPI 3) 完整操作指南
  • 郑州企业网站模板建站中国建设银行大学助学贷款网站
  • 微信 网站模板网站毕业设计图怎么做
  • RTMP推流平台EasyDSS:视频推拉流技术赋能幼儿园安全可视化与家园共育新实践
  • iChat:RabbitMQ封装
  • 悬镜安全CEO子芽荣获“2025年度OSCAR开源人物”