当前位置: 首页 > news >正文

交叉熵损失函数(Cross-Entropy Loss)个人理解

作为交叉熵损失函数的笔记,方便理解与复习。

其实我在之前学习KL散度的时候,就已经大概讲解了交叉熵是什么,而交叉熵损失函数,实际上就是用交叉熵作为损失函数。(个人KL散度的链接:点我跳转,有关交叉熵的部分在第5点补充说明部分)

所以这篇文章会额外拓展一些与交叉熵有关的其他内容。

之前总结的内容精简如下:

  • 信息量I(x) = -\log P(x)

  • :表示基于真实分布 P 进行最优编码时,所需的最短平均信息量。
    H(P) = -\sum P(i) \log P(i)

  • 交叉熵:表示使用估计的分布 Q 来对源自真实分布 P 的数据进行编码所需的平均信息量。
    H(P,Q) = -\sum P(i) \log Q(i)

  • KL散度:衡量两个分布 P 和 Q 之间的差异。
    D_{\text{KL}}(P \mid\mid Q) = H(P,Q) - H(P)

一.定义

交叉熵损失函数主要用于衡量两个概率分布之间的差异,在机器学习中,尤其是在分类任务中,我们用它来比较模型的预测分布与数据的真实分布之间的差距。

离散情况下的公式为:

$Loss = -\sum_{i=1}^{C} y_i \log(p_i)$

其中:

  • C:类别的总数。

  • y_i​:一个独热编码向量中第 i 个位置的值。对于真实类别,y_i=1;对于其他类别,y_i=0

  • p_i:模型预测的该样本属于第 i 个类别的概率。

由于 y 是独热编码,只有一个位置是 1(假设是第 k 个位置),其他都是 0,所以在实际计算损失时,这个求和公式可以大大简化为:

Loss = -\log(p_k)

这里的 p_k 就是模型预测的真实类别所对应的概率,整个式子也化简成了目标类别的信息量

