当前位置: 首页 > news >正文

【深度学习】深入理解交叉熵损失函数 (Cross-Entropy Loss Function)

深入理解交叉熵损失函数 (Cross-Entropy Loss Function)

在机器学习和深度学习领域,损失函数(Loss Function)是衡量模型预测值与真实值之间差异的关键工具。而交叉熵损失函数(Cross-entropy Loss)则是分类问题中最常用、也是最重要的损失函数之一。本文将深入浅出地介绍交叉熵损失函数的定义、原理、不同形式及其在实践中的应用。

文章目录

  • 深入理解交叉熵损失函数 (Cross-Entropy Loss Function)
    • 1 什么是交叉熵?
    • 2 如何在分类问题中使用交叉熵来定义损失函数?
    • 3 交叉熵损失的两种主要形式
      • 3.1 二元交叉熵损失 (Binary Cross-Entropy Loss)
      • 3.2 分类交叉熵损失 (Categorical Cross-Entropy Loss)
      • 3.3 稀疏分类交叉熵 (Sparse Categorical Cross-Entropy)
      • 3.4 大型语言模型(LLM)中的交叉熵损失函数
        • 3.4.1 LLM的核心任务是预测下一个词元(Token)
        • 3.4.2 LLM中的交叉熵损失:如何衡量“预测得有多准”
        • 3.4.3 对于一个完整文本序列的损失
    • 4 交叉熵与最大似然估计的关系
    • 5 总结

1 什么是交叉熵?

在信息论中,交叉熵是用来衡量两个概率分布之间差异的指标。假设有一个真实概率分布 p 和一个预测概率分布 q,交叉熵表示用预测分布 q 来编码真实分布 p 中的事件所需的平均比特数。

交叉熵的计算公式如下:

H(p,q)=−∑ip(xi)log⁡(q(xi))H(p, q) = - \sum_{i} p(x_i) \log(q(x_i))H(p,q)=ip(xi)log(q(xi))

其中,p(xi)p(x_i)p(xi) 是事件 xix_ixi 在真实分布中的概率,q(xi)q(x_i)q(xi) 是事件 xix_ixi 在预测分布中的概率。

当预测分布 q 与真实分布 p 完全相同时,交叉熵达到最小值,这个最小值就是 p 的信息熵。在机器学习中,我们正是利用这一特性,通过最小化交叉熵损失,来使得模型的预测概率分布尽可能地接近真实的标签分布。

2 如何在分类问题中使用交叉熵来定义损失函数?

在分类任务中,我们的目标是让模型输出的概率分布尽可能地接近真实标签的概率分布。对于一个样本,其真实标签可以看作是一个one-hot编码的概率分布。例如,在一个三分类问题中,如果一个样本的真实类别是第二类,那么其真实概率分布就是 [0, 1, 0]

假设模型的输出经过Softmax激活函数后,得到的预测概率分布为 [0.2, 0.7, 0.1]。我们可以看到,这个预测分布与真实分布已经比较接近了。

交叉熵损失函数能够很好地惩罚那些“自信的错误预测”。

  • 当预测正确时:如果模型对正确类别的预测概率很高(例如0.9),那么根据交叉熵的计算公式,−1⋅log⁡(0.9)-1 \cdot \log(0.9)1log(0.9) 的值会很小,损失也就很小。
  • 当预测错误时:如果模型对正确类别的预测概率很低(例如0.1),那么 −1⋅log⁡(0.1)-1 \cdot \log(0.1)1log(0.1) 的值会很大,损失也就会很大。这种大的损失会促使模型在反向传播过程中进行更大幅度的参数调整。

相比于均方误差(MSE)等损失函数,交叉熵在分类问题中通常能提供更强的梯度信号,从而加速模型的收敛。

3 交叉熵损失的两种主要形式

根据分类任务的类型,交叉熵损失函数主要分为两种形式:二元交叉熵(Binary Cross-Entropy)和分类交叉熵(Categorical Cross-Entropy)。

3.1 二元交叉熵损失 (Binary Cross-Entropy Loss)

二元交叉熵主要用于二分类问题,即每个样本只有两个类别(例如,是/否,垃圾邮件/非垃圾邮件)。在这种情况下,模型的输出层通常只有一个神经元,并使用Sigmoid激活函数将其输出值映射到 (0, 1) 区间,表示样本属于正类的概率。

二元交叉熵的公式为:

L=−1N∑i=1N[yilog⁡(pi)+(1−yi)log⁡(1−pi)]L = - \frac{1}{N} \sum_{i=1}^{N} [y_i \log(p_i) + (1 - y_i) \log(1 - p_i)]L=N1i=1N[yilog(pi)+(1yi)log(1pi)]

其中:

  • NNN 是样本总数。
  • yiy_iyi 是第 iii 个样本的真实标签(0或1)。
  • pip_ipi 是模型预测第 iii 个样本为正类(类别1)的概率。

