当前位置: 首页 > news >正文

深度学习中神经网络与损失函数优化

一、神经网络基础

1.核心概念

        人工神经网络(ANN):模仿人脑神经元结构的计算模型,核心是 “人工神经元”。

                生物神经元:树突接收信号→细胞核处理→轴突输出信号(需达到阈值才激活)。

                人工神经元:对输入x加权求和(Σx*w + bw是权重、b是偏置),再通过激活函数输出,公式为 output = f(Σx*w + b)

2.神经网络结构

网络层级作用关键特点
输入层接收原始数据(如图片像素、手机参数)神经元数量 = 输入特征数(如手机数据有 20 个特征,输入层就有 20 个神经元)
隐藏层提取数据特征(复杂逻辑由多层隐藏层实现)同一层无连接,与相邻层全连接(第 N 层每个神经元连第 N-1 层所有神经元)
输出层输出预测结果(分类 / 回归)分类任务:神经元数量 = 类别数(如手机价格 4 分类,输出层 4 个神经元);回归任务:1 个神经元

二、激活函数

1.核心作用

        没有激活函数:网络本质是 “线性模型”,无法拟合复杂数据(如房价与面积的非线性关系)。

        有激活函数:通过非线性变换,让网络能逼近任意复杂函数,解决复杂任务。

2.常见激活函数对比

激活函数公式输出范围优点缺点适用场景
Sigmoidf(x)=1/(1+e^-x)(0,1)输出可看作概率,适合二分类梯度易消失(输入 <-6 或> 6 时导数接近 0)、非 0 中心仅二分类任务的输出层
Tanhf(x)=(1-e^-2x)/(1+e^-2x)(-1,1)0 中心,梯度比 Sigmoid 大,收敛更快仍有梯度消失问题早期隐藏层(现在少用)
ReLUf(x)=max(0,x)[0,∞)计算快(无需指数运算)、缓解梯度消失(x>0 时导数 = 1)存在 “神经元死亡”(x<0 时权重无法更新)首选隐藏层(90% 以上场景用)
SoftMaxf(x_i)=e^x_i/Σe^x_j(0,1)(所有输出和为 1)输出是类别概率,清晰反映分类置信度仅用于分类,不适合隐藏层多分类任务的输出层

3.选择口诀

隐藏层:优先 ReLU → 效果差换 Leaky ReLU(解决神经元死亡);

输出层:二分类用 Sigmoid、多分类用 SoftMax、回归用 “恒等函数”(直接输出,无激活)。

三、参数初始化:避免“训练崩溃”的关键

1.核心问题

全0/全1初始化:所有神经元输出相同,梯度相同,权重无法更新(“对称权重问题”);

随机值太大:激活函数输出趋近于 0 或 1,梯度消失,训练不动。

2.常用初始化方法(pytorch实现)

方法原理适用场景代码示例
均匀分布(-1/√d, 1/√d)随机取值(d = 输入神经元数)通用场景nn.init.uniform_(linear.weight)
正态分布均值 = 0,标准差 = 1,取小值通用场景nn.init.normal_(linear.weight, mean=0, std=1)
He 初始化(Kaiming)正态分布:std=√(2/d);均匀分布:limit=√(6/d)搭配 ReLU 激活函数(隐藏层)nn.init.kaiming_normal_(linear.weight)
Xavier 初始化正态分布:std=√(2/(d_in+d_out));均匀分布:limit=√(6/(d_in+d_out))搭配 Sigmoid/Tanh 激活函数

nn.init.xavier_normal_(linear.weight)

3.注:pytorch层(如nn.Linear)有默认初始化,但复杂网络建议手动用 kaiming/Xavier。

四、损失函数

损失函数越小,模型预测越准,核心分“分类任务”和回归任务两类。

1.分类任务

损失函数公式适用场景PyTorch 代码
多分类交叉熵L=-Σy_true*log(y_pred)(y_pred 是 SoftMax 输出)多分类(类别互斥)nn.CrossEntropyLoss()(内置 SoftMax,输入无需先过 SoftMax)
二分类交叉熵L=-y_true*log(y_pred) - (1-y_true)*log(1-y_pred)二分类nn.BCELoss()(输入需先过 Sigmoid,输出是概率)

2.回归任务(预测连续值,如房价、温度)

损失函数公式优点缺点PyTorch 代码
MAE(L1 损失)`L=Σy_true - y_pred/n`对异常值鲁棒(不怕极端值)梯度在 0 点不平滑,易跳过最优解nn.L1Loss()
MSE(L2 损失)L=Σ(y_true - y_pred)²/n梯度平滑,训练稳定对异常值敏感(极端值会放大损失)nn.MSELoss()
Smooth L1分段函数:x<1 时0.5x²x≥1 时 `x-0.5`结合 MAE 和 MSE 优点,鲁棒且平滑计算略复杂nn.SmoothL1Loss()

五、网络优化:让模型“快速找到最优解”

核心解决“梯度下降慢、卡鞍点、局部最小值问题”,分“优化器”和“学习率衰减”两类。

1.梯度下降优化器(核心是怎么更新权重)

