
Li Mu, Chapter 10: Error when implementing Seq2SeqAttentionDecoder

Problem

OS: Windows 11
GPU: 5060
CUDA: 12.8
PyTorch: 2.7.1+cu128
PyCharm: 2025.2

After implementing Seq2SeqAttentionDecoder, running:

encoder = d2l.Seq2SeqEncoder(vocab_size=10, embed_size=8, num_hiddens=16, num_layers=2)
encoder.eval()
decoder = Seq2SeqAttentionDecoder(vocab_size=10, embed_size=8, num_hiddens=16, num_layers=2)
decoder.eval()
X = torch.zeros((4, 7), dtype=torch.long)
state = decoder.init_state(encoder(X), None)
output, state = decoder(X, state)
output.shape, len(state), state[0].shape, len(state[1]), state[1][0].shape

raises an error reporting that AdditiveAttention.__init__() received the wrong number of arguments:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[4], line 3
      1 encoder = d2l.Seq2SeqEncoder(vocab_size=10, embed_size=8, num_hiddens=16, num_layers=2)
      2 encoder.eval()
----> 3 decoder = Seq2SeqAttentionDecoder(vocab_size=10, embed_size=8, num_hiddens=16, num_layers=2)
      4 decoder.eval()
      6 X = torch.zeros((4, 7), dtype=torch.long)

Cell In[3], line 4, in Seq2SeqAttentionDecoder.__init__(self, vocab_size, embed_size, num_hiddens, num_layers, dropout, **kwargs)
      2 def __init__(self, vocab_size, embed_size, num_hiddens, num_layers, dropout=0, **kwargs):
      3     super(Seq2SeqAttentionDecoder, self).__init__(**kwargs)
----> 4     self.attention = d2l.AdditiveAttention(num_hiddens, num_hiddens, num_hiddens, dropout)
      5     self.embedding = nn.Embedding(vocab_size, embed_size)
      6     self.rnn = nn.GRU(embed_size+num_hiddens, num_hiddens, num_layers, dropout=dropout)

TypeError: AdditiveAttention.__init__() takes 3 positional arguments but 5 were given

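Analysis

After some investigation, the cause turned out to be that the AdditiveAttention definition in the d2l package differs from the one in the textbook: the package constructor takes only num_hiddens and dropout, while the decoder code from the textbook also passes key_size and query_size.
Definition in the d2l package: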

class AdditiveAttention(nn.Module):
    """Additive attention.

    Defined in :numref:`subsec_batch_dot`"""
    def __init__(self, num_hiddens, dropout, **kwargs):
        super(AdditiveAttention, self).__init__(**kwargs)
        self.W_k = nn.LazyLinear(num_hiddens, bias=False)
        self.W_q = nn.LazyLinear(num_hiddens, bias=False)
        self.w_v = nn.LazyLinear(1, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values, valid_lens):
        queries, keys = self.W_q(queries), self.W_k(keys)
        # After dimension expansion, shape of queries: (batch_size, no. of
        # queries, 1, num_hiddens) and shape of keys: (batch_size, 1, no. of
        # key-value pairs, num_hiddens). Sum them up with broadcasting
        features = queries.unsqueeze(2) + keys.unsqueeze(1)
        features = torch.tanh(features)
        # There is only one output of self.w_v, so we remove the last
        # one-dimensional entry from the shape. Shape of scores: (batch_size,
        # no. of queries, no. of key-value pairs)
        scores = self.w_v(features).squeeze(-1)
        self.attention_weights = masked_softmax(scores, valid_lens)
        # Shape of values: (batch_size, no. of key-value pairs, value
        # dimension)
        return torch.bmm(self.dropout(self.attention_weights), values)

Definition in the textbook:

class AdditiveAttention(nn.Module):
    def __init__(self, key_size, query_size, num_hiddens, dropout, **kwargs):
        super(AdditiveAttention, self).__init__(**kwargs)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=False)
        self.W_q = nn.Linear(query_size, num_hiddens, bias=False)
        self.w_v = nn.Linear(num_hiddens, 1, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values, valid_lens):
        queries, keys = self.W_q(queries), self.W_k(keys)
        # Broadcast-sum queries and keys, then apply the tanh nonlinearity
        # that additive attention requires (same step as in the package version)
        features = queries.unsqueeze(2) + keys.unsqueeze(1)
        features = torch.tanh(features)
        scores = self.w_v(features).squeeze(-1)
        self.attention_weights = masked_softmax(scores, valid_lens)
        return torch.bmm(self.dropout(self.attention_weights), values)
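The mismatch between the two constructor signatures above can be checked with inspect.signature. The following is only a minimal sketch: PackageStyleAttention and TextbookStyleAttention are hypothetical stand-in classes that mirror the two __init__ signatures, so the snippet runs without torch or d2l installed.

```python
import inspect


class PackageStyleAttention:
    """Hypothetical stand-in mirroring the d2l package constructor."""

    def __init__(self, num_hiddens, dropout, **kwargs):
        self.num_hiddens = num_hiddens
        self.dropout = dropout


class TextbookStyleAttention:
    """Hypothetical stand-in mirroring the textbook constructor."""

    def __init__(self, key_size, query_size, num_hiddens, dropout, **kwargs):
        self.key_size = key_size
        self.query_size = query_size
        self.num_hiddens = num_hiddens
        self.dropout = dropout


def positional_params(cls):
    """Return the positional parameter names of cls.__init__, without self."""
    params = inspect.signature(cls.__init__).parameters.values()
    return [
        p.name
        for p in params
        if p.kind is inspect.Parameter.POSITIONAL_OR_KEYWORD and p.name != "self"
    ]


print(positional_params(PackageStyleAttention))
print(positional_params(TextbookStyleAttention))

# Calling the package-style constructor the way the textbook decoder does
# reproduces the same kind of TypeError as in the traceback.
error_message = None
try:
    PackageStyleAttention(16, 16, 16, 0)
except TypeError as err:
    error_message = str(err)
    print(error_message)
```

With the real package installed, the same check would presumably be inspect.signature(d2l.AdditiveAttention.__init__), which shows which of the two definitions your environment actually provides.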

Solution

Use the textbook's AdditiveAttention definition: paste it into d2l/pytorch.py, replacing the original definition. After that, the code runs normally.
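A less invasive variant, sketched here as an assumption rather than something the original post tested, is to leave the installed package untouched, define the textbook-style class in your own notebook, and have the decoder use AdditiveAttention instead of d2l.AdditiveAttention. The masked_softmax below is a simplified stand-in for the book's helper and only handles valid_lens that is None or a 1-D tensor of shape (batch_size,).

```python
import torch
from torch import nn


def masked_softmax(X, valid_lens):
    """Softmax over the last axis, ignoring key positions >= valid_lens.

    Simplified stand-in: valid_lens is None or a 1-D tensor (batch_size,).
    """
    if valid_lens is None:
        return torch.softmax(X, dim=-1)
    positions = torch.arange(X.shape[-1], device=X.device)
    # mask shape: (batch_size, 1, num_keys), broadcast over the query axis
    mask = positions[None, None, :] < valid_lens[:, None, None]
    return torch.softmax(X.masked_fill(~mask, -1e6), dim=-1)


class AdditiveAttention(nn.Module):
    """Local copy with the textbook constructor signature."""

    def __init__(self, key_size, query_size, num_hiddens, dropout, **kwargs):
        super().__init__(**kwargs)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=False)
        self.W_q = nn.Linear(query_size, num_hiddens, bias=False)
        self.w_v = nn.Linear(num_hiddens, 1, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values, valid_lens):
        queries, keys = self.W_q(queries), self.W_k(keys)
        # (batch, queries, 1, hidden) + (batch, 1, keys, hidden), then tanh
        features = torch.tanh(queries.unsqueeze(2) + keys.unsqueeze(1))
        scores = self.w_v(features).squeeze(-1)
        self.attention_weights = masked_softmax(scores, valid_lens)
        return torch.bmm(self.dropout(self.attention_weights), values)


# Same call pattern as in the failing decoder: four positional arguments.
attention = AdditiveAttention(16, 16, 16, 0.0)
attention.eval()
queries = torch.zeros((4, 1, 16))
keys = values = torch.ones((4, 7, 16))
valid_lens = torch.tensor([7, 5, 3, 1])
out = attention(queries, keys, values, valid_lens)
print(out.shape)
```

Another option worth trying, untested here: because the package version shown above infers its input sizes through nn.LazyLinear, changing the decoder call to d2l.AdditiveAttention(num_hiddens, dropout) should also match the package signature without editing any file in the package.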