这个公式可以直观地理解:

  • 当真实标签 yi=1y_i=1yi=1 时,损失项为 −log⁡(pi)- \log(p_i)log(pi)。为了最小化损失,模型需要让 pip_ipi 趋近于1。
  • 当真实标签 yi=0y_i=0yi=0 时,损失项为 −log⁡(1−pi)- \log(1-p_i)log(1pi)。为了最小化损失,模型需要让 pip_ipi 趋近于0。

二元交叉熵也常用于多标签分类问题,即一个样本可以同时属于多个类别。在这种情况下,每个类别都被视为一个独立的二分类问题。

3.2 分类交叉熵损失 (Categorical Cross-Entropy Loss)

分类交叉熵用于多分类问题,其中每个样本只属于一个类别(类别之间互斥)。例如,手写数字识别(0-9)。在这种情况下,模型的输出层通常有C个神经元(C为类别数),并使用Softmax激活函数将输出转换为一个概率分布,所有类别的概率之和为1。

分类交叉熵的公式为:

L=−1N∑i=1N∑j=1Cyijlog⁡(pij)L = - \frac{1}{N} \sum_{i=1}^{N} \sum_{j=1}^{C} y_{ij} \log(p_{ij})L=N1i=1Nj=1Cyijlog(pij)

其中:

  • NNN 是样本总数。
  • CCC 是类别总数。
  • yijy_{ij}yij 是一个符号函数(0或1),如果样本 iii 的真实类别是 jjj,则为1,否则为0(即one-hot编码)。
  • pijp_{ij}pij 是模型预测样本 iii 属于类别 jjj 的概率。

由于 yijy_{ij}yij 是one-hot编码,对于每个样本 iii,只有一个类别 jjj 对应的 yijy_{ij}yij 为1,其他都为0。因此,上述公式可以简化为仅计算真实类别对应的预测概率的负对数。

3.3 稀疏分类交叉熵 (Sparse Categorical Cross-Entropy)

在实践中,如果你的真实标签是整数形式(例如 [1, 2, 0]),而不是one-hot编码形式(例如 [[0,1,0], [0,0,1], [1,0,0]]),使用稀疏分类交叉熵会更加方便,因为它可以在内部处理整数标签到one-hot编码的转换,从而节省内存和计算量。其本质和分类交叉熵是一样的。

3.4 大型语言模型(LLM)中的交叉熵损失函数

3.4.1 LLM的核心任务是预测下一个词元(Token)

要理解LLM的损失函数,首先要明白LLM在训练时最核心的任务是什么:预测下一个词元(Token)

你可以把一个庞大的、训练有素的LLM想象成一个终极“接头霸王”或“文字联想大师”。你给它一段文字,它的工作就是预测在这段文字之后,最可能出现的下一个词元是什么。

例如,当你输入:
"今天天气真不错,我们一起去公园"

模型在训练时,看到 ...去 这个词之后,它的目标就是预测出 公园

3.4.2 LLM中的交叉熵损失:如何衡量“预测得有多准”

LLM的交叉熵损失,正是用来衡量模型在“预测下一个词元”这个任务上表现得有多好的标尺。它具体是这样工作的:

1. 模型的预测 (Predicted Distribution)

当模型接收到一段输入文本(比如 ...我们一起去)后,它会对词汇表(Vocabulary)中所有可能的下一个词元进行打分。这个词汇表可能包含数万甚至数十万个词元(单词、字、标点符号等)。

这个原始的“打分”(通常称为logits)会经过一个 Softmax 函数,转换成一个概率分布。

例如,对于输入 ...我们一起去,模型输出的概率分布可能是这样的:

  • P("公园") = 0.65
  • P("散步") = 0.15
  • P("吃饭") = 0.05
  • P("学习") = 0.001
  • … (其他几万个词元的概率) …

这个由Softmax生成的、包含词汇表中所有词元概率的列表,就是交叉熵公式中的 预测概率分布 q

2. 真实答案 (True Distribution)

在训练数据中,我们明确知道正确的下一个词元是什么。在这个例子里,真实的下一个词元是 公园

为了与模型的概率分布进行比较,我们需要把这个“真实答案”也表示成一个概率分布。这里使用的就是 One-Hot 编码。假设词汇表有50000个词元,而公园是第2048个词元,那么真实的概率分布就是一个长度为50000的向量,其中:

  • 第2048个元素为 1 (代表100%的概率)
  • 所有其他49999个元素都为 0

这个One-Hot向量,就是交叉熵公式中的 真实概率分布 p

3. 计算损失

现在我们有了两个分布:模型的预测分布 q 和真实的one-hot分布 p。我们可以套用分类交叉熵的公式:

L=−∑j=1Cyjlog⁡(pj)L = - \sum_{j=1}^{C} y_{j} \log(p_{j})L=j=1Cyjlog(pj)

其中,CCC 是词汇表的大小,yjy_jyj 是真实分布(one-hot向量)的第j个元素,pjp_jpj 是模型预测分布的第j个元素。

