当前位置: 首页 > news >正文

深度学习之参数初始化和损失函数(四)

深度学习之参数初始化和损失函数(四)


文章目录

  • 深度学习之参数初始化和损失函数(四)
    • 一、参数初始化
      • 1.1 固定值初始化(仅演示,**权重不要用**)
      • 1.2 随机初始化(打破对称性的起点)
      • 1.3 Xavier(Glorot)初始化 —— 均衡前向与反向方差
      • 1.4 He(Kaiming)初始化 —— 专为 ReLU 优化
      • 1.5 高级初始化速览
    • 二、损失函数
      • 2.1 回归任务
      • 2.2 分类任务
        • 2.2.1 多类单标签 —— CrossEntropyLoss
        • 2.2.2 二分类 / 多标签 —— BCEWithLogitsLoss
    • 三、总结
    • 四、案例


关键词:参数初始化、Xavier / He 初始值、对称性破坏、损失函数、MAE / MSE / CrossEntropy / BCE



一、参数初始化

1.1 固定值初始化(仅演示,权重不要用

方法代码缺陷
全零nn.init.zeros_(w)对称性未被破坏,所有神经元等价
全一nn.init.ones_(w)同上,且激活后输出恒等
任意常数nn.init.constant_(w, val)仍无法打破对称性
import torch.nn as nn
fc = nn.Linear(4, 3)
nn.init.zeros_(fc.weight)      # 仅偏置可用

1.2 随机初始化(打破对称性的起点)

分布代码方差公式备注
均匀nn.init.uniform_(w, a, b)(b−a)212\displaystyle \frac{(b-a)^2}{12}12(ba)2需手动调区间
正态nn.init.normal_(w, mean, std)σ2\displaystyle \sigma^2σ2std 难校准

局限:未考虑前向/反向方差,深层网络仍需“自适应”方法。


1.3 Xavier(Glorot)初始化 —— 均衡前向与反向方差

激活函数数学依据PyTorch API
Sigmoid / TanhVar(W)=2nin+nout\displaystyle \text{Var}(W)=\frac{2}{n_{\text{in}}+n_{\text{out}}}Var(W)=nin+nout2xavier_uniform_ / xavier_normal_
  • 均匀区间[−6nin+nout,6nin+nout][-\sqrt{\frac{6}{n_{\text{in}}+n_{\text{out}}}},\; \sqrt{\frac{6}{n_{\text{in}}+n_{\text{out}}}}][nin+nout6,nin+nout6]
  • 正态方差2nin+nout\frac{2}{n_{\text{in}}+n_{\text{out}}}nin+nout2
# 示例:对 Tanh 使用 Xavier
fc = nn.Linear(128, 64)
nn.init.xavier_uniform_(fc.weight, gain=nn.init.calculate_gain('tanh'))

1.4 He(Kaiming)初始化 —— 专为 ReLU 优化

模式方差公式场景API
fan_in2nin\frac{2}{n_{\text{in}}}nin2前向方差稳定kaiming_normal_(..., mode='fan_in')
fan_out2nout\frac{2}{n_{\text{out}}}nout2反向梯度稳定kaiming_uniform_(..., mode='fan_out')
# 示例:ReLU + He
fc = nn.Linear(256, 128)
nn.init.kaiming_normal_(fc.weight, nonlinearity='relu')

1.5 高级初始化速览

方法一句话说明代码
orthogonal_生成(半)正交矩阵,保持动态等距nn.init.orthogonal_(w)
sparse_指定稀疏度的高斯权重nn.init.sparse_(w, sparsity=0.9)

二、损失函数

2.1 回归任务

损失公式PyTorch
MAE (L1)$\frac{1}{n}\sumy_i-\hat y_i
MSE (L2)1n∑(yi−y^i)2\frac{1}{n}\sum(y_i-\hat y_i)^2n1(yiy^i)2nn.MSELoss()
pred   = torch.randn(32, 1)
target = torch.randn(32, 1)
print("MSE:", nn.MSELoss()(pred, target).item())
print("MAE:", nn.L1Loss()(pred, target).item())

2.2 分类任务

2.2.1 多类单标签 —— CrossEntropyLoss
  • 已内置 Softmax不要再手动 softmax
  • 公式:
    L=−1N∑ilog⁡exi,yi−max⁡(xi)∑jexi,j−max⁡(xi)\displaystyle \mathcal{L}=-\frac{1}{N}\sum_{i}\log\frac{e^{x_{i,y_i}-\max(x_i)}}{\sum_j e^{x_{i,j}-\max(x_i)}}L=N1ilogjexi,jmax(xi)exi,yimax(xi)
logits = torch.randn(16, 10)      # (batch, n_classes)
labels = torch.randint(0, 10, (16,))
loss = nn.CrossEntropyLoss()(logits, labels)
print("CrossEntropy:", loss.item())
2.2.2 二分类 / 多标签 —— BCEWithLogitsLoss
  • 已内置 Sigmoid,推荐一步到位。
  • 公式:
    −1N∑i[yilog⁡y^i+(1−yi)log⁡(1−y^i)]\displaystyle -\frac{1}{N}\sum_{i}\left[y_i\log\hat y_i+(1-y_i)\log(1-\hat y_i)\right]N1i[yilogy^i+(1yi)log(1y^i)]
logits  = torch.randn(8, 5)       # 8 个样本,5 个标签
targets = torch.randint(0, 2, (8, 5)).float()
loss = nn.BCEWithLogitsLoss()(logits, targets)
print("Multi-label BCE:", loss.item())

三、总结

任务类型输出层初始化损失备注
线性回归无激活Xavier / He 均可MSE / MAE输出无需激活
二分类1 个神经元 + SigmoidHeBCEWithLogitsLoss标签 0/1
多类单标签SoftmaxHeCrossEntropyLoss无需手动 Softmax
多标签Sigmoid(每个类)HeBCEWithLogitsLoss标签多热编码

四、案例

完整训练片段:初始化 + 损失

import torch
import torch.nn as nn
import torch.optim as optimclass MLP(nn.Module):def __init__(self, in_dim=784, hidden=256, out_dim=10):super().__init__()self.net = nn.Sequential(nn.Linear(in_dim, hidden),nn.ReLU(),nn.Linear(hidden, out_dim))self._init_weights()def _init_weights(self):for m in self.modules():if isinstance(m, nn.Linear):nn.init.kaiming_normal_(m.weight, nonlinearity='relu')nn.init.zeros_(m.bias)def forward(self, x):return self.net(x)model = MLP()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)# 模拟一个 batch
x = torch.randn(64, 784)
y = torch.randint(0, 10, (64,))
out = model(x)
loss = criterion(out, y)optimizer.zero_grad()
loss.backward()
optimizer.step()
print("Step loss:", loss.item())

