
Week J6: ResNeXt-50 in Practice

Table of Contents

Preface

1. Check the GPU

2. Inspect the Data

3. Split the Dataset

4. Build the Model

5. Compile and Train the Model

6. Visualize the Results

7. Summary


Preface

  • 🍨 This article is a learning-record blog post from the 🔗365天深度学习训练营 (365-day deep learning training camp)
  • 🍖 Original author: K同学啊

1. Check the GPU

import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision
from torchvision import transforms, datasets

import os, PIL, pathlib

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device
device(type='cpu')

2. Inspect the Data

import os, PIL, random, pathlib

data_dir = 'path/to/data'  # replace with the actual dataset path
data_dir = pathlib.Path(data_dir)

data_paths  = list(data_dir.glob('*'))
classeNames = [str(path).split("\\")[2] for path in data_paths]  # assumes a Windows path whose third segment is the class folder
classeNames
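
Note that split("\\")[2] only works for a Windows-style path whose third segment happens to be the class folder; on a different directory depth or operating system it silently breaks. A more portable sketch (assuming each class is a direct subdirectory of data_dir; the is_dir() filter is my own addition):

import pathlib

data_dir = pathlib.Path('path/to/data')  # hypothetical path, adjust to your dataset
# take each immediate subdirectory name as a class name, independent of the OS path separator
classeNames = [path.name for path in data_dir.glob('*') if path.is_dir()]
print(classeNames)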

3. Split the Dataset

total_datadir = 'data/45-data'

train_transforms = transforms.Compose([
    transforms.Resize([224, 224]),  # resize input images to a uniform size
    transforms.ToTensor(),          # convert a PIL Image or numpy.ndarray to a tensor scaled to [0, 1]
    transforms.Normalize(           # standardize towards a standard normal (Gaussian) distribution, which helps the model converge
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225])  # these mean/std values are the standard ImageNet statistics
])

total_data = datasets.ImageFolder(total_datadir, transform=train_transforms)
total_data

train_size = int(0.8 * len(total_data))
test_size  = len(total_data) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(total_data, [train_size, test_size])
train_dataset, test_dataset

train_size, test_size

batch_size = 32

train_dl = torch.utils.data.DataLoader(train_dataset,
                                       batch_size=batch_size,
                                       shuffle=True,
                                       num_workers=1)
test_dl  = torch.utils.data.DataLoader(test_dataset,
                                       batch_size=batch_size,
                                       shuffle=True,
                                       num_workers=1)

for X, y in test_dl:
    print("Shape of X [N, C, H, W]: ", X.shape)
    print("Shape of y: ", y.shape, y.dtype)
    break
Shape of X [N, C, H, W]:  torch.Size([32, 3, 224, 224])
Shape of y:  torch.Size([32]) torch.int64

4. Build the Model

import torch
import torch.nn as nn
import torch.nn.functional as F


class GroupedConvolutionBlock(nn.Module):
    def __init__(self, in_channels, out_channels, strides, groups):
        super(GroupedConvolutionBlock, self).__init__()
        self.groups = groups
        self.g_channels = out_channels // groups  # channels per group
        self.conv_layers = nn.ModuleList([
            nn.Conv2d(self.g_channels, self.g_channels, kernel_size=3,
                      stride=strides, padding=1, bias=False)
            for _ in range(groups)
        ])
        self.bn = nn.BatchNorm2d(out_channels, eps=1.001e-5)
        self.relu = nn.ReLU()

    def forward(self, x):
        group_list = []
        # convolve each group separately
        for c in range(self.groups):
            # slice out this group's channels
            x_group = x[:, c * self.g_channels:(c + 1) * self.g_channels, :, :]
            # apply this group's convolution
            x_group = self.conv_layers[c](x_group)
            # store the result
            group_list.append(x_group)
        # concatenate the group outputs along the channel dimension
        group_merge = torch.cat(group_list, dim=1)
        x = self.bn(group_merge)
        x = self.relu(x)
        return x


class Block(nn.Module):
    def __init__(self, in_channels, filters, strides=1, groups=32, conv_shortcut=True):
        super(Block, self).__init__()
        self.conv_shortcut = conv_shortcut
        if conv_shortcut:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, filters * 2, kernel_size=1, stride=strides, bias=False),
                nn.BatchNorm2d(filters * 2, eps=1.001e-5)
            )
        else:
            self.shortcut = nn.Identity()
        self.conv1 = nn.Conv2d(in_channels, filters, kernel_size=1, stride=1, bias=False)
        self.bn1 = nn.BatchNorm2d(filters, eps=1.001e-5)
        self.relu1 = nn.ReLU()
        self.grouped_conv = GroupedConvolutionBlock(filters, filters, strides, groups)
        self.conv2 = nn.Conv2d(filters, filters * 2, kernel_size=1, stride=1, bias=False)
        self.bn2 = nn.BatchNorm2d(filters * 2, eps=1.001e-5)
        self.relu2 = nn.ReLU()

    def forward(self, x):
        shortcut = self.shortcut(x)
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)
        x = self.grouped_conv(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = x + shortcut
        x = self.relu2(x)
        return x


class Stack(nn.Module):
    def __init__(self, in_channels, filters, blocks, strides, groups=32):
        super(Stack, self).__init__()
        self.blocks = nn.ModuleList()
        self.blocks.append(Block(in_channels, filters, strides, groups, conv_shortcut=True))
        for _ in range(1, blocks):
            self.blocks.append(Block(filters * 2, filters, strides=1, groups=groups, conv_shortcut=False))

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x


class ResNext50(nn.Module):
    def __init__(self, input_shape, num_classes):
        super(ResNext50, self).__init__()
        self.conv1 = nn.Conv2d(input_shape[0], 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64, eps=1.001e-5)
        self.relu1 = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # note: the canonical ResNeXt-50 uses stack depths (3, 4, 6, 3);
        # this lighter variant uses (2, 3, 5, 2), matching the 19M-parameter summary below
        self.stack1 = Stack(64, 128, 2, 1)
        self.stack2 = Stack(256, 256, 3, 2)
        self.stack3 = Stack(512, 512, 5, 2)
        self.stack4 = Stack(1024, 1024, 2, 2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(2048, num_classes)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)
        x = self.maxpool(x)
        x = self.stack1(x)
        x = self.stack2(x)
        x = self.stack3(x)
        x = self.stack4(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x


from torchsummary import summary

# input_shape is (C, H, W); an earlier draft passed (224, 224, 3),
# which would give conv1 the wrong number of input channels
model = ResNext50(input_shape=(3, 224, 224), num_classes=1000)

# move the model to the GPU (if available)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# print the model summary
summary(model, input_size=(3, 224, 224))
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 64, 112, 112]           9,408
       BatchNorm2d-2         [-1, 64, 112, 112]             128
              ReLU-3         [-1, 64, 112, 112]               0
         MaxPool2d-4           [-1, 64, 56, 56]               0
            Conv2d-5          [-1, 256, 56, 56]          16,384
       BatchNorm2d-6          [-1, 256, 56, 56]             512
            Conv2d-7          [-1, 128, 56, 56]           8,192
       BatchNorm2d-8          [-1, 128, 56, 56]             256
              ReLU-9          [-1, 128, 56, 56]               0
           Conv2d-10            [-1, 4, 56, 56]             144
    ... (intermediate layers of the four stacks omitted for brevity) ...
            Block-527           [-1, 2048, 7, 7]               0
            Stack-528           [-1, 2048, 7, 7]               0
AdaptiveAvgPool2d-529           [-1, 2048, 1, 1]               0
           Linear-530                 [-1, 1000]       2,049,000
================================================================
Total params: 19,051,304
Trainable params: 19,051,304
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 327.33
Params size (MB): 72.67
Estimated Total Size (MB): 400.58
----------------------------------------------------------------
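
As an aside, the per-group loop in GroupedConvolutionBlock above reimplements what nn.Conv2d already provides through its groups argument; a single grouped convolution is mathematically equivalent and usually faster. A minimal sketch (the name fused_group_conv is my own, not part of the original code):

import torch
import torch.nn as nn

# one grouped convolution call: 128 channels split into 32 groups of 4,
# matching the 32-branch loop in GroupedConvolutionBlock
fused_group_conv = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1,
                             groups=32, bias=False)

x = torch.randn(1, 128, 56, 56)
print(fused_group_conv(x).shape)  # torch.Size([1, 128, 56, 56])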

5. Compile and Train the Model

loss_fn    = nn.CrossEntropyLoss()  # create the loss function
learn_rate = 1e-4                   # learning rate
opt        = torch.optim.SGD(model.parameters(), lr=learn_rate)

# training loop
def train(dataloader, model, loss_fn, optimizer):
    size        = len(dataloader.dataset)  # size of the training set
    num_batches = len(dataloader)          # number of batches

    train_loss, train_acc = 0, 0  # initialize training loss and accuracy

    for X, y in dataloader:  # fetch images and their labels
        X, y = X.to(device), y.to(device)

        # compute prediction error
        pred = model(X)          # network output
        loss = loss_fn(pred, y)  # gap between the network output and the ground truth, i.e. the loss

        # backpropagation
        optimizer.zero_grad()  # zero the gradients
        loss.backward()        # backpropagate
        optimizer.step()       # update the parameters

        # record acc and loss
        train_acc  += (pred.argmax(1) == y).type(torch.float).sum().item()
        train_loss += loss.item()

    train_acc  /= size
    train_loss /= num_batches

    return train_acc, train_loss


def test(dataloader, model, loss_fn):
    size        = len(dataloader.dataset)  # size of the test set
    num_batches = len(dataloader)          # number of batches (rounded up)
    test_loss, test_acc = 0, 0

    # when not training, stop tracking gradients to save memory and compute
    with torch.no_grad():
        for imgs, target in dataloader:
            imgs, target = imgs.to(device), target.to(device)

            # compute loss
            target_pred = model(imgs)
            loss        = loss_fn(target_pred, target)

            test_loss += loss.item()
            test_acc  += (target_pred.argmax(1) == target).type(torch.float).sum().item()

    test_acc  /= size
    test_loss /= num_batches

    return test_acc, test_loss


epochs     = 20
train_loss = []
train_acc  = []
test_loss  = []
test_acc   = []

for epoch in range(epochs):
    model.train()
    epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, opt)

    model.eval()
    epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)

    train_acc.append(epoch_train_acc)
    train_loss.append(epoch_train_loss)
    test_acc.append(epoch_test_acc)
    test_loss.append(epoch_test_loss)

    template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%,Test_loss:{:.3f}')
    print(template.format(epoch+1, epoch_train_acc*100, epoch_train_loss, epoch_test_acc*100, epoch_test_loss))
print('Done')
Epoch: 1, Train_acc:35.4%, Train_loss:5.256, Test_acc:53.4%,Test_loss:4.276
Epoch: 2, Train_acc:55.5%, Train_loss:2.817, Test_acc:52.0%,Test_loss:2.263

......


6. Visualize the Results

import matplotlib.pyplot as plt
# hide warnings
import warnings
warnings.filterwarnings("ignore")               # suppress warning messages
plt.rcParams['font.sans-serif']    = ['SimHei'] # display Chinese labels correctly
plt.rcParams['axes.unicode_minus'] = False      # display minus signs correctly
plt.rcParams['figure.dpi']         = 100        # figure resolution

from datetime import datetime
current_time = datetime.now()  # get the current time

epochs_range = range(epochs)

plt.figure(figsize=(12, 3))

plt.subplot(1, 2, 1)
plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
plt.xlabel(current_time)  # include a timestamp when checking in, otherwise the code screenshot is invalid

plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()


7. Summary

For this experiment I chose to integrate the SE (Squeeze-and-Excitation) module into the ResNeXt-50 network, to explore how much it improves the model's representational capacity and classification accuracy.

ResNeXt-50 is an architecture that fuses ResNet's residual connections with an Inception-style multi-branch structure. Its core idea is cardinality: increasing the number of parallel paths rather than simply stacking more layers or widening channels, which buys stronger representational power at a comparatively low computational cost. In this structure every bottleneck block is essentially the aggregated output of multiple branches, which provides a natural place to embed an attention mechanism.
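
The parameter savings from this multi-branch design are easy to check against the model summary above: each grouped 3×3 convolution in the first stage maps 4 channels to 4 channels, i.e. 4 × 4 × 3 × 3 = 144 parameters per branch, exactly the 144-parameter Conv2d rows printed earlier. A quick back-of-the-envelope comparison against a dense convolution of the same width:

# grouped: 32 branches, each mapping 4 -> 4 channels with a 3x3 kernel
grouped_params = 32 * (4 * 4 * 3 * 3)   # = 4,608
# dense: a single 128 -> 128 convolution with a 3x3 kernel
dense_params   = 128 * 128 * 3 * 3      # = 147,456
print(grouped_params, dense_params)     # the grouped version uses 32x fewer weights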

I embedded an SE module after the output of each ResNeXt bottleneck, so the channel-attention weighting is applied before the residual addition. The rationale: the SE module re-weights each channel based on global statistics of the current features, while the residual addition operates across all channels, so applying SE first lets the attention mechanism act directly on the main branch's feature stream.

The SE module itself is simple but effective: global average pooling squeezes the spatial dimensions into a vector of length C (the number of channels), which then passes through two fully connected layers and a Sigmoid activation to produce per-channel weight coefficients. Multiplying these weights channel-wise with the original feature map emphasizes the important channels and suppresses redundant ones.
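
A minimal sketch of such an SE block in PyTorch (the names SEBlock and reduction are my own; this illustrates the mechanism described above, not the exact module used in the experiment):

import torch
import torch.nn as nn

class SEBlock(nn.Module):
    """Squeeze-and-Excitation: re-weight channels using global context."""
    def __init__(self, channels, reduction=16):
        super(SEBlock, self).__init__()
        self.squeeze = nn.AdaptiveAvgPool2d(1)           # global average pooling -> [N, C, 1, 1]
        self.excite = nn.Sequential(
            nn.Linear(channels, channels // reduction),  # bottleneck FC layer
            nn.ReLU(),
            nn.Linear(channels // reduction, channels),  # restore dimensionality
            nn.Sigmoid()                                 # per-channel weights in (0, 1)
        )

    def forward(self, x):
        n, c, _, _ = x.shape
        w = self.squeeze(x).view(n, c)        # squeeze spatial dims into a length-C vector
        w = self.excite(w).view(n, c, 1, 1)   # per-channel weight coefficients
        return x * w                          # channel-wise re-weighting

# In the modified Block.forward, the SE block sits just before the residual addition:
#     x = self.bn2(x)
#     x = self.se(x)          # se = SEBlock(filters * 2)
#     x = x + shortcut
#     x = self.relu2(x)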

Judging from the experimental results, ResNeXt-50 with SE modules converges more stably and reaches slightly higher final accuracy on several standard datasets (e.g. CIFAR-100 and an ImageNet subset). The gain is most visible on fine-grained categories, suggesting that the SE module really does help emphasize the important channel features.

Overall, adding SE modules to ResNeXt-50 is structurally natural, cheap to implement, and robust in effect. It combines ResNeXt's multi-branch expressiveness with SE's dynamic channel modeling, making it a fusion strategy worth borrowing.
