当前位置: 首页 > news >正文

self attention, masked self attention, cross attention

1. 普通 Self-Attention(缩放点积)

  • Q K V是根据同一个输入X
  • 没有约束,所有位置都可以互相关注
2. Masked Self-Attention

其中 MM 是 掩码矩阵(mask matrix),定义为:

这意味着:第 ii个位置只能关注第 $1到到i个位置(含自己),不能看到个位置(含自己),不能看到i+1, ..., T$ 的未来 token。

假设序列长度为 4,True 表示被 mask 掉(不可见),False 表示可见。

这个叫做 upper triangular mask(上三角掩码),也叫 causal mask(因果掩码)

3. Cross Attention

Q来自一个源(target), K, V来自另一个源(source)。

Q和K,V的length可能不一样,但是d_model是一样的

代码实现多头注意力机制:

其实可以看到代码和论文是有些出入的,论文是先分头,然后每个head都分别有Wq, Wk, Wv,但是在代码中是先共用Wq, Wk, Wv,然后再分头,只做QKV矩阵乘法,再concatenate

import torch
import torch.nn as nn
import torch.nn.functional as F
import mathclass MultiHeadAttention(nn.Module):def __init__(self, d_model, num_heads):super(MultiHeadAttention, self).__init__()self.d_model = d_modelself.num_heads = num_headsself.d_k = d_model // num_headsassert d_model % num_heads == 0, "d_model must be divisible by num_heads"# Linear projections for Q, K, Vself.W_q = nn.Linear(d_model, d_model)self.W_k = nn.Linear(d_model, d_model)self.W_v = nn.Linear(d_model, d_model)self.W_o = nn.Linear(d_model, d_model)def scaled_dot_product_attention(self, Q, K, V, mask=None):# Q, K, V: (batch_size, num_heads, seq_len, d_k)scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)if mask is not None:scores = scores.masked_fill(mask == 0, float('-inf'))attn = F.softmax(scores, dim=-1)return torch.matmul(attn, V)def split_heads(self, x, batch_size):# x: (batch_size, seq_len, d_model)x = x.view(batch_size, -1, self.num_heads, self.d_k)  # (batch_size, seq_len, num_heads, d_k)return x.transpose(1, 2)  # (batch_size, num_heads, seq_len, d_k)def combine_heads(self, x, batch_size):# x: (batch_size, num_heads, seq_len, d_k)x = x.transpose(1, 2).contiguous()  # (batch_size, seq_len, num_heads, d_k)return x.view(batch_size, -1, self.d_model)  # (batch_size, seq_len, d_model)def forward(self, Q, K, V, mask=None):batch_size = Q.size(0)Q = self.W_q(Q)  # (batch_size, seq_len, d_model)K = self.W_k(K)V = self.W_v(V)Q = self.split_heads(Q, batch_size)  # (batch_size, num_heads, seq_len, d_k)K = self.split_heads(K, batch_size)V = self.split_heads(V, batch_size)attn_output = self.scaled_dot_product_attention(Q, K, V, mask)output = self.combine_heads(attn_output, batch_size)return self.W_o(output)  # Final linear projection

需要对比一下,看输出一不一样

http://www.dtcms.com/a/479742.html

相关文章:

  • 基于51单片机心率温度语音播报、显示时间
  • 商城网站建设公司招聘北京房产网北京二手房
  • 前端图片加载失败、 img 出现裂图的原因全解析
  • Linux——进程优先级
  • 宝塔面板建设二级域名网站访问不了网站建设需
  • wordpress 会被墙吗福田企业网站优化排名
  • ftp上传网站社群营销与运营
  • 吉林网站建设费用接推广是什么意思
  • 如何做营销型网站人工智能的关键词
  • 旅游网站的设计栏目动易网站系统怎么样
  • 上海网站制作建设多少钱.net 快速网站开发
  • 网站营销活动策划wordpress制作网页教程
  • 如何完整保存网站并做修改中企动力邮箱网页版
  • 网站模板下载破解版廊坊那家做网站排行榜
  • 深圳网站设计九曲个人小说网站怎么做
  • 沈阳网站建设024wwordpress 会员登录
  • 【开题答辩全过程】以 濒危动物保护管理系统为例,包含答辩的问题和答案
  • VS Code 智能提示(IntelliSense)完全配置指南(C++/Python/JavaScript)
  • wordpress 全站密码网站建设基础策划
  • 长沙 网站seo服务 网络服务wordpress实时预览载入中
  • 简述电子商务网站的建设流程图企业网站建设 西宁
  • 网页设计与网站建设 入门必练青海网站建设哪家好
  • 【HashMap全面知识点】— 快速理解HashMap
  • 【系统分析师】写作框架:面向对象设计方法及其应用
  • 图书网站建设实训总结人像摄影网站十大排名
  • 网站开发常问的技术性问题哈尔滨建站模板厂家
  • 国内信息图制作网站有哪些网站开发的技术支撑 经验能力
  • 上海缔客网站建设公司婚纱摄影网站
  • 湖北省建设厅官方网站网页传奇游戏哪个好
  • 河南省建设工程标准定额管理网站福建seo搜索引擎优化