当前位置: 首页 > news >正文

MindSpore框架学习项目-ResNet药物分类-构建模型

目录

2.构建模型

2.1定义模型类

2.1.1 基础块ResidualBlockBase

ResidualBlockBase代码解析

2.1.2 瓶颈块ResidualBlock

ResidualBlock代码解释

2.1.3 构建层

构建层代码说明

2.1.4 定义不同组合(block,layer_nums)的ResNet网络实现

ResNet组建类代码解析

2.1.5 实例化resnet_xx网络

实例化resnet_xx网络代码分析

2.2模型初始化

模型初始化代码解析


参考内容: 昇思MindSpore | 全场景AI框架 | 昇思MindSpore社区官网 华为自研的国产AI框架,训推一体,支持动态图、静态图,全场景适用,有着不错的生态

本项目可以在华为云modelart上租一个实例进行,也可以在配置至少为单卡3060的设备上进行

https://console.huaweicloud.com/modelarts/

Ascend环境也适用,但是注意修改device_target参数

需要本地编译器的一些代码传输、修改等可以勾上ssh远程开发

说明:项目使用的数据集来自华为云的数据资源。项目以深度学习任务构建的一般流程展开(数据导入、处理 > 模型选择、构建 > 模型训练 > 模型评估 > 模型优化)。

主线为‘一般流程’,同时代码中会标注出一些要点(# 要点1-1-1:设置使用的设备

)作为支线,帮助学习mindspore框架在进行深度学习任务时一些与pytorch的差异。

可以只看目录中带数字标签的部分来快速查阅代码。

2.构建模型

2.1定义模型类

要求:

补充如下代码的空白处

主要完成:

1. 实现1个卷积层和1个ReLU激活函数的定义

2. 实现ResidualBlockBase和ResidualBlock模块的残差连接,并补全self.layer4的参数

导入mindspore训练环节(包括模型构建、激活函数、反向传播、损失函数等需要的库)

from mindspore import Model
from mindspore import context
import mindspore.ops as ops
from mindspore import Tensor, nn, set_context, GRAPH_MODE, train
from mindspore import load_checkpoint, load_param_into_net
from typing import Type, Union, List, Optional
from mindspore import nn, train
from mindspore.common.initializer import Normal

初始化:weight_init = Normal(mean=0, sigma=0.02) 用于初始化卷积层;

gamma_init = Normal(mean=1, sigma=0.02) 用于初始化批归一化层

weight_init = Normal(mean=0, sigma=0.02)
gamma_init = Normal(mean=1, sigma=0.02)

2.1.1 基础块ResidualBlockBase

conv1 和 conv2 的参数设置:

conv1 负责处理输入数据的空间下采样(通过 stride 参数)或通道数变换(通过 out_channels),同时进行第一次特征提取。

conv2 固定为 3×3 卷积,不改变空间尺寸(默认 stride=1),仅对 conv1 的输出进一步提取特征。

(卷积层当池化层用)

