PyTorch 构建神经网络
组件 | 作用 |
---|---|
层(Layer) | 网络基本单元,如卷积层(Conv2d)、线性层(Linear),负责张量数据变换 |
模型(Model) | 由多层按逻辑组合而成的整体,实现从输入到输出的映射 |
损失函数 | 衡量预测值与真实值的差距,如交叉熵损失(CrossEntropyLoss),是参数优化的目标 |
优化器 | 通过反向传播更新模型参数以最小化损失,如 Adam、SGD |
PyTorch 模型构建
1. nn.Module:可训练参数的 “管理者”
特点:所有带可学习参数的层(如 Conv2d、Linear)均继承自 nn.Module,能自动追踪参数,支持与模型容器结合使用。
用法:自定义模型需继承 nn.Module,在__init__
中定义层,在forward
中实现前向传播逻辑。
示例:定义一个简单线性层模块
python运行
import torch.nn as nn
class SimpleLinear(nn.Module):def __init__(self, in_dim, out_dim):super().__init__()self.linear = nn.Linear(in_dim, out_dim) # 可学习参数由nn.Module管理def forward(self, x):return self.linear(x)
nn.functional:纯函数式工具
特点:无参数的 “纯函数” 集合,如激活函数(ReLU)、池化(max_pool2d),需手动传入参数(若有),无法与模型容器直接结合。
注意:dropout 操作若用 nn.functional 实现,需手动区分训练 / 测试模式;而 nn.Dropout(继承自 nn.Module)可通过model.eval()
自动切换状态。
三种模型构建方法
1. 直接继承 nn.Module:最灵活
适用于复杂网络结构,需手动定义每一层的连接逻辑。例如构建含批归一化的全连接网络:
python运行
import torch.nn.functional as F
class FCModel(nn.Module):def __init__(self, in_dim=784, n_hidden=300, out_dim=10):super().__init__()self.flatten = nn.Flatten() # 展平28*28图像self.linear1 = nn.Linear(in_dim, n_hidden)self.bn1 = nn.BatchNorm1d(n_hidden) # 批归一化def forward(self, x):x = self.flatten(x)x = F.relu(self.bn1(self.linear1(x))) # 前向传播逻辑return x
2. nn.Sequential:按序堆叠,快速高效
适合层与层按顺序连接的简单网络,支持三种定义方式:
可变参数:直接传入层实例,无需命名
python运行
seq = nn.Sequential(nn.Flatten(), nn.Linear(784, 300), nn.ReLU())
- add_module:为每层指定名称,便于后续查看
python运行
seq = nn.Sequential() seq.add_module("flatten", nn.Flatten()) seq.add_module("linear1", nn.Linear(784, 300))
- OrderedDict:用有序字典定义,兼顾顺序与命名
python运行
from collections import OrderedDict seq = nn.Sequential(OrderedDict([("flatten", nn.Flatten()),("linear1", nn.Linear(784, 300)) ]))
nn.Sequential 封装残差块:
python运行
class ResBlockWrapper(nn.Module):def __init__(self):super().__init__()# 用nn.Sequential封装残差块内的层self.res_block = nn.Sequential(nn.Conv2d(64, 64, 3, padding=1),nn.BatchNorm2d(64))def forward(self, x):return F.relu(x + self.res_block(x))
从自定义模块到 ResNet18
1. 定义两种残差块
ResNet18 包含两种残差块,分别处理 “维度不变” 和 “维度下采样” 场景:
python运行
class RestNetBasicBlock(nn.Module):# 基础残差块:输入输出维度一致,无需额外调整def __init__(self, in_channels, out_channels, stride):super().__init__()self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride, padding=1)self.bn1 = nn.BatchNorm2d(out_channels)self.conv2 = nn.Conv2d(out_channels, out_channels, 3, stride, padding=1)self.bn2 = nn.BatchNorm2d(out_channels)def forward(self, x):out = F.relu(self.bn1(self.conv1(x)))out = self.bn2(self.conv2(out))return F.relu(x + out) # 残差连接class RestNetDownBlock(nn.Module):# 下采样残差块:用1×1卷积调整输入维度,适配残差连接def __init__(self, in_channels, out_channels, stride):super().__init__()self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride[0], padding=1)self.bn1 = nn.BatchNorm2d(out_channels)self.conv2 = nn.Conv2d(out_channels, out_channels, 3, stride[1], padding=1)self.bn2 = nn.BatchNorm2d(out_channels)# 1×1卷积调整输入通道和分辨率self.extra = nn.Sequential(nn.Conv2d(in_channels, out_channels, 1, stride[0]),nn.BatchNorm2d(out_channels))def forward(self, x):extra_x = self.extra(x) # 维度调整out = F.relu(self.bn1(self.conv1(x)))out = self.bn2(self.conv2(out))return F.relu(extra_x + out)
2. 组合成 ResNet18 架构
基于两种残差块,按 “初始卷积→残差层→全局池化→全连接” 的顺序构建 ResNet18,适配 3 通道的人脸图像:
python运行
class RestNet18(nn.Module):def __init__(self, num_classes): # num_classes:人脸类别数super().__init__()# 初始层:降维+下采样self.conv1 = nn.Conv2d(3, 64, 7, stride=2, padding=3)self.bn1 = nn.BatchNorm2d(64)self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)# 4个残差层:2个基础块+2个下采样块组合self.layer1 = nn.Sequential(RestNetBasicBlock(64, 64, 1), RestNetBasicBlock(64, 64, 1))self.layer2 = nn.Sequential(RestNetDownBlock(64, 128, [2,1]), RestNetBasicBlock(128, 128, 1))self.layer3 = nn.Sequential(RestNetDownBlock(128, 256, [2,1]), RestNetBasicBlock(256, 256, 1))self.layer4 = nn.Sequential(RestNetDownBlock(256, 512, [2,1]), RestNetBasicBlock(512, 512, 1))# 分类头:全局平均池化+全连接self.avgpool = nn.AdaptiveAvgPool2d((1,1))self.fc = nn.Linear(512, num_classes)def forward(self, x):# 前向传播:按层顺序执行x = self.bn1(self.conv1(x))x = self.maxpool(x)x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)x = self.layer4(x)x = self.avgpool(x)x = x.reshape(x.shape[0], -1) # 展平为一维向量return self.fc(x)
总结
PyTorch 模型构建的核心在于 “灵活组合”:通过 nn.Module 管理可训练参数,用 nn.Sequential 等容器简化层连接,结合自定义模块(如残差块)可实现复杂架构。从基础全连接网络到 ResNet18