优化器原理优点适用场景PyTorch 代码
SGD(随机梯度下降)w = w - lr*grad(每次用一个 Batch 更新)简单,通用基础场景,需搭配动量optim.SGD([w], lr=0.01, momentum=0.9)
Momentum(动量)累加历史梯度:grad_total = β*grad_prev + (1-β)*grad_curr缓解震荡,加速收敛,易跨过鞍点所有场景,尤其是数据震荡时同上(加 momentum 参数)
AdaGrad自适应学习率:lr_i = lr/√(Σgrad_i² + ε)(梯度大的参数,学习率小)适合稀疏数据(如文本)后期学习率可能过小,停在次优解optim.Adagrad([w], lr=0.01)
RMSProp改进 AdaGrad:用指数加权平均替代历史梯度和解决 AdaGrad 学习率衰减过快问题通用场景,尤其是非凸优化optim.RMSprop([w], lr=0.01, alpha=0.9)
Adam结合 Momentum(梯度平滑)和 RMSProp(自适应学习率)收敛快、稳定,几乎所有场景最优推荐首选(90% 以上任务用)optim.Adam([w], lr=0.01, betas=[0.9, 0.99])

2.学习率衰减(避免后期学习率太大,跳过最优解)

等间隔衰减:每N个Epoch,学习率*gamma,代码:

optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5)

指定间隔衰减:在特定 Epoch(如 50、125 轮)衰减,代码:optim.lr_scheduler.MultiStepLR(optimizer, milestones=[50,125], gamma=0.5)

指数衰减:每轮学习率 ×gamma(如 gamma=0.95),代码:optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95)

六、正则化:缓解“过拟合”

1.Dropout(随机失活)

原理:训练时,让每个神经元以概率p(通常 0.4-0.6)“失效”(输出置 0),未失效的神经元输出 ×1/(1-p)(保证期望不变);测试时,所有神经元都生效,不置 0。

作用:避免神经元过度依赖某几个输入,增强泛化能力。

代码:nn.Dropout(p=0.4)(只在训练时用,测试时需关 Dropout:model.eval())。

2.BN层(批量归一化)

原理:对每一层的输入数据做 “标准化”(均值 = 0,方差 = 1),再通过可学习参数γ(缩放)和β(平移)调整,公式:output = γ*(x-mean)/√(var+ε) + β

作用:加速训练(允许更高学习率)、缓解过拟合、减少初始化影响。

适用场景:计算机视觉(CNN)、深层全连接网络,代码:nn.BatchNorm1d(num_features=128)(1d 对应全连接,2d 对应 CNN)。

七、总结

网络结构:输入→隐藏(ReLU)→输出(分类用 SoftMax / 二分类 Sigmoid,回归直接输出);

训练三要素:损失函数(分类交叉熵、回归 MAE/MSE)、优化器(首选 Adam)、正则化(Dropout+BN);

实践流程:数据准备→模型构建→训练(反向传播)→评估→调优。

http://www.dtcms.com/a/392180.html

相关文章:

  • 整体设计 完整的逻辑链条 之1 点dots/线lines/面faces 的三曲:三进三出的三个来回
  • 微调基本理论
  • LeetCode算法日记 - Day 48: 课程表II、火星词典
  • 【面板数据】地级市中国方言多样性指数数据集
  • C++编程学习(第35天)
  • SS443A 霍尔效应传感器:高性能磁感应解决方案
  • MIT新论文:数据即上限,扩散模型的关键能力来自图像统计规律,而非复杂架构
  • GitHub 热榜项目 - 日榜(2025-09-20)
  • 怎么判断 IP是独享的
  • Linux多进程编程(上)
  • 如何在Spring Boot项目中添加自定义的配置文件?
  • 【MySQL初阶】01-MySQL服务器和客户端下载与安装
  • AI搜索的下一站:多模态、个性化与GEO的道德指南
  • OpenLayers地图交互 -- 章节四:修改交互详解
  • Gradle插件的分析与使用
  • 如何避免everything每次都重建索引
  • 基于SIFT+flann+RANSAC+GTM算法的织物图像拼接matlab仿真,对比KAZE,SIFT和SURF
  • 笔记:现代操作系统:原理与实现(3)
  • 【智能系统项目开发与学习记录】Docker 基础
  • 数据展示方案:Prometheus+Grafana+JMeter 备忘
  • flask获取ip地址各种方法
  • 17.6 LangChain多模态实战:语音图像文本融合架构,PPT生成效率提升300%!
  • MyBatis实战教程:SQL映射与动态查询技巧
  • 在 Windows Docker 中通过 vLLM 镜像启动指定大模型的方法与步骤
  • 分类预测 | Matlab实现SSA-BP麻雀搜索算法优化BP神经网络多特征分类预测
  • GO实战项目:基于 `HTML/CSS/JS + Gin + Gorm + 文心一言API`AI 备忘录应用
  • 数据结构【堆(⼆叉树顺序结构)和⼆叉树的链式结构】
  • 我爱学算法之—— 位运算(下)
  • LeetCode第364题_加权嵌套序列和II
  • 云计算和云手机之间的关系