class ResidualBlockBase(nn.Cell):
    expansion: int = 1def __init__(self, in_channel: int, out_channel: int,
                 stride: int = 1, norm: Optional[nn.Cell] = None,
                 down_sample: Optional[nn.Cell] = None) -> None:super(ResidualBlockBase, self).__init__()if not norm:
            self.norm = nn.BatchNorm2d(out_channel)else:
            self.norm = norm# 要点2-1-1:实现1个卷积层和一个ReLU激活函数的定义# 1. Conv2d:#    in_channels (int) - Conv2d层输入Tensor的空间维度。#    out_channels (int) - Conv2d层输出Tensor的空间维度。#    kernel_size (Union[int, tuple[int]]) - 指定二维卷积核的高度和宽度, 卷积核大小为3X3;#    stride (Union[int, tuple[int]],可选) - 二维卷积核的移动步长。
        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
                               kernel_size=3, stride=stride,
                               weight_init=weight_init)
        self.conv2 = nn.Conv2d(in_channel, out_channel,
                               kernel_size=3, weight_init=weight_init)# 2. ReLU:逐元素计算ReLU(Rectified Linear Unit activation function)修正线性单元激活函数。需要调用MindSpore的相关API.
        self.relu = nn.ReLU()
        self.down_sample = down_sampledef construct(self, x):"""ResidualBlockBase construct."""
        identity = x        out = self.conv1(x)
        out = self.norm(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.norm(out)if self.down_sample is not None:
            identity = self.down_sample(x)# 要点2-1-2: # 1. 实现ResidualBlockBase模块的残差连接
        out = out+identity  # 输出为主分支与shortcuts之和
        out = self.relu(out)return out

ResidualBlockBase代码解析

核心类定义:ResidualBlockBase

作用:实现残差网络的基础块(Basic Block),包含主分支(卷积路径)和短路连接(Shortcut),解决深层网络梯度消失问题。

输入:

in_channel :输入特征图通道数

out_channel :输出特征图通道数

stride :卷积步长(控制特征图尺寸变化,用于下采样)

norm :归一化层(默认使用 BatchNorm2d)

down_sample :下采样模块(用于调整短路连接的维度,确保与主分支输出维度一致)

要点 2-1-1:定义卷积层和 ReLU 激活函数

Conv2d 层实现

self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
                       kernel_size=3, stride=stride,  # 3x3卷积核,步长由参数控制
                       weight_init=weight_init)     # 权重初始化(正态分布,sigma=0.02)
self.conv2 = nn.Conv2d(in_channel, out_channel,  # 第二个卷积层输入通道仍为in_channel(残差块基础版)
                       kernel_size=3, weight_init=weight_init)

关键参数:

kernel_size=3:固定使用 3x3 卷积核,符合残差块基础设计(如 ResNet-18/34 的 Basic Block)。

stride=stride:第一个卷积层的步长由外部传入(用于下采样),第二个卷积层步长固定为 1(保持尺寸)。

weight_init=weight_init:使用正态分布初始化权重(Normal(mean=0, sigma=0.02)),避免梯度爆炸 / 消失。

ReLU 激活函数

self.relu = nn.ReLU()  # 直接调用MindSpore的ReLU模块,逐元素计算max(0, x)

作用:引入非线性,避免网络退化为线性层,同时缓解梯度消失。

要点 2-1-2:实现残差连接

if self.down_sample is not None:
    identity = self.down_sample(x)  # 下采样:调整短路连接的维度(通道数/尺寸)
out = out + identity  # 残差连接核心:主分支输出与短路连接相加
out = self.relu(out)   # 最后一次ReLU激活,输出非线性特征

残差连接逻辑:

短路连接(Identity Mapping):当输入x的维度(通道数 / 尺寸)与主分支输出out一致时,直接相加(identity = x)。

若维度不一致(如通道数增加或尺寸缩小),通过down_sample模块对x进行下采样(通常是 1x1 卷积 + 步长调整),确保形状匹配。

相加操作:

核心公式:输出 = 主分支输出 + 短路连接,强制保留原始输入信息,使梯度能直接回传至浅层。

激活函数位置:相加后再进行一次 ReLU 激活,确保输出为非线性特征,符合 ResNet 设计规范。

关键模块解析

1. 归一化层(Norm)处理

if not norm:
    self.norm = nn.BatchNorm2d(out_channel)  # 默认使用BatchNorm2d
else:
    self.norm = norm  # 支持自定义归一化层(如GroupNorm)

作用:对卷积输出进行归一化,加速训练并提升模型鲁棒性。

位置:每个卷积层后立即接归一化层,再接 ReLU 激活(Conv→Norm→ReLU 顺序)。

2. 下采样模块(down_sample)

self.down_sample = down_sample  # 由外部传入,通常是1x1卷积+步长

触发场景:当in_channel ≠ out_channel或stride > 1时,需通过下采样调整短路连接的维度。

典型实现:

