当前位置: 首页 > news >正文

自注意力(Self-Attention)和位置编码

自注意力

  • 给定序列 x 1 , … , x n \mathbf{x}_1, \ldots, \mathbf{x}_n x1,,xn, ∀ x i ∈ R d \forall \mathbf{x}_i \in \mathbb{R}^d xiRd

  • 自注意力池化层将 x i \mathbf{x}_i xi 当做key, value, query来对序列抽取特征得到 y 1 , … , y n \mathbf{y}_1, \ldots, \mathbf{y}_n y1,,yn, 这里

    y i = f ( x i , ( x 1 , x 1 ) , … , ( x n , x n ) ) ∈ R d \mathbf{y}_i = f(\mathbf{x}_i, (\mathbf{x}_1, \mathbf{x}_1), \ldots, (\mathbf{x}_n, \mathbf{x}_n)) \in \mathbb{R}^d yi=f(xi,(x1,x1),,(xn,xn))Rd
    在这里插入图片描述
    与 CNN、RNN 的比较
    在这里插入图片描述

CNNRNN自注意力
计算复杂度O( k n d 2 knd^2 knd2)O( n d 2 nd^2 nd2)O( n 2 d n^2d n2d)
并行度O( n n n)O( 1 1 1)O( n n n)
最长路径O( n / k n/k n/k)O( n n n)O( 1 1 1)

位置编码

  • 跟CNN/RNN不同,自注意力并没有记录位置信息
  • 位置编码将位置信息注入到输入里
    • 假设长度为 n n n 的序列是 X ∈ R n × d \mathbf{X} \in \mathbb{R}^{n \times d} XRn×d,那么使用位置编码矩阵 P ∈ R n × d \mathbf{P} \in \mathbb{R}^{n \times d} PRn×d 来输出 X + P \mathbf{X} + \mathbf{P} X+P 作为自编码输入
  • P \mathbf{P} P 的元素如下计算:
    p i , 2 j = sin ⁡ ( i 1000 0 2 j / d ) , p i , 2 j + 1 = cos ⁡ ( i 1000 0 2 j / d ) p_{i,2j} = \sin\left(\frac{i}{10000^{2j/d}}\right), \quad p_{i,2j+1} = \cos\left(\frac{i}{10000^{2j/d}}\right) pi,2j=sin(100002j/di),pi,2j+1=cos(100002j/di)

位置编码矩阵

  • P ∈ R n × d \mathbf{P} \in \mathbb{R}^{n \times d} PRn×d: p i , 2 j = sin ⁡ ( i 1000 0 2 j / d ) , p i , 2 j + 1 = cos ⁡ ( i 1000 0 2 j / d ) p_{i,2j} = \sin\left(\frac{i}{10000^{2j/d}}\right), \quad p_{i,2j+1} = \cos\left(\frac{i}{10000^{2j/d}}\right) pi,2j=sin(100002j/di),pi,2j+1=cos(100002j/di)

相对位置信息

  • 位于 i + δ i+\delta i+δ 处的位置编码可以线性投影位置 i i i 处的位置编码来表示

  • ω j = 1 / 1000 0 2 j / d \omega_j = 1/10000^{2j/d} ωj=1/100002j/d,那么在这里插入图片描述

总结

  • 自注意力池化层将 x i \mathbf{x}_i xi 当做key, value, query来对序列抽取特征
  • 完全并行、最长序列为1、但对长序列计算复杂度高
  • 位置编码在输入中加入位置信息,使得自注意力能够记忆位置信息

代码实现

首先导入必要的环境

import math
import torch
from torch import nn
from d2l import torch as d2l

自注意力

num_hiddens, num_heads = 100, 5
attention = d2l.MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,num_hiddens, num_heads, 0.5)
attention.eval()

在这里插入图片描述
位置编码

