当前位置: 首页 > news >正文

使用pytorch创建模型时,nn.BatchNorm1d(128)的作用是什么?

在PyTorch中,nn.BatchNorm1d(128) 的作用是对 一维输入数据(如全连接层的输出或时间序列数据)进行批标准化(Batch Normalization),具体功能与实现原理如下:

1. 核心作用

  • 标准话数据分布
    对每个批次的输入数据进行归一化,使其均值接近0、方差接近1,公式如下:
    x^=x−μbatchσbatch2+e\hat{\mathbf{x}}=\frac{\mathbf{x}-\mathbf{\mu}_{batch}}{\sqrt{\sigma^{2}_{batch}+e}}x^=σbatch2+exμbatch
    其中:
    • μbatch\mu_{batch}μbatch:当前批次的均值
    • σbatch\sigma_{batch}σbatch:当前批次的方差
    • eee: 防止除零的小常数(默认1e-5)
  • 可学习的缩放与偏移:
    通过参数γ\gammaγ (缩放)和 β\betaβ(偏移)保留模型的表达能力:
    y=γx^+β y = \gamma \hat{\mathbf{x}}+\beta y=γx^+β

2. 参数解释

在这里插入图片描述

3. 全连接网络应用场景

import torch.nn as nnmodel = nn.Sequential(nn.Linear(64, 128),nn.BatchNorm1d(128),  # 对128维特征归一化nn.ReLU(),nn.Linear(128, 10)
)

数学效果:
若输入特征x∈Rm×128\mathbf{x}\in \mathbb{R}^{m\times128}xRm×128,输出yyy满足:
E[y:j]≈0,Var(y:,j)≈1 \mathbb{E}[y_{:j}]\approx0, Var(y_{:,j})\approx1 E[y:j]0,Var(y:,j)1

4. 与其他归一化层的对比

在这里插入图片描述

5. 训练与推理的差异

  • 训练阶段
    使用当前批次的统计量μbatch\mu_{batch}μbatchσbatch2\sigma_{batch}^2σbatch2,并更新全局统计量:
    μrunnning←μrunning×(1−momentum)+μbatch×momentum\mu_{runnning} \leftarrow \mu_{running}\times(1-momentum) + \mu_{batch}\times momentumμrunnningμrunning×(1momentum)+μbatch×momentum
  • 推理阶段(测试阶段)
    固定使用训练积累的全局统计量μbatch\mu_{batch}μbatchσbatch2\sigma_{batch}^2σbatch2
    KaTeX parse error: Undefined control sequence: \sigmma at position 54: …unning}}{\sqrt{\̲s̲i̲g̲m̲m̲a̲^{2}_{running}+…

6. 代码战争数学性质

import torch# 模拟输入(batch_size=4, 128维特征)
x = torch.randn(4, 128) * 2 + 1  # 均值1,方差4bn = nn.BatchNorm1d(128, affine=False)  # 禁用γ和β
output = bn(x)print("输入均值:", x.mean(dim=0).mean().item())   # ≈1
print("输出均值:", output.mean(dim=0).mean().item())  # ≈0
print("输入方差:", x.var(dim=0).mean().item())    # ≈4
print("输出方差:", output.var(dim=0).mean().item())  # ≈1
http://www.dtcms.com/a/289139.html

相关文章:

  • gradle关于dependency-management的使用
  • SpringBoot 整合 Langchain4j 实现会话记忆存储深度解析
  • OpenCV 入门知识:图片展示、摄像头捕获、控制鼠标及其 Trackbar(滑动条)生成!
  • 【LeetCode刷题指南】--反转链表,链表的中间结点,合并两个有序链表
  • Day25| 491.递增子序列、46.全排列
  • SQL Server(2022)安装教程及使用_sqlserver下载安装图文
  • redis-plus-plus安装与使用
  • [BUG]关于UE5.6编译时出现“Microsoft.MakeFile.Targets(44,5): Error MSB3073”问题的解决
  • 30天打牢数模基础-SVM讲解
  • Facebook 开源多季节性时间序列数据预测工具:Prophet 快速入门 Quick Start
  • UE5多人MOBA+GAS 26、为角色添加每秒回血回蓝(番外:添加到UI上)
  • Go并发聊天室:从零构建实战
  • Mysql(事务)
  • 30个常用的Linux命令汇总和实战场景示例
  • 30天打牢数模基础-粒子群算法讲解
  • 详解Mysql索引合并
  • Jetpack - ViewModel、LiveData、DataBinding(数据绑定、双向数据绑定)
  • langchain调用本地ollama语言模型和嵌入模型
  • 梯度提升之原理
  • COGNEX康耐视IS5403-01智能相机加Navitar 18R00 LR1010WM52镜头
  • React 英语打地鼠游戏——一个寓教于乐的英语学习游戏
  • [Windows] Bili视频转图文笔记 v1.7.5
  • 网鼎杯2020青龙组notes复现
  • 7. 命令模式
  • Modbus Slave 使用教程:快速搭建模拟从站进行测试与开发
  • Ribbon轮询实现原理
  • Unity笔记——Unity 封装方法指南
  • day24——Java高级技术深度解析:单元测试、反射、注解与动态代理
  • [Python] -项目实战类3- 用Python制作一个记事本应用
  • CVE-2022-41128