pytorch深度学习-ResNet残差网络-CIFAR-10
目录
基础概念
一、梯度消失:传话游戏里的信息丢失
1. 什么是 “梯度”?
2. 梯度消失为什么像传话游戏?
二、ResNet:给信息开 “绿色通道” 的聪明设计
1. ResNet 的核心思想:不从头学,只学 “差异”
2. 残差块:信息的 “双车道高速公路”
3. 为什么 ResNet 能解决梯度消失?
三、ResNet 代码中的 “捷径” 长啥样?
网络架构
一、输入层:乐高积木的「基础底板」
场景:用乐高搭建机器人,第一步是铺好底板
二、残差块:乐高机器人的「模块化关节」
核心:解决「积木搭太高容易散架」的问题(梯度消失)
ResNet18 用的是「Basic Block」,分两种情况:
情况 1:积木尺寸不变(如 layer1 的块)
结构:「双轨道拼接」的关节
情况 2:积木尺寸缩小(如 layer2-3-4 的块,需要「适配器」)
结构:「轨道 + 适配器」的关节
三、18 层的准确计数(只数带孔的积木层)
1. 初始卷积层(1 层)
2. 四个残差块组(16 层)
残差块组结构
残差块核心设计
3. 全连接层(1 层)
总结
代码实战
代码模型设计
ResNet 主类结构
网络主体结构
残差层构建
1. _make_layer 函数的作用
2. 四个残差层的具体构建
1. layer1 构建过程
以 layer1 的构建为例:
2. layer2 构建过程
3. layer3 构建过程
4. layer4 构建过程
全局池化与分类器
前向传播流程
ResNet-18 模型工厂函数
网络结构总结
完整代码
基础概念
一、梯度消失:传话游戏里的信息丢失
1. 什么是 “梯度”?
- 类比:学习时的 “错误提示”。比如考试后老师告诉你哪题错了,该怎么改,这就是 “梯度”。
- 在神经网络中:模型预测错了,计算机通过 “反向传播” 算出每个参数该怎么调,这个 “调整方向” 就是梯度。
2. 梯度消失为什么像传话游戏?
-
传话游戏场景:
一群人排成队,第一个人小声说 “今天吃火锅”,传给第二个人,第二个人传给第三个人…… 传到最后一个人时,可能变成 “明天吃面条”。
信息在传递中逐渐丢失。 -
神经网络中的梯度消失:
深层网络就像很长的队伍,梯度(错误提示)从最后一层反向传播到第一层时,可能变得非常微弱(甚至消失),导致底层参数无法更新。
结果:深层网络学不动,性能反而不如浅层网络。
二、ResNet:给信息开 “绿色通道” 的聪明设计
1. ResNet 的核心思想:不从头学,只学 “差异”
-
普通神经网络:像学画画时从头开始画一幅画,难度大。
-
ResNet:像在一幅已有的画上修改(比如把猫改成狗),只学 “差异” 部分,更容易。
-
数学表达: 普通网络:输出 = 复杂函数计算结果 ResNet:输出 = 输入 + 复杂函数计算结果(只学输入和输出的差异)
2. 残差块:信息的 “双车道高速公路”
-
主车道:正常的卷积计算(绕路爬山)
-
应急车道(捷径连接):直接让输入信息 “抄近路” 到输出(直线爬山)
-
类比场景: 你要从山脚到山顶:
- 主车道:绕山路爬(可能累到爬不动)
- 应急车道:坐缆车直接上去(保证能到山顶) ResNet 让信息至少能通过应急车道传递,不会丢失。
3. 为什么 ResNet 能解决梯度消失?
-
普通网络的梯度传递:像水流过狭窄的管道,越流越小,最后没水了(梯度消失)。
-
ResNet 的梯度传递:管道旁边加了个粗水管(捷径连接),即使细管道水流小,粗水管也能保证水流到源头(梯度不消失)。
-
关键公式比喻:
三、ResNet 代码中的 “捷径” 长啥样?
看代码里的残差块:
def forward(self, x):out = 主路径计算(x) # 绕路爬山out += self.shortcut(x) # 加上抄近路的x(坐缆车)return out
- 核心操作:
out += x
,即 “输出 = 主路径结果 + 输入”。 - 效果:如果主路径没学到东西,至少输出等于输入(不会变差);如果主路径学到了东西,就相当于在输入基础上改进。
网络架构
一、输入层:乐高积木的「基础底板」
场景:用乐高搭建机器人,第一步是铺好底板
-
7×7 卷积层(算 1 层)
- 作用:用 7×7 的 “乐高底板”(卷积核)把原始图像(224×224×3)压成 64 个 “特征积木块”(64 通道),尺寸缩小到 112×112。
- 类比:底板上的每个凸起(权重)对应识别一种特征(如边缘、纹理),64 个凸起对应 64 种特征探测器。
- 关键参数:步长 = 2,相当于每铺 2 格积木才放一个凸起,减少后续工作量。
-
BatchNorm+ReLU + 最大池化(不计入层数)
- BatchNorm:把所有积木块的高度标准化(调整数据分布),避免有的太高有的太低。
- ReLU:只保留高度为正的积木块(丢弃无效特征),让后续搭建更高效。
- 3×3 最大池化:用 3×3 的 “筛子” 筛选积木块,只保留每个区域最高的那个(特征最明显),尺寸再缩小到 56×56。
注:这里少画了一个Relu,没啥关系,但是标准的是上面叙述的(图是李沐大神的)
二、残差块:乐高机器人的「模块化关节」
核心:解决「积木搭太高容易散架」的问题(梯度消失)
ResNet18 用的是「Basic Block」,分两种情况:
情况 1:积木尺寸不变(如 layer1 的块)
结构:「双轨道拼接」的关节
-
轨道 2(主轨道):
- 第 1 层:3×3 卷积(像给积木打孔)→ BN(调整孔的大小)→ ReLU(只保留有用的孔)
- 第 2 层:3×3 卷积(再打更深的孔)→ BN(再次调整)
-
轨道 1(跳跃连接):
- 直接把输入积木 “抄近路” 送到轨道末端,和主轨道的积木拼接。
- 类比:如果主轨道的孔打错了,抄近路的积木能保证至少原积木还在,防止信息丢失。
情况 2:积木尺寸缩小(如 layer2-3-4 的块,需要「适配器」)
结构:「轨道 + 适配器」的关节
-
轨道 2(主轨道):
- 第 1 层:3×3 卷积(步长 = 2,直接把积木缩小一半)→ BN→ReLU
- 第 2 层:3×3 卷积(打更深的孔)→ BN
- 同样两层卷积,算 2 层。
-
轨道 1(适配器):
- 用 1×1 卷积(像 “积木压缩器”)调整积木尺寸和通道数,让它能和主轨道的积木匹配。
- 关键点:1×1 卷积有权重,但它属于跳跃连接的一部分,残差块的层数只算主轨道的两层卷积。
三、18 层的准确计数(只数带孔的积木层)
ResNet-18 的 “18 层” 指网络中包含18 个可学习的权重层(不包含池化、激活函数、批量归一化等无参数层)。其核心结构由7×7 卷积层、四个残差块组和全连接层组成,具体分布如下:
1. 初始卷积层(1 层)
- 7×7 卷积:步长 2,填充 3,输出 64 通道。
作用:捕获大感受野,减少空间维度(输入 224×224→112×112)。
2. 四个残差块组(16 层)
每个残差块组由 2 个BasicBlock堆叠而成,每个 BasicBlock 包含 2 个 3×3 卷积层,共4 组 ×2 块 ×2 层 = 16 层。
残差块组结构
-
Stage1(Conv2_x):
- 2 个 BasicBlock,每个块包含:
- 3×3 卷积(步长 1)→ BN → ReLU
- 3×3 卷积(步长 1)→ BN → ReLU
- 输出维度:64 通道,空间尺寸保持 112×112。
- 2 个 BasicBlock,每个块包含:
-
Stage2(Conv3_x):
- 2 个 BasicBlock,第一个块的第一个卷积层步长 2,实现下采样:
- 3×3 卷积(步长 2)→ BN → ReLU → 3×3 卷积(步长 1)→ BN → ReLU
- 输出维度:128 通道,空间尺寸减半至 56×56。
- 2 个 BasicBlock,第一个块的第一个卷积层步长 2,实现下采样:
-
Stage3(Conv4_x):
- 2 个 BasicBlock,第一个块的第一个卷积层步长 2:
- 3×3 卷积(步长 2)→ BN → ReLU → 3×3 卷积(步长 1)→ BN → ReLU
- 输出维度:256 通道,空间尺寸减半至 28×28。
- 2 个 BasicBlock,第一个块的第一个卷积层步长 2:
-
Stage4(Conv5_x):
- 2 个 BasicBlock,第一个块的第一个卷积层步长 2:
- 3×3 卷积(步长 2)→ BN → ReLU → 3×3 卷积(步长 1)→ BN → ReLU
- 输出维度:512 通道,空间尺寸减半至 14×14。
- 2 个 BasicBlock,第一个块的第一个卷积层步长 2:
残差块核心设计
- 跳跃连接:输入直接与卷积输出相加,解决梯度消失问题。
- 1×1 卷积:当输入输出通道数不一致或需要下采样时,用于调整维度(如 Stage2-4 的第一个块)。
3. 全连接层(1 层)
- 全局平均池化:将特征图压缩为 1×1×512。
- 线性层:512 通道→10类(CIFAR-10),最终输出分类结果。
总结
- 总层数:1(初始卷积) + 16(残差块) + 1(全连接) = 18 层。
- 核心优势:通过残差连接和跳跃结构,允许训练极深网络而不退化,同时保持高效的特征提取能力。
代码实战
对于 CIFAR-10 这种小图像数据集,使用 3×3 卷积也是合理的(许多论文和实现中都这样做),因为 7×7 卷积可能会导致过多的信息丢失。
代码模型设计
class BasicBlock(nn.Module):expansion = 1def __init__(self,in_channels,out_channels,stride=1):super(BasicBlock,self).__init__()self.conv1 = nn.Conv2d(in_channels,out_channels,kernel_size=3,stride=stride,padding=1,bias=False)self.bn1 = nn.BatchNorm2d(out_channels)self.relu = nn.ReLU(inplace=True)self.conv2 = nn.Conv2d(out_channels,out_channels,kernel_size=3,stride=1,padding=1)self.bn2 = nn.BatchNorm2d(out_channels)self.shortcut = nn.Sequential()if stride != 1 or in_channels != self.expansion * out_channels:self.shortcut = nn.Sequential(nn.Conv2d(in_channels,self.expansion * out_channels,kernel_size=1,stride=stride,bias=False),nn.BatchNorm2d(self.expansion * out_channels))def forward(self,x):out = self.relu(self.bn1(self.conv1(x)))out = self.bn2(self.conv2(out))out += self.shortcut(x)out = self.relu(out)return outclass ResNet(nn.Module):def __init__(self,block,num_blocks,num_classes=10):super(ResNet,self).__init__()self.in_channels =64self.conv1 = nn.Conv2d(3,64,kernel_size=3,stride=1,padding=1,bias=False)self.bn1 = nn.BatchNorm2d(64)self.relu = nn.ReLU(inplace=True)self.layer1 = self._make_layer(block,64,num_blocks[0],stride =1)self.layer2 = self._make_layer(block,128,num_blocks[1],stride =2)self.layer3 = self._make_layer(block,256,num_blocks[2],stride=2)self.layer4 = self._make_layer(block,512,num_blocks[3],stride=2)self.avgpool = nn.AdaptiveAvgPool2d((1,1))self.fc = nn.Linear(512 * block.expansion,num_classes)def _make_layer(self,block,out_channels,num_blocks,stride):stridies = [stride] +[1] * (num_blocks-1)layers = []for stride in stridies:layers.append(block(self.in_channels,out_channels,stride))self.in_channels = out_channels * block.expansionreturn nn.Sequential(*layers)def forward(self,x):out = self.relu(self.bn1(self.conv1(x)))out = self.layer1(out)out = self.layer2(out)out = self.layer3(out)out = self.layer4(out)out = self.avgpool(out)out = out.view(out.size(0), -1)out = self.fc(out)return outdef Resnet18():return ResNet(BasicBlock,[2,2,2,2])
ResNet 主类结构
class ResNet(nn.Module):def __init__(self, block, num_blocks, num_classes=10):super(ResNet, self).__init__()self.in_channels = 64
- 初始化参数:
block
:残差块类型(如BasicBlock
)。num_blocks
:各层残差块数量的列表(如[2,2,2,2]
对应 ResNet-18)。num_classes
:分类任务的类别数(默认 10)。
- 初始通道数:
self.in_channels = 64
,作为第一层的输入通道数。
网络主体结构
self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(64)
self.relu = nn.ReLU(inplace=True)
- 初始卷积层:将 3 通道的输入图像(如 RGB)转换为 64 通道的特征图,使用 3×3 卷积核,步长 1,保持尺寸不变。
残差层构建
1. _make_layer
函数的作用
def _make_layer(self, block, out_channels, num_blocks, stride):strides = [stride] + [1] * (num_blocks - 1)layers = []for stride in strides:layers.append(block(self.in_channels, out_channels, stride))self.in_channels = out_channels * block.expansionreturn nn.Sequential(*layers)
-
参数:
block
:残差块类型(如BasicBlock
)out_channels
:输出通道数num_blocks
:该层包含的残差块数量stride
:该层第一个残差块的步长
-
核心逻辑:
- 生成步长列表:第一个残差块使用指定的
stride
,后续残差块使用步长 1 - 创建残差块序列:根据步长列表创建残差块,并更新
in_channels
为当前输出通道数
- 生成步长列表:第一个残差块使用指定的
2. 四个残差层的具体构建
self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
以 ResNet-18 为例(num_blocks=[2,2,2,2]
),详细分析每一层的构建过程:
1. layer1 构建过程
self.layer1 = self._make_layer(BasicBlock, 64, 2, stride=1)
-
步长列表:strides = [stride] + [1] * (num_blocks - 1) 计算后
[1, 1]
(两个残差块都使用步长 1) -
构建步骤:
- 第一个残差块:
BasicBlock(64, 64, stride=1)
- 输入通道 64,输出通道 64,步长 1
- 跳跃连接:直接连接(因通道数不变且步长 = 1)
- 第二个残差块:
BasicBlock(64, 64, stride=1)
- 输入通道 64(由上一步更新),输出通道 64,步长 1
- 跳跃连接:直接连接
self.in_channels
更新为 64(64 * BasicBlock.expansion = 64
)
- 第一个残差块:
-
输出特征:尺寸不变(32×32),通道数 64
以 layer1
的构建为例:
# 假设block是BasicBlock,num_blocks[0]=2
self.layer1 = self._make_layer(block, 64, 2, stride=1)
_make_layer
函数内部会生成一个包含两个残差块的列表:
layers = [BasicBlock(64, 64, stride=1), # 第一个残差块BasicBlock(64, 64, stride=1) # 第二个残差块
]
然后通过 nn.Sequential(*layers)
将它们组合成一个序列模块,相当于:
self.layer1 = nn.Sequential(BasicBlock(64, 64, stride=1),BasicBlock(64, 64, stride=1)
)
当输入数据通过 layer1
时,会依次经过这两个残差块:
output = self.layer1(input)
# 等价于:
# output = BasicBlock2(BasicBlock1(input))
2. layer2 构建过程
self.layer2 = self._make_layer(BasicBlock, 128, 2, stride=2)
-
步长列表:strides = [stride] + [1] * (num_blocks - 1) 计算后
[2, 1]
(第一个残差块步长 2,第二个步长 1) -
第一次循环(创建第一个残差块)
# 此时self.in_channels=64(来自layer1的输出) layers.append(block(self.in_channels, out_channels, stride)) # 即:layers.append(BasicBlock(64, 128, 2))# 更新self.in_channels为当前输出通道数 self.in_channels = out_channels * block.expansion # BasicBlock.expansion=1,所以self.in_channels=128*1=128
第二次循环(创建第二个残差块)
# 此时self.in_channels=128(上一步更新后) layers.append(block(self.in_channels, out_channels, stride)) # 即:layers.append(BasicBlock(128, 128, 1))# 再次更新self.in_channels=128*1=128(保持不变)
4. 返回组合好的序列模块
return nn.Sequential(*layers) # 等价于: # return nn.Sequential( # BasicBlock(64, 128, 2), # BasicBlock(128, 128, 1) # )
1. 第一个残差块:
BasicBlock(64, 128, 2)
主路径:
# 第一个卷积层(步长2,会使尺寸减半) self.conv1 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1) self.bn1 = nn.BatchNorm2d(128) self.relu = nn.ReLU(inplace=True)# 第二个卷积层(步长1,保持尺寸) self.conv2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1) self.bn2 = nn.BatchNorm2d(128)
跳跃连接:
# 由于stride=2且in_channels=64≠out_channels=128 self.shortcut = nn.Sequential(nn.Conv2d(64, 128, kernel_size=1, stride=2, bias=False),nn.BatchNorm2d(128) )
2. 第二个残差块:
BasicBlock(128, 128, 1)
主路径:
# 两个卷积层均使用步长1 self.conv1 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1) self.bn1 = nn.BatchNorm2d(128) self.relu = nn.ReLU(inplace=True)self.conv2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1) self.bn2 = nn.BatchNorm2d(128)
跳跃连接:
# 由于stride=1且in_channels=128=out_channels=128 self.shortcut = nn.Sequential() # 直接连接
-
输出特征:尺寸减半(16×16),通道数 128
3. layer3 构建过程
self.layer3 = self._make_layer(BasicBlock, 256, 2, stride=2)
-
步长列表:strides = [stride] + [1] * (num_blocks - 1) 计算后
[2, 1]
(第一个残差块步长 2,第二个步长 1) -
构建步骤:
- 第一个残差块:
BasicBlock(128, 256, stride=2)
- 输入通道 128,输出通道 256,步长 2
- 跳跃连接:1×1 卷积(128→256,尺寸减半)
- 第二个残差块:
BasicBlock(256, 256, stride=1)
- 输入通道 256,输出通道 256,步长 1
self.in_channels
更新为 256
- 第一个残差块:
-
输出特征:尺寸减半(8×8),通道数 256
4. layer4 构建过程
self.layer4 = self._make_layer(BasicBlock, 512, 2, stride=2)
-
步长列表:strides = [stride] + [1] * (num_blocks - 1) 计算后
[2, 1]
(第一个残差块步长 2,第二个步长 1) -
构建步骤:
- 第一个残差块:
BasicBlock(256, 512, stride=2)
- 输入通道 256,输出通道 512,步长 2
- 跳跃连接:1×1 卷积(256→512,尺寸减半)
- 第二个残差块:
BasicBlock(512, 512, stride=1)
- 输入通道 512,输出通道 512,步长 1
self.in_channels
更新为 512
- 第一个残差块:
-
输出特征:尺寸减半(4×4),通道数 512
全局池化与分类器
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(512 * block.expansion, num_classes)
- 全局平均池化:将特征图压缩为 1×1 大小,减少参数量。
- 全连接层:将 512 维特征映射到
num_classes
个类别。
前向传播流程
def forward(self, x):out = self.relu(self.bn1(self.conv1(x)))out = self.layer1(out)out = self.layer2(out)out = self.layer3(out)out = self.layer4(out)out = self.avgpool(out)out = out.view(out.size(0), -1) # 修正:原文中的self.view改为out.viewout = self.fc(out)return out
- 数据流向:
- 通过初始卷积层提取基础特征。
- 依次通过 4 个残差层,逐步增加通道数并缩小尺寸。
- 全局池化后展平为一维向量。
- 通过全连接层输出分类结果。
ResNet-18 模型工厂函数
def Resnet18():return ResNet(BasicBlock, [2, 2, 2, 2])
- 参数解释:
- 使用
BasicBlock
(适合浅网络)。 - 每个残差层包含 2 个残差块,共 18 层(1 个初始卷积 + 4×2×2 个残差块 + 1 个全连接)。
- 使用
网络结构总结
层名 | 输出尺寸 | 通道数 | 残差块数 | 说明 |
---|---|---|---|---|
conv1 | H×W | 64 | - | 初始卷积 |
layer1 | H×W | 64 | 2 | 尺寸不变,通道不变 |
layer2 | H/2×W/2 | 128 | 2 | 尺寸减半,通道翻倍 |
layer3 | H/4×W/4 | 256 | 2 | 尺寸减半,通道翻倍 |
layer4 | H/8×W/8 | 512 | 2 | 尺寸减半,通道翻倍 |
avgpool | 1×1 | 512 | - | 全局平均池化 |
fc | - | 10 | - | 分类器 |
完整代码
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import datasets,transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as npdevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"使用设备:{device}")if torch.cuda.is_available():print(f"Gpu:{torch.cuda.get_device_name(0)}")print(f"内存:{torch.cuda.get_device_properties(0).total_memory /1024/1024}MB")transform_train = transforms.Compose([transforms.RandomCrop(32,padding=4),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize((0.4914,0.4822,0.4465),(0.2023,0.1994,0.2010))
])transform_test = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.4914,0.4822,0.4465),(0.2023,0.1994,0.2010))
])train_dataset = datasets.CIFAR10('./data',train=True,download=True,transform=transform_train)test_dataset = datasets.CIFAR10('./data',train=False,download=True,transform=transform_test)
train_loader = DataLoader(train_dataset,batch_size=128,shuffle=True,num_workers=2) #num_workers=2: 使用 2 个子进程并行加载数据,加快数据读取速度
test_loader = DataLoader(test_dataset,batch_size=100,shuffle=False,num_workers=2)class BasicBlock(nn.Module):expansion = 1def __init__(self,in_channels,out_channels,stride=1):super(BasicBlock,self).__init__()self.conv1 = nn.Conv2d(in_channels,out_channels,kernel_size=3,stride=stride,padding=1,bias=False)self.bn1 = nn.BatchNorm2d(out_channels)self.relu = nn.ReLU(inplace=True)self.conv2 = nn.Conv2d(out_channels,out_channels,kernel_size=3,stride=1,padding=1)self.bn2 = nn.BatchNorm2d(out_channels)self.shortcut = nn.Sequential()if stride != 1 or in_channels != self.expansion * out_channels:self.shortcut = nn.Sequential(nn.Conv2d(in_channels,self.expansion * out_channels,kernel_size=1,stride=stride,bias=False),nn.BatchNorm2d(self.expansion * out_channels))def forward(self,x):out = self.relu(self.bn1(self.conv1(x)))out = self.bn2(self.conv2(out))out += self.shortcut(x)out = self.relu(out)return outclass ResNet(nn.Module):def __init__(self,block,num_blocks,num_classes=10):super(ResNet,self).__init__()self.in_channels =64self.conv1 = nn.Conv2d(3,64,kernel_size=3,stride=1,padding=1,bias=False)self.bn1 = nn.BatchNorm2d(64)self.relu = nn.ReLU(inplace=True)self.layer1 = self._make_layer(block,64,num_blocks[0],stride =1)self.layer2 = self._make_layer(block,128,num_blocks[1],stride =2)self.layer3 = self._make_layer(block,256,num_blocks[2],stride=2)self.layer4 = self._make_layer(block,512,num_blocks[3],stride=2)self.avgpool = nn.AdaptiveAvgPool2d((1,1))self.fc = nn.Linear(512 * block.expansion,num_classes)def _make_layer(self,block,out_channels,num_blocks,stride):stridies = [stride] +[1] * (num_blocks-1)layers = []for stride in stridies:layers.append(block(self.in_channels,out_channels,stride))self.in_channels = out_channels * block.expansionreturn nn.Sequential(*layers)def forward(self,x):out = self.relu(self.bn1(self.conv1(x)))out = self.layer1(out)out = self.layer2(out)out = self.layer3(out)out = self.layer4(out)out = self.avgpool(out)out = out.view(out.size(0), -1)out = self.fc(out)return outdef Resnet18():return ResNet(BasicBlock,[2,2,2,2])model = Resnet18().to(device)
loss_function = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(),lr=0.01,mometum =0.9,weight_decay=5e-4)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer,T_max=200)best_acc = 0.0
for epoch in range(1, 201):print(f'\nEpoch {epoch}')# 训练阶段model.train()train_loss = 0correct = 0total = 0for batch_idx, (inputs, targets) in enumerate(train_loader):inputs, targets = inputs.to(device), targets.to(device)optimizer.zero_grad()outputs = model(inputs)loss = loss_function(outputs, targets)loss.backward()optimizer.step()train_loss += loss.item()_, predicted = outputs.max(1)total += targets.size(0)correct += predicted.eq(targets).sum().item()if batch_idx % 100 == 0:print(f'Batch: {batch_idx}/{len(train_loader)}, 'f'Loss: {train_loss / (batch_idx + 1):.3f} | 'f'Acc: {100. * correct / total:.3f}%')# 测试阶段model.eval()test_loss = 0correct = 0total = 0with torch.no_grad():for batch_idx, (inputs, targets) in enumerate(test_loader):inputs, targets = inputs.to(device), targets.to(device)outputs = model(inputs)loss = loss_function(outputs, targets)test_loss += loss.item()_, predicted = outputs.max(1)total += targets.size(0)correct += predicted.eq(targets).sum().item()acc = 100. * correct / totalprint(f'Test Loss: {test_loss / len(test_loader):.3f}, 'f'Accuracy: {acc:.3f}%')scheduler.step()# 保存最佳模型if acc > best_acc:print(f'Saving best model with accuracy: {acc:.3f}%')torch.save(model.state_dict(), ' CIFAR-10.pth')best_acc = acc