当前位置: 首页 > news >正文

PyTorch中的损失函数

PyTorch 提供了丰富的损失函数用于不同类型的机器学习任务。下面我将全面介绍 PyTorch 中的主要损失函数,包括它们的数学表达式、使用场景和实际代码示例。

一、回归任务损失函数

1. MSELoss (均方误差损失)

torch.nn.MSELoss(reduction='mean')
  • 公式loss = (x - y)²

  • 特点: 对异常值敏感,惩罚大误差更重

  • 应用: 一般回归问题

    criterion = nn.MSELoss()
    loss = criterion(outputs, targets)

 2. L1Loss (平均绝对误差)

torch.nn.L1Loss(reduction='mean')
  • 公式loss = |x - y|

  • 特点: 对异常值更鲁棒

  • 应用: 需要减少异常值影响的回归问题

3. SmoothL1Loss (Huber损失)

torch.nn.SmoothL1Loss(reduction='mean', beta=1.0)

公式

 

  • 特点: 结合L1和L2的优点

  • 应用: 目标检测(如Faster R-CNN)

二、分类任务损失函数

1. CrossEntropyLoss (交叉熵损失)

torch.nn.CrossEntropyLoss(weight=None, ignore_index=-100, reduction='mean')
  • 公式loss = -log(exp(x[class]) / ∑exp(x[j]))

  • 特点: 自动应用softmax

  • 应用: 多分类问题

    criterion = nn.CrossEntropyLoss()
    loss = criterion(outputs, targets)  # targets是类别索引

 2. BCELoss (二元交叉熵)

torch.nn.BCELoss(weight=None, reduction='mean')
  • 公式:
     

  • 要求: 输入需经过sigmoid(0-1之间)

  • 应用: 二分类问题

3. BCEWithLogitsLoss

torch.nn.BCEWithLogitsLoss(weight=None, reduction='mean', pos_weight=None)
  • 特点: 结合sigmoid和BCELoss,数值更稳定

  • 应用: 推荐用于二分类问题

三、其他重要损失函数

1. NLLLoss (负对数似然损失)

torch.nn.NLLLoss(weight=None, ignore_index=-100, reduction='mean')
  • 要求: 输入需经过log-softmax

  • 应用: 通常与LogSoftmax配合使用

2. KLDivLoss (KL散度) 

torch.nn.KLDivLoss(reduction='mean')
  • 公式loss = y * (log(y) - x)

  • 应用: 衡量概率分布差异,如VAE

3. MarginRankingLoss

torch.nn.MarginRankingLoss(margin=0.0, reduction='mean')
  • 应用: 排序任务

4. TripletMarginLoss

torch.nn.TripletMarginLoss(margin=1.0, p=2.0, eps=1e-06, swap=False)
  • 应用: 度量学习,人脸识别

5. CosineEmbeddingLoss 

torch.nn.CosineEmbeddingLoss(margin=0.0, reduction='mean')
  • 应用: 相似度学习

四、损失函数选择指南

任务类型推荐损失函数备注
回归问题MSELoss/L1Loss/SmoothL1Loss根据异常值情况选择
二分类BCEWithLogitsLoss优于BCELoss
多分类CrossEntropyLoss最常用
多标签分类BCEWithLogitsLoss每个类别独立判断
分布匹配KLDivLoss如VAE
相似度学习TripletMarginLoss/CosineEmbeddingLoss度量学习

 五、自定义损失函数示例

class CustomLoss(nn.Module):
    def __init__(self, weight=1.0):
        super().__init__()
        self.weight = weight
        
    def forward(self, inputs, targets):
        # 计算L1损失
        l1_loss = torch.abs(inputs - targets)
        # 计算特殊惩罚项
        penalty = torch.where(targets > inputs, 2.0 * l1_loss, l1_loss)
        # 组合损失
        return (penalty.mean() + self.weight * l1_loss.mean())
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.dtcms.com/a/110653.html

相关文章:

  • 【Django】教程-10-ajax请求Demo,结合使用
  • 算法导论(动态规划)——子数组系列
  • 了解Docker容器的常见退出状态码及其含义
  • dify新版本1.1.3的一些问题
  • 生成对抗网络(GAN)详解(代码实现)
  • MySQL 中的 MVCC 版本控制机制原理
  • PCIe初始化Detect状态解读
  • 32f4,usart2fifo,2025
  • 【大模型系列篇】大模型基建工程:基于 FastAPI 自动构建 SSE MCP 服务器
  • 模版进阶(沉淀中)
  • 云原生安全渗透篇
  • 让AI再次伟大-MCP-Client开发指南
  • strace命令详解
  • .NET用C#在PDF文档中添加、删除和替换图片
  • InfluxDB用户管理全攻略:从入门到精通
  • C++ 继承方式使用场景(极简版)
  • fastGPT—nextjs—mongoose—团队管理之部门相关api接口实现
  • 当系统会“说话“:用人类能听懂的方式聊聊Syslog和Kafka
  • 【MongoDB + 向量搜索引擎】MongoDB Atlas 向量搜索 提供全托管解决方案
  • Docker自动部署Spring Boot项目的Shell脚本
  • Caddy 从入门到实战指南(一)
  • 鸿蒙NEXT小游戏开发:井字棋
  • Java学习总结-io流-字节流
  • 基于51单片机的模拟条形码识别系统proteus仿真
  • GitLab CVE-2025-2255 漏洞解决方案
  • 【通知】STM32MP157驱动开发课程全新升级!零基础入门嵌入式Linux驱动,掌握底层开发核心技能!
  • Linux信号——信号的保存(2)
  • HTML5 Video(视频)学习笔记
  • AVR128单片机红外遥控8*8LED点阵屏显示
  • 【python中级】使用 setuptools生成 whl 轮子文件