当前位置: 首页 > news >正文

BN测试和训练时有什么不同, 在测试时怎么使用?

我们来彻底搞懂 Batch Normalization(BN) 在训练和测试阶段的区别,以及 测试时怎么用


🧠 一句话总结:

训练时:使用 当前 mini-batch 的均值和方差
测试时:使用 整个训练集估计的“滑动平均均值和方差”

🧪 训练阶段的 BN(model.train()):

1. 对每个 mini-batch:

• 计算该 batch 的均值 μ 和方差 σ²

• 进行标准化:

\hat{x} = \frac{x - \mu_{\text{batch}}}{\sqrt{\sigma^2_{\text{batch}} + \epsilon}}

• 然后进行缩放和平移:

y = \gamma \hat{x} + \beta

2. 同时更新全局滑动均值和滑动方差(moving mean/var)用于测试。


✅ 测试阶段的 BN(model.eval()):

1. 不再用当前 batch 的统计量(因为测试数据可能 batch 很小,甚至 batch=1,统计不稳定)

2. 使用训练时保存下来的滑动均值和方差来做归一化:

\hat{x} = \frac{x - \mu_{\text{moving}}}{\sqrt{\sigma^2_{\text{moving}} + \epsilon}}


🔁 在 PyTorch 中怎么做?

训练时:

model.train()  # 启用 dropout、BN 用 batch 统计

测试/验证时:

model.eval()   # 关闭 dropout,BN 用全局 moving 均值/方差

你想实际测试一个 BN 层,可以这样

import torch
import torch.nn as nn

bn = nn.BatchNorm2d(3)

x = torch.randn(4, 3, 8, 8)

# 训练时
bn.train()
out_train = bn(x)

# 测试时
bn.eval()
out_test = bn(x)

📌 小结表格

模式

统计方式

是否更新均值/方差

适用场景

train()

当前 mini-batch

✅ 是

训练

eval()

滑动平均值

❌ 否

推理、验证


全局滑动均值和滑动方差是什么

全局滑动均值(moving mean)和滑动方差(moving variance)是 Batch Normalization(BN) 在训练过程中用于统计数据分布的全局估计值。它们通过 指数加权移动平均 的方式进行更新,并在 测试阶段 使用。具体而言,它们用于避免测试时每次计算时都依赖小批量数据(mini-batch)的统计信息。


🧠 具体解释:

1. 滑动均值(Moving Mean):

• 训练时,BN层 会不断计算当前 batch 的均值(mean)和方差(variance)。

• 但是,为了让测试时能够使用更稳定的均值和方差,BN会 维护一个滑动均值

• 这个均值是通过以下公式更新的:

\mu_{\text{moving}}^{t} = \alpha \cdot \mu_{\text{batch}}^t + (1 - \alpha) \cdot \mu_{\text{moving}}^{t-1}

其中,α 是一个常数(通常接近1,比如 0.9 或 0.99),用于控制历史统计对当前的影响。

2. 滑动方差(Moving Variance):

• 类似于滑动均值,BN层也会维护一个 滑动方差,它是通过以下公式更新的:

• 这里的σ² 是当前 batch 的方差。

3. 为什么要使用滑动均值和滑动方差?

稳定性:在测试时,数据通常是一次性输入模型,batch size 可能为 1 或者非常小,这会导致当前 batch 的统计量(均值和方差)非常不稳定。如果每次测试时都重新计算均值和方差,可能会导致不稳定的推理结果。

全局统计:滑动均值和方差通过结合历史数据的统计量,提供了一个更为稳定的全局估计,这有助于在整个训练过程中产生更一致的推理结果。


🚀 总结:

训练阶段:每个 mini-batch 计算新的均值和方差,并且更新滑动均值和滑动方差。

测试阶段:使用训练中计算得到的滑动均值和滑动方差,而不再使用当前 batch 的均值和方差。

小结表格:

阶段

使用的均值/方差

更新均值/方差

训练

当前 batch 的均值和方差

✅ 更新滑动均值/方差

测试

滑动均值和滑动方差

❌ 不更新

http://www.dtcms.com/a/113560.html

相关文章:

  • 【C++项目】从零实现RPC框架「四」:业务层实现与项目使用
  • 【51单片机】2-7【I/O口】点亮数码管
  • 线程池的工作原理
  • 线代[12]|《高等几何》陈绍菱(1984.9)(文末有对三大空间的分析及一个合格数学系毕业生的要求)
  • Python 语法学习 1(类比 java 学习)-附Python 中 self
  • 前端用用jsonp的方式解决跨域问题
  • [ICLR 2025]Biologically Plausible Brain Graph Transformer
  • Reids 的io并发模型
  • 程序化广告行业(60/89):算法优化与DSP系统实例解析
  • Linux系统程序设计:从入门到高级Day03
  • 第八章:流量治理_《凤凰架构:构建可靠的大型分布式系统》
  • DDPM 做了什么
  • 2007-2019年各省地方财政其他支出数据
  • 格式工厂怎样插入内置音频文件
  • 硬件工程师面试问题(五):蓝牙面试问题与详解
  • 在响应式网页的开发中使用固定布局、流式布局、弹性布局哪种更好
  • vllm作为服务启动,无需额外编写sh文件,一步到位【Ubuntu】
  • 『Linux_网络』 第一章 网络基础概念
  • 分表字段选择策略:以电商交易订单为例的最佳实践
  • Java项目之基于ssm的怀旧唱片售卖系统(源码+文档)
  • 大数据时代的隐私保护:区块链技术的创新应用
  • 通过构造函数和几何条件,研究了不同函数的最近点存在性、性质及单调性
  • ZKmall开源商城多云高可用架构方案:AWS/Azure/阿里云全栈实践
  • 紧急更新!MinIO发布RELEASE.2025-04-03T14-56-28Z版本,修复高危漏洞CVE-2025-31489,用户需立即升级!
  • raft协议中一条数据写入流程
  • Java 实现插入排序:[通俗易懂的排序算法系列之三]
  • 文献总结:TPAMI综述BEV感知—Delving into the devils of bird‘s-eye-view perception
  • Socket编程TCP
  • HarmonyOS:WebView 控制及 H5 原生交互实现
  • 硬件学习之器件篇-蜂鸣器