卷积神经网络CNN-part7-批量规范化BatchNorm
卷积神经网络CNN-part6-GoogLeNet-CSDN博客
摘要:批量规范化(Batch Normalization)是一种加速深层神经网络训练的有效技术。文章详细介绍了在全连接层和卷积层中应用批量规范化的方法,包括其数学公式和实现过程。通过比较训练模式和预测模式下的不同计算方式,说明了批量规范化的运作机制。文章提供了使用PyTorch实现批量规范化层的代码示例,并将其应用于LeNet网络结构中,展示了在Fashion-MNIST数据集上的训练效果。最后还演示了如何直接使用PyTorch内置的批量规范化模块(BatchNorm1d和BatchNorm2d)来简化网络构建过。【AI】
1批量规范化(batch normalization)
训练深层神经网络是十分困难的,特别是较短时间内使它们收敛更加棘手。批量规范化(batch normalization)是一种有效的可持续加速深层网络收敛的方法。这里我们讨论全连接层和卷积层的批量规范化。
1.1全连接层
通常,我们将批量规范化层置于全连接层中的仿射变换和激活函数之间。设输入为x,权重参数W,偏置参数为b,激活函数为,批量规范化为BN,输出公式为:
1.2卷积层
卷积层,我们再卷积层之后和非线性激活函数之前应用批量规范化。当有多个通道时,对每个通道的输出进行批量规范化。小批量包含m个样本,对每个通道输出有高度p和宽度q,则我们在每个输出通道进行m·p·q个元素上同时进行批量规范化。
1.3预测过程中的批量规范化
批量规范化层在训练模式和预测模式下的计算结果不一样。
2.批量规范化实现
2.1 创建BatchNorm层
import torch
import torch.nn as nn
from d2l import torch as d2ldef batch_norm(X,gamma,beta,moving_mean,moving_var,eps,momentum):#通过is_grad_enabled方法来判断当前模式是训练模式还是预测模式if not torch.is_grad_enabled():#如果是在预测模式下,直接使用传入的移动平均所得的均值和方差X_hat=(X-moving_mean)/torch.sqrt(moving_var+eps)else:assert len(X.shape) in (2,4)if len(X.shape) == 2:#使用全连接层情况,计算特征维上的均值和方差mean=X.mean(dim=0)var=((X-mean)**2).mean(dim=0)else:#使用二维卷积层情况,计算通道维上(axis=1)的均值和方差#需要保持X的形状以便后面可以做广播运算mean=X.mean(dim=(0,2,3),keepdims=True)var=((X-mean)**2).mean(dim=(0,2,3),keepdims=True)#训练模式下,用当前的均值和方差做标准化X_hat=(X-mean)/torch.sqrt(var+eps)#更新移动平均的均值和方差moving_mean=momentum*moving_mean+(1-momentum)*meanmoving_var=momentum*moving_var+(1-momentum)*varY=gamma*X_hat+beta#缩放和移位return Y,moving_mean.data,moving_var.data
我们这里构建一个BatchNorm层的类
class BatchNorm(nn.Module):#num_features全连接层的输出数量或卷积层的输出通道数#num_dims:2表示完全连接层,4表示卷积层def __init__(self,num_features,num_dims):super().__init__()if num_dims==2:shape=(1,num_features)else:shape=(1,num_features,1,1)#参与求梯度和迭代的拉伸参数和偏移参数,其分别初始化成1和0self.gamma=nn.Parameter(torch.ones(shape))self.beta=nn.Parameter(torch.zeros(shape))#非模型参数的变量初始化为0和1self.moving_mean=torch.zeros(shape)self.moving_var=torch.zeros(shape)def forward(self, x):#如果x不在内存上,将moving_mean和moving_var复制到X所在的显存上if self.moving_mean.device!=x.device:self.moving_mean=self.moving_mean.to(x.device)self.moving_var=self.moving_var.to(x.device)#保存更新过的y,self.moving_mean,self.moving_var=batch_norm(x,self.gamma,self.beta,self.moving_mean,self.moving_var,eps=1e-5,momentum=0.9)return y
2.2 在LeNet中使用
2.2.1构建网络
#使用批量规范化层的LeNet
net=nn.Sequential(nn.Conv2d(1,6,kernel_size=5),BatchNorm(6,4),nn.Sigmoid(),nn.AvgPool2d(2,2),nn.Conv2d(6,16,kernel_size=5),BatchNorm(16,4),nn.Sigmoid(),nn.AvgPool2d(2,2),nn.Flatten(),nn.Linear(16*4*4,120),BatchNorm(120,2),nn.Sigmoid(),nn.Linear(120,84),BatchNorm(84,2),nn.Sigmoid(),nn.Linear(84,10))
2.2.2训练
lr,num_epochs,batch_size=1.0,10,256
train_iter,test_iter=d2l.load_data_fashion_mnist(batch_size)
d2l.train_ch6(net,train_iter,test_iter,num_epochs,lr,d2l.try_gpu())net[1].gamma.reshape((-1,)),net[1].beta.reshape((-1,))
结果:
(tensor([1.5512, 1.3251, 2.9371, 1.5447, 5.2065, 4.7697], grad_fn=<ViewBackward0>),
tensor([-0.7209, -3.0481, 3.3563, -2.6427, 2.8693, -4.5432], grad_fn=<ViewBackward0>))
3.使用torch中的批量规范化
3.1网络构建
net=nn.Sequential(nn.Conv2d(1,6,kernel_size=5),nn.BatchNorm2d(6),nn.Sigmoid(),nn.AvgPool2d(kernel_size=2,stride=2),nn.Conv2d(6,16,kernel_size=5),nn.BatchNorm2d(16),nn.Sigmoid(),nn.AvgPool2d(kernel_size=2,stride=2),nn.Flatten(),nn.Linear(16*4*4,120),nn.BatchNorm1d(120),nn.Sigmoid(),nn.Linear(120,84),nn.BatchNorm1d(84),nn.Sigmoid(),nn.Linear(84,10))
3.2训练
d2l.train_ch6(net,train_iter,test_iter,num_epochs,lr,d2l.try_gpu())
结果: