代码案例实践:基于 PyTorch 的 CIFAR-10 图像分类(三种 CNN 架构对比)
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import collections# ----------------------------
# 1. Hyperparameter definitions
# ----------------------------
BATCH_SIZE = 128  # mini-batch size for training
EPOCHS = 10  # number of full passes over the training set
LR = 0.001  # SGD learning rate
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # prefer GPU when available
# CIFAR-10 class names, indexed by integer label id
CLASSES = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
# ----------------------------
# 2. Model definitions (multiple CNN architectures)
# ----------------------------
class CNNNet(nn.Module):
    """Two-conv CNN for 3x32x32 CIFAR-10 images.

    Layout: conv5 -> pool2 -> conv3 -> pool2 -> fc -> fc.
    Spatial sizes: 32 -> 28 -> 14 -> 12 -> 6, so the flattened
    feature size is 36 * 6 * 6.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=5, stride=1)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=36, kernel_size=3, stride=1)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(36 * 6 * 6, 128)  # matches the 36x6x6 feature map above
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        # Flatten keeping the batch dimension explicit: view(x.size(0), -1)
        # raises on a shape mismatch instead of silently changing the batch
        # size, which view(-1, 36*6*6) would do.
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x


class Net(nn.Module):
    """CNN variant whose head uses adaptive average pooling, so the
    classifier is independent of the spatial size of the last feature map.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, 5)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(16, 36, 5)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.aap = nn.AdaptiveAvgPool2d(1)  # pool each channel down to 1x1
        self.fc3 = nn.Linear(36, 10)

    def forward(self, x):
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        x = self.aap(x)             # shape: [batch, 36, 1, 1]
        x = x.view(x.shape[0], -1)  # flatten to [batch, 36]
        x = self.fc3(x)
        return x


