当前位置: 首页 > news >正文

Seq2Seq - CrossEntropyLoss细节讨论

在 PyTorch 中,损失函数 CrossEntropyLoss 的输入参数通常需要满足特定的形状要求。对于 CrossEntropyLoss,输入参数的形状要求如下:

  1. input:模型的输出,形状为 [N, C],其中:

    • N 是样本数量(或展平后的序列长度)。

    • C 是类别数量(目标词汇表的大小)。

  2. target:目标标签,形状为 [N],其中每个元素是一个类别索引(整数)。

在上一节的代码中:

loss = loss_fn(similarities.view(-1, len(cn_vocab)), batch_labels.view(-1))

similaritiesbatch_labels 需要被调整为上述形状,以便符合 CrossEntropyLoss 的输入要求。以下详细解释为什么要这样写:

1. similarities.view(-1, len(cn_vocab))

  • similarities 的原始形状[batch_size, seq_len, output_dim],其中:

    • batch_size 是批次大小。

    • seq_len 是序列长度。

    • output_dim 是目标词汇表的大小(len(cn_vocab))。

  • 目标形状[N, C],其中:

    • N 是展平后的序列长度,即 batch_size * seq_len

    • C 是目标词汇表的大小,即 len(cn_vocab)

  • 操作

    • 使用 view(-1, len(cn_vocab))similarities 展平为二维张量,形状为 [batch_size * seq_len, len(cn_vocab)]

    • 这样,每个时间步的输出都被展平为一个二维张量,每一行表示一个时间步的预测概率分布。

2. batch_labels.view(-1)

  • batch_labels 的原始形状[batch_size, seq_len],其中:

    • batch_size 是批次大小。

    • seq_len 是序列长度。

  • 目标形状[N],其中:

    • N 是展平后的序列长度,即 batch_size * seq_len

  • 操作

    • 使用 view(-1)batch_labels 展平为一维张量,形状为 [batch_size * seq_len]

    • 这样,每个时间步的目标标签都被展平为一个一维张量,每个元素是一个类别索引。

3. 为什么这样写

  • 符合 CrossEntropyLoss 的输入要求

    • CrossEntropyLoss 要求输入的预测概率分布是一个二维张量 [N, C],其中每一行表示一个样本的预测概率分布。

    • 目标标签是一个一维张量 [N],其中每个元素是一个类别索引。

  • 处理序列数据

    • 在序列到序列的任务中,每个时间步都有一个预测和一个目标标签。

    • 通过展平操作,可以将所有时间步的预测和目标标签合并为一个批次,从而一次性计算整个批次的损失。

示例

假设:

  • batch_size = 2

  • seq_len = 3

  • output_dim = 5(目标词汇表大小)

原始数据:

similarities: [2, 3, 5]  # [batch_size, seq_len, output_dim]
batch_labels: [2, 3]     # [batch_size, seq_len]

经过 view 操作后:

similarities.view(-1, 5): [6, 5]  # [batch_size * seq_len, output_dim]
batch_labels.view(-1): [6]       # [batch_size * seq_len]

这样,similarities 的每一行表示一个时间步的预测概率分布,batch_labels 的每个元素是一个类别索引,完全符合 CrossEntropyLoss 的输入要求。

http://www.dtcms.com/a/122718.html

相关文章:

  • 深入理解 Vuex:核心概念、API 详解与最佳实践
  • 网络安全应急响应-启动项和任务计划排查
  • 2. git init
  • 探索生成式AI在游戏开发中的应用——3D角色生成式 AI 实现
  • 华为数字芯片机考2025合集3已校正
  • 今天你学C++了吗?——set
  • 深入浅出SPI通信协议与STM32实战应用(W25Q128驱动)(实战部分)
  • 思维森林理论(Cognitive Forest Theory)重构医疗信息系统集群路径探析
  • VectorBT量化入门系列:第三章 VectorBT策略回测基础
  • 【AI News | 20250409】每日AI进展
  • Pyppeteer实战:基于Python的无头浏览器控制新选择
  • React十案例下
  • Java基础第19天-MySQL数据库
  • IT+开发+业务一体化:AI驱动的ITSM解决方案Jira Service Management价值分析(文末免费获取报告)
  • 云轴科技ZStackCTO王为:AI Infra平台具备解决私有化AI全栈难题能力
  • 超便捷超实用的文档处理工具,PDF排序,功能强大,应用广泛,无需下载,在线使用,简单易用快捷!
  • 【JSON2WEB】16 login.html 登录密码加密传输
  • IDEA 调用 Generate 生成 Getter/Setter 快捷键
  • AWS Bedrock生成视频详解:AI视频创作新时代已来临
  • 【零基础实战】Ubuntu搭建DVWA漏洞靶场全流程详解(附渗透测试示例)
  • Java常用工具算法-5--哈希算法、加密算法、签名算法关系梳理
  • 蓝桥杯 B3620 x 进制转 10 进制
  • 【蓝桥杯】15届JAVA研究生组F回文字符串
  • STM32单片机入门学习——第29节: [9-5] 串口收发HEX数据包串口收发文本数据包
  • C++设计模式+异常处理
  • 21 天 Python 计划:MySQL 数据库初识
  • LangChain使用大语言模型构建强大的应用程序
  • 开源模型应用落地-模型上下文协议(MCP)-从数据孤岛到万物互联(一)
  • Linux 实时查看 CUDA 显卡的使用情况命令
  • 基于形状补全和形态测量描述符的腓骨游离皮瓣下颌骨重建自动规划|文献速递-深度学习医疗AI最新文献