使用 PyTorch 构建自定义模型
PyTorch 提供了灵活的方式来构建自定义神经网络模型。下面将由浅入深地介绍自定义模型的构建方法,并配以实际代码示例和最佳实践。
一、基础模型构建
1. 继承 nn.Module 基类
所有自定义模型都应继承 torch.nn.Module 类,并实现两个基本方法:__init__()(先调用父类初始化,再定义网络层)和 forward()(定义前向传播逻辑)。示例如下:
import torch.nn as nn
import torch.nn.functional as F
class MyModel(nn.Module):
    """LeNet-style CNN classifier.

    The layer sizes assume 28x28 input: two 5x5 convolutions, each followed by
    2x2 max-pooling, shrink it to a 50 x 4 x 4 feature map, which matches the
    4*4*50 input size of fc1.

    Args:
        in_channels: Number of input channels. Defaults to 1.
        num_classes: Number of output classes. Defaults to 10.
    """

    def __init__(self, in_channels=1, num_classes=10):
        super().__init__()  # must run before any submodule is assigned
        # Network layers.
        self.conv1 = nn.Conv2d(in_channels, 20, 5)
        self.conv2 = nn.Conv2d(20, 50, 5)
        self.fc1 = nn.Linear(4 * 4 * 50, 500)
        self.fc2 = nn.Linear(500, num_classes)

    def forward(self, x):
        """Run the forward pass.

        Args:
            x: Input batch of shape [B, in_channels, 28, 28].

        Returns:
            Log-probabilities of shape [B, num_classes] (log_softmax over dim 1).
        """
        x = F.max_pool2d(F.relu(self.conv1(x)), 2, 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2, 2)
        # Flatten each sample independently. view(-1, 4*4*50) would silently
        # regroup values across samples when the spatial size is wrong (e.g.
        # 16 inputs of 32x32 become 25 rows); flatten(1) keeps the batch
        # dimension, so a size mismatch raises a clear error in fc1 instead.
        x = x.flatten(1)
        x = F.relu(self.fc1(x))
        return F.log_softmax(self.fc2(x), dim=1)
