使用 PyTorch 构建并训练 CNN 模型
卷积神经网络(CNN)在计算机视觉领域占据核心地位,尤其在图像分类任务中表现出色。CIFAR-10 数据集是入门计算机视觉的经典数据集,包含10 类(飞机、汽车、鸟等)、分辨率为32×32
的彩色图像,非常适合验证 CNN 模型的效果。
本文将基于 PyTorch 框架,从数据预处理、模型定义、训练过程到性能评估,完整演示如何构建并训练 CNN 模型完成 CIFAR-10 分类任务。
一、环境准备与库导入
首先导入所需的 Python 库,确保 PyTorch、TorchVision 等已安装(可通过pip install torch torchvision
安装)。
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import numpy as np
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from collections import Counter
torch
:PyTorch 核心库,提供张量运算、自动微分等功能。torch.nn
:神经网络模块库,包含卷积、池化、全连接层等。torch.optim
:优化器库,如 SGD、Adam 等。torchvision
:计算机视觉工具库,提供数据集、图像变换等功能。numpy
:数值计算库,辅助数据处理。
二、数据预处理与加载
为了提升模型泛化能力,训练集需做数据增强(随机裁剪、水平翻转);测试集保持 “干净”,仅做标准化等基础变换。
2.1 数据变换定义
# 设备自动检测:优先使用GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')# 训练集变换:数据增强 + 标准化
transform_train = transforms.Compose([transforms.RandomCrop(32, padding=4), # 随机裁剪(填充4像素后裁为32×32)transforms.RandomHorizontalFlip(), # 随机水平翻转transforms.ToTensor(), # 转为Tensor(范围缩至[0,1])transforms.Normalize( # 标准化(减均值、除标准差)mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010)),
])# 测试集变换:仅标准化(无数据增强)
transform_test = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010)),
])
2.2 数据集与数据加载器
# 加载训练集和测试集
trainset = torchvision.datasets.CIFAR10(root='./data', # 数据保存路径train=True, # 训练集download=False, # 若本地已有,设为Falsetransform=transform_train
)
trainloader = DataLoader(trainset, batch_size=128, # 批次大小shuffle=True, # 打乱数据num_workers=2 # 多线程加载数据
)testset = torchvision.datasets.CIFAR10(root='./data', train=False, # 测试集download=False, transform=transform_test
)
testloader = DataLoader(testset, batch_size=100, shuffle=False, num_workers=2
)# CIFAR-10的10个类别标签
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
数据增强通过增加训练样本多样性,减少模型过拟合;DataLoader
负责按批次高效加载数据,提升训练效率。
三、CNN 模型定义
CNN 的核心是卷积层(提取局部特征)、池化层(缩小特征图并保留关键信息)、全连接层(分类决策)。下面定义两个模型:自定义 CNN(Net
)和经典 LeNet。
3.1 自定义 CNN 模型(Net
)
class Net(nn.Module):def __init__(self):super(Net, self).__init__()# 第一层卷积:输入3通道(彩色),输出16通道,卷积核5×5self.conv1 = nn.Conv2d(3, 16, 5)# 最大池化:核2×2,步长2self.pool1 = nn.MaxPool2d(2, 2)# 第二层卷积:输入16通道,输出36通道,卷积核5×5self.conv2 = nn.Conv2d(16, 36, 5)# 第二层池化self.pool2 = nn.MaxPool2d(2, 2)# 自适应平均池化:将特征图缩为1×1(通道数保留)self.aap = nn.AdaptiveAvgPool2d(1)# 全连接层:输入36(通道数),输出10(类别数)self.fc3 = nn.Linear(36, 10)def forward(self, x):# 卷积 → ReLU激活 → 池化x = self.pool1(F.relu(self.conv1(x)))x = self.pool2(F.relu(self.conv2(x)))# 自适应池化(统一特征图尺寸)x = self.aap(x)# 展平(保留批次维度,其余维度合并)x = x.view(x.shape[0], -1)# 全连接层输出类别概率x = self.fc3(x)return x
3.2 经典 LeNet 模型
LeNet 是 CNN 的经典雏形,结构更简洁,适合理解 CNN 基本流程:
class LeNet(nn.Module):def __init__(self):super(LeNet, self).__init__()self.conv1 = nn.Conv2d(3, 6, 5) # 输入3通道,输出6通道,核5×5self.conv2 = nn.Conv2d(6, 16, 5)self.fc1 = nn.Linear(16 * 5 * 5, 120) # 特征图经池化后尺寸为5×5self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):# 卷积 → ReLU → 池化out = F.relu(self.conv1(x))out = F.max_pool2d(out, 2)out = F.relu(self.conv2(out))out = F.max_pool2d(out, 2)# 展平为一维向量out = out.view(out.size(0), -1)# 全连接层 → ReLUout = F.relu(self.fc1(out))out = F.relu(self.fc2(out))# 输出层(无激活,配合交叉熵损失)out = self.fc3(out)return out
四、模型训练过程
训练的核心是 **“前向传播计算损失 → 反向传播求梯度 → 优化器更新参数”** 的循环。
4.1 初始化模型、损失函数与优化器
# 初始化模型并放到设备(GPU/CPU)
net = Net().to(device)
# 交叉熵损失(适合多分类任务)
criterion = nn.CrossEntropyLoss()
# SGD优化器(带动量,加速收敛)
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
4.2 训练循环
epochs = 10 # 训练轮数
for epoch in range(epochs):running_loss = 0.0 # 累计损失# 遍历训练数据加载器for i, data in enumerate(trainloader, 0):# 获取输入和标签,并转移到设备inputs, labels = datainputs, labels = inputs.to(device), labels.to(device)# 1. 梯度清零(防止累积)optimizer.zero_grad()# 2. 前向传播:模型预测outputs = net(inputs)# 计算损失loss = criterion(outputs, labels)# 3. 反向传播:计算梯度loss.backward()# 4. 优化器更新参数optimizer.step()# 统计损失running_loss += loss.item()# 每2000个mini-batch打印一次平均损失if i % 2000 == 1999:print(f'[{epoch + 1}, {i + 1}] loss: {running_loss / 2000:.3f}')running_loss = 0.0print('Finished Training')
训练过程中,running_loss
用于监控损失变化:若损失持续下降,说明模型在学习;若损失震荡或上升,可能是学习率过大或模型过拟合。
五、模型结构可视化(可选)
为了更清晰地了解模型各层的输入、输出和参数,我们实现类似 Keras model.summary()
的功能:
import collectionsdef params_summary(input_size, model):def register_hook(module):def hook(module, input, output):# 获取模块类名class_name = str(module.__class__).split('.')[-1].split("'")[0]module_idx = len(summary)m_key = f'{class_name}-{module_idx+1}'summary[m_key] = collections.OrderedDict()# 记录输入/输出形状(批次维度设为-1,代表任意大小)summary[m_key]['input_shape'] = list(input[0].size())summary[m_key]['input_shape'][0] = -1summary[m_key]['output_shape'] = list(output.size())summary[m_key]['output_shape'][0] = -1# 统计参数数量params = 0if hasattr(module, 'weight') and hasattr(module.weight, 'size'):params += torch.prod(torch.LongTensor(list(module.weight.size())))summary[m_key]['trainable'] = module.weight.requires_gradif hasattr(module, 'bias') and hasattr(module.bias, 'size'):params += torch.prod(torch.LongTensor(list(module.bias.size())))summary[m_key]['nb_params'] = params# 仅对非容器类模块注册hookif not isinstance(module, nn.Sequential) and \not isinstance(module, nn.ModuleList) and \not (module == model):hooks.append(module.register_forward_hook(hook))# 生成随机输入用于前向传播if isinstance(input_size[0], (list, tuple)):x = [torch.rand(1, *in_size) for in_size in input_size]else:x = torch.rand(1, *input_size)summary = collections.OrderedDict()hooks = []# 注册所有模块的hookmodel.apply(register_hook)# 执行前向传播(触发hook)model(x)# 移除hook(避免影响后续操作)for h in hooks:h.remove()return summary# 查看Net的结构摘要
summary = params_summary((3, 32, 32), Net())
for layer, info in summary.items():print(f"Layer: {layer}")print(f" Input Shape: {info['input_shape']}")print(f" Output Shape: {info['output_shape']}")print(f" Params: {info['nb_params']}")print(f" Trainable: {info.get('trainable', False)}")print("-" * 50)
运行后,会打印每一层的输入形状、输出形状、参数数量和是否可训练,帮助我们验证模型结构是否符合预期。
六、模型评估(测试集性能)
训练完成后,在测试集上评估模型准确率(泛化能力):
correct = 0
total = 0
# 测试时无需计算梯度,加快速度
with torch.no_grad():for data in testloader:images, labels = dataimages, labels = images.to(device), labels.to(device)outputs = net(images)# 获取最大概率对应的类别_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print(f'模型在10000张测试图像上的准确率: {100 * correct / total:.2f}%')
还可以进一步分析每个类别的准确率:
class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))
with torch.no_grad():for data in testloader:images, labels = dataimages, labels = images.to(device), labels.to(device)outputs = net(images)_, predicted = torch.max(outputs, 1)c = (predicted == labels).squeeze()# 假设batch_size为100,可根据实际调整循环次数for i in range(100):label = labels[i]class_correct[label] += c[i].item()class_total[label] += 1for i in range(10):print(f'{classes[i]}的准确率: {100 * class_correct[i] / class_total[i]:.2f}%')
七、总结与优化方向
本文完整演示了 “数据预处理→模型定义→训练→评估” 的深度学习流程。若想进一步提升性能,可尝试:
- 网络结构优化:增加卷积层、加入 BatchNorm 层、使用残差连接(ResNet 思想)。
- 超参数调整:增大学习率(配合学习率衰减)、调整批次大小、增加训练轮数。
- 优化器升级:改用 Adam 优化器(自适应学习率,收敛更快)。
- 正则化手段:加入 Dropout 层、权重衰减(L2 正则),缓解过拟合。
通过不断迭代优化,CIFAR-10 的分类准确率可逐步提升至 80% 以上~