当前位置: 首页 > news >正文

Transformer Masked loss原理精讲及其PyTorch逐行实现

Masked Loss 的核心原理是:在计算损失函数时,只考虑真实有意义的词元(token),而忽略掉为了数据对齐而填充的无意义的填充词元(padding token)。

这是重要的技术,可以确保模型专注于学习有意义的任务,并得到一个正确的性能评估。

1.原理精讲

为什么需要 Masked Loss?

在训练神经网络时,我们通常会用一个批次(batch)的数据进行训练,而不是一次只用一个样本。对于自然语言处理任务,我们会一次性处理多句话。但这些句子的长度都几乎不一样。

例如,我们有一个包含两个句子的批次:

["我", "是", "学生"] (长度为 3)

["今天", "天气", "真", "好"] (长度为 4)

为了将它们放入一个统一的张量(tensor)中进行高效的并行计算,我们必须将较短的句子“填充”到一个统一的长度(通常是这个批次中最长句子的长度)。我们会使用一个特殊的 <pad> 词元来完成这个任务。

填充后的数据就变成了:

["我", "是", "学生", "<pad>"]

["今天", "天气", "真", "好"]

现在,问题来了。当模型在训练时,它会为每个位置都生成一个预测。对于第一句话,它也会尝试在第4个位置预测 <pad>。如果我们不加处理,损失函数就会计算模型预测 <pad> 的准确度,并把这个“误差”也算进总的损失里。

这样做有两个坏处:

  1. 浪费计算资源:强迫模型去学习一个无意义的任务——“在句子末尾预测填充符”。

  2. 评估指标失真:这个无意义任务的损失会“稀释”我们真正关心的、对真实词元的预测损失,导致我们无法准确评估模型的真实性能。

Masked Loss 就是为了解决这个问题而生的。它的目标就是创建一个“掩码(mask)”,告诉损失函数不计算PAD。


PyTorch 逐行实现

在 PyTorch 中,实现 Masked Loss 非常简单,因为 nn.CrossEntropyLoss 已经内置了处理它的高效方法。

我们将一步步模拟这个过程。

第零步:准备工作

我们先导入库,并设定一些基本参数。

import torch
import torch.nn as nn#设定参数BATCH_SIZE = 2      # 一个批次里有2句话SEQ_LEN = 5         # 统一填充后的句子长度是5VOCAB_SIZE = 10     # 假设我们的词汇表很小,只有10个词PADDING_IDX = 0     # 我们约定,ID为0的词元就是 <pad> 填充符

代码解释: 我们设定了一个场景:一个批次包含2个句子,每个句子被填充到长度5,词汇表共10个词,并且我们用 0 来代表 <pad>

第一步:模拟模型输出和真实标签

我们创建两个张量:一个是模型预测的 logits,另一个是带填充的真实标签 target

# 模拟模型的原始输出 (logits)
# 形状: (批量大小, 序列长度, 词汇表大小)
logits = torch.randn(BATCH_SIZE, SEQ_LEN, VOCAB_SIZE)# 模拟真实的标签 (ground truth)
# 注意其中包含了 PADDING_IDX (0)
target = torch.tensor([[1, 5, 4, 2, PADDING_IDX],  # 第1句话,最后一个是padding[3, 8, 7, PADDING_IDX, PADDING_IDX]   # 第2句话,最后两个是padding
])print("模型预测 Logits 的形状:", logits.shape)
print("真实标签 Target 的形状:", target.shape)
print("真实标签内容:\n", target)

代码解释logits 是模型对每个位置、每个词的预测得分。target 是我们的“标准答案”,可以看到,为了对齐,较短的句子末尾被填充了 0

第二步:定义损失函数 

# 定义交叉熵损失函数
# 关键:告诉损失函数,所有标签值为 PADDING_IDX 的位置都被忽略criterion = nn.CrossEntropyLoss(ignore_index=PADDING_IDX)