PS:独热编码(One-Hot Encoding

独热编码就是用只有一个1、其余全0的向量来表示某个类别。

假设有 3 个类别:猫、狗、鸟
则这三个分类分别对应的独热编码为:

  • 猫 → [1, 0, 0]
  • 狗 → [0, 1, 0]
  • 鸟 → [0, 0, 1]

每个向量中只有对应类别的那一维是 1,其他都是 0。

而在分类模型中,对于每一个输入的样本,最终的输出会是一个向量,向量中每一个数值代表该样本属于对应类别的得分(logits),这个得分在经过softmax函数后将会转变为该样本属于对应类别的概率

PS:柔性最大值(soft maximum,softmax

softmax将一个包含任意实数的向量“压缩”为另一个值在 (0, 1) 之间,且所有元素之和为 1 的向量。它放大了最大值的比重,但并不会完全忽略其他较小的值。

其公式为:

$\text{softmax}(z_i) = \frac{e^{z_i}}{\sum_{j=1}^{K} e^{z_j}}$

举个例子:

softmax([1, 2, 3, 4])\approx [0.03, 0.09, 0.24, 0.64]

在这个例子中,最大值4被赋予了最高的概率(0.64),但其他值也获得了非零的概率,概率大小与它们自身的值成比例。

那么既然有柔性最大值,那有没有“硬”最大值呢?当然有,那就是argmax。

argmax(argument of the maximum)会返回最大值所在的索引位置

举个例子 (假设第一个数的索引是0):

argmax([1, 2, 3, 4]) = 3

一般来说,当具有多个最大值时会返回第一个最大值的索引:

argmax([1, 4, 3, 4]) = 1

另外,一般来说,softmax和argmax在代码里都作为函数名使用,不会专门进行翻译。

二.在pytorch的理论和实际流程对比

1.理论流程

步骤1:对每个样本应用softmax

$p_i = \frac{\exp(x_i)}{\sum_{j=1}^{C} \exp(x_j)}$

步骤2:由于真实标签是类别索引k,直接选取对应的概率pₖ

步骤3:计算损失

$\text{Loss} = -\log(p_k)$

2.实际流程

Loss = -\log(p_k)\\=-\log(\mathrm{Softmax}(x)_k)\\= -\log\left( \frac{e^{x_k}}{\sum_j e^{x_j}} \right)\\= -\log(e^{x_k}) + \log\left( \sum_j e^{x_j} \right)\\= -x_k + \log\left( \sum_j e^{x_j} \right)\\= -x_k+m + \log\left( \sum_j e^{x_j - m} \right)

其中:

m = \max(x_1, x_2, \ldots, x_C),C代表有C个分类。

PS:为什么要这样计算?

在理论计算流程中,可能出现以下问题:

  • 上溢:如果某个 logit 值很大(例如1000),exp(1000) 会变成一个超出计算机浮点数表示范围的巨大数值,结果为 inf 。

  • 下溢:如果某个 logit 值很小(例如-1000),exp(-1000) 会下溢为0。当这个值出现在Softmax分子中时,可能导致 log(0) 的情况,结果为 nan 。

在实际流程中:

  • 解决上溢:计算 \log\left( \sum_j e^{x_j} \right) 时,先把每一个x_j减去最大值 m = max(logits),使得 exp(x_j - m)的最大值为1,所有指数项都 ≤ 1,彻底杜绝了上溢。

  • 解决下溢:即使某些 exp(x_j - m) 下溢为0,它们在求和时也被有效地视为0,而 \log\left( \sum_j e^{x_j - m} \right) 仍然是一个有效的数值,因为求和时至少包含一个来自最大值自身的e^{m - m}=1,这保证了计算的连续性,不会产生 nan。

三.代码中的实现

pytorch在库中封装了 nn.LogSoftmax 与 nn.NLLLoss 两个函数,计算交叉熵损失函数的过程,其实就是分别执行 nn.LogSoftmax 与 nn.NLLLoss 两个函数的过程。

\mathrm{LogSoftmax}(x_k) = x_k - \left[ m + \log\left( \sum_j e^{x_j - m} \right) \right]

\mathrm{NLLLoss} = -\mathrm{LogSoftmax}(x_k) = -x_k + m + \log\left( \sum_j e^{x_j - m} \right)

需要澄清的是,实际上交叉熵损失函数 nn.CrossEntropyLoss 不是调用这两个函数实现的,只是逻辑上是这样的,并且这两个函数也有实际的封装。

# 直接使用(推荐)
loss = nn.CrossEntropyLoss()(logits, targets)# 模块组合(需要自定义流程时)
log_probs = nn.LogSoftmax(dim=1)(logits)
loss = nn.NLLLoss()(log_probs, targets)# 两者是等价的,但是CrossEntropyLoss是优化实现,通常更快

举个最简单的使用示例,只需要直接调用 nn.CrossEntropyLoss 即可:

import torch
import torch.nn as nn
import torch.optim as optim# 1. 准备数据
X = torch.tensor([[1.0, 2.0], [2.0, 1.0], [3.0, 4.0], [4.0, 3.0]])  # 4个样本,每个样本2个特征
y = torch.tensor([0, 0, 1, 1])  # 4个标签(2分类)# 2. 定义模型
model = nn.Sequential(nn.Linear(2, 10),  # 输入2维,隐藏层10维nn.ReLU(),nn.Linear(10, 2)   # 输出2维(2分类)
)# 3. 定义损失函数和优化器
# 只需要在这里调用交叉熵损失函数即可
criterion = nn.CrossEntropyLoss()  # 交叉熵损失
optimizer = optim.SGD(model.parameters(), lr=0.01)  # 随机梯度下降# 4. 训练测试
def train(model, X, y, criterion, optimizer, epochs):pass
def test(model, X, y, criterion):pass
train(model, X, y, criterion, optimizer, 100)
test(model, X, y, criterion)

四.总结

• 作用:衡量预测概率与真实分布的差异
• 计算:Loss = -log(正确类别的预测概率)
• 实现:PyTorch中直接调用 nn.CrossEntropyLoss()

http://www.dtcms.com/a/598953.html

相关文章:

  • 结对编程:提升编程效率与团队协作的最佳实践 | 如何通过结对编程实现高效协作和代码质量提升
  • 缓存优化(SpringCache、XXL-JOB)
  • 网站建设长期待摊费用个人网站的留言板怎么做
  • 优惠劵网站怎么做srm系统
  • Hugging Face Gated 模型下载全攻略:解决 401/403 和访问受限问题
  • 建筑行业网站模板ajax实现wordpress导航栏
  • 网站建设服务 杭州甜品店网页模板html
  • 状态机的实现方法--C语言版本
  • 网站做app开发有梦商城公司网站
  • 网站开发系统毕业综合实践报告电子版个人简历模板
  • 线代强化NO5|矩阵的运算法则|分块矩阵|逆矩阵|伴随矩阵|初等矩阵
  • 最新域名网站查询网站背景大小
  • 服装网站建设发展状况wordpress数据库访问慢
  • 大同市住房城乡建设网站扬州网站建设 天维
  • nat123做网站 查封编写网站的软件
  • 天津房地产网站建设福建联美建设集团有限公司网站
  • 简述网站建设有哪些步骤有什么网站可以做推广
  • C语言进阶:位操作
  • 建站网站苏州wordpress架设系统
  • wordpress短代码返回html石家庄网站seo优化
  • python合适做网站吗网站建设与维护面试
  • 什么是Hinge损失函数
  • 网站设计的趋势百度双站和响应式网站的区别
  • usrsctp之cookie
  • CC防护:抵御应用层攻击的精确防线
  • 如何自己制作链接内容泰安网站建设优化
  • 芜湖哪里做网站亚马逊雨林的资料
  • Manus高精度动捕数据手套,Metagloves Pro对比Quantum Metagloves:谁是你的灵巧手研发最佳选择?
  • 佛山网站建设3lue3lue修改图片网站
  • 【开题答辩实录分享】以《中医古籍管理系统》为例进行答辩实录分享