当前位置: 首页 > news >正文

PyTorch损失函数全解析与实战指南

以下是更详细的 ​PyTorch 损失函数解析,涵盖数学公式、适用场景、优缺点对比及代码示例:


1. 回归损失函数(Regression Losses)​

​(1) L1Loss (Mean Absolute Error, MAE)​
  • 公式:
  • 特点:
    • 对异常值鲁棒(因为绝对值函数梯度恒定)。
    • 梯度不随误差大小变化,可能导致收敛慢。
  • 适用场景:
    • 目标变量存在异常值时(如房价预测)。
    • 需要稳定梯度的场景。
​(2) MSELoss (Mean Squared Error, MSE)​
  • 公式:
  • 特点:
    • 对误差大的样本惩罚更重(平方放大了异常值影响)。
    • 梯度随误差增大而增大,可能收敛更快。
  • 适用场景:
    • 数据干净且异常值较少时(如物理实验预测)。
    • 需要强信号指导优化的任务。
​(3) SmoothL1Loss
  • 公式:
  • 特点:
    • 结合MSE和L1的优点:小误差时平滑,大误差时鲁棒。
    • 避免MSE的梯度爆炸问题。
  • 适用场景:
    • 目标检测中的边界框回归(如Faster R-CNN)。
    • 需要平衡精度和鲁棒性的任务。
​(4) HuberLoss (Smooth Mean Absolute Error)​
  • 公式:
  • 特点:
    • 可调节超参数δ控制鲁棒性。
    • 比SmoothL1Loss更灵活。
  • 适用场景:
    • 需要动态调整鲁棒性的回归任务(如金融预测)。
​(5) KLDivLoss (Kullback-Leibler Divergence)​
  • 公式:
  • 特点:
    • 衡量两个概率分布的差异(非对称)。
    • 常用于生成模型(如GANs、VAEs)。
  • 适用场景:
    • 概率分布对齐(如文本生成、图像生成)。

2. 分类损失函数(Classification Losses)​

​(1) NLLLoss (Negative Log-Likelihood)​
  • 公式:(y^​i,c​是样本i属于类别c的概率)
  • 特点:
    • 要求输入是log概率​(通常配合LogSoftmax使用)。
    • 数值稳定(避免log(0)问题)。
  • 适用场景:
    • 多分类任务(如图像分类)。
​(2) BCELoss (Binary Cross-Entropy)​
  • 公式:
  • 特点:
    • 直接处理0-1概率输出(需配合Sigmoid)。
    • 对类别不平衡敏感。
  • 适用场景:
    • 二分类任务(如垃圾邮件检测)。
​(3) BCEWithLogitsLoss
  • 公式:
  • 特点:
    • 内部集成Sigmoid,数值更稳定。
    • 适合输出未归一化的logits。
  • 适用场景:
    • 二分类任务(默认选择,优于BCELoss)。
​(4) CrossEntropyLoss
  • 公式:
  • 特点:
    • 内部集成Softmax和NLLLoss。
    • 直接处理原始logits,无需手动归一化。
  • 适用场景:
    • 多分类任务(如ResNet、ViT)。
​(5) MultiLabelSoftMarginLoss
  • 公式:
  • 特点:
    • 允许一个样本属于多个类别(多标签分类)。
    • 输出为每个类别的概率(0到1之间)。
  • 适用场景:
    • 多标签分类(如图像中同时包含猫和狗)。

3. 其他损失函数

​(1) HingeEmbeddingLoss (用于半监督/度量学习)​
  • 公式:
  • 特点:
    • 鼓励相似样本靠近,不相似样本远离。
    • 常用于Siamese网络(如人脸验证)。
​(2) TripletLoss (用于度量学习)​
  • 公式:(d为距离,a为锚点,p为正样本,n为负样本)
  • 适用场景:
    • 需要学习特征嵌入的任务(如FaceNet人脸识别)。
​(3) CTCLoss (Connectionist Temporal Classification)​
  • 特点:
    • 无需对齐输入输出序列(如语音识别)。
    • 处理不定长序列数据。
  • 适用场景:
    • 序列标注任务(如OCR文字识别)。

以下是更详细的 ​PyTorch 损失函数解析,涵盖数学公式、适用场景、优缺点对比及代码示例:


1. 回归损失函数(Regression Losses)​

​(1) L1Loss (Mean Absolute Error, MAE)​
  • 公式:L=N1​i=1∑N​∣yi​−y^​i​∣
  • 特点:
    • 对异常值鲁棒(因为绝对值函数梯度恒定)。
    • 梯度不随误差大小变化,可能导致收敛慢。
  • 适用场景:
    • 目标变量存在异常值时(如房价预测)。
    • 需要稳定梯度的场景。
​(2) MSELoss (Mean Squared Error, MSE)​
  • 公式:L=N1​i=1∑N​(yi​−y^​i​)2
  • 特点:
    • 对误差大的样本惩罚更重(平方放大了异常值影响)。
    • 梯度随误差增大而增大,可能收敛更快。
  • 适用场景:
    • 数据干净且异常值较少时(如物理实验预测)。
    • 需要强信号指导优化的任务。
