神经网络工具箱
1. 继承nn.Module基类
方式:自定义类继承nn.Module,在__init__
中定义层,在forward
中定义前向传播
优点:灵活性最高,可定制复杂结构
2. 使用nn.Sequential顺序构建
三种实现方式:
可变参数:nn.Sequential(layer1, layer2, ...)
add_module方法:可指定每层名称
OrderedDict方法:显式命名每层
优点:代码简洁,适合顺序结构
3. 继承+模型容器组合
容器类型:
nn.Sequential
:顺序容器
nn.ModuleList
:列表式容器(可迭代)
nn.ModuleDict
:字典式容器(可按名称访问)
优点:兼顾灵活性和代码组织
4. 基本残差块
结构:输入 → 卷积 → 激活 → 卷积 → 与输入相加 → ReLU
特点:恒等映射,要求输入输出形状一致
5. 带降维残差块
结构:添加1×1卷积调整通道数和分辨率
作用:使输入输出形状匹配,支持维度变化
6. ResNet18组合
组成:交替使用基本残差块和降维残差块
意义:现代经典网络结构,解决梯度消失问题
7.六步训练法:
加载预处理数据集 - 数据准备
定义损失函数 - 如CrossEntropyLoss
定义优化方法 - 如SGD、Adam
循环训练模型 - 前向传播、反向传播、参数更新
循环测试/验证模型 - 评估性能
可视化结果 - 分析训练过程
特性 | nn.Module | nn.functional |
---|---|---|
参数管理 | 自动管理 | 手动传入 |
与Sequential结合 | 支持 | 不支持 |
状态管理 | 自动切换训练/测试 | 无状态管理 |
代码复用 | 高 | 低 |
适用场景 | 有参数层 | 无参数操作 |