MindSpore框架学习项目-ResNet药物分类-构建模型
目录
2.构建模型
2.1定义模型类
2.1.1 基础块ResidualBlockBase
ResidualBlockBase代码解析
2.1.2 瓶颈块ResidualBlock
ResidualBlock代码解释
2.1.3 构建层
构建层代码说明
2.1.4 定义不同组合(block,layer_nums)的ResNet网络实现
ResNet组建类代码解析
2.1.5 实例化resnet_xx网络
实例化resnet_xx网络代码分析
2.2模型初始化
模型初始化代码解析
本项目可以在华为云modelart上租一个实例进行,也可以在配置至少为单卡3060的设备上进行
https://console.huaweicloud.com/modelarts/
Ascend环境也适用,但是注意修改device_target参数
需要本地编译器的一些代码传输、修改等可以勾上ssh远程开发
说明:项目使用的数据集来自华为云的数据资源。项目以深度学习任务构建的一般流程展开(数据导入、处理 > 模型选择、构建 > 模型训练 > 模型评估 > 模型优化)。
主线为‘一般流程’,同时代码中会标注出一些要点(# 要点1-1-1:设置使用的设备
)作为支线,帮助学习mindspore框架在进行深度学习任务时一些与pytorch的差异。
可以只看目录中带数字标签的部分来快速查阅代码。
2.构建模型
2.1定义模型类
要求:
补充如下代码的空白处
主要完成:
1. 实现1个卷积层和1个ReLU激活函数的定义
2. 实现ResidualBlockBase和ResidualBlock模块的残差连接,并补全self.layer4的参数
导入mindspore训练环节(包括模型构建、激活函数、反向传播、损失函数等需要的库)
from mindspore import Model
from mindspore import context
import mindspore.ops as ops
from mindspore import Tensor, nn, set_context, GRAPH_MODE, train
from mindspore import load_checkpoint, load_param_into_net
from typing import Type, Union, List, Optional
from mindspore import nn, train
from mindspore.common.initializer import Normal
初始化:weight_init = Normal(mean=0, sigma=0.02) 用于初始化卷积层;
gamma_init = Normal(mean=1, sigma=0.02) 用于初始化批归一化层
weight_init = Normal(mean=0, sigma=0.02)
gamma_init = Normal(mean=1, sigma=0.02)
2.1.1 基础块ResidualBlockBase
conv1 和 conv2 的参数设置:
conv1 负责处理输入数据的空间下采样(通过 stride 参数)或通道数变换(通过 out_channels),同时进行第一次特征提取。
conv2 固定为 3×3 卷积,不改变空间尺寸(默认 stride=1),仅对 conv1 的输出进一步提取特征。
(卷积层当池化层用)
class ResidualBlockBase(nn.Cell):
expansion: int = 1def __init__(self, in_channel: int, out_channel: int,
stride: int = 1, norm: Optional[nn.Cell] = None,
down_sample: Optional[nn.Cell] = None) -> None:super(ResidualBlockBase, self).__init__()if not norm:
self.norm = nn.BatchNorm2d(out_channel)else:
self.norm = norm# 要点2-1-1:实现1个卷积层和一个ReLU激活函数的定义# 1. Conv2d:# in_channels (int) - Conv2d层输入Tensor的空间维度。# out_channels (int) - Conv2d层输出Tensor的空间维度。# kernel_size (Union[int, tuple[int]]) - 指定二维卷积核的高度和宽度, 卷积核大小为3X3;# stride (Union[int, tuple[int]],可选) - 二维卷积核的移动步长。
self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
kernel_size=3, stride=stride,
weight_init=weight_init)
self.conv2 = nn.Conv2d(in_channel, out_channel,
kernel_size=3, weight_init=weight_init)# 2. ReLU:逐元素计算ReLU(Rectified Linear Unit activation function)修正线性单元激活函数。需要调用MindSpore的相关API.
self.relu = nn.ReLU()
self.down_sample = down_sampledef construct(self, x):"""ResidualBlockBase construct."""
identity = x out = self.conv1(x)
out = self.norm(out)
out = self.relu(out)
out = self.conv2(out)
out = self.norm(out)if self.down_sample is not None:
identity = self.down_sample(x)# 要点2-1-2: # 1. 实现ResidualBlockBase模块的残差连接
out = out+identity # 输出为主分支与shortcuts之和
out = self.relu(out)return out
ResidualBlockBase代码解析
核心类定义:ResidualBlockBase
作用:实现残差网络的基础块(Basic Block),包含主分支(卷积路径)和短路连接(Shortcut),解决深层网络梯度消失问题。
输入:
in_channel :输入特征图通道数
out_channel :输出特征图通道数
stride :卷积步长(控制特征图尺寸变化,用于下采样)
norm :归一化层(默认使用 BatchNorm2d)
down_sample :下采样模块(用于调整短路连接的维度,确保与主分支输出维度一致)
要点 2-1-1:定义卷积层和 ReLU 激活函数
Conv2d 层实现
self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
kernel_size=3, stride=stride, # 3x3卷积核,步长由参数控制
weight_init=weight_init) # 权重初始化(正态分布,sigma=0.02)
self.conv2 = nn.Conv2d(in_channel, out_channel, # 第二个卷积层输入通道仍为in_channel(残差块基础版)
kernel_size=3, weight_init=weight_init)
关键参数:
kernel_size=3:固定使用 3x3 卷积核,符合残差块基础设计(如 ResNet-18/34 的 Basic Block)。
stride=stride:第一个卷积层的步长由外部传入(用于下采样),第二个卷积层步长固定为 1(保持尺寸)。
weight_init=weight_init:使用正态分布初始化权重(Normal(mean=0, sigma=0.02)),避免梯度爆炸 / 消失。
ReLU 激活函数
self.relu = nn.ReLU() # 直接调用MindSpore的ReLU模块,逐元素计算max(0, x)
作用:引入非线性,避免网络退化为线性层,同时缓解梯度消失。
要点 2-1-2:实现残差连接
if self.down_sample is not None:
identity = self.down_sample(x) # 下采样:调整短路连接的维度(通道数/尺寸)
out = out + identity # 残差连接核心:主分支输出与短路连接相加
out = self.relu(out) # 最后一次ReLU激活,输出非线性特征
残差连接逻辑:
短路连接(Identity Mapping):当输入x的维度(通道数 / 尺寸)与主分支输出out一致时,直接相加(identity = x)。
若维度不一致(如通道数增加或尺寸缩小),通过down_sample模块对x进行下采样(通常是 1x1 卷积 + 步长调整),确保形状匹配。
相加操作:
核心公式:输出 = 主分支输出 + 短路连接,强制保留原始输入信息,使梯度能直接回传至浅层。
激活函数位置:相加后再进行一次 ReLU 激活,确保输出为非线性特征,符合 ResNet 设计规范。
关键模块解析
1. 归一化层(Norm)处理
if not norm:
self.norm = nn.BatchNorm2d(out_channel) # 默认使用BatchNorm2d
else:
self.norm = norm # 支持自定义归一化层(如GroupNorm)
作用:对卷积输出进行归一化,加速训练并提升模型鲁棒性。
位置:每个卷积层后立即接归一化层,再接 ReLU 激活(Conv→Norm→ReLU 顺序)。
2. 下采样模块(down_sample)
self.down_sample = down_sample # 由外部传入,通常是1x1卷积+步长
触发场景:当in_channel ≠ out_channel或stride > 1时,需通过下采样调整短路连接的维度。
典型实现:
down_sample = nn.SequentialCell([
nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=stride, weight_init=weight_init),
nn.BatchNorm2d(out_channel)
])
通过 1x1 卷积调整通道数,步长调整尺寸,确保与主分支输出形状一致。
3. 权重初始化策略
weight_init = Normal(mean=0, sigma=0.02) # 卷积层权重初始化
gamma_init = Normal(mean=1, sigma=0.02) # BatchNorm的γ参数初始化(未在当前代码中使用)
正态分布初始化:较小的标准差(σ=0.02)避免初始权重过大导致激活值饱和,符合深度学习框架的常见实践(如 PyTorch 的默认初始化)。
2.1.2 瓶颈块ResidualBlock
class ResidualBlock(nn.Cell):
expansion = 4 def __init__(self, in_channel: int, out_channel: int,
stride: int = 1, down_sample: Optional[nn.Cell] = None) -> None:
super(ResidualBlock, self).__init__() self.conv1 = nn.Conv2d(in_channel, out_channel,
kernel_size=1, weight_init=weight_init)
self.norm1 = nn.BatchNorm2d(out_channel)
self.conv2 = nn.Conv2d(out_channel, out_channel,
kernel_size=3, stride=stride,
weight_init=weight_init)
self.norm2 = nn.BatchNorm2d(out_channel)
self.conv3 = nn.Conv2d(out_channel, out_channel * self.expansion,
kernel_size=1, weight_init=weight_init)
self.norm3 = nn.BatchNorm2d(out_channel * self.expansion) self.relu = nn.ReLU()
self.down_sample = down_sample def construct(self, x): identity = x out = self.conv1(x)
out = self.norm1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.norm2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.norm3(out) if self.down_sample is not None:
identity = self.down_sample(x)
# 2. 实现ResidualBlock模块的残差连接
out = out+identity # 输出为主分支与shortcuts之和
out = self.relu(out) return out
ResidualBlock代码解释
核心类定义:ResidualBlock(瓶颈块)
作用:实现深层残差网络的瓶颈结构,通过 “降维 - 特征提取 - 升维” 减少计算量,支持构建更深的网络(如 50 层以上)。
核心参数:
expansion=4 :升维因子(固定为 4,符合 ResNet 设计规范),即最后一个 1x1 卷积将通道数扩展为out_channel×4。
in_channel :输入特征图通道数
out_channe :中间层特征图通道数(经 1x1 卷积降维后的通道数)
stride :3x3 卷积的步长(控制特征图尺寸变化,用于下采样)
down_sample :下采样模块(调整短路连接的维度,确保与主分支输出维度一致)
要点:瓶颈块的结构设计
瓶颈块通过三层卷积实现 “降维→特征提取→升维”,显著减少计算量(对比基础块的两层 3x3 卷积):
第一层:1x1 卷积(降维)
self.conv1 = nn.Conv2d(in_channel, out_channel, kernel_size=1, weight_init=weight_init)
作用:将输入通道数从in_channel降为out_channel(如输入 256→输出 64),减少后续 3x3 卷积的计算量。
卷积核大小:1x1,仅改变通道数,不改变特征图尺寸(stride=1,无填充)。
第二层:3x3 卷积(特征提取)
self.conv2 = nn.Conv2d(out_channel, out_channel, kernel_size=3, stride=stride, weight_init=weight_init)
作用:在降维后的低维空间提取空间特征(如边缘、纹理)。
关键参数:stride=stride:支持下采样(如 stride=2 时特征图尺寸减半),由外部传入(用于构建不同 stage 的残差块)。
kernel_size=3:保持 3x3 卷积核,确保感受野与基础块一致。
第三层:1x1 卷积(升维)
self.conv3 = nn.Conv2d(out_channel, out_channel * self.expansion, kernel_size=1, weight_init=weight_init)
作用:将通道数从out_channel升至out_channel×expansion(如 64→256),与短路连接维度匹配(因 ResNet 的 stage 设计中,输出通道数通常是输入的 4 倍)。
核心公式:输出通道数 = out_channel × expansion(此处expansion=4是 ResNet 瓶颈块的固定设计)。
要点:残差连接与维度匹配
if self.down_sample is not None:
identity = self.down_sample(x) # 调整短路连接的维度
out = out + identity # 残差连接核心:主分支输出与短路连接相加
out = self.relu(out) # 最后一次ReLU激活
触发下采样的场景:
当以下任意条件成立时,需通过down_sample调整短路连接:输入通道数in_channel ≠ 输出通道数out_channel×expansion(升维导致通道数不匹配)。
stride > 1(特征图尺寸缩小,短路连接需同步下采样)。
下采样模块实现(通常由外部传入):
down_sample = nn.SequentialCell([
nn.Conv2d(in_channel, out_channel*self.expansion, kernel_size=1, stride=stride, weight_init=weight_init),
nn.BatchNorm2d(out_channel*self.expansion)
])
通过 1x1 卷积调整通道数,步长调整尺寸,确保identity与主分支输出out形状一致(通道数、高度、宽度均相同)。
归一化层与激活函数的顺序
# 每一层的处理流程:Conv → BatchNorm → ReLU
out = self.conv1(x) # 1x1卷积(降维)
out = self.norm1(out) # BatchNorm2d
out = self.relu(out) # ReLU激活
out = self.conv2(out) # 3x3卷积(特征提取)
out = self.norm2(out) # BatchNorm2d
out = self.relu(out) # ReLU激活
out = self.conv3(out) # 1x1卷积(升维)
out = self.norm3(out) # BatchNorm2d(升维后归一化)
设计原则:符合 ResNet 的 “Post-Normalization” 架构,即在卷积后立即归一化,再激活,确保每一层输入处于稳定分布,加速训练收敛。
权重初始化策略
weight_init = Normal(mean=0, sigma=0.02) # 与基础块一致,小方差初始化避免梯度爆炸
作用:对 1x1 和 3x3 卷积的权重进行正态分布初始化,确保初始权重较小,激活值不会因过大输入导致饱和(如 ReLU 的负数区域失活)。
与基础残差块(ResidualBlockBase)的区别
2.1.3 构建层
根据给定参数构建由指定数量残差块组成的网络层,包括处理下采样及层间连接等
def make_layer(last_out_channel, block: Type[Union[ResidualBlockBase, ResidualBlock]],
channel: int, block_nums: int, stride: int = 1):
down_sample = None if stride != 1 or last_out_channel != channel * block.expansion: down_sample = nn.SequentialCell([
nn.Conv2d(last_out_channel, channel * block.expansion,
kernel_size=1, stride=stride, weight_init=weight_init),
nn.BatchNorm2d(channel * block.expansion, gamma_init=gamma_init)
]) layers = []
layers.append(block(last_out_channel, channel, stride=stride, down_sample=down_sample)) in_channel = channel * block.expansion for _ in range(1, block_nums): layers.append(block(in_channel, channel)) return nn.SequentialCell(layers)
构建层代码说明
功能定位
ResNet 的整体架构就是通过make_layer函数不断堆叠残差块,形成多个 stage,每个 stage 内部保持相同的通道数,相邻 stage 之间通过下采样调整尺寸和通道数,最终构建出深度神经网络。
输入参数:
last_out_channel :上一层输出的特征图通道数(用于判断是否需要下采样)。
block :残差块类型(ResidualBlockBase基础块或ResidualBlock瓶颈块,通过Type[Union]支持两种类型)。
channel :当前 stage 的基础通道数(瓶颈块中为降维后的通道数,基础块中为输出通道数)。
block_nums :当前 stage 包含的残差块数量(如 ResNet-50 的每个 stage 包含 3/4/6/3 个瓶颈块)。
stride :当前 stage 第一个残差块的卷积步长(控制下采样,默认 1 表示不采样)。
输出:
由多个残差块组成的nn.SequentialCell序列(可直接作为网络的一个 stage,如 ResNet 的layer1、layer2等)。
核心代码逻辑解析
1. 下采样模块(down_sample)的条件判断与创建(核心考点)
if stride != 1 or last_out_channel != channel * block.expansion:
down_sample = nn.SequentialCell([
nn.Conv2d(last_out_channel, channel * block.expansion,
kernel_size=1, stride=stride, weight_init=weight_init),
nn.BatchNorm2d(channel * block.expansion, gamma_init=gamma_init)
])
触发条件(满足任意一条即需下采样):
stride != 1:需要对特征图尺寸进行下采样(如 stride=2 时尺寸减半)。
last_out_channel != channel * block.expansion:输入通道数与当前 stage 输出通道数不一致(瓶颈块中输出通道是channel×4,基础块中是channel×1)。
下采样实现:
通过1x1 卷积调整通道数(从last_out_channel到channel×block.expansion)。
卷积步长设为stride,同步调整特征图尺寸(与主分支的 3x3 卷积步长一致)。
接 BatchNorm 层归一化,确保短路连接的输出分布稳定。
核心作用:保证短路连接(identity)的维度与主分支输出一致,使out + identity操作可行。
2. 残差块序列的构建
layers = []
# 添加第一个残差块(可能包含下采样)
layers.append(block(last_out_channel, channel, stride=stride, down_sample=down_sample))
# 更新输入通道为当前stage的输出通道(block.expansion倍)
in_channel = channel * block.expansion
# 添加后续残差块(无下采样,stride=1,通道数已对齐)
for _ in range(1, block_nums):
layers.append(block(in_channel, channel)) # 输入通道为上一个块的输出通道
第一个块的特殊性:
传入stride和down_sample,处理当前 stage 的下采样和通道对齐(如 ResNet 中layer2的第一个块 stride=2,实现尺寸减半)。
若无需下采样(stride=1 且通道数匹配),down_sample=None,短路连接直接使用输入x。
后续块的一致性:
输入通道in_channel固定为channel×block.expansion(即上一个块的输出通道)。
不再传入stride(默认 1)和down_sample(无需下采样,通道数已对齐),所有后续块仅做恒等残差连接。
2.1.4 定义不同组合(block,layer_nums)的ResNet网络实现
from mindspore import load_checkpoint, load_param_into_net
from mindspore import ops
class ResNet(nn.Cell):
def __init__(self, block: Type[Union[ResidualBlockBase, ResidualBlock]],
layer_nums: List[int], num_classes: int, input_channel: int) -> None:
super(ResNet, self).__init__() self.relu = nn.ReLU()
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, weight_init=weight_init)
self.norm = nn.BatchNorm2d(64)
self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')
self.layer1 = make_layer(64, block, 64, layer_nums[0])
self.layer2 = make_layer(64 * block.expansion, block, 128, layer_nums[1], stride=2)
self.layer3 = make_layer(128 * block.expansion, block, 256, layer_nums[2], stride=2)
self.layer4 = make_layer(256 * block.expansion, block, 512, layer_nums[3], stride=2) # 要点2-1-3:layer4的输出通道参数‘512’的含义
self.avg_pool = ops.ReduceMean(keep_dims=True)
self.flatten = nn.Flatten()
self.fc = nn.Dense(in_channels=input_channel, out_channels=num_classes) def construct(self, x): x = self.conv1(x)
x = self.norm(x)
x = self.relu(x)
x = self.max_pool(x) x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x) x = self.avg_pool(x,(2,3)) x = self.flatten(x)
x = self.fc(x) return x
ResNet组建类代码解析
1. 类定义与核心参数
class ResNet(nn.Cell):
def __init__(self, block: Type[Union[ResidualBlockBase, ResidualBlock]],
layer_nums: List[int], num_classes: int, input_channel: int) -> None:
super(ResNet, self).__init__()
block:残差块类型(基础块ResidualBlockBase或瓶颈块ResidualBlock),决定网络层数和计算复杂度。
layer_nums:各 stage 的残差块数量(如[3, 4, 6, 3]对应 ResNet-50)。
num_classes:分类任务的类别数(如中药材分类的 12 类)。
input_channel:全连接层输入通道数(由最后一个 stage 的输出通道决定,如瓶颈块下为512×4=2048)。
2. 主干网络结构
输入层与初始特征提取
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, weight_init=weight_init) # 7x7卷积
self.norm = nn.BatchNorm2d(64) # 归一化
self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same') # 最大池化
7x7 卷积:输入 3 通道(RGB 图像),输出 64 通道,步长 2,初步提取特征并降采样(尺寸减半)。
最大池化:核大小 3x3,步长 2,pad_mode='same'保持空间尺寸对称减半(如 224→112→56)。
四个 stage(layer1-layer4)
self.layer1 = make_layer(64, block, 64, layer_nums[0]) # stride=1(默认)
self.layer2 = make_layer(64 * block.expansion, block, 128, layer_nums[1], stride=2) # 下采样
self.layer3 = make_layer(128 * block.expansion, block, 256, layer_nums[2], stride=2)
self.layer4 = make_layer(256 * block.expansion, block, 512, layer_nums[3], stride=2) # 关键考点:补全layer4参数
make_layer功能:动态构建残差块序列,每个 stage:第一个块通过stride=2实现下采样(layer2-layer4),通道数翻倍(如 64→128→256→512)。
block.expansion控制通道升维(基础块 = 1,瓶颈块 = 4),例如瓶颈块下64×4=256作为下 stage 输入。
输出层
self.avg_pool = ops.ReduceMean(keep_dims=True) # 全局平均池化
self.flatten = nn.Flatten() # 展平特征
self.fc = nn.Dense(in_channels=input_channel, out_channels=num_classes) # 全连接分类
全局平均池化:替代全连接层前的全连接操作,减少参数数量,输出特征图尺寸为(batch, 512×expansion, 1, 1)。
全连接层:将特征映射到num_classes维空间,输出分类概率。
3. 前向传播逻辑
def construct(self, x):
x = self.conv1(x) → self.norm(x) → self.relu(x) → self.max_pool(x) # 初始特征提取
x = self.layer1(x) → self.layer2(x) → self.layer3(x) → self.layer4(x) # 四级残差块特征提取
x = self.avg_pool(x, (2, 3)) # 对空间维度(H=2, W=3,假设输入为7x7)做平均池化
x = self.flatten(x) # 展平为一维向量(shape: [batch, input_channel])
x = self.fc(x) # 分类输出
return x
空间尺寸变化:假设输入 224x224,经过conv1(stride=2)和max_pool(stride=2)后尺寸为 56x56,每层 stage 若stride=2则尺寸减半(56→28→14→7),最终layer4输出 7x7。
通道数变化:随 stage 递增(64→128→256→512),经block.expansion后瓶颈块通道数为 256→512→1024→2048。
4. 核心要点与设计原则
layer4参数补全(题目要求):输入通道为256×block.expansion(上一 stage 输出),当前 stage 基础通道512,stride=2实现最后一次下采样。
残差块类型兼容性:通过block参数支持基础块(浅层)和瓶颈块(深层),expansion自动适配通道逻辑(无需为不同块编写独立代码)。
下采样策略:每个 stage 的第一个块通过stride=2和1x1卷积调整通道 / 尺寸,保证残差连接维度匹配。
计算效率:瓶颈块通过1x1卷积降维减少 3x3 卷积计算量,使深层网络(如 ResNet-152)训练可行。
5. 代码关键作用
模块化构建:通过make_layer和残差块组合,快速搭建不同深度的 ResNet(如 50 层、101 层)。
特征提取流程:从浅层边缘检测到深层语义特征,逐层抽象,适应图像分类任务。
维度匹配:自动处理残差连接的通道和尺寸对齐,避免手动计算错误。
2.1.5 实例化resnet_xx网络
实例化resnet50
def _resnet(block: Type[Union[ResidualBlockBase, ResidualBlock]],
layers: List[int], num_classes: int, pretrained: bool, pretrained_ckpt: str,
input_channel: int):
model = ResNet(block, layers, num_classes, input_channel)return modeldef resnet50(num_classes: int = 1000, pretrained: bool = False):
resnet50_ckpt = "./LoadPretrainedModel/resnet50_224_new.ckpt"return _resnet(ResidualBlock, [3, 4, 6, 3], num_classes,
pretrained, resnet50_ckpt, 2048)
实例化resnet_xx网络代码分析
ps:代码中‘ return _resnet(ResidualBlock, [3, 4, 6, 3], num_classes,
pretrained, resnet50_ckpt, 2048)’
为什么能跨函数识别到 pretrained参数?
在 Python 中,这是因为作用域的规则 。在resnet50函数中,pretrained是该函数的参数,属于局部作用域。当调用_resnet函数时,pretrained作为参数传递给_resnet函数,所以_resnet函数能够识别并使用这个参数
1. 通用模型构建函数 _resnet
def _resnet(block: Type[Union[ResidualBlockBase, ResidualBlock]],
layers: List[int], num_classes: int, pretrained: bool, pretrained_ckpt: str,
input_channel: int):
model = ResNet(block, layers, num_classes, input_channel)return model
功能:通用 ResNet 模型构建接口,通过参数化残差块类型、层数、分类数等,灵活生成不同配置的 ResNet 模型。
参数解析:block:残差块类型(ResidualBlockBase基础块或ResidualBlock瓶颈块)。
layers:各 stage 的残差块数量列表(如[3,4,6,3]对应 ResNet-50 的四个 stage)。
num_classes:分类任务的类别数(如中药材的 12 类)。
pretrained:是否加载预训练权重(布尔值,True表示加载)。
pretrained_ckpt:预训练权重文件路径(如"./LoadPretrainedModel/resnet50_224_new.ckpt")。
input_channel:全连接层输入维度(由最后一个 stage 的输出通道决定,如 ResNet-50 为 2048)。
2. 特定模型:ResNet-50 的封装 resnet50
def resnet50(num_classes: int = 1000, pretrained: bool = False):
resnet50_ckpt = "./LoadPretrainedModel/resnet50_224_new.ckpt"
return _resnet(ResidualBlock, [3, 4, 6, 3], num_classes,
pretrained, resnet50_ckpt, 2048)
功能:直接生成 ResNet-50 模型,固定了 ResNet-50 的核心配置(瓶颈块、各 stage 块数、输入通道)。
关键参数固定值:block=ResidualBlock:使用瓶颈块(Bottleneck Block),适用于深层网络(ResNet-50/101/152)。
layers=[3,4,6,3]:ResNet-50 的标准配置(四个 stage 分别包含 3、4、6、3 个瓶颈块)。
input_channel=2048:最后一个 stage 输出通道数(512 基础通道 × 瓶颈块 expansion=4)。
预训练支持:resnet50_ckpt指定了预训练权重路径(如用户需要加载 ImageNet 预训练权重,可通过pretrained=True启用)。
3. 代码设计核心价值
模块化与复用性:
_resnet作为通用构建函数,通过参数化残差块类型和层数,可扩展生成 ResNet-18(基础块 +[2,2,2,2])、ResNet-101([3,4,23,3])等变体,避免重复代码。用户友好性:
resnet50函数封装了 ResNet-50 的具体配置,用户只需指定num_classes(分类数)和pretrained(是否加载预训练),即可快速获取模型,降低使用门槛。
2.2模型初始化
要求:
对定义的ResNet50模型进行实例化
实例化一个用于12分类的resnet50模型
# 要点2-2-1: 对定义的ResNet50模型进行实例化
network = resnet50(num_classes=12)
num_class = 12
in_channel = network.fc.in_channels
fc = nn.Dense(in_channels=in_channel, out_channels=num_class)
network.fc = fcfor param in network.get_parameters():
param.requires_grad = True
模型初始化代码解析
1. 实例化 ResNet50 模型
network = resnet50(num_classes=12)
作用:调用resnet50函数创建 ResNet50 模型实例,指定分类数为 12(如中药材的 12 类)。
内部逻辑:
resnet50函数通过_resnet生成 ResNet50 模型,默认使用瓶颈块(ResidualBlock)和标准层数[3,4,6,3],并将原 1000 类的全连接层(fc 层)初始化为 12 类输出(但需后续调整,见下文)。2. 替换全连接层适配新任务
num_class = 12 # 新任务的分类数(如中药材的12类)
in_channel = network.fc.in_channels # 获取原fc层的输入通道数(ResNet50为2048)
fc = nn.Dense(in_channels=in_channel, out_channels=num_class) # 新建12类输出的全连接层
network.fc = fc # 替换原模型的fc层
背景:预训练 ResNet50 的 fc 层通常输出 1000 类(ImageNet 任务),需替换为新任务的分类数(12 类)。
关键操作:获取原 fc 层输入维度(in_channel=2048,由 ResNet50 的全局平均池化输出决定)。
新建全连接层fc,输入维度保持 2048,输出维度改为 12。
替换原模型的 fc 层,完成模型输出适配。
3. 启用所有参数训练 -- 全量微调
for param in network.get_parameters():
param.requires_grad = True
作用:将模型所有参数的梯度计算标志(requires_grad)设为True,允许训练时更新所有参数。
场景意义:若使用预训练模型(pretrained=True),此操作表示 “端到端微调”(所有层参数均参与训练),适合新数据集与预训练数据分布差异较大的场景(如中药材分类 vs ImageNet 通用分类)。
若未使用预训练(pretrained=False),则模型从头开始训练,所有参数自然需要梯度更新。