当前位置: 首页 > news >正文

【深入理解Batch Normalization(1)】原理与作用

1.为什么需要Normalization

深度学习网络模型训练困难的原因是,cnn包含很多隐含层,每层参数都会随着训练而改变优化,所以隐层的输入分布总会变化,每个隐层都会面临covariate shift的问题。

internal covariate shift(ICS)使得每层输入不再是独立同分布。这就造成,上一层数据需要适应新的输入分布,数据输入激活函数时,会落入饱和区,使得学习效率过低,甚至梯度消失。

Batch Normalization 的主要目的是通过减少内部协变量偏移(Internal Covariate Shift)​​ 来加速训练并提高模型的稳定性

2.Normalization的基本思想

由于cnn层数多,ICS会使激活输入分布偏移,落入饱和区,导致反向传播时出现梯度消失,这是训练收敛越来越慢的本质原因。而BN就是通过归一化手段,将每层输入强行拉回均值0方差为1的标准正态分布,这样使得激活输入值分布在非线性函数梯度敏感区域,从而避免梯度消失问题,大大加快训练速度。

在这里插入图片描述

如上图,sigmoid函数,BN使输入值分布在-1~1之间,在此区间梯度值大,有效避免梯度消失并提高收敛速度。

但是,归一化后,激活输入值均被分布于-1~1之间,这会导致非线性程度降低,夸张一点说,其实输入域的分布把原来的非线性函数转变成了线性函数。这意味着网络的表达能力下降了。因此BN为了保证非线性,对变换后的满足均值为0方差为1的x又进行了scale shift操作,即y=scale*x+shift。这两个参数通过训练获得,其实又将输入分布在标准正态分布的基础上进行了平移。其实就是为了在线性与非线性间找到平衡,让泛化能力与收敛能力最大程度的体现。

3、BN中均值、方差通过哪些维度计算得到?

神经网络中传递的张量数据,其维度通常记为[N, H, W, C],其中N是batch_size,H、W是行、列,C是通道数。那么上式中BN的输入集合就是下图中蓝色的部分。
在这里插入图片描述

  • 均值的计算,就是在一个批次内,将每个通道中的数字单独加起来,再除以 N×W×H。举个例子:该批次内有10张图片,每张图片有三个通道RBG,每张图片的高、宽是H、W,那么均值就是计算10张图片R通道的像素数值总和除以 10×W×H,再计算B通道全部像素值总和除以10×W×H,最后计算G通道的像素值总和除以10×W×H。
  • 方差的计算类似。
  • 可训练参数 γ , β 的维度等于张量的通道数,在上述例子中,RBG三个通道分别需要一个 γ 和一个 β ,所以 γ , β 的维度等于3。

4.训练BatchNorm

其核心操作是对每一批(Batch)数据的每个特征通道进行归一化:

•​计算一个批次数据的均值和方差​:对于每个特征通道,计算当前批次所有样本在该通道上所有值的均值(μ)和方差(σ²)。

•​归一化​:使用计算得到的均值和方差对该通道上的所有值进行归一化:x_hat = (x - μ) / sqrt(σ² + ε),其中 ε 是一个很小的数,防止除以零。

•​缩放和偏移​:引入两个可学习的参数 γ(缩放)和 β(偏移)​,对归一化后的值进行变换:y = γ * x_hat + β。这是为了保持模型的表达能力,避免归一化破坏原本已学到的特征分布。

同时,利用当前批次的统计量,通过指数移动平均(EMA) 来更新全局均值(μ)全局方差(σ²),公式一般为:new_running_mean = (1 - momentum) * running_mean + momentum * batch_mean

BN 层通常插入在卷积层(或线性层)和激活函数(如 ReLU)之间​。
在这里插入图片描述
在这里插入图片描述

每层BN参数是根据特征图的channel数来确定的。

5.不同模型中 BN 层数量举例

•​简单CNN​:一个只有几层的卷积神经网络可能只有 ​2-4 个 BN 层。

•​ResNet​:更深的网络如 ResNet-50 可能包含 ​几十个 BN 层​(例如,ResNet-50 有 53 个 BN 层)。

•​轻量级模型​:一些为移动设备设计的模型(如 MobileNet)可能会减少 BN 层的使用以降低计算量,但其数量仍然可观。

•​不使用BN的模型​:有些模型可能使用其他归一化技术(如 Layer Normalization, Group Normalization)或干脆不使用归一化层

BN 层数量的影响
BN 层能加速模型收敛、提供一定的正则化效果从而可能降低过拟合风险,并允许使用更高的学习率​。但其数量也并非越多越好:
•​计算开销​:BN 层会增加模型的计算量和训练时间。