#@save
class PositionalEncoding(nn.Module):"""位置编码"""def __init__(self, num_hiddens, dropout, max_len=1000):"""初始化位置编码类参数:num_hiddens: int编码的隐藏维度大小(即每个位置的编码维度)dropout: floatDropout的概率,用于防止过拟合max_len: int, 默认值为1000最大序列长度,用于生成足够长的位置编码矩阵"""super(PositionalEncoding, self).__init__()# 定义Dropout层,用于在前向传播中随机丢弃部分神经元self.dropout = nn.Dropout(dropout)# 创建一个形状为 (1, max_len, num_hiddens) 的位置编码矩阵 P# 其中 1 表示批量维度,max_len 表示序列长度,num_hiddens 表示编码维度self.P = torch.zeros((1, max_len, num_hiddens))# 生成位置索引的张量,形状为 (max_len, 1)# 每个位置索引除以 10000 的幂次,幂次由编码维度决定X = torch.arange(max_len, dtype=torch.float32).reshape(-1, 1) / torch.pow(10000, torch.arange(0, num_hiddens, 2, dtype=torch.float32) / num_hiddens)# 对编码维度的偶数索引位置应用正弦函数self.P[:, :, 0::2] = torch.sin(X)# 对编码维度的奇数索引位置应用余弦函数self.P[:, :, 1::2] = torch.cos(X)def forward(self, X):"""前向传播函数,将位置编码添加到输入张量 X 上参数:X: torch.Tensor输入张量,形状为 (batch_size, seq_len, num_hiddens)返回:torch.Tensor添加了位置编码的张量,形状与输入张量相同"""# 将位置编码矩阵 P 的前 seq_len 个位置与输入张量 X 相加# 并将 P 移动到与 X 相同的设备(如 GPU 或 CPU)X = X + self.P[:, :X.shape[1], :].to(X.device)# 应用 Dropout 并返回结果return self.dropout(X)

行代表标记在序列中的位置,列代表位置编码的不同维度

encoding_dim, num_steps = 32, 60
pos_encoding = PositionalEncoding(encoding_dim, 0)
pos_encoding.eval()
X = pos_encoding(torch.zeros((1, num_steps, encoding_dim)))
P = pos_encoding.P[:, :X.shape[1], :]
d2l.plot(torch.arange(num_steps), P[0, :, 6:10].T, xlabel='Row (position)',figsize=(6, 2.5), legend=["Col %d" % d for d in torch.arange(6, 10)])

在这里插入图片描述
在编码维度上降低频率
在这里插入图片描述

相关文章:

  • ByteArrayOutputStream 类详解
  • 在Java中,什么是checked exception(即compile-time exception、编译时异常)?
  • 【学习笔记】机器学习(Machine Learning) | 第五章(3)| 分类与逻辑回归
  • Go小技巧易错点100例(三十)
  • DEX平台引领风尚 XBIT让数字资产回归简单与透明
  • 乐视系列玩机------乐视pro3精英版-x722的一些刷机救砖教程与固件资源
  • Gateway网关:路由和鉴权
  • Android控件View、ImageView、WebView用法
  • QT 在圆的边界画出圆
  • Python打造智能化多目标车辆跟踪系统:从理论到实践
  • LeetCode 热题 100 70. 爬楼梯
  • python读取图片自动旋转的问题解决
  • 深入解析:删除有序数组中的重复项 II——巧用双指针实现条件筛选
  • 【Leetcode 每日一题 - 补卡】838. 推多米诺
  • 掌握流量管理:利用 EKS Ingress 和 AWS 负载均衡器控制器
  • 用户模块 - IP归属地技术方案
  • TCP/IP协议深度解析:从分层架构到TCP核心机制
  • MySQL 复合查询
  • Spring AMQP源码解析
  • 英伟达语音识别模型论文速读:Fast Conformer
  • 山大齐鲁医院通报“子宫肌瘤论文现男性患者”:存在学术不端
  • 韩国总统选举民调:共同民主党前党首李在明支持率超46%
  • 马丽称不会与沈腾终止合作,“他是我的恩人,也是我的贵人”
  • 包揽金银!王宗源、郑九源夺得跳水世界杯总决赛男子3米板冠亚军
  • 中国队夺跳水世界杯总决赛首金
  • 印巴局势紧张或爆发军事冲突,印度空军能“一雪前耻”吗?