PyTorch 神经网络工具箱核心知识梳理
一、神经网络核心组件
神经网络的构建与训练依赖四大核心组件,各组件功能明确且相互配合,构成完整的模型运行体系:
组件 | 核心功能 |
---|---|
层(Layer) | 神经网络的基本结构单元,负责将输入张量通过数据变换(如卷积、线性运算)转换为输出张量。 |
模型(Model) | 由多个层按特定逻辑组合而成的网络结构,是实现任务(分类、回归等)的核心载体。 |
损失函数 | 定义参数学习的目标函数,量化模型预测值(Y')与真实值(Y)的差异,是参数优化的依据。 |
优化器 | 采用特定算法(如 SGD、Adam)最小化损失函数,实现模型参数(权重等)的迭代更新。 |
二、构建神经网络的主要工具
PyTorch 提供nn.Module
和nn.functional
两大核心工具,二者在用法和场景上存在显著差异:
1. 核心工具对比
维度 | nn.Module | nn.functional |
---|---|---|
本质 | 面向对象的模块,继承自nn.Module 基类 | 纯函数式接口 |
典型应用 | 卷积层(nn.Conv2d )、全连接层(nn.Linear )、dropout 层(nn.Dropout ) | 激活函数(F.relu )、池化层(F.max_pool2d )、损失计算(F.cross_entropy ) |
用法规范 | 需先实例化并传入参数,再以函数调用方式传入数据(如layer = nn.Linear(10,2); layer(x) ) | 直接调用函数并传入数据及参数(如F.linear(x, weight, bias) ) |
参数管理 | 自动定义和管理weight 、bias 等可学习参数 | 需手动定义和传入weight 、bias ,不利于代码复用 |
容器兼容性 | 可与nn.Sequential 等模型容器无缝结合 | 无法与nn.Sequential 等容器结合 |
状态转换(如 dropout) | 调用model.eval() 后自动切换训练 / 测试状态 | 需手动控制状态,无自动转换功能 |
三、模型构建的三种核心方式
PyTorch 支持多种模型构建方式,可根据模型复杂度和模块化需求选择:
1. 继承nn.Module
基类构建(灵活度最高)
适用于复杂模型,需手动定义层结构和正向传播逻辑,核心步骤包括:
- 定义模型类并继承
nn.Module
; - 在
__init__
方法中初始化各层组件(如全连接层、批归一化层); - 实现
forward
方法定义数据流向(正向传播过程)。
代码示例片段:
python
运行
import torch
from torch import nn
import torch.nn.functional as Fclass Model_Seq(nn.Module):def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):super(Model_Seq, self).__init__()self.flatten = nn.Flatten()self.linear1 = nn.Linear(in_dim, n_hidden_1)self.bn1 = nn.BatchNorm1d(n_hidden_1)self.linear2 = nn.Linear(n_hidden_1, n_hidden_2)self.bn2 = nn.BatchNorm1d(n_hidden_2)self.out = nn.Linear(n_hidden_2, out_dim)def forward(self, x):x = self.flatten(x)x = F.relu(self.bn1(self.linear1(x)))x = F.relu(self.bn2(self.linear2(x)))x = F.softmax(self.out(x), dim=1)return x
2. 使用nn.Sequential
按层顺序构建(简洁高效)
适用于线性堆叠的简单模型,无需手动定义forward
方法,支持三种实现方式:
- 可变参数方式:直接传入层实例,无法指定层名称;
python
运行
Seq_arg = nn.Sequential(nn.Flatten(),nn.Linear(in_dim, n_hidden_1),nn.ReLU() )
add_module
方法:通过add_module("层名称", 层实例)
指定层名称;OrderedDict
方法:通过有序字典传入(键为层名称,值为层实例),保证层顺序。
3. 继承nn.Module
结合模型容器构建(模块化兼顾灵活)
将模型拆分为多个子模块,通过nn.Sequential
、nn.ModuleList
、nn.ModuleDict
等容器管理,平衡模块化与灵活性:
nn.Sequential
容器:按顺序封装子模块,适合固定流程的子网络;nn.ModuleList
容器:以列表形式存储层,支持索引访问,适合动态调整层数量;nn.ModuleDict
容器:以字典形式存储层,通过键名访问,适合灵活控制层顺序。
nn.ModuleDict
实现示例片段:
python
运行
class Model_dict(nn.Module):def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):super(Model_dict, self).__init__()self.layers_dict = nn.ModuleDict({"flatten": nn.Flatten(),"linear1": nn.Linear(in_dim, n_hidden_1),"relu": nn.ReLU(),"out": nn.Linear(n_hidden_2, out_dim)})def forward(self, x):layers = ["flatten", "linear1", "relu", "out"] # 手动定义执行顺序for layer in layers:x = self.layers_dict[layer](x)return x
四、自定义网络模块(以 ResNet 为例)
对于复杂网络结构(如残差网络 ResNet),可通过自定义模块实现复用,核心包括两种残差块:
- 基础残差块(
RestNetBasicBlock
):输入与输出形状一致,直接将输入与卷积输出相加后激活; - 下采样残差块(
RestNetDownBlock
):通过 1×1 卷积调整输入通道数和分辨率,确保输入与输出可相加。
将两种模块组合可构建 ResNet18 等经典网络,示例如下:
python
运行
class RestNet18(nn.Module):def __init__(self):super(RestNet18, self).__init__()self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)self.layer1 = nn.Sequential(RestNetBasicBlock(64, 64, 1), RestNetBasicBlock(64, 64, 1))self.layer2 = nn.Sequential(RestNetDownBlock(64, 128, [2, 1]), RestNetBasicBlock(128, 128, 1))# 后续layer3、layer4及全连接层省略...def forward(self, x):# 正向传播逻辑省略...return out
五、模型训练流程
完整的模型训练需遵循固定流程,确保模型有效学习:
- 加载预处理数据集:准备训练 / 测试数据,进行标准化、批处理等预处理;
- 定义损失函数:根据任务选择(如分类用
nn.CrossEntropyLoss
,回归用nn.MSELoss
); - 定义优化方法:选择优化器(如
torch.optim.SGD
、torch.optim.Adam
)并配置学习率; - 循环训练模型:迭代输入数据,执行正向传播→计算损失→反向传播(
loss.backward()
)→参数更新(optimizer.step()
); - 循环测试或验证:定期在验证集上评估模型性能,避免过拟合;
- 可视化结果:绘制损失曲线、准确率曲线等,分析模型训练效果。