[Deep Learning | Study Notes] Which Loss Functions Are Used in Neural Networks? (Part 1)
Table of Contents
- [Deep Learning | Study Notes] Which Loss Functions Are Used in Neural Networks? (Part 1)
- Foreword
- 1. Overview: A Map of Loss Functions by Task Paradigm
- 2. A Ready-to-Use PyTorch "Loss Zoo" 🧰
Foreword
- Below is a panorama of neural-network loss functions plus a directly reusable PyTorch code library (**Loss Zoo**).
- The losses that are common and high-frequency in practice are grouped by machine-learning paradigm, each with an implementation skeleton you can paste straight into your own project (pure PyTorch, no external dependencies).

Note: "all" loss functions cannot be exhaustively enumerated, but the list below covers the losses and representative variants most used in research and industry across supervised / self-supervised / generative / metric-learning / ranking / segmentation / time-series / uncertainty / RL settings.
1. Overview: A Map of Loss Functions by Task Paradigm
| Paradigm | Representative problems | Common losses |
| --- | --- | --- |
| Regression / dense prediction | numeric regression, depth estimation, super-resolution | MSE (L2), L1/MAE, Huber/SmoothL1, Log-Cosh, Charbonnier, Quantile (Pinball), Gaussian/Laplace NLL, CRPS (probabilistic forecasting) |
| Classification (single-/multi-label) | image/text classification, multi-label tagging | Cross Entropy (with Label Smoothing), BCEWithLogits (multi-label), Focal, Asymmetric Loss (ASL), weighted/class-balanced CE, pairwise AUC surrogates |
| Structured output / segmentation | semantic/instance segmentation, medical imaging | Dice / Soft Dice, Jaccard (IoU) / Lovász-Softmax, Tversky / Focal-Tversky, boundary / Hausdorff-distance surrogates |
| Metric learning / retrieval | person re-ID, face recognition, retrieval | Triplet, Contrastive, InfoNCE/NT-Xent/SupCon, ArcFace/CosFace/AM-Softmax, Circle Loss, Center Loss |
| Ranking / recommendation | top-K, CTR ranking | Pairwise Hinge/Logistic, BPR, ListNet/ListMLE, LambdaRank family (NDCG surrogates) |
| Sequence / speech | ASR, OCR, translation | CTC, Cross Entropy (autoregressive), Label Smoothing, KD |
| Generative modeling | VAE/GAN/diffusion/autoregressive | ELBO (reconstruction + KL), GAN BCE/Hinge/WGAN-GP, diffusion MSE / ε-loss / v-parameterization, NLL (Flow/Transformer) |
| Self-supervised | contrastive / redundancy reduction / masking | InfoNCE/NT-Xent, Barlow Twins, VICReg, BYOL (negative-free), MSE/CE for MAE/MIM |
| Graph learning | link prediction / node classification | BCE / contrastive, Cross Entropy, BPR (graph recommendation) |
| Uncertainty / robustness | heteroscedastic regression / long tail | Gaussian/Laplace NLL (predicted variance), quantile loss, GCE/SCE (noise-robust) |
| Reinforcement learning | policy/value learning | Policy Gradient + PPO clipping, value MSE, entropy bonus, SAC (actor/critic with temperature term) |
| Other (regularizers / auxiliary) | not main-task losses, but often combined | L2/L1 weight penalties, orthogonality, Center, feature decorrelation, consistency (semi-supervised) |
2. A Ready-to-Use PyTorch "Loss Zoo" 🧰
- Note: every class/function below can be copied independently; `logits` denotes raw network outputs that have not been through softmax; for multi-label tasks use `BCEWithLogits` to avoid numerical instability (see the sketch after this note).
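As a quick illustration of that note, here is a minimal sketch (toy shapes, assumed only for this example) comparing the fused logits-based call with the naive sigmoid-then-BCE route; the fused version works in log space, which is why it is the safer default for multi-label heads:

```python
import torch
import torch.nn.functional as F

logits = torch.randn(4, 5) * 50                # deliberately extreme logits
targets = torch.randint(0, 2, (4, 5)).float()

# Fused: operates on logits directly via the log-sum-exp trick, stays well-behaved
stable = F.binary_cross_entropy_with_logits(logits, targets)

# Naive: sigmoid saturates to exactly 0/1 in fp32, so the log terms lose precision
probs = torch.sigmoid(logits)
naive = F.binary_cross_entropy(probs, targets)
```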
```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
```
```python
# ========== A. Regression / dense prediction ==========
class HuberLoss(nn.Module):
    def __init__(self, delta=1.0, reduction='mean'):
        super().__init__()
        self.delta, self.reduction = delta, reduction

    def forward(self, pred, target):
        diff = pred - target
        absd = diff.abs()
        quad = 0.5 * (diff ** 2)
        lin = self.delta * (absd - 0.5 * self.delta)
        loss = torch.where(absd <= self.delta, quad, lin)
        return loss.mean() if self.reduction == 'mean' else loss.sum()


class LogCoshLoss(nn.Module):
    def __init__(self, reduction='mean'):
        super().__init__()
        self.reduction = reduction

    def forward(self, pred, target):
        x = pred - target
        # Numerically stable identity: log(cosh(x)) = |x| + log1p(exp(-2|x|)) - log 2
        loss = x.abs() + torch.log1p(torch.exp(-2 * x.abs())) - math.log(2.0)
        return loss.mean() if self.reduction == 'mean' else loss.sum()


class CharbonnierLoss(nn.Module):  # differentiable approximation of L1
    def __init__(self, eps=1e-3, reduction='mean'):
        super().__init__()
        self.eps, self.reduction = eps, reduction

    def forward(self, pred, target):
        loss = torch.sqrt((pred - target) ** 2 + self.eps ** 2)
        return loss.mean() if self.reduction == 'mean' else loss.sum()


class QuantileLoss(nn.Module):  # pinball loss; supports multiple quantiles
    def __init__(self, quantiles=(0.1, 0.5, 0.9), reduction='mean'):
        super().__init__()
        self.qs = torch.tensor(quantiles)
        self.reduction = reduction

    def forward(self, pred, target):
        # pred: [B, Q], target: [B]
        if pred.ndim == 1:
            pred = pred.unsqueeze(-1)
        qs = self.qs.to(pred.device)
        diff = target.unsqueeze(1) - pred
        loss = torch.max(qs * diff, (qs - 1) * diff)
        return loss.mean() if self.reduction == 'mean' else loss.sum()


class HeteroscedasticGaussianNLL(nn.Module):
    """Predict mean mu and log-variance s = log sigma^2; minimize NLL = 0.5*exp(-s)*(y-mu)^2 + 0.5*s."""
    def __init__(self, reduction='mean'):
        super().__init__()
        self.reduction = reduction

    def forward(self, mu, log_var, target):
        inv_var = torch.exp(-log_var)
        loss = 0.5 * (inv_var * (target - mu) ** 2 + log_var)
        return loss.mean() if self.reduction == 'mean' else loss.sum()
```
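A minimal smoke test for the regression losses above (batch size and shapes are made up for the example); note that QuantileLoss expects one prediction column per quantile, and the heteroscedastic NLL takes a predicted mean and log-variance per sample:

```python
import torch

pred = torch.randn(8, requires_grad=True)
target = torch.randn(8)

huber = HuberLoss(delta=1.0)(pred, target)
logcosh = LogCoshLoss()(pred, target)

# QuantileLoss: one prediction column per requested quantile -> [B, Q]
qpred = torch.randn(8, 3, requires_grad=True)
pinball = QuantileLoss(quantiles=(0.1, 0.5, 0.9))(qpred, target)

# Heteroscedastic NLL: the network outputs a mean and a log-variance per sample
mu = torch.randn(8, requires_grad=True)
log_var = torch.zeros(8, requires_grad=True)
nll = HeteroscedasticGaussianNLL()(mu, log_var, target)

(huber + logcosh + pinball + nll).backward()
```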
```python
# ========== B. Classification (single-/multi-label) ==========
class LabelSmoothingCE(nn.Module):
    def __init__(self, eps=0.1, weight=None, reduction='mean'):
        super().__init__()
        self.eps, self.weight, self.reduction = eps, weight, reduction

    def forward(self, logits, target):
        n = logits.size(-1)
        logp = F.log_softmax(logits, dim=-1)
        with torch.no_grad():
            t = torch.full_like(logits, self.eps / (n - 1))
            t.scatter_(1, target.unsqueeze(1), 1 - self.eps)
        if self.weight is not None:
            w = self.weight[target].unsqueeze(1)
            loss = -(w * t * logp).sum(dim=1)
        else:
            loss = -(t * logp).sum(dim=1)
        return loss.mean() if self.reduction == 'mean' else loss.sum()


class FocalLoss(nn.Module):  # suits long-tailed data / hard examples
    def __init__(self, gamma=2.0, alpha=None, reduction='mean'):
        super().__init__()
        self.g, self.alpha, self.reduction = gamma, alpha, reduction

    def forward(self, logits, target):
        logp = F.log_softmax(logits, dim=-1)
        pt = logp.gather(1, target.unsqueeze(1)).exp().squeeze(1)
        if self.alpha is not None:
            a = self.alpha.to(logits.device)[target]
        else:
            a = 1.0
        loss = -a * (1 - pt) ** self.g * torch.log(pt + 1e-12)
        return loss.mean() if self.reduction == 'mean' else loss.sum()


class AsymmetricLossMultiLabel(nn.Module):
    """ASL: a BCE variant for imbalanced multi-label data (paper: Asymmetric Loss For Multi-Label Classification)."""
    def __init__(self, gamma_pos=0.0, gamma_neg=4.0, clip=0.05, reduction='mean'):
        super().__init__()
        self.gp, self.gn, self.clip, self.reduction = gamma_pos, gamma_neg, clip, reduction

    def forward(self, logits, targets):
        # logits/targets: [B, C], targets in {0, 1}
        xs_pos = F.logsigmoid(logits)   # log sigma(x)
        xs_neg = F.logsigmoid(-logits)  # log (1 - sigma(x))
        if self.clip is not None and self.clip > 0:
            xs_neg = torch.clamp(xs_neg, min=math.log(self.clip))
        # different focusing factors for positives and negatives
        loss_pos = torch.pow(1 - torch.exp(xs_pos), self.gp) * (-xs_pos) * targets
        loss_neg = torch.pow(1 - torch.exp(xs_neg), self.gn) * (-xs_neg) * (1 - targets)
        loss = loss_pos + loss_neg
        return loss.mean() if self.reduction == 'mean' else loss.sum()


class HingeLossMulticlass(nn.Module):  # SVM-style
    def __init__(self, margin=1.0, reduction='mean'):
        super().__init__()
        self.m, self.reduction = margin, reduction

    def forward(self, logits, target):
        # logits: [B, C], target: [B]
        B, C = logits.shape
        true = logits[torch.arange(B), target].unsqueeze(1)
        margins = F.relu(self.m - true + logits)
        margins[torch.arange(B), target] = 0.0
        loss = margins.sum(dim=1)
        return loss.mean() if self.reduction == 'mean' else loss.sum()
```
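A quick sketch showing how these classification losses are called; class counts and shapes here are arbitrary example values:

```python
import torch

logits = torch.randn(16, 10)              # [B, C]
labels = torch.randint(0, 10, (16,))      # class indices

ce_ls = LabelSmoothingCE(eps=0.1)(logits, labels)
focal = FocalLoss(gamma=2.0)(logits, labels)

# Multi-label: one independent {0, 1} target per class
ml_logits = torch.randn(16, 10)
ml_targets = torch.randint(0, 2, (16, 10)).float()
asl = AsymmetricLossMultiLabel()(ml_logits, ml_targets)
```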
```python
# ========== C. Segmentation / structured output ==========
class DiceLoss(nn.Module):
    def __init__(self, eps=1e-6):
        super().__init__()
        self.eps = eps

    def forward(self, logits, targets):
        # Binary: logits [B,1,H,W] or [B,H,W]; multi-class: per-class Dice, then averaged
        if logits.ndim == 3:
            logits = logits.unsqueeze(1)
        if logits.size(1) == 1:
            probs = torch.sigmoid(logits).squeeze(1)
            targets = targets.float()
            inter = (probs * targets).sum(dim=(1, 2))
            denom = probs.sum(dim=(1, 2)) + targets.sum(dim=(1, 2))
            dice = (2 * inter + self.eps) / (denom + self.eps)
            return 1 - dice.mean()
        else:
            probs = F.softmax(logits, dim=1)
            B, C, H, W = logits.shape
            oh = F.one_hot(targets, num_classes=C).permute(0, 3, 1, 2).float()
            inter = (probs * oh).sum(dim=(0, 2, 3))
            denom = probs.sum(dim=(0, 2, 3)) + oh.sum(dim=(0, 2, 3))
            dice = (2 * inter + self.eps) / (denom + self.eps)
            return 1 - dice.mean()


class JaccardLoss(nn.Module):  # IoU = TP / (TP + FP + FN)
    def __init__(self, eps=1e-6):
        super().__init__()
        self.eps = eps

    def forward(self, logits, targets):
        probs = torch.sigmoid(logits)
        targets = targets.float()
        inter = (probs * targets).sum(dim=(1, 2, 3))
        union = probs.sum(dim=(1, 2, 3)) + targets.sum(dim=(1, 2, 3)) - inter
        iou = (inter + self.eps) / (union + self.eps)
        return 1 - iou.mean()


class TverskyLoss(nn.Module):
    def __init__(self, alpha=0.7, beta=0.3, eps=1e-6):
        super().__init__()
        self.a, self.b, self.eps = alpha, beta, eps

    def forward(self, logits, targets):
        p = torch.sigmoid(logits)
        t = targets.float()
        tp = (p * t).sum(dim=(1, 2, 3))
        fp = (p * (1 - t)).sum(dim=(1, 2, 3))
        fn = ((1 - p) * t).sum(dim=(1, 2, 3))
        tv = (tp + self.eps) / (tp + self.a * fp + self.b * fn + self.eps)
        return 1 - tv.mean()


class FocalTverskyLoss(nn.Module):
    def __init__(self, alpha=0.7, beta=0.3, gamma=0.75, eps=1e-6):
        super().__init__()
        self.core = TverskyLoss(alpha, beta, eps)
        self.g = gamma

    def forward(self, logits, targets):
        # Focal Tversky = (1 - Tversky index)^gamma; core already returns 1 - TI
        return self.core(logits, targets) ** self.g
```
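A common recipe in segmentation is to sum a per-pixel cross entropy with a region-overlap term such as Dice; a minimal sketch, with made-up tensor sizes:

```python
import torch
import torch.nn.functional as F

B, C, H, W = 2, 4, 32, 32
logits = torch.randn(B, C, H, W, requires_grad=True)
masks = torch.randint(0, C, (B, H, W))   # per-pixel class indices

# CE for per-pixel accuracy + Dice for region overlap
loss = F.cross_entropy(logits, masks) + DiceLoss()(logits, masks)
loss.backward()
```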
```python
# ========== D. Metric learning / retrieval ==========
class ContrastiveLoss(nn.Module):  # Siamese: y in {0 (negative), 1 (positive)}
    def __init__(self, margin=1.0, reduction='mean'):
        super().__init__()
        self.m, self.reduction = margin, reduction

    def forward(self, z1, z2, y):
        d = (z1 - z2).pow(2).sum(dim=1).clamp(min=1e-12).sqrt()
        loss = y * d.pow(2) + (1 - y) * F.relu(self.m - d).pow(2)
        return loss.mean() if self.reduction == 'mean' else loss.sum()


class TripletLoss(nn.Module):
    def __init__(self, margin=0.2, reduction='mean'):
        super().__init__()
        self.m, self.reduction = margin, reduction

    def forward(self, za, zp, zn):
        pos = (za - zp).pow(2).sum(dim=1)
        neg = (za - zn).pow(2).sum(dim=1)
        loss = F.relu(pos - neg + self.m)
        return loss.mean() if self.reduction == 'mean' else loss.sum()


class NTXentLoss(nn.Module):  # basis of SimCLR / SupCon
    def __init__(self, temperature=0.2):
        super().__init__()
        self.t = temperature

    def forward(self, z):  # z: [2B, D], two augmented views per sample
        z = F.normalize(z, dim=1)
        sim = torch.matmul(z, z.t()) / self.t  # [2B, 2B]
        mask = torch.eye(sim.size(0), dtype=torch.bool, device=z.device)
        sim = sim - 1e9 * mask                 # exclude self-similarity
        B2 = z.size(0)
        targets = torch.arange(B2, device=z.device)
        targets = (targets + (B2 // 2)) % B2   # index of the positive sample
        loss = F.cross_entropy(sim, targets)
        return loss


class ArcFaceLoss(nn.Module):  # angular margin (faces / retrieval)
    def __init__(self, s=30.0, m=0.50):
        super().__init__()
        self.s, self.m = s, m

    def forward(self, logits, target):
        # logits = cos(theta); features and weights must be L2-normalized beforehand
        theta = torch.acos(logits.clamp(-1 + 1e-7, 1 - 1e-7))
        logits_m = torch.cos(theta + self.m)
        onehot = F.one_hot(target, num_classes=logits.size(1)).float()
        out = self.s * (onehot * logits_m + (1 - onehot) * logits)
        return F.cross_entropy(out, target)


class CenterLoss(nn.Module):
    """Encourages intra-class compactness: sum ||x - c_y||^2; combine with CE and update the class centers during training."""
    def __init__(self, num_classes, feat_dim, lr=0.5):
        super().__init__()
        self.centers = nn.Parameter(torch.randn(num_classes, feat_dim))
        self.lr = lr  # step size if the centers are updated manually

    def forward(self, feats, labels):
        c = self.centers[labels]
        return ((feats - c) ** 2).sum(dim=1).mean()
```
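Usage sketch for NTXentLoss: the two augmented views are stacked into a single [2B, D] batch so that row i and row i+B form a positive pair (all sizes here are illustrative):

```python
import torch

B, D = 8, 128
z1 = torch.randn(B, D, requires_grad=True)  # embeddings of view 1
z2 = torch.randn(B, D, requires_grad=True)  # embeddings of view 2

# Stack both views: pair i <-> i + B
z = torch.cat([z1, z2], dim=0)
loss = NTXentLoss(temperature=0.2)(z)
loss.backward()
```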
```python
# ========== E. Ranking / recommendation ==========
class PairwiseHingeLoss(nn.Module):  # item i should rank above item j
    def __init__(self, margin=1.0):
        super().__init__()
        self.m = margin

    def forward(self, s_pos, s_neg):
        return F.relu(self.m - (s_pos - s_neg)).mean()


class BPRLoss(nn.Module):  # Bayesian Personalized Ranking
    def forward(self, s_pos, s_neg):  # item scores
        return -torch.log(torch.sigmoid(s_pos - s_neg) + 1e-12).mean()
```
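Both ranking losses take a positive and a negative score per training pair; a tiny sketch with random stand-in scores:

```python
import torch

s_pos = torch.randn(32, requires_grad=True)  # score of the observed item
s_neg = torch.randn(32, requires_grad=True)  # score of a sampled negative

loss = BPRLoss()(s_pos, s_neg) + PairwiseHingeLoss(margin=1.0)(s_pos, s_neg)
loss.backward()
```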
```python
# ========== F. Generative modeling ==========
class VAELoss(nn.Module):
    """Reconstruction (MSE or BCE) + KL( q(z|x) || N(0, 1) )."""
    def __init__(self, recon='mse', beta=1.0):
        super().__init__()
        self.recon = recon
        self.beta = beta

    def forward(self, x, x_recon, mu, logvar):
        if self.recon == 'bce':
            rec = F.binary_cross_entropy_with_logits(x_recon, x, reduction='sum') / x.size(0)
        else:
            rec = F.mse_loss(torch.sigmoid(x_recon), x, reduction='mean')
        kl = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
        return rec + self.beta * kl


class GANLoss(nn.Module):
    """mode='vanilla' uses BCE; mode='hinge' uses the hinge form; call the discriminator and generator losses separately."""
    def __init__(self, mode='vanilla'):
        super().__init__()
        self.mode = mode

    def d_loss(self, real_logits, fake_logits):
        if self.mode == 'vanilla':
            return (F.binary_cross_entropy_with_logits(real_logits, torch.ones_like(real_logits)) +
                    F.binary_cross_entropy_with_logits(fake_logits, torch.zeros_like(fake_logits)))
        else:  # hinge
            return F.relu(1 - real_logits).mean() + F.relu(1 + fake_logits).mean()

    def g_loss(self, fake_logits):
        if self.mode == 'vanilla':
            return F.binary_cross_entropy_with_logits(fake_logits, torch.ones_like(fake_logits))
        else:
            return -fake_logits.mean()


def wgan_gp(discriminator, real, fake, lambda_gp=10.0):
    """Gradient-penalty term for WGAN-GP."""
    B = real.size(0)
    eps = torch.rand(B, 1, 1, 1, device=real.device)
    inter = eps * real + (1 - eps) * fake
    inter.requires_grad_(True)
    d_inter = discriminator(inter)
    grad = torch.autograd.grad(d_inter.sum(), inter, create_graph=True)[0]
    gp = ((grad.view(B, -1).norm(2, dim=1) - 1) ** 2).mean()
    return lambda_gp * gp

# Diffusion: usually epsilon-prediction MSE or v-parameterization MSE, i.e. MSE(eps_pred, eps) at train time
```
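The GAN losses are computed once per optimization step, with the discriminator and generator terms taken separately; a sketch with random logits standing in for real discriminator outputs (wgan_gp additionally needs the discriminator module plus real/fake image batches):

```python
import torch

real_logits = torch.randn(16, 1)  # D(x) on a real batch (stand-in values)
fake_logits = torch.randn(16, 1)  # D(G(z)) on a fake batch

gan = GANLoss(mode='hinge')
d_loss = gan.d_loss(real_logits, fake_logits)  # discriminator update
g_loss = gan.g_loss(fake_logits)               # generator update
```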
```python
# ========== G. Self-supervised ==========
class BarlowTwinsLoss(nn.Module):
    """Push the cross-correlation matrix toward identity: diagonal -> 1, off-diagonal -> 0."""
    def __init__(self, lambda_offdiag=5e-3):
        super().__init__()
        self.lmb = lambda_offdiag

    def forward(self, z1, z2, eps=1e-9):
        z1 = (z1 - z1.mean(0)) / (z1.std(0) + eps)
        z2 = (z2 - z2.mean(0)) / (z2.std(0) + eps)
        c = (z1.T @ z2) / z1.size(0)  # [D, D]
        on = (torch.diagonal(c) - 1).pow(2).sum()
        off = c.pow(2).sum() - torch.diagonal(c).pow(2).sum()  # off-diagonal entries only
        return on + self.lmb * off
```
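Usage sketch: BarlowTwinsLoss consumes the projector outputs of the two augmented views directly; batch and feature sizes below are arbitrary:

```python
import torch

z1 = torch.randn(64, 256, requires_grad=True)  # projector output, view 1
z2 = torch.randn(64, 256, requires_grad=True)  # projector output, view 2

loss = BarlowTwinsLoss(lambda_offdiag=5e-3)(z1, z2)
loss.backward()
```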
```python
# ========== H. Sequence / speech ==========
# CTC: use nn.CTCLoss(blank=0, reduction='mean') directly; autoregressive models use CrossEntropy
```
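Minimal nn.CTCLoss sketch: it expects log-probabilities of shape [T, B, C], padded targets, and explicit per-sample lengths (all sizes below are made up; index 0 is reserved for the blank):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

T, B, C = 50, 4, 20                      # time steps, batch, classes (incl. blank=0)
log_probs = F.log_softmax(torch.randn(T, B, C), dim=-1)
targets = torch.randint(1, C, (B, 10))   # label indices; 0 is the blank
input_lengths = torch.full((B,), T, dtype=torch.long)
target_lengths = torch.full((B,), 10, dtype=torch.long)

loss = nn.CTCLoss(blank=0, reduction='mean')(log_probs, targets, input_lengths, target_lengths)
```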
```python
# ========== I. Reinforcement learning (core snippet) ==========
class PPOClipLoss:
    """Policy loss only: L = -E[min(r*A, clip(r)*A)]; an entropy bonus can be added via entropy_coef."""
    def __init__(self, clip_ratio=0.2, entropy_coef=0.0):
        self.eps = clip_ratio
        self.ent = entropy_coef

    def __call__(self, logp_new, logp_old, advantage, entropy):
        r = (logp_new - logp_old).exp()
        unclipped = r * advantage
        clipped = torch.clamp(r, 1 - self.eps, 1 + self.eps) * advantage
        loss = -torch.min(unclipped, clipped).mean() - self.ent * entropy.mean()
        return loss
```
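Sketch of a PPO policy-loss call, with random tensors standing in for the per-sample log-probabilities, advantages, and entropies collected from rollouts:

```python
import torch

logp_new = torch.randn(64, requires_grad=True)       # log pi_theta(a|s)
logp_old = logp_new.detach() + 0.1 * torch.randn(64)  # from the behavior policy
advantage = torch.randn(64)
entropy = torch.rand(64)

ppo = PPOClipLoss(clip_ratio=0.2, entropy_coef=0.01)
loss = ppo(logp_new, logp_old, advantage, entropy)
loss.backward()
```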
- Other built-ins: nn.CrossEntropyLoss, nn.BCEWithLogitsLoss, nn.CosineEmbeddingLoss, nn.TripletMarginLoss, nn.GaussianNLLLoss, nn.CTCLoss, etc. can all be called directly (a short sketch follows).
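For completeness, a short sketch calling a few of these built-ins directly; note that the label_smoothing argument of nn.CrossEntropyLoss requires a reasonably recent PyTorch, and nn.GaussianNLLLoss takes the variance, not the log-variance:

```python
import torch
import torch.nn as nn

logits = torch.randn(8, 5)
labels = torch.randint(0, 5, (8,))
ce = nn.CrossEntropyLoss(label_smoothing=0.1)(logits, labels)

anchor, pos, neg = (torch.randn(8, 64) for _ in range(3))
trip = nn.TripletMarginLoss(margin=0.2)(anchor, pos, neg)

mu, target, var = torch.randn(8), torch.randn(8), torch.rand(8) + 0.1
gnll = nn.GaussianNLLLoss()(mu, target, var)
```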