当前位置: 首页 > news >正文

大模型中的三角位置编码实现

Transformer中嵌入表示 + 位置编码的实现

import torch
import math
from torch import nn# 词嵌入位置编码实现
class EmbeddingWithPosition(nn.Module):"""vocab_size:词表大小emb_size: 词向量维度seq_max_len: 句子最大长度 (人为设定,例如GPT2的最大长度是1024) """def __init__(self, vocab_size, emb_size, dropout=0.1, seq_max_len=5000):self.seq_emb = nn.Embedding(vocab_size, emb_size) # 序列中每个token的embedding向量表示#  位置编码实现 (硬编码方式)position_idx = torch.arange(0, seq_max_len, dtype=torch.float).unsqueeze(-1)position_emb_fill = position_idx * torch.exp(-torch.arange(0, emb_size, 2) * math.log(10000.0) / emb_size) # 三角位置编码实现position_emb = torch.zeros(seq_max_len, emb_size) # 位置编码 emb_size是嵌入维度大小position_emb[:, 0::2] = torch.sin(position_emb_fill)position_emb[:, 1::2] = torch.cos(position_emb_fill)self.register_buffer('pos_encoding', position_emb) # 固定参数,不需要trainself.dropout = nn.Dropout(dropout)def forward(self, x):x = self.seq_emb(x) # 嵌入层表示 (batch_size, seq_len, emb_size)# x = x + self.pos_encoding.unsqueeze(0)[:,:x.size()[1],:] # 添加位置编码x += self.pos_encoding.unsqueeze(0)return self.dropout(x)

自己动手实现易懂版本:

assert 10 % 2 == 0,  "wrong assert"
# 如果前面判断正确的话,则不会引发异常;否则,则会引发异常import torchimport torch
def creat_pe_absolute_sincos_embedding_gov(n_pos_vec, dim):assert dim % 2 == 0, "wrong dim"position_embedding = torch.zeros(n_pos_vec.numel(), dim, dtype=torch.float)omega = torch.arange(dim//2, dtype=torch.float)omega /= dim/2.omega = 1./(10000**omega)sita = n_pos_vec[:,None] @ omega[None,:]emb_sin = torch.sin(sita)emb_cos = torch.cos(sita)position_embedding[:,0::2] = emb_sinposition_embedding[:,1::2] = emb_cosreturn position_embeddingdef create_pe_absulute_sincos_embedding(n_pos_vec, dim):"""绝对位置编码:param n_pos_vec: 位置编码的长度向量:param dim: 词向量的维度:return: 位置编码"""assert dim % 2 == 0, "dim must be even"position_embedding = torch.zeros(n_pos_vec.numel(), dim, dtype=torch.float) # 三角函数位置编码omega = torch.arange(dim // 2, dtype=torch.float) # 0 ~ i, max_i: dim // 2omega *= 2omega /= dim omega = torch.pow(10000, omega) # 10000^(2i/dim)omega = 1 / omegaomega = omegaprint("n_pos_vec shape:",n_pos_vec.unsqueeze(1).shape)print("omega shape:", omega.shape).squeezeposition_embedding[:, 0::2] = torch.sin(n_pos_vec.unsqueeze(1) * omega) # 偶数位置position_embedding[:, 1::2] = torch.cos(n_pos_vec.unsqueeze(1) * omega) # 奇数位置return position_embeddingif __name__ == "__main__":n_pos = 4dim = 8n_pos_vec = torch.arange(n_pos, dtype=torch.float)position_embeddding = create_pe_absulute_sincos_embedding(n_pos_vec, dim)position_embeddding_1 = creat_pe_absolute_sincos_embedding_gov(n_pos_vec, dim)print(position_embeddding == position_embeddding_1)print("position embedding shape:", position_embeddding.shape)

参考版本

相关文章:

  • WinCC V7.2到V8.0与S71200/1500系列连接通讯教程以及避坑点
  • C++学习之模板初阶学习
  • 数据治理框架在企业中的落地:从理念到实践
  • 第三章 Freertos智能小车遥控控制
  • 互联网大厂Java面试实录:Spring Boot与微服务架构在电商场景中的应用解析
  • 21.【.NET 8 实战--孢子记账--从单体到微服务--转向微服务】--单体转微服务--身份认证服务拆分规划
  • diy装机成功录
  • C++ learning day 02
  • day010-命令实战练习题
  • 第一个SpringBoot程序
  • 软考中级数据库备考-上午篇
  • Spark的三种部署模式及其特点与区别
  • Autoware播放提示音
  • 基于Spring Boot + Vue的高校心理教育辅导系统
  • adb命令查询不到设备?
  • QTableWidget实现多级表头、表头冻结效果
  • 模型 启动效应
  • WPF之集合绑定深入
  • 配置高级相关
  • 深入理解卷积神经网络的输入层:数据的起点与预处理核心
  • 体验中国传统文化、采购非遗文创,波兰游客走进上海市群艺馆
  • 2025上海十大动漫IP评选活动启动
  • 习近平会见缅甸领导人敏昂莱
  • 中华人民共和国和俄罗斯联邦关于进一步加强合作维护国际法权威的联合声明
  • 轿车追尾半挂车致3死1伤,事故调查报告:司机过分依赖巡航系统
  • 央视315晚会曝光“保水虾仁”后,湛江4家涉事企业被罚超800万元