down_sample = nn.SequentialCell([
    nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=stride, weight_init=weight_init),
    nn.BatchNorm2d(out_channel)
])

通过 1x1 卷积调整通道数,步长调整尺寸,确保与主分支输出形状一致。

3. 权重初始化策略

weight_init = Normal(mean=0, sigma=0.02)  # 卷积层权重初始化
gamma_init = Normal(mean=1, sigma=0.02)    # BatchNorm的γ参数初始化(未在当前代码中使用)

正态分布初始化:较小的标准差(σ=0.02)避免初始权重过大导致激活值饱和,符合深度学习框架的常见实践(如 PyTorch 的默认初始化)。

2.1.2 瓶颈块ResidualBlock
class ResidualBlock(nn.Cell):
    expansion = 4    def __init__(self, in_channel: int, out_channel: int,
                 stride: int = 1, down_sample: Optional[nn.Cell] = None) -> None:
        super(ResidualBlock, self).__init__()        self.conv1 = nn.Conv2d(in_channel, out_channel,
                               kernel_size=1, weight_init=weight_init)
        self.norm1 = nn.BatchNorm2d(out_channel)
        self.conv2 = nn.Conv2d(out_channel, out_channel,
                               kernel_size=3, stride=stride,
                               weight_init=weight_init)
        self.norm2 = nn.BatchNorm2d(out_channel)
        self.conv3 = nn.Conv2d(out_channel, out_channel * self.expansion,
                               kernel_size=1, weight_init=weight_init)
        self.norm3 = nn.BatchNorm2d(out_channel * self.expansion)        self.relu = nn.ReLU()
        self.down_sample = down_sample    def construct(self, x):        identity = x        out = self.conv1(x)
        out = self.norm1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.norm2(out)
        out = self.relu(out)
        out = self.conv3(out)
        out = self.norm3(out)        if self.down_sample is not None:
            identity = self.down_sample(x)
        # 2. 实现ResidualBlock模块的残差连接
        out = out+identity  # 输出为主分支与shortcuts之和
        out = self.relu(out)        return out

ResidualBlock代码解释

核心类定义:ResidualBlock(瓶颈块)

作用:实现深层残差网络的瓶颈结构,通过 “降维 - 特征提取 - 升维” 减少计算量,支持构建更深的网络(如 50 层以上)。

核心参数:

expansion=4 :升维因子(固定为 4,符合 ResNet 设计规范),即最后一个 1x1 卷积将通道数扩展为out_channel×4。

in_channel :输入特征图通道数

out_channe :中间层特征图通道数(经 1x1 卷积降维后的通道数)

stride :3x3 卷积的步长(控制特征图尺寸变化,用于下采样)

down_sample :下采样模块(调整短路连接的维度,确保与主分支输出维度一致)

要点:瓶颈块的结构设计

瓶颈块通过三层卷积实现 “降维→特征提取→升维”,显著减少计算量(对比基础块的两层 3x3 卷积):

第一层:1x1 卷积(降维)

self.conv1 = nn.Conv2d(in_channel, out_channel, kernel_size=1, weight_init=weight_init)

作用:将输入通道数从in_channel降为out_channel(如输入 256→输出 64),减少后续 3x3 卷积的计算量。

卷积核大小:1x1,仅改变通道数,不改变特征图尺寸(stride=1,无填充)。

第二层:3x3 卷积(特征提取)

self.conv2 = nn.Conv2d(out_channel, out_channel, kernel_size=3, stride=stride, weight_init=weight_init)

作用:在降维后的低维空间提取空间特征(如边缘、纹理)。

关键参数:stride=stride:支持下采样(如 stride=2 时特征图尺寸减半),由外部传入(用于构建不同 stage 的残差块)。

kernel_size=3:保持 3x3 卷积核,确保感受野与基础块一致。

第三层:1x1 卷积(升维)

self.conv3 = nn.Conv2d(out_channel, out_channel * self.expansion, kernel_size=1, weight_init=weight_init)

