当前位置: 首页 > news >正文

卷积神经网络CNN-part7-批量规范化BatchNorm

卷积神经网络CNN-part6-GoogLeNet-CSDN博客

摘要:批量规范化(Batch Normalization)是一种加速深层神经网络训练的有效技术。文章详细介绍了在全连接层和卷积层中应用批量规范化的方法,包括其数学公式和实现过程。通过比较训练模式和预测模式下的不同计算方式,说明了批量规范化的运作机制。文章提供了使用PyTorch实现批量规范化层的代码示例,并将其应用于LeNet网络结构中,展示了在Fashion-MNIST数据集上的训练效果。最后还演示了如何直接使用PyTorch内置的批量规范化模块(BatchNorm1d和BatchNorm2d)来简化网络构建过。【AI】

 1批量规范化(batch normalization)

        训练深层神经网络是十分困难的,特别是较短时间内使它们收敛更加棘手。批量规范化(batch normalization)是一种有效的可持续加速深层网络收敛的方法。这里我们讨论全连接层和卷积层的批量规范化。

 1.1全连接层

        通常,我们将批量规范化层置于全连接层中的仿射变换和激活函数之间。设输入为x,权重参数W,偏置参数为b,激活函数为,批量规范化为BN,输出公式为:

h=\phi (BN(Wx+b))

1.2卷积层

        卷积层,我们再卷积层之后和非线性激活函数之前应用批量规范化。当有多个通道时,对每个通道的输出进行批量规范化。小批量包含m个样本,对每个通道输出有高度p和宽度q,则我们在每个输出通道进行m·p·q个元素上同时进行批量规范化。

1.3预测过程中的批量规范化

        批量规范化层在训练模式和预测模式下的计算结果不一样。

2.批量规范化实现

2.1 创建BatchNorm层

import torch
import torch.nn as nn
from d2l import torch as d2ldef batch_norm(X,gamma,beta,moving_mean,moving_var,eps,momentum):#通过is_grad_enabled方法来判断当前模式是训练模式还是预测模式if not torch.is_grad_enabled():#如果是在预测模式下,直接使用传入的移动平均所得的均值和方差X_hat=(X-moving_mean)/torch.sqrt(moving_var+eps)else:assert len(X.shape) in (2,4)if len(X.shape) == 2:#使用全连接层情况,计算特征维上的均值和方差mean=X.mean(dim=0)var=((X-mean)**2).mean(dim=0)else:#使用二维卷积层情况,计算通道维上(axis=1)的均值和方差#需要保持X的形状以便后面可以做广播运算mean=X.mean(dim=(0,2,3),keepdims=True)var=((X-mean)**2).mean(dim=(0,2,3),keepdims=True)#训练模式下,用当前的均值和方差做标准化X_hat=(X-mean)/torch.sqrt(var+eps)#更新移动平均的均值和方差moving_mean=momentum*moving_mean+(1-momentum)*meanmoving_var=momentum*moving_var+(1-momentum)*varY=gamma*X_hat+beta#缩放和移位return Y,moving_mean.data,moving_var.data

        我们这里构建一个BatchNorm层的类

class BatchNorm(nn.Module):#num_features全连接层的输出数量或卷积层的输出通道数#num_dims:2表示完全连接层,4表示卷积层def __init__(self,num_features,num_dims):super().__init__()if num_dims==2:shape=(1,num_features)else:shape=(1,num_features,1,1)#参与求梯度和迭代的拉伸参数和偏移参数,其分别初始化成1和0self.gamma=nn.Parameter(torch.ones(shape))self.beta=nn.Parameter(torch.zeros(shape))#非模型参数的变量初始化为0和1self.moving_mean=torch.zeros(shape)self.moving_var=torch.zeros(shape)def forward(self, x):#如果x不在内存上,将moving_mean和moving_var复制到X所在的显存上if self.moving_mean.device!=x.device:self.moving_mean=self.moving_mean.to(x.device)self.moving_var=self.moving_var.to(x.device)#保存更新过的y,self.moving_mean,self.moving_var=batch_norm(x,self.gamma,self.beta,self.moving_mean,self.moving_var,eps=1e-5,momentum=0.9)return y