•​小批量大小问题​:当训练时的批量大小(Batch Size)过小时,BN 层对均值和方差的估计会不准确,可能影响模型性能

6.BatchNorm推理(Inference)

参数类别参数名称是否可学习?数量 (基于特征维度 C)推理阶段行为
可学习参数缩放因子 (γ, gamma)C使用训练最终学到的固定值
偏移因子 (β, beta)C使用训练最终学到的固定值
非学习统计量全局均值 (running_mean)C使用训练阶段通过移动平均计算的固定值
全局方差 (running_var)C使用训练阶段通过移动平均计算的固定值
超参数ε (epsilon)1固定的小常数,用于数值稳定
动量 (momentum)1仅训练时用于更新统计量,推理时不使用

因此,对于一个特征维度为 C 的 BN 层,其参数总量4 * C + 2(4C个与维度相关的参数,加上2个超参数)。
BN层的参数可以分为可学习参数非学习的统计量两大类:

  1. 可学习参数 (Learned Parameters)

    • 缩放因子 (γ, gamma):一个维度为 num_features 的可学习向量。用于在标准化后恢复数据原本的表达能力,初始值通常为全1。
    • 偏移因子 (β, beta):一个维度为 num_features 的可学习向量。用于在标准化后恢复数据原本的表达能力,初始值通常为全0。
  2. 非学习的统计量 (Non-learned Statistics)

    • 全局均值 (running_mean):在训练过程中,通过指数移动平均 (EMA) 累积计算的整个训练数据集的均值估计,维度为 num_features
    • 全局方差 (running_var):在训练过程中,通过指数移动平均 (EMA) 累积计算的整个训练数据集的方差估计,维度为 num_features
  3. 超参数 (Hyperparameters)

    • ε (epsilon):一个很小的常数(例如 1e-5),添加到方差中以防止除以零,确保数值稳定性。
    • 动量 (momentum):用于控制指数移动平均 (EMA) 更新速度的超参数,决定当前批次的统计量对全局统计量的贡献程度,PyTorch 中默认为 0.1。需要注意的是,BN中的momentum与优化器中的momentum是不同的概念
  • 推理阶段
    • 不再使用当前批次的统计量,而是使用训练期间通过EMA累积得到的固定全局均值(running_mean)全局方差(running_var)
    • 标准化公式变为:x_hat = (x - running_mean) / sqrt(running_var + ε)
    • 同样使用训练好的、固定的参数 γβ 进行缩放和偏移:y = γ * x_hat + β
    • 这样做是为了确保推理结果的一致性和稳定性,避免因输入样本数量或内容不同而导致输出波动。

重要提醒

  • 参数固定:在推理阶段,BN层的所有参数(γ, β)和统计量(running_mean, running_var)都是固定的,直接使用训练阶段学习或计算好的值,不需要也不应该再更新
  • model.eval() 的重要性:在PyTorch等框架中,将模型设置为评估模式(model.eval())会自动切换BN层的行为到推理模式,使用 running_mean 和 running_var 并进行计算。
  • 训练模式的影响:如果模型在推理时意外处于训练模式(model.train()),BN层会尝试使用当前输入批次的统计量,这可能因为批次特性(如批次大小为1)导致性能下降或产生不一致的结果。

推理时,均值、方差是基于所有批次的期望计算所得,公式如下:
在这里插入图片描述
有了均值和方差,每个隐层神经元也已经有对应训练好的Scaling参数和Shift参数,就可以在推导的时候对每个神经元的激活数据计算NB进行变换了,在推理过程中进行BN采取如下方式:
在这里插入图片描述

beta、gamma在训练状态下,是可训练参数,在推理状态下,直接加载训练好的数值。moving_mean、moving_var在训练、推理中都是不可训练参数,只根据滑动平均计算公式更新数值,不会随着网络的训练BP而改变数值;在推理时,直接加载储存计算好的滑动平均之后的数值,作为推理时的均值和方差。

滑动平均,储存固定个数Batch的均值和方差,不断迭代更新推理时需要的E(x),Var(x)。

7.BatchNorm的作用

1.加快收敛速度,有效避免梯度消失。
2.提升模型泛化能力,BN的缩放因子可以有效的识别对网络贡献不大的神经元,经过激活函数后可以自动削弱或消除一些神经元。另外,由于归一化,很少发生数据分布不同导致的参数变动过大问题。

最后还想谈一谈Instance normalization