作用:将通道数从out_channel升至out_channel×expansion(如 64→256),与短路连接维度匹配(因 ResNet 的 stage 设计中,输出通道数通常是输入的 4 倍)。

核心公式:输出通道数 = out_channel × expansion(此处expansion=4是 ResNet 瓶颈块的固定设计)。

要点:残差连接与维度匹配

if self.down_sample is not None:
    identity = self.down_sample(x)  # 调整短路连接的维度
out = out + identity  # 残差连接核心:主分支输出与短路连接相加
out = self.relu(out)   # 最后一次ReLU激活

触发下采样的场景:

当以下任意条件成立时,需通过down_sample调整短路连接:输入通道数in_channel ≠ 输出通道数out_channel×expansion(升维导致通道数不匹配)。

stride > 1(特征图尺寸缩小,短路连接需同步下采样)。

下采样模块实现(通常由外部传入):

down_sample = nn.SequentialCell([
    nn.Conv2d(in_channel, out_channel*self.expansion, kernel_size=1, stride=stride, weight_init=weight_init),
    nn.BatchNorm2d(out_channel*self.expansion)
])

通过 1x1 卷积调整通道数,步长调整尺寸,确保identity与主分支输出out形状一致(通道数、高度、宽度均相同)。

归一化层与激活函数的顺序

# 每一层的处理流程:Conv → BatchNorm → ReLU
out = self.conv1(x)       # 1x1卷积(降维)
out = self.norm1(out)     # BatchNorm2d
out = self.relu(out)      # ReLU激活
out = self.conv2(out)     # 3x3卷积(特征提取)
out = self.norm2(out)     # BatchNorm2d
out = self.relu(out)      # ReLU激活
out = self.conv3(out)     # 1x1卷积(升维)
out = self.norm3(out)     # BatchNorm2d(升维后归一化)

设计原则:符合 ResNet 的 “Post-Normalization” 架构,即在卷积后立即归一化,再激活,确保每一层输入处于稳定分布,加速训练收敛。

权重初始化策略

weight_init = Normal(mean=0, sigma=0.02)  # 与基础块一致,小方差初始化避免梯度爆炸

作用:对 1x1 和 3x3 卷积的权重进行正态分布初始化,确保初始权重较小,激活值不会因过大输入导致饱和(如 ReLU 的负数区域失活)。

与基础残差块(ResidualBlockBase)的区别

2.1.3 构建层

根据给定参数构建由指定数量残差块组成的网络层,包括处理下采样及层间连接等

def make_layer(last_out_channel, block: Type[Union[ResidualBlockBase, ResidualBlock]],
               channel: int, block_nums: int, stride: int = 1):
    down_sample = None    if stride != 1 or last_out_channel != channel * block.expansion:        down_sample = nn.SequentialCell([
            nn.Conv2d(last_out_channel, channel * block.expansion,
                      kernel_size=1, stride=stride, weight_init=weight_init),
            nn.BatchNorm2d(channel * block.expansion, gamma_init=gamma_init)
        ])    layers = []
    layers.append(block(last_out_channel, channel, stride=stride, down_sample=down_sample))    in_channel = channel * block.expansion    for _ in range(1, block_nums):        layers.append(block(in_channel, channel))    return nn.SequentialCell(layers)

构建层代码说明

功能定位

ResNet 的整体架构就是通过make_layer函数不断堆叠残差块,形成多个 stage,每个 stage 内部保持相同的通道数,相邻 stage 之间通过下采样调整尺寸和通道数,最终构建出深度神经网络。

输入参数:

last_out_channel :上一层输出的特征图通道数(用于判断是否需要下采样)。

block :残差块类型(ResidualBlockBase基础块或ResidualBlock瓶颈块,通过Type[Union]支持两种类型)。

channel :当前 stage 的基础通道数(瓶颈块中为降维后的通道数,基础块中为输出通道数)。

block_nums :当前 stage 包含的残差块数量(如 ResNet-50 的每个 stage 包含 3/4/6/3 个瓶颈块)。