ignore_index=PADDING_IDX 这个参数就是实现 Masked Loss 的方法。当我们把 padding_idx (这里是0) 传给它,CrossEntropyLoss 在内部计算时,会自动跳过所有目标标签是 0 的位置。

第三步:调整张量形状

CrossEntropyLoss 期望的输入形状是:Input: (N, C)Target: (N),其中 N 是样本总数,C 是类别数。而我们现在的 logitstarget 都是二维的批次数据,需要调整一下。

# CrossEntropyLoss 需要的输入形状是 (N, C)
# N 是总的需要计算的元素数量, C是类别数 (即词汇表大小)
# 我们用 .view() 来重塑张量# 将 logits 从 (2, 5, 10) 变为 (10, 10)
reshaped_logits = logits.view(-1, VOCAB_SIZE)# 将 target 从 (2, 5) 变为 (10)
reshaped_target = target.view(-1)print("\n重塑后的 Logits 形状:", reshaped_logits.shape)
print("重塑后的 Target 形状:", reshaped_target.shape)

代码解释: 我们把 (BATCH_SIZE, SEQ_LEN) 这两个维度“压平”成一个维度。-1 是一个占位符,告诉 PyTorch 自动计算这个维度的大小(在这里就是 2 * 5 = 10)。

第四步:计算损失

现在,所有准备工作都已就绪,我们可以直接计算损失。

# 计算损失
# criterion 会自动使用我们设置的 ignore_index=0 来忽略填充位置
loss = criterion(reshaped_logits, reshaped_target)print(f"\n计算出的 Masked Loss 是: {loss.item()}")

代码解释: 尽管 reshaped_target 中仍然包含 0,但由于我们在第二步中设置了 ignore_index=0,这些位置的损失不会被计算和累加。最终得到的 loss 值,是只基于那 7 个真实词元([1, 5, 4, 2][3, 8, 7])计算出来的平均损失。

这样,我们就用非常简洁的方式实现了 Masked Loss。

http://www.dtcms.com/a/295907.html

相关文章:

  • 【Spring Cloud Gateway 实战系列】高级篇:服务网格集成、安全增强与全链路压测
  • 在 Alpine Linux 中创建虚拟机时 Cgroup 挂在失败的现象
  • spring/springboot SPI(二)配合使用的接口
  • 用 AI 破解数据质量难题:从缺失值填补到动态监控的高效解决方案
  • 数据所有权与用益权分离:数字经济时代的权利博弈与“商业机遇”
  • element-plus 组件 ElMessage、ElLoading 弹框 和加载css 样式展示异常总结
  • 【数学,放缩,基本不等式】基本不等式题目
  • TDengine 转化类函数 CAST 用户手册
  • SpringBoot复习
  • Flink-1.19.0源码详解8-ExecutionGraph生成-前篇
  • 洛谷刷题7.24
  • CellFlow:Flow matching建模cell状态变化
  • 如何将拥有的域名自定义链接到我的世界服务器(Minecraft服务器)
  • 大数据集分页优化:LIMIT OFFSET的替代方案
  • Oracle国产化替代:一线DBA的技术决策突围战
  • 如何判断钱包的合约签名是否安全?
  • MySQL深度理解-MySQL索引优化
  • 数据库第一章练习题(大雪圣期末参考复习)
  • 【数据结构】二叉树进阶算法题
  • MinIO 版本管理实践指南(附完整 Go 示例)
  • 一次粗心导致的bug定位
  • 《C++ string 完全指南:string的模拟实现》
  • rust-枚举
  • 开源链动2+1模式AI智能名片S2B2C商城小程序的场景体验分析
  • HBase + PostgreSQL + ElasticSearch 联合查询方案
  • vue3 el-table 列数据合计
  • MongoDB 副本集搭建与 Monstache 实时同步 Elasticsearch 全流程教程
  • AI开放课堂:钉钉MCP开发实战
  • 【DBeaver 安装 MongoDB 插件】
  • 推荐系统如何开发