http://www.dtcms.com/a/291273.html

相关文章:

  • 深入解析MIPI C-PHY (二)C-PHY三线魔术:如何用6种“符号舞步”榨干每一滴带宽?
  • 设计模式六:工厂模式(Factory Pattern)
  • C语言:20250721笔记
  • 在 Conda 中删除环境及所有安装的库
  • 《使用 IDEA 部署 Docker 应用指南》
  • Linux-rpm和yum
  • Shell脚本编程:从入门到精通的实战指南
  • 从零开始:用Python库轻松搭建智能AI代理
  • Djoser 详解
  • 深度学习中的数据增强:从理论到实践
  • hot100回归复习(算法总结1-38)
  • 力扣面试150(35/150)
  • 【安全篇 / 反病毒】(7.6) ❀ 01. 查杀HTTPS加密网站病毒 ❀ FortiGate 防火墙
  • Excel函数 —— XLOOKUP 双向查找
  • Linux find命令:强大的文件搜索工具
  • 计算机发展史:电子管时代的辉煌与局限
  • 无人机浆叶安装顺序
  • 【算法基础】二分查找
  • 源码编译安装boost库,以及卸载boost库
  • 插值法的使用
  • Js进阶案例合集
  • iostat的使用说明
  • 基于深度学习的图像分类:使用ResNet实现高效分类
  • (10)机器学习小白入门 YOLOv:YOLOv8-cls 模型评估实操
  • G7打卡——Semi-Supervised GAN
  • numpy库的基础知识
  • 【VASP】机器学习势概述
  • 5G/4G PHY SoC:RNS802,适用于集成和分解的小型蜂窝 RAN 架构。
  • 在github上搭建自己主页
  • Blender软件入门-了解软件界面