stride :当前 stage 第一个残差块的卷积步长(控制下采样,默认 1 表示不采样)。

输出:

由多个残差块组成的nn.SequentialCell序列(可直接作为网络的一个 stage,如 ResNet 的layer1、layer2等)。

核心代码逻辑解析

1. 下采样模块(down_sample)的条件判断与创建(核心考点)

if stride != 1 or last_out_channel != channel * block.expansion:
    down_sample = nn.SequentialCell([
        nn.Conv2d(last_out_channel, channel * block.expansion,
                  kernel_size=1, stride=stride, weight_init=weight_init),
        nn.BatchNorm2d(channel * block.expansion, gamma_init=gamma_init)
    ])

触发条件(满足任意一条即需下采样):

stride != 1:需要对特征图尺寸进行下采样(如 stride=2 时尺寸减半)。

last_out_channel != channel * block.expansion:输入通道数与当前 stage 输出通道数不一致(瓶颈块中输出通道是channel×4,基础块中是channel×1)。

下采样实现:

通过1x1 卷积调整通道数(从last_out_channel到channel×block.expansion)。

卷积步长设为stride,同步调整特征图尺寸(与主分支的 3x3 卷积步长一致)。

接 BatchNorm 层归一化,确保短路连接的输出分布稳定。

核心作用:保证短路连接(identity)的维度与主分支输出一致,使out + identity操作可行。

2. 残差块序列的构建

layers = []
# 添加第一个残差块(可能包含下采样)
layers.append(block(last_out_channel, channel, stride=stride, down_sample=down_sample))
# 更新输入通道为当前stage的输出通道(block.expansion倍)
in_channel = channel * block.expansion
# 添加后续残差块(无下采样,stride=1,通道数已对齐)
for _ in range(1, block_nums):
    layers.append(block(in_channel, channel))  # 输入通道为上一个块的输出通道

第一个块的特殊性:

传入stride和down_sample,处理当前 stage 的下采样和通道对齐(如 ResNet 中layer2的第一个块 stride=2,实现尺寸减半)。

若无需下采样(stride=1 且通道数匹配),down_sample=None,短路连接直接使用输入x。

后续块的一致性:

输入通道in_channel固定为channel×block.expansion(即上一个块的输出通道)。

不再传入stride(默认 1)和down_sample(无需下采样,通道数已对齐),所有后续块仅做恒等残差连接。

2.1.4 定义不同组合(block,layer_nums)的ResNet网络实现
from mindspore import load_checkpoint, load_param_into_net
from mindspore import ops

class ResNet(nn.Cell):
    def __init__(self, block: Type[Union[ResidualBlockBase, ResidualBlock]],
                 layer_nums: List[int], num_classes: int, input_channel: int) -> None:
        super(ResNet, self).__init__()        self.relu = nn.ReLU()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, weight_init=weight_init)
        self.norm = nn.BatchNorm2d(64)
        self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')
        self.layer1 = make_layer(64, block, 64, layer_nums[0])
        self.layer2 = make_layer(64 * block.expansion, block, 128, layer_nums[1], stride=2)
        self.layer3 = make_layer(128 * block.expansion, block, 256, layer_nums[2], stride=2)
        self.layer4 = make_layer(256 * block.expansion, block, 512, layer_nums[3], stride=2) # 要点2-1-3:layer4的输出通道参数‘512’的含义
        self.avg_pool = ops.ReduceMean(keep_dims=True)
        self.flatten = nn.Flatten()
        self.fc = nn.Dense(in_channels=input_channel, out_channels=num_classes)    def construct(self, x):        x = self.conv1(x)
        x = self.norm(x)
        x = self.relu(x)
        x = self.max_pool(x)        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)        x = self.avg_pool(x,(2,3))        x = self.flatten(x)
        x = self.fc(x)        return x

ResNet组建类代码解析

1. 类定义与核心参数