2. 模型使用方式
# Minimal forward/backward sketch. input_tensor, criterion and target are
# placeholders that this snippet does not define.
model = MyModel()
output = model(input_tensor)  # calling the model invokes forward() automatically
# NOTE(review): MyModel.forward returns log_softmax outputs, so criterion is
# presumably nn.NLLLoss (nn.CrossEntropyLoss would apply log_softmax again) — confirm.
loss = criterion(output, target)
loss.backward()  # backpropagate; an optimizer step would normally follow (not shown)
二、中级构建技巧
1. 使用 nn.Sequential
nn.Sequential 是一种用于快速构建顺序网络的容器类,适用于各模块按线性顺序依次执行的模型。
class MySequentialModel(nn.Module):
    """VGG-style CNN whose layers are grouped with nn.Sequential.

    ``features`` applies two Conv-ReLU-MaxPool stages (3 -> 64 -> 128
    channels), each halving the spatial size (rounding down). ``avgpool`` then
    resamples the result to a fixed 8x8 grid, so the classifier's 128 * 8 * 8
    input size holds for any input resolution, not just 32x32. For 32x32
    inputs the feature map is already 8x8 and the pool leaves it unchanged.

    Args:
        num_classes: Number of output classes. Defaults to 10.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        # Has no parameters, so it adds no state_dict entries.
        self.avgpool = nn.AdaptiveAvgPool2d((8, 8))
        self.classifier = nn.Sequential(
            nn.Linear(128 * 8 * 8, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        """Map a [B, 3, H, W] batch to [B, num_classes] logits."""
        x = self.features(x)
        x = self.avgpool(x)
        x = x.flatten(1)  # [B, 128, 8, 8] -> [B, 128 * 8 * 8]
        return self.classifier(x)
2. 参数初始化
def initialize_weights(m):
    """Initialize the parameters of a single module in place.

    Meant to be passed to ``nn.Module.apply``. Modules other than nn.Conv2d
    and nn.Linear are left unchanged.

    - Conv2d: Kaiming-normal weights (mode='fan_out', ReLU gain), zero bias.
    - Linear: normal weights with mean 0 and std 0.01, zero bias.

    Args:
        m: The module to initialize.
    """
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
    elif isinstance(m, nn.Linear):
        nn.init.normal_(m.weight, 0, 0.01)
    else:
        return
    # Layers built with bias=False have m.bias set to None; this applies to
    # Linear as well as Conv2d.
    if m.bias is not None:
        nn.init.constant_(m.bias, 0)
model.apply(initialize_weights)  # recursively apply initialize_weights to every submodule of model and to model itself
三、高级构建模式
1. 残差连接 (ResNet风格)
class ResidualBlock(nn.Module):
    """Basic two-convolution residual block (ResNet style).

    Main path: 3x3 conv (using ``stride``) -> BN -> ReLU -> 3x3 conv -> BN.
    Skip path: identity, or a 1x1 conv + BN projection whenever the stride or
    channel count changes, so both paths produce the same shape.
    Output: ReLU(main + skip).

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels produced by the block.
        stride: Stride of the first conv and of the projection. Defaults to 1.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        # bias=False: the BatchNorm after each conv supplies its own shift.
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=3,
            stride=stride, padding=1, bias=False,
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3,
            stride=1, padding=1, bias=False,
        )
        self.bn2 = nn.BatchNorm2d(out_channels)

        shape_changes = stride != 1 or in_channels != out_channels
        if shape_changes:
            projection = [
                nn.Conv2d(in_channels, out_channels, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            ]
        else:
            projection = []
        # An empty nn.Sequential returns its input unchanged (identity skip).
        self.shortcut = nn.Sequential(*projection)

    def forward(self, x):
        skip = self.shortcut(x)
        main = F.relu(self.bn1(self.conv1(x)))
        main = self.bn2(self.conv2(main))
        return F.relu(main + skip)
2. 自定义层
class MyCustomLayer(nn.Module):
    """Fully connected layer built directly from nn.Parameter tensors.

    Computes ``y = x @ weight.T + bias`` via F.linear, the same operation as
    ``nn.Linear(input_dim, output_dim)``.

    Args:
        input_dim: Size of the last dimension of the input.
        output_dim: Size of the last dimension of the output.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # Wrapping tensors in nn.Parameter registers them with the module, so
        # they show up in parameters()/state_dict() and receive gradients.
        self.weight = nn.Parameter(torch.empty(output_dim, input_dim))
        self.bias = nn.Parameter(torch.empty(output_dim))
        self.reset_parameters()

    def reset_parameters(self):
        """(Re)initialize weight and bias from U(-1/sqrt(fan_in), 1/sqrt(fan_in)).

        This is the distribution nn.Linear uses by default. Plain torch.randn
        (std 1) makes the output variance grow linearly with input_dim, which
        destabilizes training for wide layers.
        """
        fan_in = self.weight.size(1)
        bound = 1.0 / fan_in ** 0.5 if fan_in > 0 else 0.0  # avoid 1/0 for empty input
        nn.init.uniform_(self.weight, -bound, bound)
        nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x):
        return F.linear(x, self.weight, self.bias)
四、模型保存与加载
1. 保存整个模型
# Save the entire model object. This pickles the model, so loading it later
# requires the model's class definition to be importable.
torch.save(model, 'model.pth')  # save
# Since PyTorch 2.6, torch.load defaults to weights_only=True, which rejects a
# pickled nn.Module, so loading a whole model needs weights_only=False.
# NOTE: this unpickles arbitrary objects -- only load files you trust.
model = torch.load('model.pth', weights_only=False)  # load
2. 保存状态字典 (推荐)
# Recommended: save only the parameters and buffers (the state_dict).
torch.save(model.state_dict(), 'model_state.pth')  # save
# weights_only=True limits unpickling to tensors and plain containers, so a
# tampered checkpoint cannot run code. It is the default only from PyTorch 2.6,
# so passing it explicitly keeps older versions safe as well.
# load_state_dict fills an already-constructed model of the same architecture.
model.load_state_dict(torch.load('model_state.pth', weights_only=True))  # load
五、模型部署准备
1. 模型导出为TorchScript
# Convert the model to TorchScript and save it for deployment.
scripted_model = torch.jit.script(model)  # or torch.jit.trace (which needs an example input)
scripted_model.save('model_scripted.pt')
2. ONNX格式导出
# Example input used to run the model during export.
# NOTE(review): its shape must suit the model being exported -- MyModel above
# expects [N, 1, 28, 28]; [1, 3, 224, 224] works for a 3-channel model with
# adaptive pooling such as CustomCNN. Adjust as needed.
dummy_input = torch.randn(1, 3, 224, 224)
torch.onnx.export(
    model, dummy_input, "model.onnx",
    input_names=["input"], output_names=["output"],
    # Without dynamic_axes the batch size is fixed to the dummy input's (1);
    # marking dim 0 as dynamic lets the ONNX model accept any batch size.
    dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}},
)
六、完整示例:自定义CNN分类器
import torch
from torch import nn
from torch.nn import functional as F
class CustomCNN(nn.Module):
    """Custom CNN image classifier.

    Three Conv-ReLU-MaxPool stages (in_channels -> 32 -> 64 -> 128 channels),
    adaptive average pooling to a fixed 6x6 grid, then a dropout-regularized
    two-layer classifier. Because of the adaptive pool, the classifier's input
    size does not depend on the image resolution. The input must be at least
    8x8 to survive the three 2x2 max-pools.

    Args:
        num_classes (int): Number of output classes. Defaults to 10.
        dropout_prob (float): Dropout probability used in the classifier.
            Defaults to 0.5.
        in_channels (int): Number of input image channels. Defaults to 3.
    """

    def __init__(self, num_classes=10, dropout_prob=0.5, in_channels=3):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
        self.classifier = nn.Sequential(
            nn.Dropout(p=dropout_prob),
            nn.Linear(128 * 6 * 6, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout_prob),
            nn.Linear(512, num_classes),
        )
        self._initialize_weights()

    def _initialize_weights(self):
        """Conv2d: Kaiming-normal (fan_out, ReLU); Linear: N(0, 0.01). Biases zeroed.

        Conv biases are zeroed too, matching the standalone initializer used
        elsewhere in this tutorial; layers without a bias are skipped.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
            else:
                continue
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the forward pass.

        Args:
            x (torch.Tensor): Input tensor of shape [B, C, H, W].

        Returns:
            torch.Tensor: Output logits of shape [B, num_classes].
        """
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x
七、注意事项
- 输入输出维度匹配
  - 需确保相邻模块的输入/输出维度兼容。例如,卷积层后接全连接层时,需通过 Flatten 或自适应池化调整维度。
- 调试与验证
  - 可通过模拟输入数据验证模型结构,如:
# Feed a dummy batch (batch_size=64, 3x32x32 images) through the model and
# check that the output shape is as expected. no_grad() skips building the
# autograd graph, which a shape check does not need. If the model contains
# BatchNorm, call model.eval() first so this probe does not update its
# running statistics.
dummy_input = torch.ones(64, 3, 32, 32)
with torch.no_grad():
    output = model(dummy_input)
print(output.shape)  # e.g. torch.Size([64, 10]) for a 10-class classifier