完整的训练与测试套路 小土堆pytorch记录
总算是看完入门了,流程如下
第一部分:训练代码 (train.py
) 详细解析
代码如下
import os
import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import time# 设备设置(这部分需要被导入,所以放在if外面)
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
torch.cuda.set_device(0)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {torch.cuda.get_device_name(device)}")# 准备数据集(这部分也可以被导入)
train_data = torchvision.datasets.CIFAR10(root="./P_10_dataset",train=True,transform=torchvision.transforms.ToTensor(),download=True)
test_data = torchvision.datasets.CIFAR10(root="./P_10_dataset",train=False,transform=torchvision.transforms.ToTensor(),download=True)# 数据集长度
train_data_size = len(train_data)
test_data_size = len(test_data)# 利用dataloader加载数据集
train_dataloader = DataLoader(train_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)# 创建网络模型
class Test_net(nn.Module):def __init__(self):super(Test_net, self).__init__()self.model = nn.Sequential(nn.Conv2d(3, 32, 5, 1, 2),nn.MaxPool2d(2),nn.Conv2d(32, 32, 5, 1, 2),nn.MaxPool2d(2),nn.Conv2d(32, 64, 5, 1, 2),nn.MaxPool2d(2),nn.Flatten(),nn.Linear(64 * 4 * 4, 64),nn.Linear(64, 10))def forward(self, x):x = self.model(x)return x# 将训练代码放在这个if块中
if __name__ == "__main__":print("训练集长度:{}".format(train_data_size))print("测试集长度:{}".format(test_data_size))test_net = Test_net()test_net = test_net.to(device)loss_fn = nn.CrossEntropyLoss()loss_fn = loss_fn.to(device)learning_rate = 0.01optimizer = torch.optim.SGD(test_net.parameters(), lr=learning_rate)total_train_step = 0total_test_step = 0epoch = 50writer = SummaryWriter("log_train")start_time = time.time()for i in range(epoch):test_net.train()print("------第{}轮训练开始------".format(i + 1))for data in train_dataloader:imgs, targets = dataimgs = imgs.to(device)targets = targets.to(device)outputs = test_net(imgs)loss = loss_fn(outputs, targets)optimizer.zero_grad()loss.backward()optimizer.step()total_train_step = total_train_step + 1if total_train_step % 100 == 0:end_time = time.time()print("所用时间:")print(end_time - start_time)print("训练次数:{}, Loss:{}".format(total_train_step, loss.item()))writer.add_scalar("train_loss", loss.item(), total_train_step)test_net.eval()total_test_loss = 0total_accuracy = 0with torch.no_grad():for data in test_dataloader:imgs, targets = dataimgs = imgs.to(device)targets = targets.to(device)outputs = test_net(imgs)loss = loss_fn(outputs, targets)total_test_loss = total_test_loss + loss.item()accuracy = (outputs.argmax(1) == targets).sum()total_accuracy = total_accuracy + accuracyprint("整体测试集上的Loss:{}".format(total_test_loss))print("整体测试集上的Accuracy:{}".format(total_accuracy / test_data_size))writer.add_scalar("test_loss", total_test_loss, total_test_step)writer.add_scalar("test_accuracy", total_accuracy / test_data_size, total_test_step)total_test_step = total_test_step + 1torch.save(test_net, "test_net{}.pth".format(i))print("模型已保存!")writer.close()
第二部分:测试/推理代码 (test.py
) 详细解析
代码如下
import torch
import torchvision
from PIL import Image
from torch import nn# 从训练脚本中导入设备配置(现在不会执行训练循环了)
from train_gup_2 import deviceprint(f"使用设备: {device}")# 定义CIFAR-10的类别名称
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer','dog', 'frog', 'horse', 'ship', 'truck']# 图像预处理
image_path = "img/horse.png"
image = Image.open(image_path)
image = image.convert('RGB')transform = torchvision.transforms.Compose([torchvision.transforms.Resize((32, 32)),torchvision.transforms.ToTensor()
])image = transform(image)
print(f"图像形状: {image.shape}")# 创建网络模型(结构必须与训练时完全一致)
class Test_net(nn.Module):def __init__(self):super(Test_net, self).__init__()self.model = nn.Sequential(nn.Conv2d(3, 32, 5, 1, 2),nn.MaxPool2d(2),nn.Conv2d(32, 32, 5, 1, 2),nn.MaxPool2d(2),nn.Conv2d(32, 64, 5, 1, 2),nn.MaxPool2d(2),nn.Flatten(),nn.Linear(64 * 4 * 4, 64),nn.Linear(64, 10))def forward(self, x):x = self.model(x)return x# 加载模型
model_path = "test_net49.pth"
model = torch.load("test_net49.pth", map_location=device, weights_only=False)
model = model.to(device)
model.eval()
print(f"已加载模型: {model_path}")# 准备输入数据
image = image.unsqueeze(0)
image = image.to(device)# 进行预测
with torch.no_grad():output = model(image)# 处理预测结果
probabilities = torch.nn.functional.softmax(output[0], dim=0)
predicted_idx = output.argmax(1).item()
predicted_class = class_names[predicted_idx]
confidence = probabilities[predicted_idx].item()# 输出详细预测结果
print("\n===== 预测结果 =====")
print(f"预测类别索引: {predicted_idx}")
print(f"预测类别名称: {predicted_class}")
print(f"预测置信度: {confidence:.4f}")# 输出top-3预测结果
top3_prob, top3_idx = torch.topk(probabilities, 3)
print("\nTop-3 预测:")
for i, (idx, prob) in enumerate(zip(top3_idx, top3_prob)):print(f"{i + 1}. {class_names[idx.item()]}: {prob.item():.4f}")
运行结果
总结与拓展
设备一致性 (Device Consistency):
黄金法则:模型、损失函数、输入数据必须在同一个设备上(CPU或GPU)。
务必使用
.to(device)
来移动模型和张量。
模式切换 (Train/Eval Mode):
model.train()
:训练模式,启用Dropout、BatchNorm的更新。model.eval()
:评估模式,关闭Dropout、固定BatchNorm的统计量。忘记切换是常见错误,会导致评估结果不一致。
梯度管理 (Gradient Management):
每次反向传播前必须
optimizer.zero_grad()
,否则梯度会累积。with torch.no_grad():
上下文管理器用于禁用梯度计算,节省测试时的内存和计算。
数据预处理一致性:
测试/推理时的预处理必须与训练时完全一致(相同的Resize、归一化方式等),否则性能会急剧下降。
模型保存与加载:
保存整个模型 (
torch.save(model, path)
):方便但文件大,对代码版本有依赖。保存状态字典 (
torch.save(model.state_dict(), path)
):推荐方式,灵活轻便。加载时需要先实例化模型结构。加载时注意
map_location
参数,如果是在CPU上加载GPU训练的模型,需要map_location=torch.device('cpu')
。
张量维度 (Tensor Dimensions):
模型的输入通常是4D:
(Batch_size, Channels, Height, Width)
。单张图片推理时需要用
unsqueeze(0)
或reshape(1, C, H, W)
来添加批次维度。
deepseek给出的下一步建议学习的东西
下一步学习路径:
深入理解模型结构:学习ResNet、VGG、Transformer等经典架构,理解其设计思想。
数据增强 (Data Augmentation):在
transforms.Compose
中添加如随机裁剪、翻转、颜色抖动等,这是提升模型泛化能力的关键。学习更先进的优化器:如Adam、AdamW,它们通常比SGD收敛更快。
学习使用学习率调度器 (Scheduler):如
StepLR
,CosineAnnealingLR
,动态调整学习率以获得更好性能。尝试自己的数据集:不再使用CIFAR10等内置数据集,学习如何编写自定义的
Dataset
类来加载自己的数据。
调试技巧:
多用
print(x.shape)
来检查张量维度,确保层与层之间匹配。使用TensorBoard或Matplotlib可视化训练过程、特征图,直观理解模型行为。
遇到NaN损失(梯度爆炸):尝试减小学习率、添加梯度裁剪 (
torch.nn.utils.clip_grad_norm_
)。
项目实践:
Kaggle:参加入门级的图像分类比赛(如Dogs vs Cats)。
复现论文:找一篇简单的CV领域论文(如2015年左右的),尝试复现其基本结果。
部署尝试:学习使用ONNX、TorchScript或简单的Flask/FastAPI将模型封装成API,这是工业应用的重要一环。
理论基础:
在实践的同时,务必补充《深度学习》(花书)、《统计学习方法》等理论基础,理解反向传播、优化、正则化等概念的数学本质。
大概就是可以试着去GitHub找一些代码复现一下,以及可以继续看李沐老师的深度学习了