2.2 在LeNet中使用

2.2.1构建网络

#使用批量规范化层的LeNet
net=nn.Sequential(nn.Conv2d(1,6,kernel_size=5),BatchNorm(6,4),nn.Sigmoid(),nn.AvgPool2d(2,2),nn.Conv2d(6,16,kernel_size=5),BatchNorm(16,4),nn.Sigmoid(),nn.AvgPool2d(2,2),nn.Flatten(),nn.Linear(16*4*4,120),BatchNorm(120,2),nn.Sigmoid(),nn.Linear(120,84),BatchNorm(84,2),nn.Sigmoid(),nn.Linear(84,10))

2.2.2训练

lr,num_epochs,batch_size=1.0,10,256
train_iter,test_iter=d2l.load_data_fashion_mnist(batch_size)
d2l.train_ch6(net,train_iter,test_iter,num_epochs,lr,d2l.try_gpu())net[1].gamma.reshape((-1,)),net[1].beta.reshape((-1,))

结果:

(tensor([1.5512, 1.3251, 2.9371, 1.5447, 5.2065, 4.7697], grad_fn=<ViewBackward0>),
 tensor([-0.7209, -3.0481,  3.3563, -2.6427,  2.8693, -4.5432], grad_fn=<ViewBackward0>))

3.使用torch中的批量规范化

3.1网络构建

net=nn.Sequential(nn.Conv2d(1,6,kernel_size=5),nn.BatchNorm2d(6),nn.Sigmoid(),nn.AvgPool2d(kernel_size=2,stride=2),nn.Conv2d(6,16,kernel_size=5),nn.BatchNorm2d(16),nn.Sigmoid(),nn.AvgPool2d(kernel_size=2,stride=2),nn.Flatten(),nn.Linear(16*4*4,120),nn.BatchNorm1d(120),nn.Sigmoid(),nn.Linear(120,84),nn.BatchNorm1d(84),nn.Sigmoid(),nn.Linear(84,10))

3.2训练

d2l.train_ch6(net,train_iter,test_iter,num_epochs,lr,d2l.try_gpu())

结果:

http://www.dtcms.com/a/390299.html

相关文章:

  • [xboard]02 uboot下载、移植、编译概述
  • Python入门教程之字符串运算
  • 堡垒机部署
  • 刷题记录(10)stack和queue的简单应用
  • 如何进行时间管理?
  • Spring面试题及详细答案 125道(46-65) -- 事务管理
  • OA ⇄ CRM 单点登录(SSO)实现说明
  • 人工智能在设备管理软件中的应用
  • __pycache__ 文件夹作用
  • 利欧泵业数据中心液冷系统解决方案亮相2025 ODCC开放数据中心峰会
  • 【论文阅读】Masked Conditional Variational Autoencoders for Chromosome Straightening
  • 天气预测:AI 如何为我们 “算” 出未来的天空?
  • 大数据管理与应用有什么注意事项?企业该如何发挥大数据的价值
  • CSS的opacity 属性
  • STM32 LwIP协议栈优化:从TCP延迟10ms降至1ms的内存配置手册
  • 【0基础3ds Max】创建标准基本体(长方体、球体、圆柱体等)理论
  • 驾驭未来:深度体验 Flet 0.7.0 的重大变革与服务化架构
  • 【Datawhale组队学习202509】AI硬件与机器人大模型 task01 具身智能基础
  • Go语言高并发编程全面解析:从基础到高级实战
  • leetcode算法刷题的第三十八天
  • RHEL 兼容发行版核心对比表
  • 如何解决 pip install 安装报错 ModuleNotFoundError: No module named ‘yaml’ 问题
  • 无刷电机有感方波闭环控制
  • 【EKL】
  • 设计模式-模板方法模式详解(2)
  • 算法(一)双指针法
  • C语言指针深度解析:从核心原理到工程实践
  • hsahmap的寻址算法和为是你扩容为2的N次方
  • ​​[硬件电路-243]:电源纹波与噪声
  • Kurt-Blender零基础教程:第1章:基础篇——第2节:认识界面