class ResNet(nn.Cell):
    def __init__(self, block: Type[Union[ResidualBlockBase, ResidualBlock]],
                 layer_nums: List[int], num_classes: int, input_channel: int) -> None:
        super(ResNet, self).__init__()

block:残差块类型(基础块ResidualBlockBase或瓶颈块ResidualBlock),决定网络层数和计算复杂度。

layer_nums:各 stage 的残差块数量(如[3, 4, 6, 3]对应 ResNet-50)。

num_classes:分类任务的类别数(如中药材分类的 12 类)。

input_channel:全连接层输入通道数(由最后一个 stage 的输出通道决定,如瓶颈块下为512×4=2048)。

2. 主干网络结构

输入层与初始特征提取

self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, weight_init=weight_init)  # 7x7卷积
self.norm = nn.BatchNorm2d(64)  # 归一化
self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')  # 最大池化

7x7 卷积:输入 3 通道(RGB 图像),输出 64 通道,步长 2,初步提取特征并降采样(尺寸减半)。

最大池化:核大小 3x3,步长 2,pad_mode='same'保持空间尺寸对称减半(如 224→112→56)。

四个 stage(layer1-layer4)

self.layer1 = make_layer(64, block, 64, layer_nums[0])  # stride=1(默认)
self.layer2 = make_layer(64 * block.expansion, block, 128, layer_nums[1], stride=2)  # 下采样
self.layer3 = make_layer(128 * block.expansion, block, 256, layer_nums[2], stride=2)
self.layer4 = make_layer(256 * block.expansion, block, 512, layer_nums[3], stride=2)  # 关键考点:补全layer4参数

make_layer功能:动态构建残差块序列,每个 stage:第一个块通过stride=2实现下采样(layer2-layer4),通道数翻倍(如 64→128→256→512)。

block.expansion控制通道升维(基础块 = 1,瓶颈块 = 4),例如瓶颈块下64×4=256作为下 stage 输入。

输出层

self.avg_pool = ops.ReduceMean(keep_dims=True)  # 全局平均池化
self.flatten = nn.Flatten()  # 展平特征
self.fc = nn.Dense(in_channels=input_channel, out_channels=num_classes)  # 全连接分类

全局平均池化:替代全连接层前的全连接操作,减少参数数量,输出特征图尺寸为(batch, 512×expansion, 1, 1)。

全连接层:将特征映射到num_classes维空间,输出分类概率。

3. 前向传播逻辑

def construct(self, x):
    x = self.conv1(x) → self.norm(x) → self.relu(x) → self.max_pool(x)  # 初始特征提取
    x = self.layer1(x) → self.layer2(x) → self.layer3(x) → self.layer4(x)  # 四级残差块特征提取
    x = self.avg_pool(x, (2, 3))  # 对空间维度(H=2, W=3,假设输入为7x7)做平均池化
    x = self.flatten(x)  # 展平为一维向量(shape: [batch, input_channel])
    x = self.fc(x)  # 分类输出
    return x

空间尺寸变化:假设输入 224x224,经过conv1(stride=2)和max_pool(stride=2)后尺寸为 56x56,每层 stage 若stride=2则尺寸减半(56→28→14→7),最终layer4输出 7x7。

通道数变化:随 stage 递增(64→128→256→512),经block.expansion后瓶颈块通道数为 256→512→1024→2048。

4. 核心要点与设计原则

layer4参数补全(题目要求):输入通道为256×block.expansion(上一 stage 输出),当前 stage 基础通道512,stride=2实现最后一次下采样。

残差块类型兼容性:通过block参数支持基础块(浅层)和瓶颈块(深层),expansion自动适配通道逻辑(无需为不同块编写独立代码)。

下采样策略:每个 stage 的第一个块通过stride=2和1x1卷积调整通道 / 尺寸,保证残差连接维度匹配。

计算效率:瓶颈块通过1x1卷积降维减少 3x3 卷积计算量,使深层网络(如 ResNet-152)训练可行。

5. 代码关键作用

