
ResNeXt-50 -- Grouped Convolution -- J6

Preface

- **🍨 This post is a learning-record blog for the [🔗365天深度学习训练营](https://mp.weixin.qq.com/s/rnFa-IeY93EpjVu0yzzjkw) (365-day deep learning training camp)**
- **🍖 Original author: [K同学啊](https://mtyjkh.blog.csdn.net/)**

This post reflects my own understanding; if there are mistakes, corrections are appreciated.

Setup code:

# ResNeXt -- setup: imports, device, dataset loading and splitting
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, datasets
import os, PIL, pathlib, warnings

warnings.filterwarnings('ignore')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

# Load the dataset
data_file = r'./2data_3/data'
train_forms = transforms.Compose([
    transforms.Resize([224, 224]),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],  # ImageNet statistics
                         std=[0.229, 0.224, 0.225])
])
total_data = datasets.ImageFolder(data_file, transform=train_forms)
total_data
total_data.class_to_idx

# Split the dataset 80/20
train_size = int(0.8 * len(total_data))
test_size = len(total_data) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(total_data, [train_size, test_size])
train_dataset, test_dataset

batch_size = 4
train_dl = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dl = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

for x, y in test_dl:
    print(x.shape)
    print(y.shape)
    break
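With `batch_size=4`, the loop above should print `torch.Size([4, 3, 224, 224])` for the image batch and `torch.Size([4])` for the labels.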

Network structure 1 -- the residual block

The residual block is the main building unit of a residual network. In ResNeXt-50, the single residual block is split into groups, which is how grouped convolution is realized. Its layout is as follows. The first layer is an ordinary 1×1 convolution whose purpose is to project the channel dimension to a specified width, reducing the cost of the layers that follow. The second layer performs the grouped convolution: a 3×3 kernel is applied to each group separately, with the number of groups set to the specified cardinality and the channel count left unchanged. After the grouped convolution, a 1×1 convolution raises the dimension again, doubling the number of output channels. Finally comes the residual connection: the block input is passed through its own 1×1 convolution to reach the expanded dimension and is added to the main-path output, fusing the two feature streams.
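The block code below relies on a `BN_Conv2d` helper that the post never defines. Judging from the layer summary further down (each `BN_Conv2d` expands to a bias-free `Conv2d` followed by a `BatchNorm2d` and contributes no parameters of its own), a minimal sketch consistent with that output would be:

# Hypothetical reconstruction of the undefined BN_Conv2d helper:
# Conv2d (bias=False, as the parameter counts in the summary suggest)
# -> BatchNorm2d -> ReLU.
class BN_Conv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size,
                 stride, padding, dilation=1, groups=1, bias=False):
        super(BN_Conv2d, self).__init__()
        self.seq = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size,
                      stride=stride, padding=padding,
                      dilation=dilation, groups=groups, bias=bias),
            nn.BatchNorm2d(out_channels)
        )

    def forward(self, x):
        return F.relu(self.seq(x))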

class ResNeXt_Block(nn.Module):
    """ResNeXt block with group convolutions"""
    def __init__(self, in_chnls, cardinality, group_depth, stride):
        super(ResNeXt_Block, self).__init__()
        self.group_chnls = cardinality * group_depth
        # 1x1 conv: project the input to cardinality * group_depth channels
        self.conv1 = BN_Conv2d(in_chnls, self.group_chnls, 1, stride=1, padding=0)
        # 3x3 grouped conv: one group per cardinality path, channel count unchanged
        self.conv2 = BN_Conv2d(self.group_chnls, self.group_chnls, 3, stride=stride,
                               padding=1, groups=cardinality)
        # 1x1 conv: expand to twice the grouped width
        self.conv3 = nn.Conv2d(self.group_chnls, self.group_chnls * 2, 1, stride=1, padding=0)
        self.bn = nn.BatchNorm2d(self.group_chnls * 2)
        # projection shortcut: 1x1 conv brings the input to the output shape
        self.short_cut = nn.Sequential(
            nn.Conv2d(in_chnls, self.group_chnls * 2, 1, stride, 0, bias=False),
            nn.BatchNorm2d(self.group_chnls * 2)
        )

    def forward(self, x):
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.bn(self.conv3(out))
        out += self.short_cut(x)   # residual addition
        return F.relu(out)
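As a quick sanity check (the shapes follow the first stage of the summary below), a block with cardinality 32 and group depth 4 maps 64 input channels to 32 × 4 × 2 = 256 output channels:

# Sanity check: a first-stage block (64 -> 256 channels, stride 1).
block = ResNeXt_Block(in_chnls=64, cardinality=32, group_depth=4, stride=1)
x = torch.randn(1, 64, 56, 56)
print(block(x).shape)  # expected: torch.Size([1, 256, 56, 56])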

Overall network structure:

The network begins with an ordinary convolutional layer that uses a large 7×7 kernel to bring the input to 64 channels, after which the grouped-convolution stages begin. These stages are produced by the `___make_layers` method, which uses a loop to chain the requested number of residual blocks one after another; `conv2`, `conv3`, ..., `conv5` are all built this way, with the group depth doubling from one stage to the next. A 7×7 average pooling layer then shrinks each feature map to 1×1, keeping only the per-channel features and thereby reducing computation, and a fully connected layer performs the classification. Because `nn.CrossEntropyLoss` applies softmax internally during training, the forward pass returns these raw class scores (logits) rather than probabilities.
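Since every stage is built from these grouped 3×3 convolutions, it is worth quantifying what the grouping buys. The short comparison below contrasts the parameter count of a 3×3 convolution with and without groups; the grouped version is exactly the 4,608-parameter `Conv2d-7` in the summary (both assume bias=False):

# 3x3 conv, 128 -> 128 channels, no bias.
dense   = nn.Conv2d(128, 128, 3, padding=1, bias=False)             # 128*128*9 = 147,456 params
grouped = nn.Conv2d(128, 128, 3, padding=1, groups=32, bias=False)  # 128*(128/32)*9 = 4,608 params
for m in (dense, grouped):
    print(sum(p.numel() for p in m.parameters()))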

 
class ResNeXt(nn.Module):
    """ResNeXt builder"""
    def __init__(self, layers, cardinality, group_depth, num_classes):
        super(ResNeXt, self).__init__()
        self.cardinality = cardinality
        self.channels = 64
        # stem: 7x7 conv brings the 3-channel input up to 64 channels
        self.conv1 = BN_Conv2d(3, self.channels, 7, stride=2, padding=3)
        d1 = group_depth
        self.conv2 = self.___make_layers(d1, layers[0], stride=1)
        d2 = d1 * 2
        self.conv3 = self.___make_layers(d2, layers[1], stride=2)
        d3 = d2 * 2
        self.conv4 = self.___make_layers(d3, layers[2], stride=2)
        d4 = d3 * 2
        self.conv5 = self.___make_layers(d4, layers[3], stride=2)
        self.fc = nn.Linear(self.channels, num_classes)   # 224x224 input size

    def ___make_layers(self, d, blocks, stride):
        # the first block of a stage may downsample; the rest keep stride 1
        strides = [stride] + [1] * (blocks - 1)
        layers = []
        for stride in strides:
            layers.append(ResNeXt_Block(self.channels, self.cardinality, d, stride))
            self.channels = self.cardinality * d * 2   # width after each block
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv1(x)
        out = F.max_pool2d(out, 3, 2, 1)
        out = self.conv2(out)
        out = self.conv3(out)
        out = self.conv4(out)
        out = self.conv5(out)
        out = F.avg_pool2d(out, 7)            # 7x7 average pool -> 1x1 feature maps
        out = out.view(out.size(0), -1)
        # return raw logits: nn.CrossEntropyLoss applies (log-)softmax itself,
        # so an extra F.softmax here would only weaken the gradients
        return self.fc(out)
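The post never shows how the model and the layer summary below were produced. Given the summary's [3, 4, 6, 3] block layout, cardinality 32, group depth 4, and two output classes, a plausible construction (using the third-party torchsummary package) is:

# Assumed instantiation; the arguments are inferred from the summary below:
# layers=[3, 4, 6, 3] (ResNeXt-50), cardinality=32, group_depth=4, 2 classes.
from torchsummary import summary

model = ResNeXt([3, 4, 6, 3], cardinality=32, group_depth=4, num_classes=2).to(device)
summary(model, (3, 224, 224))  # on CPU, pass device='cpu' to summary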

Network summary:

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 64, 112, 112]           9,408
       BatchNorm2d-2         [-1, 64, 112, 112]             128
         BN_Conv2d-3         [-1, 64, 112, 112]               0
            Conv2d-4          [-1, 128, 56, 56]           8,192
       BatchNorm2d-5          [-1, 128, 56, 56]             256
         BN_Conv2d-6          [-1, 128, 56, 56]               0
            Conv2d-7          [-1, 128, 56, 56]           4,608
       BatchNorm2d-8          [-1, 128, 56, 56]             256
         BN_Conv2d-9          [-1, 128, 56, 56]               0
           Conv2d-10          [-1, 256, 56, 56]          33,024
      BatchNorm2d-11          [-1, 256, 56, 56]             512
           Conv2d-12          [-1, 256, 56, 56]          16,384
      BatchNorm2d-13          [-1, 256, 56, 56]             512
    ResNeXt_Block-14          [-1, 256, 56, 56]               0
           Conv2d-15          [-1, 128, 56, 56]          32,768
      BatchNorm2d-16          [-1, 128, 56, 56]             256
        BN_Conv2d-17          [-1, 128, 56, 56]               0
           Conv2d-18          [-1, 128, 56, 56]           4,608
      BatchNorm2d-19          [-1, 128, 56, 56]             256
        BN_Conv2d-20          [-1, 128, 56, 56]               0
           Conv2d-21          [-1, 256, 56, 56]          33,024
      BatchNorm2d-22          [-1, 256, 56, 56]             512
           Conv2d-23          [-1, 256, 56, 56]          65,536
      BatchNorm2d-24          [-1, 256, 56, 56]             512
    ResNeXt_Block-25          [-1, 256, 56, 56]               0
           Conv2d-26          [-1, 128, 56, 56]          32,768
      BatchNorm2d-27          [-1, 128, 56, 56]             256
        BN_Conv2d-28          [-1, 128, 56, 56]               0
           Conv2d-29          [-1, 128, 56, 56]           4,608
      BatchNorm2d-30          [-1, 128, 56, 56]             256
        BN_Conv2d-31          [-1, 128, 56, 56]               0
           Conv2d-32          [-1, 256, 56, 56]          33,024
      BatchNorm2d-33          [-1, 256, 56, 56]             512
           Conv2d-34          [-1, 256, 56, 56]          65,536
      BatchNorm2d-35          [-1, 256, 56, 56]             512
    ResNeXt_Block-36          [-1, 256, 56, 56]               0
           Conv2d-37          [-1, 256, 56, 56]          65,536
      BatchNorm2d-38          [-1, 256, 56, 56]             512
        BN_Conv2d-39          [-1, 256, 56, 56]               0
           Conv2d-40          [-1, 256, 28, 28]          18,432
      BatchNorm2d-41          [-1, 256, 28, 28]             512
        BN_Conv2d-42          [-1, 256, 28, 28]               0
           Conv2d-43          [-1, 512, 28, 28]         131,584
      BatchNorm2d-44          [-1, 512, 28, 28]           1,024
           Conv2d-45          [-1, 512, 28, 28]         131,072
      BatchNorm2d-46          [-1, 512, 28, 28]           1,024
    ResNeXt_Block-47          [-1, 512, 28, 28]               0
           Conv2d-48          [-1, 256, 28, 28]         131,072
      BatchNorm2d-49          [-1, 256, 28, 28]             512
        BN_Conv2d-50          [-1, 256, 28, 28]               0
           Conv2d-51          [-1, 256, 28, 28]          18,432
      BatchNorm2d-52          [-1, 256, 28, 28]             512
        BN_Conv2d-53          [-1, 256, 28, 28]               0
           Conv2d-54          [-1, 512, 28, 28]         131,584
      BatchNorm2d-55          [-1, 512, 28, 28]           1,024
           Conv2d-56          [-1, 512, 28, 28]         262,144
      BatchNorm2d-57          [-1, 512, 28, 28]           1,024
    ResNeXt_Block-58          [-1, 512, 28, 28]               0
           Conv2d-59          [-1, 256, 28, 28]         131,072
      BatchNorm2d-60          [-1, 256, 28, 28]             512
        BN_Conv2d-61          [-1, 256, 28, 28]               0
           Conv2d-62          [-1, 256, 28, 28]          18,432
      BatchNorm2d-63          [-1, 256, 28, 28]             512
        BN_Conv2d-64          [-1, 256, 28, 28]               0
           Conv2d-65          [-1, 512, 28, 28]         131,584
      BatchNorm2d-66          [-1, 512, 28, 28]           1,024
           Conv2d-67          [-1, 512, 28, 28]         262,144
      BatchNorm2d-68          [-1, 512, 28, 28]           1,024
    ResNeXt_Block-69          [-1, 512, 28, 28]               0
           Conv2d-70          [-1, 256, 28, 28]         131,072
      BatchNorm2d-71          [-1, 256, 28, 28]             512
        BN_Conv2d-72          [-1, 256, 28, 28]               0
           Conv2d-73          [-1, 256, 28, 28]          18,432
      BatchNorm2d-74          [-1, 256, 28, 28]             512
        BN_Conv2d-75          [-1, 256, 28, 28]               0
           Conv2d-76          [-1, 512, 28, 28]         131,584
      BatchNorm2d-77          [-1, 512, 28, 28]           1,024
           Conv2d-78          [-1, 512, 28, 28]         262,144
      BatchNorm2d-79          [-1, 512, 28, 28]           1,024
    ResNeXt_Block-80          [-1, 512, 28, 28]               0
           Conv2d-81          [-1, 512, 28, 28]         262,144
      BatchNorm2d-82          [-1, 512, 28, 28]           1,024
        BN_Conv2d-83          [-1, 512, 28, 28]               0
           Conv2d-84          [-1, 512, 14, 14]          73,728
      BatchNorm2d-85          [-1, 512, 14, 14]           1,024
        BN_Conv2d-86          [-1, 512, 14, 14]               0
           Conv2d-87         [-1, 1024, 14, 14]         525,312
      BatchNorm2d-88         [-1, 1024, 14, 14]           2,048
           Conv2d-89         [-1, 1024, 14, 14]         524,288
      BatchNorm2d-90         [-1, 1024, 14, 14]           2,048
    ResNeXt_Block-91         [-1, 1024, 14, 14]               0
           Conv2d-92          [-1, 512, 14, 14]         524,288
      BatchNorm2d-93          [-1, 512, 14, 14]           1,024
        BN_Conv2d-94          [-1, 512, 14, 14]               0
           Conv2d-95          [-1, 512, 14, 14]          73,728
      BatchNorm2d-96          [-1, 512, 14, 14]           1,024
        BN_Conv2d-97          [-1, 512, 14, 14]               0
           Conv2d-98         [-1, 1024, 14, 14]         525,312
      BatchNorm2d-99         [-1, 1024, 14, 14]           2,048
          Conv2d-100         [-1, 1024, 14, 14]       1,048,576
     BatchNorm2d-101         [-1, 1024, 14, 14]           2,048
   ResNeXt_Block-102         [-1, 1024, 14, 14]               0
          Conv2d-103          [-1, 512, 14, 14]         524,288
     BatchNorm2d-104          [-1, 512, 14, 14]           1,024
       BN_Conv2d-105          [-1, 512, 14, 14]               0
          Conv2d-106          [-1, 512, 14, 14]          73,728
     BatchNorm2d-107          [-1, 512, 14, 14]           1,024
       BN_Conv2d-108          [-1, 512, 14, 14]               0
          Conv2d-109         [-1, 1024, 14, 14]         525,312
     BatchNorm2d-110         [-1, 1024, 14, 14]           2,048
          Conv2d-111         [-1, 1024, 14, 14]       1,048,576
     BatchNorm2d-112         [-1, 1024, 14, 14]           2,048
   ResNeXt_Block-113         [-1, 1024, 14, 14]               0
          Conv2d-114          [-1, 512, 14, 14]         524,288
     BatchNorm2d-115          [-1, 512, 14, 14]           1,024
       BN_Conv2d-116          [-1, 512, 14, 14]               0
          Conv2d-117          [-1, 512, 14, 14]          73,728
     BatchNorm2d-118          [-1, 512, 14, 14]           1,024
       BN_Conv2d-119          [-1, 512, 14, 14]               0
          Conv2d-120         [-1, 1024, 14, 14]         525,312
     BatchNorm2d-121         [-1, 1024, 14, 14]           2,048
          Conv2d-122         [-1, 1024, 14, 14]       1,048,576
     BatchNorm2d-123         [-1, 1024, 14, 14]           2,048
   ResNeXt_Block-124         [-1, 1024, 14, 14]               0
          Conv2d-125          [-1, 512, 14, 14]         524,288
     BatchNorm2d-126          [-1, 512, 14, 14]           1,024
       BN_Conv2d-127          [-1, 512, 14, 14]               0
          Conv2d-128          [-1, 512, 14, 14]          73,728
     BatchNorm2d-129          [-1, 512, 14, 14]           1,024
       BN_Conv2d-130          [-1, 512, 14, 14]               0
          Conv2d-131         [-1, 1024, 14, 14]         525,312
     BatchNorm2d-132         [-1, 1024, 14, 14]           2,048
          Conv2d-133         [-1, 1024, 14, 14]       1,048,576
     BatchNorm2d-134         [-1, 1024, 14, 14]           2,048
   ResNeXt_Block-135         [-1, 1024, 14, 14]               0
          Conv2d-136          [-1, 512, 14, 14]         524,288
     BatchNorm2d-137          [-1, 512, 14, 14]           1,024
       BN_Conv2d-138          [-1, 512, 14, 14]               0
          Conv2d-139          [-1, 512, 14, 14]          73,728
     BatchNorm2d-140          [-1, 512, 14, 14]           1,024
       BN_Conv2d-141          [-1, 512, 14, 14]               0
          Conv2d-142         [-1, 1024, 14, 14]         525,312
     BatchNorm2d-143         [-1, 1024, 14, 14]           2,048
          Conv2d-144         [-1, 1024, 14, 14]       1,048,576
     BatchNorm2d-145         [-1, 1024, 14, 14]           2,048
   ResNeXt_Block-146         [-1, 1024, 14, 14]               0
          Conv2d-147         [-1, 1024, 14, 14]       1,048,576
     BatchNorm2d-148         [-1, 1024, 14, 14]           2,048
       BN_Conv2d-149         [-1, 1024, 14, 14]               0
          Conv2d-150           [-1, 1024, 7, 7]         294,912
     BatchNorm2d-151           [-1, 1024, 7, 7]           2,048
       BN_Conv2d-152           [-1, 1024, 7, 7]               0
          Conv2d-153           [-1, 2048, 7, 7]       2,099,200
     BatchNorm2d-154           [-1, 2048, 7, 7]           4,096
          Conv2d-155           [-1, 2048, 7, 7]       2,097,152
     BatchNorm2d-156           [-1, 2048, 7, 7]           4,096
   ResNeXt_Block-157           [-1, 2048, 7, 7]               0
          Conv2d-158           [-1, 1024, 7, 7]       2,097,152
     BatchNorm2d-159           [-1, 1024, 7, 7]           2,048
       BN_Conv2d-160           [-1, 1024, 7, 7]               0
          Conv2d-161           [-1, 1024, 7, 7]         294,912
     BatchNorm2d-162           [-1, 1024, 7, 7]           2,048
       BN_Conv2d-163           [-1, 1024, 7, 7]               0
          Conv2d-164           [-1, 2048, 7, 7]       2,099,200
     BatchNorm2d-165           [-1, 2048, 7, 7]           4,096
          Conv2d-166           [-1, 2048, 7, 7]       4,194,304
     BatchNorm2d-167           [-1, 2048, 7, 7]           4,096
   ResNeXt_Block-168           [-1, 2048, 7, 7]               0
          Conv2d-169           [-1, 1024, 7, 7]       2,097,152
     BatchNorm2d-170           [-1, 1024, 7, 7]           2,048
       BN_Conv2d-171           [-1, 1024, 7, 7]               0
          Conv2d-172           [-1, 1024, 7, 7]         294,912
     BatchNorm2d-173           [-1, 1024, 7, 7]           2,048
       BN_Conv2d-174           [-1, 1024, 7, 7]               0
          Conv2d-175           [-1, 2048, 7, 7]       2,099,200
     BatchNorm2d-176           [-1, 2048, 7, 7]           4,096
          Conv2d-177           [-1, 2048, 7, 7]       4,194,304
     BatchNorm2d-178           [-1, 2048, 7, 7]           4,096
   ResNeXt_Block-179           [-1, 2048, 7, 7]               0
          Linear-180                    [-1, 2]           4,098
================================================================
Total params: 37,570,626
Trainable params: 37,570,626
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 379.37
Params size (MB): 143.32
Estimated Total Size (MB): 523.26
----------------------------------------------------------------
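At roughly 37.6M parameters, this network is noticeably heavier than torchvision's stock `resnext50_32x4d` (about 25M). The difference comes mostly from this implementation attaching a 1×1 projection shortcut to every block, whereas the stock model uses identity shortcuts everywhere except the first block of each stage.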

Training code:

def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    train_acc, train_loss = 0, 0
    for x, y in dataloader:
        x, y = x.to(device), y.to(device)
        pred = model(x)
        loss = loss_fn(pred, y)    # compute the loss: loss_fn(predictions, targets)
        optimizer.zero_grad()      # clear the old gradients
        loss.backward()            # backpropagate to compute new gradients
        optimizer.step()           # update the parameters
        # Each row of pred holds one score per class; argmax picks the
        # highest-scoring class as the prediction, e.g. for three classes:
        # pred = torch.tensor([
        #     [0.1, 0.2, 0.7],    # sample 1: class 2 scores highest
        #     [0.05, 0.15, 0.8],  # sample 2: class 2 scores highest
        #     [0.3, 0.6, 0.1],    # sample 3: class 1 scores highest
        #     [0.9, 0.08, 0.02]   # sample 4: class 0 scores highest
        # ])
        train_acc += (pred.argmax(1) == y).type(torch.float).sum().item()
        train_loss += loss.item()
    train_acc /= size          # accuracy was summed per sample, so divide by the dataset size
    train_loss /= num_batches  # loss was summed per batch, so divide by the batch count
    return train_acc, train_loss

def test(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_acc, test_loss = 0, 0
    with torch.no_grad():
        for imgs, target in dataloader:
            imgs, target = imgs.to(device), target.to(device)
            pred = model(imgs)
            loss = loss_fn(pred, target)
            test_loss += loss.item()
            test_acc += (pred.argmax(1) == target).type(torch.float).sum().item()
    test_acc /= size
    test_loss /= num_batches
    return test_acc, test_loss
import copy
import torch, gc

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)  # optimizer
loss_fn = nn.CrossEntropyLoss()  # cross-entropy loss
epochs = 30
train_loss = []
train_acc = []
test_loss = []
test_acc = []
best_acc = 0  # best test accuracy seen so far

for epoch in range(epochs):
    model.train()  # training mode
    epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, optimizer)
    model.eval()   # evaluation mode
    epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn, optimizer)
    if epoch_test_acc > best_acc:
        best_acc = epoch_test_acc
        best_model = copy.deepcopy(model)  # keep a copy of the best model
    train_acc.append(epoch_train_acc)
    test_acc.append(epoch_test_acc)
    train_loss.append(epoch_train_loss)
    test_loss.append(epoch_test_loss)
    lr = optimizer.state_dict()['param_groups'][0]['lr']
    template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, '
                'Test_acc:{:.1f}%, Test_loss:{:.3f}')
    print(template.format(epoch + 1, epoch_train_acc * 100, epoch_train_loss,
                          epoch_test_acc * 100, epoch_test_loss))
    gc.collect()
    torch.cuda.empty_cache()
print('Done')
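The loop keeps `best_model` in memory but never writes it to disk; a typical follow-up (the file name here is only an example) would be:

# Persist the weights of the best-performing model (illustrative path)
torch.save(best_model.state_dict(), './best_resnext50.pth')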
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings("ignore")               # suppress warning messages
plt.rcParams['font.sans-serif']    = ['SimHei'] # render CJK labels correctly
plt.rcParams['axes.unicode_minus'] = False      # render the minus sign correctly
plt.rcParams['figure.dpi']         = 100        # figure resolution

from datetime import datetime
current_time = datetime.now()  # current time, used to timestamp the plot

epochs_range = range(epochs)

plt.figure(figsize=(12, 3))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
plt.xlabel(current_time)  # include the timestamp when checking in; a screenshot without it is invalid

plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()
