当前位置: 首页 > news >正文

19.训练模式、评估模式

model.train()
model.eval()

训练模式

model.train() —— 训练模式

Dropout 层:会被激活,随机丢弃一部分神经元(按设定的概率 p),防止过拟合。

BatchNorm 层:使用当前批次的均值和方差进行标准化,并更新全局统计量(移动平均)。

其他行为:梯度计算开启(requires_grad=True),允许反向传播更新参数。

使用场景:在训练时调用,通常配合 optimizer.zero_grad() 和 loss.backward() 使用。

model.train()  # 切换到训练模式
for data, label in train_loader:optimizer.zero_grad()output = model(data)loss = criterion(output, label)loss.backward()optimizer.step()

评估模式

model.eval() —— 评估模式

Dropout 层:会被关闭,所有神经元均参与计算(不再随机丢弃)。

BatchNorm 层:使用训练时积累的全局均值和方差(而非当前批次的统计量),保证结果稳定。

梯度计算:通常关闭(可通过 torch.no_grad() 进一步禁用,节省内存)。

model.eval()  # 切换到评估模式
with torch.no_grad():  # 禁用梯度计算for data, label in test_loader:output = model(data)accuracy = (output.argmax(dim=1) == label).sum().item()

为什么需要区分模式?

  • 训练模式:需要随机性和参数更新(如 Dropout 和 BatchNorm 依赖批次统计)。

  • 评估模式:追求确定性和一致性(避免 Dropout 的随机性,固定 BatchNorm 的统计量)。

不区分模式可能会怎样?

训练时不写 model.train()

Dropout层:

如果模型初始化后未显式设置模式,某些情况下可能默认是训练模式(依赖具体实现),但不要依赖这种隐式行为。

如果模型之前处于eval()模式,Dropout会被关闭,神经元不会被随机丢弃,导致模型可能过拟合(因为失去了正则化作用)。

BatchNorm层:

如果未显式设置,BatchNorm可能默认按训练模式运行(更新统计量),但存在不确定性。

如果模型之前处于eval()模式,BatchNorm会使用训练阶段积累的全局均值和方差(而非当前批次的统计量),导致训练时参数更新不正确,可能收敛变慢或性能下降。

评估时不写model.eval()

Dropout层:

如果模型之前处于train()模式,Dropout会继续随机丢弃神经元,导致评估结果随机波动(同一输入多次预测结果可能不同)。

BatchNorm层:

如果处于train()模式,BatchNorm会使用当前批次的均值和方差(而非训练时积累的全局统计量),导致评估指标不准确(尤其是batch较小时,方差可能剧烈波动)。

torch.no_grad() 和 eval() 的区别:

  • eval() 只影响特定层(Dropout/BatchNorm)。

  • no_grad() 是全局禁用梯度计算,节省内存,常用于评估和推理。

http://www.dtcms.com/a/335901.html

相关文章:

  • 基于遗传编程的自动程序生成
  • JAVA面试汇总(四)JVM(二)
  • pytorch线性回归
  • 7 索引的监控
  • 数学建模 14 中心对数比变换
  • 定时器中断点灯
  • Redux搭档Next.js的简明使用教程
  • 安卓开发中遇到Medium Phone API 36.0 is already running as process XXX.
  • 突破Python性能墙:关键模块C++化的爬虫优化指南
  • 【牛客刷题】字符串按索引二进制1个数奇偶性转换大小写
  • 编程算法实例-整数分解质因数
  • Vue3 + Element Plus 人员列表搜索功能实现
  • UE5多人MOBA+GAS 48、制作闪现技能
  • 第三十九天(WebPack构建打包Mode映射DevTool源码泄漏识别还原)
  • 软件开发 - foreground 与 background
  • 电容,三极管,场效应管
  • 光耦,发声器件,继电器,瞬态抑制二极管
  • 【102页PPT】新一代数字化转型信息化总体规划方案(附下载方式)
  • Coin与Token的区别解析
  • Python爬虫-解决爬取政务网站的附件,找不到附件链接的问题
  • 数学建模-评价类问题-优劣解距离法(TOPSIS)
  • 博士招生 | 新加坡国立大学 SWEET实验室 招收人机交互方向 博士/博士后
  • 13.web api 4
  • 实现用户输入打断大模型流式输出:基于Vue与FastAPI的方案
  • 基于DSP+ARM+FPGA架构的储能协调控制器解决方案,支持全国产化
  • Diamond基础2:开发流程之LedDemo
  • JavaScirpt高级程序设计第三版学习查漏补缺(1)
  • vba学习系列(12)--反射率通过率计算复杂度优化25/8/17
  • Nacos 注册中心学习笔记
  • Yolov模型的演变