模块化构建:通过make_layer和残差块组合,快速搭建不同深度的 ResNet(如 50 层、101 层)。

特征提取流程:从浅层边缘检测到深层语义特征,逐层抽象,适应图像分类任务。

维度匹配:自动处理残差连接的通道和尺寸对齐,避免手动计算错误。

2.1.5 实例化resnet_xx网络

实例化resnet50

def _resnet(block: Type[Union[ResidualBlockBase, ResidualBlock]],
            layers: List[int], num_classes: int, pretrained: bool, pretrained_ckpt: str,
            input_channel: int):
    model = ResNet(block, layers, num_classes, input_channel)return modeldef resnet50(num_classes: int = 1000, pretrained: bool = False):
    resnet50_ckpt = "./LoadPretrainedModel/resnet50_224_new.ckpt"return _resnet(ResidualBlock, [3, 4, 6, 3], num_classes,
                   pretrained, resnet50_ckpt, 2048)

实例化resnet_xx网络代码分析

ps:代码中‘ return _resnet(ResidualBlock, [3, 4, 6, 3], num_classes,

pretrained, resnet50_ckpt, 2048)’

为什么能跨函数识别到 pretrained参数?

在 Python 中,这是因为作用域的规则 。在resnet50函数中,pretrained是该函数的参数,属于局部作用域。当调用_resnet函数时,pretrained作为参数传递给_resnet函数,所以_resnet函数能够识别并使用这个参数

1. 通用模型构建函数 _resnet

def _resnet(block: Type[Union[ResidualBlockBase, ResidualBlock]],
            layers: List[int], num_classes: int, pretrained: bool, pretrained_ckpt: str,
            input_channel: int):
    model = ResNet(block, layers, num_classes, input_channel)return model

功能:通用 ResNet 模型构建接口,通过参数化残差块类型、层数、分类数等,灵活生成不同配置的 ResNet 模型。

参数解析:block:残差块类型(ResidualBlockBase基础块或ResidualBlock瓶颈块)。

layers:各 stage 的残差块数量列表(如[3,4,6,3]对应 ResNet-50 的四个 stage)。

num_classes:分类任务的类别数(如中药材的 12 类)。

pretrained:是否加载预训练权重(布尔值,True表示加载)。

pretrained_ckpt:预训练权重文件路径(如"./LoadPretrainedModel/resnet50_224_new.ckpt")。

input_channel:全连接层输入维度(由最后一个 stage 的输出通道决定,如 ResNet-50 为 2048)。

2. 特定模型:ResNet-50 的封装 resnet50

def resnet50(num_classes: int = 1000, pretrained: bool = False):
    resnet50_ckpt = "./LoadPretrainedModel/resnet50_224_new.ckpt"
    return _resnet(ResidualBlock, [3, 4, 6, 3], num_classes,
                   pretrained, resnet50_ckpt, 2048)

功能:直接生成 ResNet-50 模型,固定了 ResNet-50 的核心配置(瓶颈块、各 stage 块数、输入通道)。

关键参数固定值:block=ResidualBlock:使用瓶颈块(Bottleneck Block),适用于深层网络(ResNet-50/101/152)。

layers=[3,4,6,3]:ResNet-50 的标准配置(四个 stage 分别包含 3、4、6、3 个瓶颈块)。

input_channel=2048:最后一个 stage 输出通道数(512 基础通道 × 瓶颈块 expansion=4)。

预训练支持:resnet50_ckpt指定了预训练权重路径(如用户需要加载 ImageNet 预训练权重,可通过pretrained=True启用)。

3. 代码设计核心价值

模块化与复用性:

_resnet作为通用构建函数,通过参数化残差块类型和层数,可扩展生成 ResNet-18(基础块 +[2,2,2,2])、ResNet-101([3,4,23,3])等变体,避免重复代码。用户友好性:

resnet50函数封装了 ResNet-50 的具体配置,用户只需指定num_classes(分类数)和pretrained(是否加载预训练),即可快速获取模型,降低使用门槛。

