VGG model structure and code
The VGG model
2. Two 3×3 convolutions replace one 5×5 convolution, and three 3×3 convolutions replace one 7×7 convolution.
(1) Computing the receptive field of the convolution output with respect to the original input image:
(F(i) denotes the receptive field of layer i, computed from the top down as F(i) = (F(i+1) − 1) × stride + ksize; the topmost layer starts with F = 1, and for stacked convolutions the formula is applied iteratively.)
For example, three 3×3 convolutions replacing one 7×7 convolution:
Starting from the output, F = 1; then F3 = (1−1)×1+3 = 3, F2 = (3−1)×1+3 = 5, F1 = (5−1)×1+3 = 7, so the stack of three 3×3 convolutions covers the same 7×7 region of the input.
(2) Parameter count comparison (both results are reproduced in the short sketch below)
① One 7×7 convolution: 7×7×C×C = 49·C·C parameters (C is the number of channels)
② Three 3×3 convolutions: 3×(3×3×C×C) = 27·C·C parameters
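A minimal Python sketch (not from the original post) that reproduces both calculations: it applies F(i) = (F(i+1) − 1) × stride + ksize backwards through three stacked 3×3 convolutions, then compares the parameter counts for an illustrative channel count C = 256. The helper name receptive_field is made up for this example.

def receptive_field(ksizes, strides):
    f = 1                                         # receptive field of the topmost (output) layer
    for k, s in zip(reversed(ksizes), reversed(strides)):
        f = (f - 1) * s + k                       # F(i) = (F(i+1) - 1) * stride + ksize
    return f

print(receptive_field([3, 3, 3], [1, 1, 1]))      # 7 -- same as a single 7x7 convolution

C = 256                                           # illustrative channel count
print(7 * 7 * C * C)                              # 49*C*C = 3,211,264 parameters (one 7x7 conv)
print(3 * (3 * 3 * C * C))                        # 27*C*C = 1,769,472 parameters (three 3x3 convs)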
3. Model code
import torch.nn as nn
import torch


class VGG(nn.Module):
    def __init__(self, features, num_classes=1000, init_weights=False):  # features: the feature-extraction network built by make_features
        super(VGG, self).__init__()
        self.features = features
        self.classifier = nn.Sequential(      # the final three fully connected layers (the classification network)
            nn.Dropout(p=0.5),                # the features are flattened to 1-D before the fully connected layers; dropout (randomly deactivating 50% of the neurons) reduces overfitting
            nn.Linear(512*7*7, 2048),         # the original paper uses 4096 nodes; simplified to 2048 here
            nn.ReLU(True),
            nn.Dropout(p=0.5),
            nn.Linear(2048, 2048),
            nn.ReLU(True),
            nn.Linear(2048, num_classes)
        )
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        # N x 3 x 224 x 224
        x = self.features(x)                  # convolutional layers extract the features
        # N x 512 x 7 x 7
        x = torch.flatten(x, start_dim=1)     # flatten (dimension 0 is the batch, so flatten from dimension 1)
        # N x 512*7*7
        x = self.classifier(x)                # fully connected layers perform the classification
        return x

    def _initialize_weights(self):            # weight initialization
        for m in self.modules():              # iterate over every layer of the network
            if isinstance(m, nn.Conv2d):      # the current layer is a convolutional layer
                # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                nn.init.xavier_uniform_(m.weight)     # initialize the convolution kernel weights
                if m.bias is not None:                # if a bias is used, initialize it to 0
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):            # the current layer is a fully connected layer
                nn.init.xavier_uniform_(m.weight)     # initialize the fully connected layer weights
                # nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)


# build the feature-extraction network
def make_features(cfg: list):                 # cfg: a list describing the network configuration
    layers = []
    in_channels = 3                           # R G B
    for v in cfg:
        if v == "M":
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
            layers += [conv2d, nn.ReLU(True)]
            in_channels = v
    return nn.Sequential(*layers)             # unpack the list into positional arguments


cfgs = {
    # every convolution kernel is 3x3
    # a number gives the number of kernels, 'M' marks a max-pooling layer
    'vgg11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'vgg13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'vgg16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'vgg19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}


# instantiate a VGG network
def vgg(model_name="vgg16", **kwargs):
    try:
        cfg = cfgs[model_name]
    except KeyError:
        print("Warning: model number {} not in cfgs dict!".format(model_name))
        exit(-1)
    model = VGG(make_features(cfg), **kwargs)
    return model
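As a quick usage sketch (not part of the original code), the vgg() factory above can be checked with a dummy batch; num_classes=5 is just an illustrative value.

if __name__ == '__main__':
    # sanity check: build a VGG-16 and verify the output shape on random input
    model = vgg(model_name="vgg16", num_classes=5, init_weights=True)
    dummy = torch.randn(2, 3, 224, 224)   # N x 3 x 224 x 224
    out = model(dummy)
    print(out.shape)                      # expected: torch.Size([2, 5])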
4. VGG experiment source code
Note: when using transfer learning with pre-trained VGG weights, subtract [123.68, 116.78, 103.94] from the R, G, B channels respectively (the per-channel means used during pre-training); when training from scratch this step can be skipped.
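A minimal preprocessing sketch for that note, assuming torchvision transforms are used and the image is kept in the 0-255 range; the name vgg_mean and the exact pipeline are illustrative, not from the original post.

from torchvision import transforms

vgg_mean = [123.68, 116.78, 103.94]           # per-channel R, G, B means to subtract

# illustrative pipeline; adjust to your own data loading
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.PILToTensor(),                 # uint8 tensor with values in 0-255
    transforms.Lambda(lambda t: t.float()),   # convert to float, keep the 0-255 scale
    transforms.Normalize(mean=vgg_mean, std=[1., 1., 1.])  # subtract the means only
])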