ResNeXt-50--分组卷积--J6
前言
- **🍨 本文为[🔗365天深度学习训练营](https://mp.weixin.qq.com/s/rnFa-IeY93EpjVu0yzzjkw) 中的学习记录博客**
- **🍖 原作者:[K同学啊](https://mtyjkh.blog.csdn.net/)**
本文为个人理解,如有错误,感谢指正
前置代码:
# DenseNet
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms,datasets
import os,PIL,pathlib,warnings
warnings.filterwarnings('ignore')
device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device# 提取数据集
data_file=r'./2data_3/data'train_forms=transforms.Compose([transforms.Resize([224,224]),transforms.ToTensor(),transforms.Normalize(mean=[0.485,0.456,0.406],std=[0.229,0.224,0.225])
])total_data=datasets.ImageFolder(data_file,transform=train_forms)
total_datatotal_data.class_to_idx# 划分数据集
train_size=int(0.8*len(total_data))
test_size=len(total_data)-train_size
# torch.utils.data.random_split()
train_dataset,test_dataset=torch.utils.data.random_split(total_data,[train_size,test_size])
train_dataset,test_datasetbatch_size=4
train_dl=torch.utils.data.DataLoader(train_dataset,batch_size=batch_size,shuffle=True
)
test_dl=torch.utils.data.DataLoader(test_dataset,batch_size=batch_size,shuffle=True
)for x, y in test_dl:print(x.shape)print(y.shape)break
网络结构1--残差块
残差块是残差神经网络的主要结构,在ResNeXt-500中,将原本的整体的残差块分成每一组的形式,从而实现分组卷积。其网络结构组成首先为一层常规的卷积层,其目的是将通道维数压缩到指定的通道数目中,从而减少后续计算量。第二层开始进行分组卷积,采用3*3的卷积核,对每组进行通道数目不变的卷积,卷积组数为指定的分组数量,在分组卷积过后,采用1*1的卷积核进行升维,将输出通道扩大为原来的两倍,最后进行一层卷积残差连接,将原数据进行1*1的卷积扩大维度,与分层卷积后的结果相加,实现特征的融合。
class ResNeXt_Block(nn.Module):"""ResNeXt block with group convolutions"""def __init__(self, in_chnls, cardinality, group_depth, stride):super(ResNeXt_Block, self).__init__()self.group_chnls = cardinality * group_depthself.conv1 = BN_Conv2d(in_chnls, self.group_chnls, 1, stride=1, padding=0)self.conv2 = BN_Conv2d(self.group_chnls, self.group_chnls, 3, stride=stride, padding=1, groups=cardinality)self.conv3 = nn.Conv2d(self.group_chnls, self.group_chnls*2, 1, stride=1, padding=0)self.bn = nn.BatchNorm2d(self.group_chnls*2)self.short_cut = nn.Sequential(nn.Conv2d(in_chnls, self.group_chnls*2, 1, stride, 0, bias=False),nn.BatchNorm2d(self.group_chnls*2))def forward(self, x):out = self.conv1(x)out = self.conv2(out)out = self.bn(self.conv3(out))out += self.short_cut(x)return F.relu(out)
整体网络结构:
首先是一层正常的卷积层,利用7*7的大卷积核使得整体通道变为64,然后开始进行分组卷积,其中,分组卷积的实现则是通过__make_layers方法,其中__make_layers方法中,利用循环将每个残差块连接在一起,从而实现了从第二层开始的逐层连接的分层卷积层,conv2,conv3...conv5都是如此。在利用一层7*7大小卷积核的全局池化层将特征图转为1*1的大小,只保留通道特征,从而降低计算量,然后利用全连接层进行分类,输出概率,实现整个网络。
class ResNeXt(nn.Module):"""ResNeXt builder"""def __init__(self, layers: object, cardinality, group_depth, num_classes) -> object:super(ResNeXt, self).__init__()self.cardinality = cardinalityself.channels = 64self.conv1 = BN_Conv2d(3, self.channels, 7, stride=2, padding=3)d1 = group_depthself.conv2 = self.___make_layers(d1, layers[0], stride=1)d2 = d1 * 2self.conv3 = self.___make_layers(d2, layers[1], stride=2)d3 = d2 * 2self.conv4 = self.___make_layers(d3, layers[2], stride=2)d4 = d3 * 2self.conv5 = self.___make_layers(d4, layers[3], stride=2)self.fc = nn.Linear(self.channels, num_classes) # 224x224 input sizedef ___make_layers(self, d, blocks, stride):strides = [stride] + [1] * (blocks-1)layers = []for stride in strides:layers.append(ResNeXt_Block(self.channels, self.cardinality, d, stride))self.channels = self.cardinality*d*2return nn.Sequential(*layers)def forward(self, x):out = self.conv1(x)out = F.max_pool2d(out, 3, 2, 1)out = self.conv2(out)out = self.conv3(out)out = self.conv4(out)out = self.conv5(out)out = F.avg_pool2d(out, 7)out = out.view(out.size(0), -1)out = F.softmax(self.fc(out),dim=1)return out
网络结构:
----------------------------------------------------------------Layer (type) Output Shape Param #
================================================================Conv2d-1 [-1, 64, 112, 112] 9,408BatchNorm2d-2 [-1, 64, 112, 112] 128BN_Conv2d-3 [-1, 64, 112, 112] 0Conv2d-4 [-1, 128, 56, 56] 8,192BatchNorm2d-5 [-1, 128, 56, 56] 256BN_Conv2d-6 [-1, 128, 56, 56] 0Conv2d-7 [-1, 128, 56, 56] 4,608BatchNorm2d-8 [-1, 128, 56, 56] 256BN_Conv2d-9 [-1, 128, 56, 56] 0Conv2d-10 [-1, 256, 56, 56] 33,024BatchNorm2d-11 [-1, 256, 56, 56] 512Conv2d-12 [-1, 256, 56, 56] 16,384BatchNorm2d-13 [-1, 256, 56, 56] 512ResNeXt_Block-14 [-1, 256, 56, 56] 0Conv2d-15 [-1, 128, 56, 56] 32,768BatchNorm2d-16 [-1, 128, 56, 56] 256BN_Conv2d-17 [-1, 128, 56, 56] 0Conv2d-18 [-1, 128, 56, 56] 4,608BatchNorm2d-19 [-1, 128, 56, 56] 256BN_Conv2d-20 [-1, 128, 56, 56] 0Conv2d-21 [-1, 256, 56, 56] 33,024BatchNorm2d-22 [-1, 256, 56, 56] 512Conv2d-23 [-1, 256, 56, 56] 65,536BatchNorm2d-24 [-1, 256, 56, 56] 512ResNeXt_Block-25 [-1, 256, 56, 56] 0Conv2d-26 [-1, 128, 56, 56] 32,768BatchNorm2d-27 [-1, 128, 56, 56] 256BN_Conv2d-28 [-1, 128, 56, 56] 0Conv2d-29 [-1, 128, 56, 56] 4,608BatchNorm2d-30 [-1, 128, 56, 56] 256BN_Conv2d-31 [-1, 128, 56, 56] 0Conv2d-32 [-1, 256, 56, 56] 33,024BatchNorm2d-33 [-1, 256, 56, 56] 512Conv2d-34 [-1, 256, 56, 56] 65,536BatchNorm2d-35 [-1, 256, 56, 56] 512ResNeXt_Block-36 [-1, 256, 56, 56] 0Conv2d-37 [-1, 256, 56, 56] 65,536BatchNorm2d-38 [-1, 256, 56, 56] 512BN_Conv2d-39 [-1, 256, 56, 56] 0Conv2d-40 [-1, 256, 28, 28] 18,432BatchNorm2d-41 [-1, 256, 28, 28] 512BN_Conv2d-42 [-1, 256, 28, 28] 0Conv2d-43 [-1, 512, 28, 28] 131,584BatchNorm2d-44 [-1, 512, 28, 28] 1,024Conv2d-45 [-1, 512, 28, 28] 131,072BatchNorm2d-46 [-1, 512, 28, 28] 1,024ResNeXt_Block-47 [-1, 512, 28, 28] 0Conv2d-48 [-1, 256, 28, 28] 131,072BatchNorm2d-49 [-1, 256, 28, 28] 512BN_Conv2d-50 [-1, 256, 28, 28] 0Conv2d-51 [-1, 256, 28, 28] 18,432BatchNorm2d-52 [-1, 256, 28, 28] 512BN_Conv2d-53 [-1, 256, 28, 28] 0Conv2d-54 [-1, 512, 28, 28] 131,584BatchNorm2d-55 [-1, 512, 28, 28] 1,024Conv2d-56 [-1, 512, 28, 28] 262,144BatchNorm2d-57 [-1, 512, 28, 28] 1,024ResNeXt_Block-58 [-1, 512, 28, 28] 0Conv2d-59 [-1, 256, 28, 28] 131,072BatchNorm2d-60 [-1, 256, 28, 28] 512BN_Conv2d-61 [-1, 256, 28, 28] 0Conv2d-62 [-1, 256, 28, 28] 18,432BatchNorm2d-63 [-1, 256, 28, 28] 512BN_Conv2d-64 [-1, 256, 28, 28] 0Conv2d-65 [-1, 512, 28, 28] 131,584BatchNorm2d-66 [-1, 512, 28, 28] 1,024Conv2d-67 [-1, 512, 28, 28] 262,144BatchNorm2d-68 [-1, 512, 28, 28] 1,024ResNeXt_Block-69 [-1, 512, 28, 28] 0Conv2d-70 [-1, 256, 28, 28] 131,072BatchNorm2d-71 [-1, 256, 28, 28] 512BN_Conv2d-72 [-1, 256, 28, 28] 0Conv2d-73 [-1, 256, 28, 28] 18,432BatchNorm2d-74 [-1, 256, 28, 28] 512BN_Conv2d-75 [-1, 256, 28, 28] 0Conv2d-76 [-1, 512, 28, 28] 131,584BatchNorm2d-77 [-1, 512, 28, 28] 1,024Conv2d-78 [-1, 512, 28, 28] 262,144BatchNorm2d-79 [-1, 512, 28, 28] 1,024ResNeXt_Block-80 [-1, 512, 28, 28] 0Conv2d-81 [-1, 512, 28, 28] 262,144BatchNorm2d-82 [-1, 512, 28, 28] 1,024BN_Conv2d-83 [-1, 512, 28, 28] 0Conv2d-84 [-1, 512, 14, 14] 73,728BatchNorm2d-85 [-1, 512, 14, 14] 1,024BN_Conv2d-86 [-1, 512, 14, 14] 0Conv2d-87 [-1, 1024, 14, 14] 525,312BatchNorm2d-88 [-1, 1024, 14, 14] 2,048Conv2d-89 [-1, 1024, 14, 14] 524,288BatchNorm2d-90 [-1, 1024, 14, 14] 2,048ResNeXt_Block-91 [-1, 1024, 14, 14] 0Conv2d-92 [-1, 512, 14, 14] 524,288BatchNorm2d-93 [-1, 512, 14, 14] 1,024BN_Conv2d-94 [-1, 512, 14, 14] 0Conv2d-95 [-1, 512, 14, 14] 73,728BatchNorm2d-96 [-1, 512, 14, 14] 1,024BN_Conv2d-97 [-1, 512, 14, 14] 0Conv2d-98 [-1, 1024, 14, 14] 525,312BatchNorm2d-99 [-1, 1024, 14, 14] 2,048Conv2d-100 [-1, 1024, 14, 14] 1,048,576BatchNorm2d-101 [-1, 1024, 14, 14] 2,048ResNeXt_Block-102 [-1, 1024, 14, 14] 0Conv2d-103 [-1, 512, 14, 14] 524,288BatchNorm2d-104 [-1, 512, 14, 14] 1,024BN_Conv2d-105 [-1, 512, 14, 14] 0Conv2d-106 [-1, 512, 14, 14] 73,728BatchNorm2d-107 [-1, 512, 14, 14] 1,024BN_Conv2d-108 [-1, 512, 14, 14] 0Conv2d-109 [-1, 1024, 14, 14] 525,312BatchNorm2d-110 [-1, 1024, 14, 14] 2,048Conv2d-111 [-1, 1024, 14, 14] 1,048,576BatchNorm2d-112 [-1, 1024, 14, 14] 2,048ResNeXt_Block-113 [-1, 1024, 14, 14] 0Conv2d-114 [-1, 512, 14, 14] 524,288BatchNorm2d-115 [-1, 512, 14, 14] 1,024BN_Conv2d-116 [-1, 512, 14, 14] 0Conv2d-117 [-1, 512, 14, 14] 73,728BatchNorm2d-118 [-1, 512, 14, 14] 1,024BN_Conv2d-119 [-1, 512, 14, 14] 0Conv2d-120 [-1, 1024, 14, 14] 525,312BatchNorm2d-121 [-1, 1024, 14, 14] 2,048Conv2d-122 [-1, 1024, 14, 14] 1,048,576BatchNorm2d-123 [-1, 1024, 14, 14] 2,048ResNeXt_Block-124 [-1, 1024, 14, 14] 0Conv2d-125 [-1, 512, 14, 14] 524,288BatchNorm2d-126 [-1, 512, 14, 14] 1,024BN_Conv2d-127 [-1, 512, 14, 14] 0Conv2d-128 [-1, 512, 14, 14] 73,728BatchNorm2d-129 [-1, 512, 14, 14] 1,024BN_Conv2d-130 [-1, 512, 14, 14] 0Conv2d-131 [-1, 1024, 14, 14] 525,312BatchNorm2d-132 [-1, 1024, 14, 14] 2,048Conv2d-133 [-1, 1024, 14, 14] 1,048,576BatchNorm2d-134 [-1, 1024, 14, 14] 2,048ResNeXt_Block-135 [-1, 1024, 14, 14] 0Conv2d-136 [-1, 512, 14, 14] 524,288BatchNorm2d-137 [-1, 512, 14, 14] 1,024BN_Conv2d-138 [-1, 512, 14, 14] 0Conv2d-139 [-1, 512, 14, 14] 73,728BatchNorm2d-140 [-1, 512, 14, 14] 1,024BN_Conv2d-141 [-1, 512, 14, 14] 0Conv2d-142 [-1, 1024, 14, 14] 525,312BatchNorm2d-143 [-1, 1024, 14, 14] 2,048Conv2d-144 [-1, 1024, 14, 14] 1,048,576BatchNorm2d-145 [-1, 1024, 14, 14] 2,048ResNeXt_Block-146 [-1, 1024, 14, 14] 0Conv2d-147 [-1, 1024, 14, 14] 1,048,576BatchNorm2d-148 [-1, 1024, 14, 14] 2,048BN_Conv2d-149 [-1, 1024, 14, 14] 0Conv2d-150 [-1, 1024, 7, 7] 294,912BatchNorm2d-151 [-1, 1024, 7, 7] 2,048BN_Conv2d-152 [-1, 1024, 7, 7] 0Conv2d-153 [-1, 2048, 7, 7] 2,099,200BatchNorm2d-154 [-1, 2048, 7, 7] 4,096Conv2d-155 [-1, 2048, 7, 7] 2,097,152BatchNorm2d-156 [-1, 2048, 7, 7] 4,096ResNeXt_Block-157 [-1, 2048, 7, 7] 0Conv2d-158 [-1, 1024, 7, 7] 2,097,152BatchNorm2d-159 [-1, 1024, 7, 7] 2,048BN_Conv2d-160 [-1, 1024, 7, 7] 0Conv2d-161 [-1, 1024, 7, 7] 294,912BatchNorm2d-162 [-1, 1024, 7, 7] 2,048BN_Conv2d-163 [-1, 1024, 7, 7] 0Conv2d-164 [-1, 2048, 7, 7] 2,099,200BatchNorm2d-165 [-1, 2048, 7, 7] 4,096Conv2d-166 [-1, 2048, 7, 7] 4,194,304BatchNorm2d-167 [-1, 2048, 7, 7] 4,096ResNeXt_Block-168 [-1, 2048, 7, 7] 0Conv2d-169 [-1, 1024, 7, 7] 2,097,152BatchNorm2d-170 [-1, 1024, 7, 7] 2,048BN_Conv2d-171 [-1, 1024, 7, 7] 0Conv2d-172 [-1, 1024, 7, 7] 294,912BatchNorm2d-173 [-1, 1024, 7, 7] 2,048BN_Conv2d-174 [-1, 1024, 7, 7] 0Conv2d-175 [-1, 2048, 7, 7] 2,099,200BatchNorm2d-176 [-1, 2048, 7, 7] 4,096Conv2d-177 [-1, 2048, 7, 7] 4,194,304BatchNorm2d-178 [-1, 2048, 7, 7] 4,096ResNeXt_Block-179 [-1, 2048, 7, 7] 0Linear-180 [-1, 2] 4,098
================================================================
Total params: 37,570,626
Trainable params: 37,570,626
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 379.37
Params size (MB): 143.32
Estimated Total Size (MB): 523.26
----------------------------------------------------------------
训练代码:
def train(dataloader,model,loss_fn,optimizer):size=len(dataloader.dataset)num_batches=len(dataloader)train_acc,train_loss=0,0for x,y in dataloader:x,y=x.to(device),y.to(device)pred=model(x)loss=loss_fn(pred,y) #计算损失 loss_fn(预测值,真实值)optimizer.zero_grad() #清除梯度loss.backward() #反向传播 计算梯度optimizer.step() #梯度自动更新train_acc+=(pred.argmax(1)==y).type(torch.float).sum().item() #在预测过后会,每一个严格样本会形成对于三种类别的概率,然后取最高的为当前预测值,累加# 假设的 pred(每行是一个样本的类别概率,概率和为1)# pred = torch.tensor([# [0.1, 0.2, 0.7], # 第1个样本:类别2的概率最高# [0.05, 0.15, 0.8], # 第2个样本:类别2的概率最高# [0.3, 0.6, 0.1], # 第3个样本:类别1的概率最高# [0.9, 0.08, 0.02] # 第4个样本:类别0的概率最高# ])train_loss+=loss.item()train_acc/=size #train_acc加的是每一个的值所以要除全部train_loss/=num_batches #加的是批次的值,除以批次return train_acc,train_lossdef test(dataloader,model,loss_fn,optimizer):size=len(dataloader.dataset)num_batches=len(dataloader)test_acc,test_loss=0,0with torch.no_grad():for imgs,target in dataloader:imgs,target=imgs.to(device),target.to(device)pred=model(imgs)loss=loss_fn(pred,target)test_loss+=loss.item()test_acc+=(pred.argmax(1)==target).type(torch.float).sum().item()test_acc/=sizetest_loss/=num_batchesreturn test_acc,test_loss
import copy
import torch, gcoptimizer=torch.optim.AdamW(model.parameters(),lr=1e-4) #优化器
loss_fn=nn.CrossEntropyLoss() #交叉损失函数
epochs=30
train_loss=[]
train_acc=[]
test_loss=[]
test_acc=[]
best_acc=0 #最好的结果
for epoch in range(epochs):model.train() #训练模式epoch_train_acc,epoch_train_loss=train(train_dl,model,loss_fn,optimizer)model.eval() #测试模式epoch_test_acc,epoch_test_loss=test(test_dl,model,loss_fn,optimizer)if epoch_test_acc>best_acc:best_acc=epoch_test_accbest_model=copy.deepcopy(model)#复制模型train_acc.append(epoch_train_acc)test_acc.append(epoch_test_acc)train_loss.append(epoch_train_loss)test_loss.append(epoch_test_loss)lr=optimizer.state_dict()['param_groups'][0]['lr']template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%,Test_loss:{:.3f}')print(template.format(epoch+1, epoch_train_acc*100, epoch_train_loss, epoch_test_acc*100, epoch_test_loss))gc.collect()torch.cuda.empty_cache()
print('Done')
import matplotlib.pyplot as plt
#隐藏警告
import warnings
warnings.filterwarnings("ignore") #忽略警告信息
plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号
plt.rcParams['figure.dpi'] = 100 #分辨率from datetime import datetime
current_time = datetime.now() # 获取当前时间epochs_range = range(epochs)plt.figure(figsize=(12, 3))
plt.subplot(1, 2, 1)plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
plt.xlabel(current_time) # 打卡请带上时间戳,否则代码截图无效plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()