class LeNet(nn.Module):
    """Classic LeNet-5-style network adapted for 3-channel CIFAR-10 input:
    spatial sizes 32 -> 28 -> 14 -> 10 -> 5, flattened to 16 * 5 * 5.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(x.size(0), -1)  # flatten, batch dim preserved
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
# 3. Model parameter summary utility
# ----------------------------
def params_summary(input_size, model):
    """Summarize per-layer input/output shapes and parameter counts.

    Runs one forward pass with a random batch of size 1 through ``model``
    and records, for every leaf module, its input shape, output shape and
    number of parameters.

    Args:
        input_size: shape of one sample, e.g. ``(3, 32, 32)``, or a
            list/tuple of such shapes for multi-input models.
        model: the ``nn.Module`` to inspect.

    Returns:
        ``collections.OrderedDict`` mapping ``"<LayerClass>-<idx>"`` to a
        dict with ``input_shape``, ``output_shape``, ``nb_params`` (int)
        and, for weighted layers, ``trainable``.
    """
    def register_hook(module):
        def hook(module, input, output):
            class_name = str(module.__class__).split('.')[-1].split("'")[0]
            module_idx = len(summary)
            m_key = f"{class_name}-{module_idx + 1}"
            summary[m_key] = collections.OrderedDict()
            summary[m_key]['input_shape'] = list(input[0].size())
            summary[m_key]['input_shape'][0] = -1  # batch dim is generic
            summary[m_key]['output_shape'] = list(output.size())
            summary[m_key]['output_shape'][0] = -1
            params = 0
            # Guard with `is not None`, not hasattr: modules built with
            # bias=False still have a `bias` attribute set to None, and
            # hasattr alone would crash on None.size().
            if getattr(module, 'weight', None) is not None:
                params += int(torch.prod(torch.LongTensor(list(module.weight.size()))))
                summary[m_key]['trainable'] = module.weight.requires_grad
            if getattr(module, 'bias', None) is not None:
                params += int(torch.prod(torch.LongTensor(list(module.bias.size()))))
            summary[m_key]['nb_params'] = params

        # Hook only leaf layers: skip containers and the top-level model.
        if not isinstance(module, nn.Sequential) and \
           not isinstance(module, nn.ModuleList) and \
           module != model:
            hooks.append(module.register_forward_hook(hook))

    # Create the probe input on the same device as the model, so the
    # summary also works after the model has been moved to CUDA.
    try:
        device = next(model.parameters()).device
    except StopIteration:  # model has no parameters
        device = torch.device('cpu')

    if isinstance(input_size[0], (list, tuple)):
        x = [torch.rand(1, *in_size, device=device) for in_size in input_size]
    else:
        x = torch.rand(1, *input_size, device=device)

    summary = collections.OrderedDict()
    hooks = []
    model.apply(register_hook)  # register hooks on every submodule
    model(x)                    # forward pass triggers the hooks
    for h in hooks:
        h.remove()              # always detach hooks afterwards
    return summary
# 4. Data preparation
# ----------------------------
def prepare_data():
    """Build the CIFAR-10 train/test DataLoaders.

    The training pipeline applies random-crop + horizontal-flip
    augmentation; both pipelines normalize with the standard
    CIFAR-10 channel statistics. Downloads the dataset to ./data
    on first use.
    """
    normalize = transforms.Normalize(
        (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

    train_tf = transforms.Compose([
        transforms.RandomCrop(32, padding=4),  # augmentation
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    test_tf = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    train_set = torchvision.datasets.CIFAR10(
        root='./data', train=True, download=True, transform=train_tf)
    test_set = torchvision.datasets.CIFAR10(
        root='./data', train=False, download=True, transform=test_tf)

    train_loader = DataLoader(
        train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
    test_loader = DataLoader(
        test_set, batch_size=100, shuffle=False, num_workers=2)
    return train_loader, test_loader
# 5. Model training and evaluation
# ----------------------------
def train_model(net, trainloader, criterion, optimizer):
    """Train `net` for EPOCHS epochs, printing the mean loss every 100 batches."""
    net.train()  # enable train-mode behavior (dropout, batchnorm, ...)
    for epoch in range(EPOCHS):
        window_loss = 0.0
        for step, (inputs, targets) in enumerate(trainloader):
            inputs = inputs.to(DEVICE)
            targets = targets.to(DEVICE)

            optimizer.zero_grad()
            loss = criterion(net(inputs), targets)
            loss.backward()
            optimizer.step()

            window_loss += loss.item()
            if step % 100 == 99:  # log once per 100 batches
                print(f'[Epoch {epoch+1}, Batch {step+1}] loss: {window_loss/100:.3f}')
                window_loss = 0.0
    print('Finished Training')


def evaluate_model(net, testloader):
    """Print top-1 accuracy of `net` over the whole test loader."""
    net.eval()  # disable train-mode behavior
    correct = 0
    total = 0
    with torch.no_grad():  # inference only, no autograd bookkeeping
        for images, labels in testloader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            predicted = net(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print(f'Accuracy on test set: {100 * correct / total:.2f}%')
# 6. Main entry point
# ----------------------------
if __name__ == '__main__':
    # Data
    trainloader, testloader = prepare_data()

    # Model, loss and optimizer (swap LeNet for CNNNet or Net to compare)
    net = LeNet().to(DEVICE)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9)

    # Per-layer summary for a single 3x32x32 CIFAR-10 image
    input_size = (3, 32, 32)
    summary = params_summary(input_size, net)
    print("Model Summary:")
    for layer, info in summary.items():
        print(f"{layer}: {info}")

    # Train, then report test-set accuracy
    train_model(net, trainloader, criterion, optimizer)
    evaluate_model(net, testloader)
本文实现了一个基于 PyTorch 的 CIFAR-10 图像分类系统,包含三种 CNN 架构(CNNNet、Net 和 LeNet,默认训练 LeNet)。系统首先加载数据集并进行数据增强,然后使用 SGD 优化器和交叉熵损失函数训练模型 10 个 epoch。代码还提供了模型参数统计工具,可输出各层的输入输出形状及参数数量。训练完成后,系统在测试集上评估模型准确率。整个流程支持 GPU 加速,并包含详细的训练过程日志输出。