当前位置: 首页 > news >正文

李沐-第十章-训练Seq2SeqAttentionDecoder报错

问题

系统: win11
显卡:5060
CUDA:12.8
pytorch:2.7.1+cu128
pycharm:2025.2

训练带有注意力机制的编码器-解码器网络时,

embed_size, num_hiddens, num_layers, dropout = 32, 32, 2, 0.1
batch_size, num_steps = 64, 10
lr, num_epochs, device = 0.005, 250, d2l.try_gpu()train_iter, src_vocab, tgt_vocab = d2l.load_data_nmt(batch_size, num_steps)encoder = d2l.Seq2SeqEncoder(len(src_vocab), embed_size, num_hiddens, num_layers, dropout)
decoder = Seq2SeqAttentionDecoder(len(tgt_vocab), embed_size, num_hiddens, num_layers, dropout)net = EncoderDecoder(encoder, decoder)
train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)

会报错unpack错误.

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[8], line 118 decoder = Seq2SeqAttentionDecoder(len(tgt_vocab), embed_size, num_hiddens, num_layers, dropout)10 net = d2l.EncoderDecoder(encoder, decoder)
---> 11 train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)Cell In[7], line 41, in train_seq2seq(net, data_iter, lr, num_epochs, tgt_vocab, device)39 bos = torch.tensor([tgt_vocab['<bos>']] * Y.shape[0], device=device).reshape(-1, 1)40 dec_input = torch.cat([bos, Y[:,:-1]], 1)
---> 41 Y_hat, _ = net(X, dec_input, X_valid_len)42 l = loss(Y_hat, Y, Y_valid_len)43 l.sum().backward()ValueError: too many values to unpack (expected 2)

分析

排查发现d2l包里面的EncoderDecoder定义和教材中不同.
d2l包里面的定义:

class EncoderDecoder(d2l.Classifier):"""The base class for the encoder--decoder architecture.Defined in :numref:`sec_encoder-decoder`"""def __init__(self, encoder, decoder):super().__init__()self.encoder = encoderself.decoder = decoderdef forward(self, enc_X, dec_X, *args):enc_all_outputs = self.encoder(enc_X, *args)dec_state = self.decoder.init_state(enc_all_outputs, *args)# Return decoder output onlyreturn self.decoder(dec_X, dec_state)[0]def predict_step(self, batch, device, num_steps,save_attention_weights=False):"""Defined in :numref:`sec_seq2seq_training`"""batch = [d2l.to(a, device) for a in batch]src, tgt, src_valid_len, _ = batchenc_all_outputs = self.encoder(src, src_valid_len)dec_state = self.decoder.init_state(enc_all_outputs, src_valid_len)outputs, attention_weights = [d2l.expand_dims(tgt[:, 0], 1), ], []for _ in range(num_steps):Y, dec_state = self.decoder(outputs[-1], dec_state)outputs.append(d2l.argmax(Y, 2))# Save attention weights (to be covered later)if save_attention_weights:attention_weights.append(self.decoder.attention_weights)return d2l.concat(outputs[1:], 1), attention_weights

教材的定义(9.6.3):

class EncoderDecoder(nn.Module):def __init__(self, encoder, decoder, **kwargs):super(EncoderDecoder, self).__init__(**kwargs)self.encoder = encoderself.decoder = decoderdef forward(self, enc_X, dec_X, *args):enc_outputs = self.encoder(enc_X, *args)dec_state = self.decoder.init_state(enc_outputs, *args)return self.decoder(dec_X, dec_state)

解决方法

使用教材中的EncoderDecoder定义, 正常训练不报错.
训练结果:
在这里插入图片描述

http://www.dtcms.com/a/349386.html

相关文章:

  • 十九、云原生分布式存储 CubeFS
  • 剧本杀APP系统开发:打造多元化娱乐生态的先锋力量
  • Go编写的轻量文件监控器. 可以监控终端上指定文件夹内的变化, 阻止删除,修改,新增操作. 可以用于AWD比赛或者终端应急响应
  • TensorFlow深度学习实战(34)——TensorFlow Probability
  • GO学习记录八——多文件封装功能+redis使用
  • Node.js(2)—— Buffer
  • 安卓Android低功耗蓝牙BLE连接异常报错133
  • Docker Compose 部署 Elasticsearch 8.12.2 集成 IK 中文分词器完整指南
  • Go初级三
  • 上海AI实验室突破扩散模型!GetMesh融合点云与三平面,重塑3D内容创作
  • 少儿舞蹈小程序需求规格说明书
  • AutoCAD Electrical缺少驱动程序“AceRedist“解决方法
  • 【STM32】G030单片机的独立看门狗
  • ELKB日志分析平台 部署
  • 完美世界招数据仓库工程师咯
  • ArcGIS JSAPI 高级教程 - 创建渐变色材质的自定义几何体
  • three.js+WebGL踩坑经验合集(8.3):合理设置camera.near和camera.far缓解实际场景中的z-fighting叠面问题
  • 大数据平台ETL任务导入分库分表数据
  • Jenkins+docker 微服务实现自动化部署安装和部署过程
  • TDengine IDMP 应用场景:电动汽车
  • AI测试工具midsence和browse_use的使用场景和差异
  • react+taro打包到不同小程序
  • Flutter旧版本升级-> Android 配置、iOS配置
  • 机器视觉的3C玻璃盖板丝印应用
  • KeepAlived+Haproxy实现负载均衡(SLB)
  • window显示驱动开发—混合系统 DDI 和 dList DLL 支持
  • Shell 循环编程:for 与 select 轻松入门
  • HTTP 与 HTTPS 深度解析:从原理到实际应用
  • Kubernetes (K8s)入门指南:Docker之后,为什么需要容器编排?
  • 安全合规:AC(上网行为安全)--下