
DenseNet

Preface:

- **🍨 This article is a learning-record blog post from the [🔗365天深度学习训练营](https://mp.weixin.qq.com/s/rnFa-IeY93EpjVu0yzzjkw)**
- **🍖 Original author: [K同学啊](https://mtyjkh.blog.csdn.net/)**

This article reflects my personal understanding; if there are mistakes, corrections are appreciated.

Setup code:

```python
import torch
import torch.nn as nn
from torchvision import transforms, datasets
import os, PIL, pathlib, warnings

warnings.filterwarnings('ignore')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

import matplotlib.pyplot as plt
from PIL import Image

image_path = './2data_1/data/0Normal'
# collect every image file in the folder
image_file = [f for f in os.listdir(image_path) if f.endswith(('.jpg', '.png', '.jpeg'))]

fig, axes = plt.subplots(3, 8, figsize=(16, 6))  # 3 rows x 8 columns
for ax, img_file in zip(axes.flat, image_file):
    img_paths = os.path.join(image_path, img_file)  # join folder and file name
    img = Image.open(img_paths)
    ax.imshow(img)
    ax.axis('off')
plt.tight_layout()  # tighten the layout
plt.show()

data_path = './2data_1/data'
# load the data and apply the transforms
train_form = transforms.Compose([
    transforms.Resize([224, 224]),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
total_data = datasets.ImageFolder(data_path, transform=train_form)
total_data
total_data.class_to_idx

# split the dataset
train_size = int(0.8 * len(total_data))
test_size = len(total_data) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(total_data, [train_size, test_size])

batch_size = 4
train_dl = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dl = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

for x, y in test_dl:
    print('shape of x:', x.shape)
    print('shape of y:', y.shape)
    break
```

DenseNet is essentially an improved version of the residual network. A residual network adds the convolved features to the original features element-wise to ease training problems such as vanishing gradients; DenseNet addresses the same problems through feature merging instead, concatenating the original features with the newly produced ones so that the features become richer and the network performs better.
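A minimal sketch of the difference (the tensor shapes here are illustrative): addition keeps the channel count fixed, while concatenation stacks the old and new channels together.

```python
import torch

x = torch.randn(1, 64, 56, 56)    # original features
new = torch.randn(1, 64, 56, 56)  # features produced by a convolution branch

res_out = x + new                       # ResNet-style: element-wise addition
dense_out = torch.cat([x, new], dim=1)  # DenseNet-style: channel concatenation

print(res_out.shape)    # torch.Size([1, 64, 56, 56])  -> channels unchanged
print(dense_out.shape)  # torch.Size([1, 128, 56, 56]) -> channels accumulate
```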

Network structure, part 1: the DenseLayer

The DenseLayer, or dense layer, is the basic unit of DenseNet. It consists of six sub-layers: bn1, relu1, conv1, bn2, relu2, conv2. conv1 is a 1x1, stride-1 convolution that takes the input feature's channel count as its input channels and outputs bn_size*growth_rate channels, where bn_size is a user-chosen expansion factor for the bottleneck and growth_rate is the chosen growth rate; their product widens the output channel count, projecting the data to a higher dimension. conv2 is a 3x3, stride-1 convolution from bn_size*growth_rate input channels down to growth_rate output channels; it reduces conv1's high-dimensional output back down to growth_rate channels while its 3x3 kernel extracts features. Finally, the layer returns the original input x concatenated with the new features new_feature along dimension 1 (the channel dimension), fusing the old feature channels with the new ones. For example, if x has shape [-1, 3, 224, 224], then after a DenseLayer with (in_channel=3, growth_rate=32, bn_size=5) the output has 3+32=35 channels. This concatenation is the core step of DenseNet.

```python
import torch.nn.functional as F  # F.dropout and F.avg_pool2d are used below

# Dense layer: the basic building block of DenseNet
# (the channel-expanding, concatenating counterpart of a residual connection)
class DenseLayer(nn.Sequential):
    def __init__(self, in_channel, growth_rate, bn_size, drop_rate):
        super(DenseLayer, self).__init__()
        self.add_module('norm1', nn.BatchNorm2d(in_channel))  # BN layer named norm1
        self.add_module('relu1', nn.ReLU(inplace=True))
        # 1x1 bottleneck: expand to bn_size (expansion factor) * growth_rate (new channels)
        self.add_module('conv1', nn.Conv2d(in_channel, bn_size * growth_rate,
                                           kernel_size=1, stride=1, bias=False))
        self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate))
        self.add_module('relu2', nn.ReLU(inplace=True))
        # 3x3 conv: reduce back down to growth_rate channels
        self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate,
                                           kernel_size=3, stride=1, padding=1, bias=False))
        self.drop_rate = drop_rate

    def forward(self, x):
        new_feature = super(DenseLayer, self).forward(x)
        if self.drop_rate > 0:
            # dropout layer
            new_feature = F.dropout(new_feature, p=self.drop_rate, training=self.training)
        # concatenate input and new features along the channel dimension
        return torch.cat([x, new_feature], 1)
```

Network structure, part 2: the DenseBlock

The DenseBlock, or dense block, stacks DenseLayers: it loops num_layers times, building one DenseLayer per iteration. Because each layer concatenates its output onto its input, the input channel count of each successive DenseLayer grows by growth_rate per iteration, so every layer in the block sees a different input and each DenseLayer in the block contributes its own new features.

```python
class DenseBlock(nn.Sequential):
    def __init__(self, num_layers, in_channel, bn_size, growth_rate, drop_rate):
        super(DenseBlock, self).__init__()
        for i in range(num_layers):
            # layer i receives the original input plus all previous layers' outputs,
            # i.e. in_channel + i * growth_rate channels
            layer = DenseLayer(in_channel + i * growth_rate, growth_rate, bn_size, drop_rate)
            self.add_module(f'denselayer{i+1}', layer)
```
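By the same bookkeeping, a block's output channel count is in_channel + num_layers * growth_rate. A short check, with values matching the first block of DenseNet-121:

```python
block = DenseBlock(num_layers=6, in_channel=64, bn_size=4, growth_rate=32, drop_rate=0)
x = torch.randn(1, 64, 56, 56)
print(block(x).shape)  # torch.Size([1, 256, 56, 56]) -> 64 + 6*32 = 256 channels
```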

Network structure, part 3: the Transition layer

The Transition is a transitional convolution layer made up of four sub-layers: a BN layer, a ReLU layer, a conv layer, and a pooling layer. The conv layer is a 1x1, stride-1 convolution whose input and output channel counts are passed in as arguments. Its main role in DenseNet is dimensionality reduction: the 1x1 kernel maps the input channel count down to the output channel count, compressing the features, while the 2x2 average pooling halves the spatial resolution.

```python
class Transition(nn.Sequential):
    def __init__(self, in_channel, out_channel):
        super(Transition, self).__init__()
        self.add_module('norm', nn.BatchNorm2d(in_channel))
        self.add_module('relu', nn.ReLU(inplace=True))
        # 1x1 conv maps in_channel down to out_channel
        self.add_module('conv', nn.Conv2d(in_channel, out_channel,
                                          kernel_size=1, stride=1, bias=False))
        # 2x2 average pooling halves the spatial resolution
        self.add_module('pool', nn.AvgPool2d(2, stride=2))
```
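A short check that the 1x1 convolution compresses the channels while the pooling halves the spatial size (values continue the example above):

```python
trans = Transition(in_channel=256, out_channel=128)
x = torch.randn(1, 256, 56, 56)
print(trans(x).shape)  # torch.Size([1, 128, 28, 28]) -> channels compressed, spatial halved
```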

The overall network structure

The network starts with a stem of conv0, norm0, relu0, pool0, where conv0 is a convolution with 3 input channels, init_channel output channels, kernel size 7, and stride 2; this first stage extracts the initial features. The dense blocks are then built in a loop: inside each block, the features of all its DenseLayers are concatenated, the merged features are compressed by a Transition layer, and the loop moves on to the next block. Finally, another loop initializes each layer's parameters according to its type, which helps avoid problems later in training.

```python
from collections import OrderedDict  # needed for the named stem layers below

class DenseNet(nn.Module):
    def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), init_channel=64,
                 bn_size=4, comperssion_rate=0.5, drop_rate=0, num_classes=1000):
        super(DenseNet, self).__init__()
        # stem: conv0 / norm0 / relu0 / pool0
        self.features = nn.Sequential(OrderedDict([
            ('conv0', nn.Conv2d(3, init_channel, kernel_size=7, stride=2, padding=3, bias=False)),
            ('norm0', nn.BatchNorm2d(init_channel)),
            ('relu0', nn.ReLU(inplace=True)),
            ('pool0', nn.MaxPool2d(3, stride=2, padding=1))
        ]))
        num_features = init_channel
        for i, num_layers in enumerate(block_config):
            # e.g. the first block is DenseBlock(num_layers=6, in_channel=init_channel,
            # bn_size=4, growth_rate=32, drop_rate=0)
            block = DenseBlock(num_layers, num_features, bn_size, growth_rate, drop_rate)
            self.features.add_module(f'denseblock{i+1}', block)
            # every DenseLayer in the block contributes growth_rate channels
            num_features += num_layers * growth_rate
            if i != len(block_config) - 1:
                # compress the concatenated features with a Transition layer
                transition = Transition(num_features, int(num_features * comperssion_rate))
                self.features.add_module(f'transition{i+1}', transition)
                num_features = int(num_features * comperssion_rate)
        self.features.add_module('norm5', nn.BatchNorm2d(num_features))
        self.features.add_module('relu5', nn.ReLU(inplace=True))
        self.classifier = nn.Linear(num_features, num_classes)
        # initialize the parameters according to each layer's type
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1)
            elif isinstance(m, nn.Linear):
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.features(x)
        x = F.avg_pool2d(x, 7, stride=1).view(x.size(0), -1)
        x = self.classifier(x)
        return x

densenet121 = DenseNet(init_channel=64, growth_rate=32, block_config=(6, 12, 24, 16), num_classes=2)
model = densenet121.to(device)
model
```
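For the densenet121 configuration above (block_config=(6,12,24,16), compression 0.5, growth_rate=32), the num_features bookkeeping in __init__ can be traced by hand; this snippet is just the constructor's arithmetic written out as a check:

```python
num_features = 64  # init_channel
for i, num_layers in enumerate((6, 12, 24, 16)):
    num_features += num_layers * 32        # channels added by the dense block
    if i != 3:                             # transition after every block except the last
        num_features = int(num_features * 0.5)
    print(num_features)
# prints 128, 256, 512, 1024 -- matching the transition and final channel
# counts in the torchsummary output below
```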
```python
import torchsummary as summary
summary.summary(model, (3, 224, 224))
```
```
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 64, 112, 112]           9,408
       BatchNorm2d-2         [-1, 64, 112, 112]             128
              ReLU-3         [-1, 64, 112, 112]               0
         MaxPool2d-4           [-1, 64, 56, 56]               0
       BatchNorm2d-5           [-1, 64, 56, 56]             128
              ReLU-6           [-1, 64, 56, 56]               0
            Conv2d-7          [-1, 128, 56, 56]           8,192
               ...  (intermediate layers omitted)          ...
       BatchNorm2d-365         [-1, 1024, 7, 7]           2,048
              ReLU-366         [-1, 1024, 7, 7]               0
            Linear-367                  [-1, 2]           2,050
================================================================
Total params: 6,955,906
Trainable params: 6,955,906
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 294.57
Params size (MB): 26.53
Estimated Total Size (MB): 321.68
----------------------------------------------------------------
```

Training code:

```python
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    train_acc, train_loss = 0, 0
    for x, y in dataloader:
        x, y = x.to(device), y.to(device)
        pred = model(x)
        loss = loss_fn(pred, y)  # compute the loss: loss_fn(prediction, target)
        optimizer.zero_grad()    # clear old gradients
        loss.backward()          # backpropagation: compute gradients
        optimizer.step()         # update the parameters
        # pred holds one score per class for every sample; argmax(1) takes the
        # highest-scoring class as the prediction, which is compared against the
        # label and accumulated as the number of correct predictions
        train_acc += (pred.argmax(1) == y).type(torch.float).sum().item()
        train_loss += loss.item()
    train_acc /= size          # accuracy was summed per sample: divide by dataset size
    train_loss /= num_batches  # loss was summed per batch: divide by batch count
    return train_acc, train_loss

def test(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_acc, test_loss = 0, 0
    with torch.no_grad():
        for imgs, target in dataloader:
            imgs, target = imgs.to(device), target.to(device)
            pred = model(imgs)
            loss = loss_fn(pred, target)
            test_loss += loss.item()
            test_acc += (pred.argmax(1) == target).type(torch.float).sum().item()
    test_acc /= size
    test_loss /= num_batches
    return test_acc, test_loss
```
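The accuracy line hinges on pred.argmax(1). The original comment sketches this with a small tensor; here it is as a runnable snippet (the labels in y are hypothetical):

```python
# each row holds one sample's class scores; argmax(1) picks the predicted class
pred = torch.tensor([
    [0.1, 0.2, 0.7],    # sample 1: class 2 scores highest
    [0.05, 0.15, 0.8],  # sample 2: class 2
    [0.3, 0.6, 0.1],    # sample 3: class 1
    [0.9, 0.08, 0.02],  # sample 4: class 0
])
y = torch.tensor([2, 2, 1, 0])  # hypothetical ground-truth labels
print((pred.argmax(1) == y).type(torch.float).sum().item())  # 4.0 correct predictions
```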
```python
import copy

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)  # optimizer
loss_fn = nn.CrossEntropyLoss()  # cross-entropy loss
epochs = 30
train_loss = []
train_acc = []
test_loss = []
test_acc = []
best_acc = 0  # best test accuracy so far

for epoch in range(epochs):
    model.train()  # training mode
    epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, optimizer)
    model.eval()   # evaluation mode
    epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn, optimizer)
    if epoch_test_acc > best_acc:
        best_acc = epoch_test_acc
        best_model = copy.deepcopy(model)  # keep a copy of the best model
    train_acc.append(epoch_train_acc)
    test_acc.append(epoch_test_acc)
    train_loss.append(epoch_train_loss)
    test_loss.append(epoch_test_loss)
    lr = optimizer.state_dict()['param_groups'][0]['lr']
    template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, '
                'Test_acc:{:.1f}%, Test_loss:{:.3f}')
    print(template.format(epoch+1, epoch_train_acc*100, epoch_train_loss,
                          epoch_test_acc*100, epoch_test_loss))
print('Done')
```
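The loop keeps the best weights in best_model but never writes them to disk; a minimal sketch for persisting them (the file name is an assumption, not from the original):

```python
# 'best_densenet121.pth' is a hypothetical file name
torch.save(best_model.state_dict(), 'best_densenet121.pth')
# restore later with: model.load_state_dict(torch.load('best_densenet121.pth'))
```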

```python
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings("ignore")             # suppress warnings
plt.rcParams['font.sans-serif'] = ['SimHei']  # display Chinese labels correctly
plt.rcParams['axes.unicode_minus'] = False    # display minus signs correctly
plt.rcParams['figure.dpi'] = 100              # figure resolution

from datetime import datetime
current_time = datetime.now()  # get the current time

epochs_range = range(epochs)
plt.figure(figsize=(12, 3))

plt.subplot(1, 2, 1)
plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
plt.xlabel(current_time)  # timestamp the figure (check-ins require it, otherwise the screenshot is invalid)

plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()
```

As the curves show, the results are somewhat better than those of the residual network.

