当前位置: 首页 > news >正文

TR3--Transformer之pytorch复现

- **🍨 本文为[🔗365天深度学习训练营](https://mp.weixin.qq.com/s/o-DaK6aQQLkJ8uE4YX1p3Q) 中的学习记录博客**
- **🍖 原作者:[K同学啊](https://mtyjkh.blog.csdn.net/)**

import math
import torch
import torch.nn as nn
device = torch.device("cpu")
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device
import torch.nn as nn
class Transpose(nn.Module):def __init__(self,*dims,contiguous=False):super().__init__()self.dims=dimsself.contiguous=contiguousdef forward(self,x):if self.contiguous:return x.transpose(*self.dims).contiguous()else:return x.transpose(*self.dims)  # 转换形式
import torch.nn.functional as F
class ScaledDotProductAttention(nn.Module):def __init__(self,d_k:int):super().__init__()self.d_k=d_kdef forward(self,q,k,v,mask=None):# 计算注意力的核心步骤scores=torch.matmul(q,k) #计算矩阵# 缩放分数scores=scores/(self.d_k**0.5)if mask is not None:scores.masked_fill_(mask,-1e9)attn=F.softmax(scores,dim=-1) #分数应用softmax得到注意力权重context=torch.matmul(attn,v) #根据注意力权重加权求和值向量return context 
class MultiHeadAttention(nn.Module):def __init__(self,d_model,n_heads):super().__init__()self.d_k=d_model//n_headsself.d_v=d_model//n_headsself.n_heads=n_headsself.W_Q=nn.Linear(d_model,self.d_k*n_heads,bias=False)self.W_K=nn.Linear(d_model,self.d_k*n_heads,bias=False)self.W_V=nn.Linear(d_model,self.d_v*n_heads,bias=False)self.W_O=nn.Linear(n_heads*self.d_v,d_model,bias=False)def forward(self,Q,K,V,mask=None):bs=Q.size(0)q_s=self.W_Q(Q).view(bs,-1,self.n_heads,self.d_k).transpose(1,2)k_s=self.W_K(K).view(bs,-1,self.n_heads,self.d_k).permute(0,2,3,1)v_s=self.W_V(V).view(bs,-1,self.n_heads,self.d_v).transpose(1,2)context=ScaledDotProductAttention(self.d_k)(q_s,k_s,v_s)context=context.transpose(1,2).contiguous().view(bs,-1,self.n_heads*self.d_v)output=self.W_O(context)return output 
class Feedforward(nn.Module):def __init__(self,d_model,d_ff,dropout=0.1):super().__init__()self.linear1=nn.Linear(d_model,d_ff)self.dropout=nn.Dropout(dropout)self.linear2=nn.Linear(d_ff,d_model)def forward(self,x):x=torch.nn.functional.relu(self.linear1(x))x=self.dropout(x)x=self.linear2(x)return x 
class PositionalEncoding(nn.Module):def __init__(self,d_model,dropout,max_len=5000):super().__init__()self.dropout=nn.Dropout(p=dropout)pe=torch.zeros(max_len,d_model).to(device)position=torch.arange(0,max_len).unsqueeze(1)div_term=torch.exp(torch.arange(0,d_model,2)*(math.log(10000.0)/d_model))pe[:,0::2]=torch.sin(position*div_term) #什么意思,pe[:,1::2]=torch.cos(position*div_term) #计算PE(pos,2i+1)pe=pe.unsqueeze(0) #计算self.register_buffer('pe',pe)def forward(self,x):print(x.device)x = x + self.pe[:x.size(1), :].transpose(0, 1).to(device)print(x.device)return self.dropout(x)
class EncoderLayer(nn.Module):def __init__(self,d_model,n_heads,d_ff,dropout=0.1):super().__init__()self.self_attn=MultiHeadAttention(d_model,n_heads)self.feedforward=Feedforward(d_model,d_ff,dropout)self.norm1=nn.LayerNorm(d_model)self.norm2=nn.LayerNorm(d_model)self.dropout=nn.Dropout(dropout)def forward(self,x,mask):attn_output=self.self_attn(x,x,x,mask)x=x+self.dropout(attn_output)x=self.norm1(x)ff_output=self.feedforward(x)x=x+self.dropout(ff_output)x=self.norm2(x)return x
class DecoderLayer(nn.Module):def __init__(self,d_model,n_heads,d_ff,dropout=0.1):super().__init__()self.self_attn=MultiHeadAttention(d_model,n_heads)self.enc_attn=MultiHeadAttention(d_model,n_heads)self.feedforward=Feedforward(d_model,d_ff,dropout)self.norm1=nn.LayerNorm(d_model)self.norm2=nn.LayerNorm(d_model)self.norm3=nn.LayerNorm(d_model)self.dropout=nn.Dropout(dropout)def forward(self,x,enc_output,self_mask,context_mask):attn_output=self.self_attn(x,x,x,self_mask)x=x+self.dropout(attn_output)x=self.norm1(x)attn_output=self.enc_attn(x,enc_output,enc_output,context_mask)x=x+self.dropout(attn_output)x=self.norm2(x)ff_output=self.feedforward(x)x=x+self.dropout(ff_output)x=self.norm3(x)return x
# 构建 
class Transformer(nn.Module):def __init__(self,vocab_size,d_model,n_heads,n_encoder_layers,n_decoder_layers,d_ff,dropout=0.1):super().__init__()self.embedding=nn.Embedding(vocab_size,d_model)self.positional_encoding=PositionalEncoding(d_model,dropout)self.encoder_layers=nn.ModuleList([EncoderLayer(d_model,n_heads,d_ff,dropout) for _ in range(n_encoder_layers)])self.decoder_layers=nn.ModuleList([DecoderLayer(d_model,n_heads,d_ff,dropout) for _ in range(n_decoder_layers)])self.fc_out=nn.Linear(d_model,vocab_size)self.dropout=nn.Dropout(dropout)def forward(self,src,trg,src_mask,trg_mask):src=self.embedding(src)src=self.positional_encoding(src)trg=self.embedding(trg)trg=self.positional_encoding(trg)for layer in self.encoder_layers:src=layer(src,src_mask)for layer in self.decoder_layers:trg=layer(trg,src,trg_mask,src_mask)output=self.fc_out(trg)return output 
#使用示例
vocab_size =10000 #假设词汇表大小为10000
d_model=512
n_heads=8
n_encoder_layers=6
n_decoder_layers=6
d_ff = 2048
dropout =0.1transformer_model = Transformer(vocab_size, d_model, n_heads, n_encoder_layers, n_decoder_layers, d_ff, dropout)#定义输入,这里的输入是假设的,需要根据实际情况修改
src=torch.randint(0,vocab_size,(32,10)) #源语言句子
trg=torch.randint(0,vocab_size,(32,20)) #目标语言句子#掩码,用于屏蔽填充的位置
src_mask=(src !=0).unsqueeze(1).unsqueeze(2)
trg_mask =(trg !=0).unsqueeze(1).unsqueeze(2) #掩码,用于屏蔽填充的位置print("实际|输入数据维度:",src.shape)
print("预期|输出数据维度:",trg.shape)
output =transformer_model(src,trg,src_mask,trg_mask)
print("实际|输出数据维度:",output.shape)

实际|输入数据维度: torch.Size([32, 10])
预期|输出数据维度: torch.Size([32, 20])
实际|输出数据维度: torch.Size([32, 20, 10000])

http://www.dtcms.com/a/495194.html

相关文章:

  • Traccar本地文件包含漏洞(CVE-2025-61666)
  • 建站网站推荐icp域名备案查询系统
  • 智能美颜引擎:美颜SDK如何实现自适应芯片性能优化
  • Java中的boolean与Boolean
  • Flutter高级进阶教程(视频教程)
  • Rocketmq 分布式事务 两阶段提交
  • 骑行,团骑和独骑冲突吗?
  • 对网站和网页的认识鞍山信息网便民信息
  • 《算法通关指南---C++编程篇(2)》
  • 【论文速递】2025年第29周(Jul-13-19)(Robotics/Embodied AI/LLM)
  • 网站 模板更改网站备案
  • VR反诈一体机-VR预防诈骗模拟系统-VR防诈骗体验馆方案
  • 大型网站seo课程沈阳关键词优化费用
  • Kubernetes PVC 扩容完全指南:静态迁移 vs 动态扩容
  • 【题解】B2613【深基1.习5】打字速度
  • Elastic DevRel 通讯 — 2025 年 10 月
  • Java面试基础题
  • 博客标题:快速解决 VS Code 终端运行 petalinux-config 界面显示错乱问题
  • 强化学习【Monte Carlo Learning][MC Basic 算法]
  • 杭州网站开发制作公司小程序源码出售
  • 从0到1学习Qt -- 创建项目
  • dw做网站基础wap网站开发价格
  • 【实时Linux实战系列】实时应用的多版本共存与无缝升级
  • Linux小课堂: 文件操作核心命令深度解析(cp、mv 与 rm 命令)
  • 【大模型小实验】考一考qwen3-8b对于历史人物的理解
  • 商家建设网站的好处公司单页设计
  • 鹿泉区住房建设局网站网站建设公司 项目经理 的工作指责
  • 字体设计网站有哪些免费网站模块在线制作教程
  • YOLOv3
  • 腾讯元宝-Deepseek 的文章摘要功能测试