BN适用于判别模型中,比如图片分类模型。因为BN注重对每个batch进行归一化,从而保证数据分布的一致性,而判别模型的结果正是取决于数据整体分布。但是BN对batchsize的大小比较敏感,由于每次计算均值和方差是在一个batch上,所以如果batchsize太小,则计算的均值、方差不足以代表整个数据分布;

IN适用于生成模型中,比如图片风格迁移,GAN等。因为图片生成的结果主要依赖于某个图像实例,所以对整个batch归一化不适合图像风格化中,在风格迁移中使用Instance Normalization不仅可以加速模型收敛,并且可以保持每个图像实例之间的独立。
在这里插入图片描述

上图中,从C方向看过去是指一个个通道,从N看过去是一张张图片。每6个竖着排列的小正方体组成的长方体代表一张图片的一个feature map。蓝色的方块是一起进行Normalization的部分。由此就可以很清楚的看出,Batch Normalization是指6张图片中的每一张图片的同一个通道一起进行Normalization操作。而Instance Normalization是指单张图片的单个通道单独进行Noramlization操作。

参考

https://blog.csdn.net/litt1e/article/details/105817224


文章转载自:

http://zO0KfDll.jrbyz.cn
http://KhyLfdcy.jrbyz.cn
http://0CiuPAIo.jrbyz.cn
http://XKMkATnV.jrbyz.cn
http://XdqGjLdJ.jrbyz.cn
http://tdYCV8Ph.jrbyz.cn
http://A0AvIppM.jrbyz.cn
http://bqQ0bVkZ.jrbyz.cn
http://zVG0DyME.jrbyz.cn
http://iQwS2dEp.jrbyz.cn
http://34BHaxon.jrbyz.cn
http://Fzx5d0YV.jrbyz.cn
http://M4Z2Y6au.jrbyz.cn
http://9u3IgE8C.jrbyz.cn
http://FiCF3tZI.jrbyz.cn
http://vocgU78I.jrbyz.cn
http://whl3vkDP.jrbyz.cn
http://uXOiu06H.jrbyz.cn
http://f4heGBqL.jrbyz.cn
http://Vh8sM9o6.jrbyz.cn
http://tw60Gfss.jrbyz.cn
http://g2L0rSvg.jrbyz.cn
http://UhzbjhK0.jrbyz.cn
http://lPUiVB6a.jrbyz.cn
http://S4NHhbe9.jrbyz.cn
http://TpBdeMrT.jrbyz.cn
http://PJ0H3KqJ.jrbyz.cn
http://5cM6EhXG.jrbyz.cn
http://HTMHwWPJ.jrbyz.cn
http://ODNdcRns.jrbyz.cn
http://www.dtcms.com/a/369248.html

相关文章:

  • 【教程】快速入门golang
  • Day21_【机器学习—决策树(2)—ID3树 、C4.5树、CART树】
  • std::complex
  • 深度解读:PSPNet(Pyramid Scene Parsing Network) — 用金字塔池化把“场景理解”装进分割网络
  • 【WRF-Chem】SYNMAP 土地覆盖数据概述及处理(二进制转geotiff)
  • 怎么快速构建一个deep search模型呢
  • Dify基础应用
  • 日语学习-日语知识点小记-构建基础-JLPT-N3阶段(26):文法+单词第8回3 复习 +考え方6
  • Screen 三步上手
  • Pspice仿真电路:(三十六)变压器仿真
  • pydantic定义llm response数据模型
  • 开学信息收集不再愁,这个工具太省心
  • 豆包 arraylist顺序会变么
  • 软考最稳定的一个科目,你认同吗?
  • 【问题解决】mac笔记本遇到鼠标无法点击键盘可响应处理办法?(Command+Option+P+R)
  • 介电常数何解?
  • VMwaer虚拟机安装完Centos后无法联网问题
  • 【阿里存储桶OSS】桶ACL解释
  • Beetle RP2350开发板使用指南之【环境搭建 / 点灯】
  • Y3垂起标准配置文件解析()
  • JSON转义
  • Kaggle - LLM Science Exam 大模型做科学选择题
  • CSS定位与浮动:脱离常规流的艺术
  • C/C++ 与 Lua 互相调用详解
  • mysq集群高可用架构之组复制MGR(单主复制-多主复制)
  • PyInstaller完整指南:将Python程序打包成可执行文件
  • SQL工具30年演进史:从Oracle到Navicat、DBeaver,再到Web原生SQLynx
  • Linux 综合练习
  • 详解iOS应用如何成功上架App Store:从准备到发布与优化
  • 2025.09.05 用队列实现栈 有效的括号 删除字符串中的所有相邻重复项