2.2模型初始化

要求:

对定义的ResNet50模型进行实例化

实例化一个用于12分类的resnet50模型

# 要点2-2-1: 对定义的ResNet50模型进行实例化
network = resnet50(num_classes=12)
num_class = 12
in_channel = network.fc.in_channels
fc = nn.Dense(in_channels=in_channel, out_channels=num_class)
network.fc = fcfor param in network.get_parameters():
    param.requires_grad = True

模型初始化代码解析

1. 实例化 ResNet50 模型

network = resnet50(num_classes=12)

作用:调用resnet50函数创建 ResNet50 模型实例,指定分类数为 12(如中药材的 12 类)。

内部逻辑:

resnet50函数通过_resnet生成 ResNet50 模型,默认使用瓶颈块(ResidualBlock)和标准层数[3,4,6,3],并将原 1000 类的全连接层(fc 层)初始化为 12 类输出(但需后续调整,见下文)。2. 替换全连接层适配新任务

num_class = 12  # 新任务的分类数(如中药材的12类)
in_channel = network.fc.in_channels  # 获取原fc层的输入通道数(ResNet50为2048)
fc = nn.Dense(in_channels=in_channel, out_channels=num_class)  # 新建12类输出的全连接层
network.fc = fc  # 替换原模型的fc层

背景:预训练 ResNet50 的 fc 层通常输出 1000 类(ImageNet 任务),需替换为新任务的分类数(12 类)。

关键操作:获取原 fc 层输入维度(in_channel=2048,由 ResNet50 的全局平均池化输出决定)。

新建全连接层fc,输入维度保持 2048,输出维度改为 12。

替换原模型的 fc 层,完成模型输出适配。

3. 启用所有参数训练 -- 全量微调

for param in network.get_parameters():
    param.requires_grad = True

作用:将模型所有参数的梯度计算标志(requires_grad)设为True,允许训练时更新所有参数。

场景意义:若使用预训练模型(pretrained=True),此操作表示 “端到端微调”(所有层参数均参与训练),适合新数据集与预训练数据分布差异较大的场景(如中药材分类 vs ImageNet 通用分类)。

若未使用预训练(pretrained=False),则模型从头开始训练,所有参数自然需要梯度更新。

相关文章:

  • 卷积神经网络实战(4)代码详解
  • 把Excel数据文件导入到Oracle数据库
  • k8s之statefulset
  • 低成本自动化改造的18个技术锚点深度解析
  • go语言封装、继承与多态:
  • 生信服务器如何安装cellranger|生信服务器安装软件|单细胞测序软件安装
  • K8S - Harbor 镜像仓库部署与 GitLab CI 集成实战
  • 【亲测有效】如何清空但不删除GitHub仓库中的所有文件(main分支)
  • K8S扩缩容及滚动更新和回滚
  • 昆仑万维一季度营收增长46% AI业务成新增长点
  • 集成管理工具Gitlab
  • 软考高级系统架构设计师备考分享:操作系统核心知识点整理
  • Java设计模式之原型模式详解:从入门到精通
  • 纯Java实现反向传播算法:零依赖神经网络实战
  • Docker常见疑难杂症解决指南:深入解析与实战解决方案
  • 【阿里云免费领取域名以及ssl证书,通过Nginx反向代理web服务】
  • STM32TIM定时中断(6)
  • 数据统计的意义:钱包余额变动
  • 区块链详解
  • leetcode 383. Ransom Note
  • 碧桂园境外债务重组:相当于现有公众票据本金额逾50%的持有人已加入协议
  • “浦东时刻”在京展出:沉浸式体验海派风情
  • 总奖池超百万!第五届七猫现实题材征文大赛颁奖在即
  • 九家企业与上海静安集中签约,投资额超10亿元
  • 光大华夏:近代中国私立大学遥不可及的梦想
  • 化学家、台湾地区“中研院”原学术副院长陈长谦逝世