当前位置: 首页 > news >正文

BN 层做预测的时候, 方差均值怎么算

✅ 一、Batch Normalization(BN)回顾

 

BN 层在训练和推理阶段的行为是不一样的,核心区别就在于:

训练时用 mini-batch 里的均值方差,预测时用全局的“滑动平均”均值方差。

🧪 二、训练阶段(Training mode)

• 每个小批量(batch)都会计算:

\mu_{\text{batch}} = \frac{1}{m} \sum_{i=1}^{m} x_i

\sigma^2_{\text{batch}} = \frac{1}{m} \sum_{i=1}^{m} (x_i - \mu_{\text{batch}})^2

\hat{x}_i = \frac{x_i - \mu_{\text{batch}}}{\sqrt{\sigma^2_{\text{batch}} + \epsilon}}

为了后面预测用得上,训练时还会维护全局“滑动平均”:
\mu_{\text{running}} = \rho \cdot \mu_{\text{running}} + (1 - \rho) \cdot \mu_{\text{batch}} \sigma^2_{\text{running}} = \rho \cdot \sigma^2_{\text{running}} + (1 - \rho) \cdot \sigma^2_{\text{batch}}

其中\rho是动量参数(momentum),通常为 0.9 或 0.99。


🧠 三、推理阶段(Evaluation / Inference)

推理阶段不会再计算当前 batch 的均值和方差。

而是使用训练时积累的滑动平均

\hat{x}i = \frac{x_i - \mu{\text{running}}}{\sqrt{\sigma^2_{\text{running}} + \epsilon}}

这样能保证预测过程中结果稳定、不依赖 batch 大小或数据分布波动


🧰 四、在 PyTorch / TensorFlow 中自动切换

PyTorch:

model.train()   # 启用训练模式,BN 用 batch 均值方差
model.eval()    # 启用评估模式,BN 用滑动均值方差

TensorFlow (Keras):

model.fit(...)       # 自动使用训练模式
model.evaluate(...)  # 自动使用推理模式

📌 总结一句话:

BN 层预测时的均值和方差,来自 训练期间累计的滑动平均值,而不是实时计算。

五、补充知识:Keras 是什么


🧠 一句话定义:

Keras 是一个高级神经网络 API,用来快速搭建、训练和部署深度学习模型,底层运行在 TensorFlow 上。

📦 二、Keras 的定位

特性

说明

高级封装

用几行代码就能搭建复杂模型,适合快速开发

基于 TensorFlow

现在是 TensorFlow 的官方高层 API(tf.keras)

易学易用

类似积木式的拼接方式,语法简洁,初学者友好

灵活性强

同时支持顺序模型(Sequential)和函数式模型(Functional API)

支持多种任务

图像分类、NLP、生成模型、时间序列、强化学习等

支持多平台部署

可以导出为 SavedModel,支持 TensorFlow Serving、TFLite、ONNX、Web 等


⚙️ 三、简单例子(Keras 搭建一个 MLP 分类器)

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense

model = Sequential([
    Dense(64, activation='relu', input_shape=(100,)),
    Dense(10, activation='softmax')
])

model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              metrics=['accuracy'])

# 假设 x_train.shape = (1000, 100),y_train 是 one-hot 标签
model.fit(x_train, y_train, epochs=10, batch_size=32)

🏗 四、Keras 模型的两种构建方式

1. Sequential(顺序模型)

• 一层接一层,简单好用

model = Sequential([...])

2. Functional API(函数式模型)

• 灵活连接,适合多输入/多输出、残差连接等复杂结构

from tensorflow.keras import Model, Input
x = Input(shape=(100,))
h = Dense(64, activation='relu')(x)
y = Dense(10, activation='softmax')(h)
model = Model(inputs=x, outputs=y)

🔥 五、Keras 常见模块

模块

作用

tf.keras.models

创建模型(Sequential、Model)

tf.keras.layers

各种神经网络层(Dense、Conv2D、LSTM 等)

tf.keras.optimizers

优化器(SGD、Adam、RMSprop 等)

tf.keras.losses

损失函数(MSE、CrossEntropy 等)

tf.keras.metrics

评价指标(Accuracy、Precision 等)

tf.keras.callbacks

回调函数(EarlyStopping、ModelCheckpoint 等)


📌 总结一句话:

Keras = 深度学习“乐高”,用来快速搭建模型,适合初学者,也支持复杂自定义模型,是 TensorFlow 的核心部分。

相关文章:

  • c++的map基本知识
  • Hyperlane框架全面详解与应用指南 [特殊字符][特殊字符][特殊字符]
  • React 初学者进阶指南:从环境搭建到部署上线
  • stc8g1k08a adc采集电压输出到串口和屏幕
  • 深入理解 QScrollArea 的 widgetResizable 属性
  • C++——静态成员
  • flutter 专题 六十八 Flutter 多图片上传
  • C++:函数
  • AF3 OpenFoldDataLoader类解读
  • PostgreSQL 一文从安装到入门掌握基本应用开发能力!
  • 【C++】--- string的使用
  • go游戏后端开发24:写完赢三张游戏
  • C++中如何使用STL中的list定义一个双向链表,并且实现增、删、改、查操作
  • #SVA语法滴水穿石# (012)关于 first_match、throughout、within 的用法
  • 华为交换机配置指南:基础到高级命令详解
  • 51单片机使用定时器实现LCD1602的时间显示(STC89C52RC)
  • 迭代器运算详解(四十二)
  • OSI模型中协议数据单元(PDU)
  • 21 天 Python 计划:MySQL库相关操作
  • 深信服护网蓝初面试题
  • 网站搜索排优化怎么做/微营销
  • wordpress调用指定文章id/长沙优化网站推广
  • 做网站找雷鸣/万能搜索引擎网站
  • 做程序的网站/2023广州疫情最新消息今天
  • 甘肃网站建设哪家便宜/百度网站收录提交
  • 电子商务网站建设的一般/免费外链发布平台在线