由于真实分布 p 是一个 one-hot 向量,它只有一个位置是1,其他全是0。这使得整个求和计算被极大地简化了:最终的损失值只取决于模型为那个“正确答案”所预测的概率

L=−1⋅log⁡(P模型(正确词元))L = -1 \cdot \log(P_{模型}(\text{正确词元}))L=1log(P模型(正确词元))

  • 如果模型预测很准:比如模型预测 P("公园") = 0.9。损失就是 -log(0.9) ≈ 0.105,这是一个很小的损失值。
  • 如果模型预测很差:比如模型预测 P("公园") = 0.0001。损失就是 -log(0.0001) ≈ 9.21,这是一个非常大的损失值。

这个巨大的损失会产生强大的梯度信号,在反向传播时,促使模型大幅度调整其内部参数,以便下次遇到类似情况时,能给 公园 这个词元更高的预测概率。

3.4.3 对于一个完整文本序列的损失

在实际训练中,LLM并不仅仅预测一个词元。模型会处理一整段文本,并计算每一个位置上预测下一个词元的损失,然后将这些损失平均起来,作为整个序列的最终损失。

例如,对于句子 The cat sat on the mat

  1. 输入 <start>,模型需要预测 The。计算损失 L1L_1L1
  2. 输入 The,模型需要预测 cat。计算损失 L2L_2L2
  3. 输入 The cat,模型需要预测 sat。计算损失 L3L_3L3
  4. 输入 The cat sat,模型需要预测 on。计算损失 L4L_4L4
  5. …以此类推。

整个句子的总损失就是这些单独损失的平均值:

Ltotal=1n∑i=1nLiL_{total} = \frac{1}{n} \sum_{i=1}^{n} L_iLtotal=n1i=1nLi

当讨论训练LLM时所说的“损失”(Loss)或“困惑度”(Perplexity,是交叉熵损失的指数形式,elosse^{loss}eloss),指的就是这个衡量模型预测文本能力的核心指标。

4 交叉熵与最大似然估计的关系

从另一个角度看,最小化交叉熵损失等价于最大化对数似然函数(Log-Likelihood)。在统计学中,最大似然估计是一种通过最大化观测数据出现的概率来估计模型参数的方法。

对于分类问题,模型的似然函数是所有样本被正确分类的联合概率。通过取对数并取负号,就可以将其转换为一个最小化问题,而这个形式恰好就是交叉熵损失函数。因此,当我们训练模型以最小化交叉熵损失时,实际上是在寻找能够使训练数据出现概率最大的模型参数。

5 总结

交叉熵损失函数是深度学习分类任务中不可或缺的一部分。它源于信息论,通过衡量预测概率分布与真实概率分布的差异来指导模型的学习。其优雅的数学形式、强大的梯度信号以及与最大似然估计的深刻联系,使其成为分类问题中的标准和首选损失函数。无论是进行二分类、多分类还是多标签分类,理解并正确使用二元交叉熵和分类交叉熵都是构建高性能模型的关键一步。

http://www.dtcms.com/a/303301.html

相关文章:

  • Lambda表达式Stream流-函数式编程-java8函数式编程(Lambda表达式,Optional,Stream流)从入门到精通-最通俗易懂
  • React与Rudex的合奏
  • 代码解读:微调Qwen2.5-Omni 实战
  • 从单枪匹马到联盟共生:白钰玮的 IP 破局之路|创客匠人
  • 2025创始人IP如何破局?
  • 很妙的一道题 Leetcode234. 回文链表
  • windows部署ACE-Step记录
  • 从一起知名线上故障,谈配置灰度发布的重要性
  • 大模型的开发应用(十九):多模态模型基础
  • 源代码管理工具有哪些?有哪些管理场景?
  • React面试题
  • 2025年SDK游戏盾终极解析:重新定义手游安全的“隐形护甲”
  • 【Linux操作系统】简学深悟启示录:Linux环境基础开发工具使用
  • 浅谈面试中的递归算法
  • 进程通信————system V 消息队列 信号量
  • 卡内基梅隆大学提出Human2LocoMan:基于人类预训练的四足机器人「多功能操作学习框架」
  • sqlite3学习---基础知识、增删改查和排序和限制、打开执行关闭函数
  • AAAI 2025多模态重大突破:SENA框架重塑多模态学习,零标注实现自进化
  • 【Python】—— 语法糖
  • 求两数之和
  • R语言与作物模型(以DSSAT模型为例)融合应用高级实战技术
  • window显示驱动开发—Direct3D 11 视频设备驱动程序接口 (DDI)
  • 图片上传 el+node后端+数据库
  • 数据库事务中的陷阱:脏读、幻读与不可重复读
  • 第四章:分析 Redis 性能高原因和核心字符串类型命令
  • 特性阻抗的近似计算
  • 【Linux】协议——TCP/IP协议
  • PTX指令集基础以及warp级矩阵乘累加指令介绍
  • 5G MBS(组播广播服务)深度解析:从标准架构到商用实践
  • 机器学习(重学版)基础篇(算法与模型一)