​(3) SmoothL1Loss
  • 公式:L={0.5(x)2∣x∣−0.5​if ∣x∣<1otherwise​
  • 特点:
    • 结合MSE和L1的优点:小误差时平滑,大误差时鲁棒。
    • 避免MSE的梯度爆炸问题。
  • 适用场景:
    • 目标检测中的边界框回归(如Faster R-CNN)。
    • 需要平衡精度和鲁棒性的任务。
​(4) HuberLoss (Smooth Mean Absolute Error)​
  • 公式:L={21​(yi​−y^​i​)2δ(∣yi​−y^​i​∣−21​δ)​if ∣yi​−y^​i​∣<δotherwise​
  • 特点:
    • 可调节超参数δ控制鲁棒性。
    • 比SmoothL1Loss更灵活。
  • 适用场景:
    • 需要动态调整鲁棒性的回归任务(如金融预测)。
​(5) KLDivLoss (Kullback-Leibler Divergence)​
  • 公式:L=i=1∑N​yi​log(y^​i​yi​​)
  • 特点:
    • 衡量两个概率分布的差异(非对称)。
    • 常用于生成模型(如GANs、VAEs)。
  • 适用场景:
    • 概率分布对齐(如文本生成、图像生成)。

2. 分类损失函数(Classification Losses)​

​(1) NLLLoss (Negative Log-Likelihood)​
  • 公式:L=−N1​i=1∑N​log(y^​i,c​)(y^​i,c​是样本i属于类别c的概率)
  • 特点:
    • 要求输入是log概率​(通常配合LogSoftmax使用)。
    • 数值稳定(避免log(0)问题)。
  • 适用场景:
    • 多分类任务(如图像分类)。
​(2) BCELoss (Binary Cross-Entropy)​
  • 公式:L=−N1​i=1∑N​[yi​log(y^​i​)+(1−yi​)log(1−y^​i​)]
  • 特点:
    • 直接处理0-1概率输出(需配合Sigmoid)。
    • 对类别不平衡敏感。
  • 适用场景:
    • 二分类任务(如垃圾邮件检测)。
​(3) BCEWithLogitsLoss
  • 公式:L=−N1​i=1∑N​[yi​log(σ(xi​))+(1−yi​)log(1−σ(xi​))](σ(xi​)是Sigmoid函数)
  • 特点:
    • 内部集成Sigmoid,数值更稳定。
    • 适合输出未归一化的logits。
  • 适用场景:
    • 二分类任务(默认选择,优于BCELoss)。
​(4) CrossEntropyLoss
  • 公式:L=−N1​i=1∑N​c=1∑C​yi,c​log(∑j=1C​exi,j​exi,c​​)
  • 特点:
    • 内部集成Softmax和NLLLoss。
    • 直接处理原始logits,无需手动归一化。
  • 适用场景:
    • 多分类任务(如ResNet、ViT)。
​(5) MultiLabelSoftMarginLoss
  • 公式:L=−N1​i=1∑N​c=1∑C​[yi,c​log(σ(xi,c​))+(1−yi,c​)log(1−σ(xi,c​))]
  • 特点:
    • 允许一个样本属于多个类别(多标签分类)。
    • 输出为每个类别的概率(0到1之间)。
  • 适用场景:
    • 多标签分类(如图像中同时包含猫和狗)。

3. 其他损失函数

​(1) HingeEmbeddingLoss (用于半监督/度量学习)​
  • 公式:L={max(0,1−(xi​−xj​))max(0,1+(xi​−xj​))​if yij​=1if yij​=−1​
  • 特点:
    • 鼓励相似样本靠近,不相似样本远离。
    • 常用于Siamese网络(如人脸验证)。
​(2) TripletLoss (用于度量学习)​
  • 公式:L=max(0,d(a,p)−d(a,n)+margin)(d为距离,a为锚点,p为正样本,n为负样本)
  • 适用场景:
    • 需要学习特征嵌入的任务(如FaceNet人脸识别)。
​(3) CTCLoss (Connectionist Temporal Classification)​
  • 特点:
    • 无需对齐输入输出序列(如语音识别)。
    • 处理不定长序列数据。
  • 适用场景:
    • 序列标注任务(如OCR文字识别)。

如何选择损失函数?​

​​​​​​​

任务类型推荐损失函数原因
回归(鲁棒性)SmoothL1Loss/HuberLoss对异常值不敏感
回归(精度)MSELoss误差大时梯度大,收敛快
二分类BCEWithLogitsLoss数值稳定,集成Sigmoid
多分类CrossEntropyLoss直接处理logits,避免手动Softmax
多标签分类MultiLabelSoftMarginLoss允许样本属于多个类别
度量学习TripletLoss/HingeLoss学习样本间相似性
序列标注CTCLoss处理不定长序列
http://www.dtcms.com/a/350661.html

相关文章:

  • 高性能C++实践:原子操作与无锁队列实现
  • C++ #pragma
  • C++初阶(3)C++入门基础2
  • 现代C++工具链实战:CMake + Conan + vcpkg依赖管理
  • MYSQL的bin log是什么
  • JUC并发编程08 - 同步模式/异步模式
  • ROS2 python功能包launch,config文件编译后找不到
  • 链表OJ习题(2)
  • 搭建基于LangChain实现复杂RAG聊天机器人
  • AI在软件研发流程中的提效案例
  • 在vue3后台项目中使用热力图,并给热力图增加点击选中事件
  • Java中删除字符串首字符
  • 【51单片机】【protues仿真】基于51单片机数码管温度报警器系统
  • AR眼镜赋能水利智能巡检的创新实践
  • 算法题打卡力扣第167题:两数之和——输入有序数组(mid)
  • VASP计算层错能(SFE)全攻略2
  • python自学笔记12 NumPy 常见运算
  • QT(1)
  • 独立显卡接口操作指南
  • 小程序开发指南(四)(UI 框架整合)
  • Linux系统网络管理
  • UE5 UI遮罩
  • 人形机器人产业风口下,低延迟音视频传输如何成为核心竞争力
  • Linux笔记9——shell编程基础-3
  • OpenFeign的原理解析
  • FMS回顾和总结
  • C++ 中 `std::map` 的 `insert` 函数
  • 【机器学习项目 心脏病预测】
  • 【广告系列】流量归因模型
  • centos 用 docker 方式安装 dufs