当前位置: 首页 > news >正文

完整的训练与测试套路 小土堆pytorch记录

总算是看完入门了,流程如下

第一部分:训练代码 (train.py) 详细解析

代码如下

import os
import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import time# 设备设置(这部分需要被导入,所以放在if外面)
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
torch.cuda.set_device(0)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {torch.cuda.get_device_name(device)}")# 准备数据集(这部分也可以被导入)
train_data = torchvision.datasets.CIFAR10(root="./P_10_dataset",train=True,transform=torchvision.transforms.ToTensor(),download=True)
test_data = torchvision.datasets.CIFAR10(root="./P_10_dataset",train=False,transform=torchvision.transforms.ToTensor(),download=True)# 数据集长度
train_data_size = len(train_data)
test_data_size = len(test_data)# 利用dataloader加载数据集
train_dataloader = DataLoader(train_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)# 创建网络模型
class Test_net(nn.Module):def __init__(self):super(Test_net, self).__init__()self.model = nn.Sequential(nn.Conv2d(3, 32, 5, 1, 2),nn.MaxPool2d(2),nn.Conv2d(32, 32, 5, 1, 2),nn.MaxPool2d(2),nn.Conv2d(32, 64, 5, 1, 2),nn.MaxPool2d(2),nn.Flatten(),nn.Linear(64 * 4 * 4, 64),nn.Linear(64, 10))def forward(self, x):x = self.model(x)return x# 将训练代码放在这个if块中
if __name__ == "__main__":print("训练集长度:{}".format(train_data_size))print("测试集长度:{}".format(test_data_size))test_net = Test_net()test_net = test_net.to(device)loss_fn = nn.CrossEntropyLoss()loss_fn = loss_fn.to(device)learning_rate = 0.01optimizer = torch.optim.SGD(test_net.parameters(), lr=learning_rate)total_train_step = 0total_test_step = 0epoch = 50writer = SummaryWriter("log_train")start_time = time.time()for i in range(epoch):test_net.train()print("------第{}轮训练开始------".format(i + 1))for data in train_dataloader:imgs, targets = dataimgs = imgs.to(device)targets = targets.to(device)outputs = test_net(imgs)loss = loss_fn(outputs, targets)optimizer.zero_grad()loss.backward()optimizer.step()total_train_step = total_train_step + 1if total_train_step % 100 == 0:end_time = time.time()print("所用时间:")print(end_time - start_time)print("训练次数:{}, Loss:{}".format(total_train_step, loss.item()))writer.add_scalar("train_loss", loss.item(), total_train_step)test_net.eval()total_test_loss = 0total_accuracy = 0with torch.no_grad():for data in test_dataloader:imgs, targets = dataimgs = imgs.to(device)targets = targets.to(device)outputs = test_net(imgs)loss = loss_fn(outputs, targets)total_test_loss = total_test_loss + loss.item()accuracy = (outputs.argmax(1) == targets).sum()total_accuracy = total_accuracy + accuracyprint("整体测试集上的Loss:{}".format(total_test_loss))print("整体测试集上的Accuracy:{}".format(total_accuracy / test_data_size))writer.add_scalar("test_loss", total_test_loss, total_test_step)writer.add_scalar("test_accuracy", total_accuracy / test_data_size, total_test_step)total_test_step = total_test_step + 1torch.save(test_net, "test_net{}.pth".format(i))print("模型已保存!")writer.close()

第二部分:测试/推理代码 (test.py) 详细解析

代码如下

import torch
import torchvision
from PIL import Image
from torch import nn# 从训练脚本中导入设备配置(现在不会执行训练循环了)
from train_gup_2 import deviceprint(f"使用设备: {device}")# 定义CIFAR-10的类别名称
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer','dog', 'frog', 'horse', 'ship', 'truck']# 图像预处理
image_path = "img/horse.png"
image = Image.open(image_path)
image = image.convert('RGB')transform = torchvision.transforms.Compose([torchvision.transforms.Resize((32, 32)),torchvision.transforms.ToTensor()
])image = transform(image)
print(f"图像形状: {image.shape}")# 创建网络模型(结构必须与训练时完全一致)
class Test_net(nn.Module):def __init__(self):super(Test_net, self).__init__()self.model = nn.Sequential(nn.Conv2d(3, 32, 5, 1, 2),nn.MaxPool2d(2),nn.Conv2d(32, 32, 5, 1, 2),nn.MaxPool2d(2),nn.Conv2d(32, 64, 5, 1, 2),nn.MaxPool2d(2),nn.Flatten(),nn.Linear(64 * 4 * 4, 64),nn.Linear(64, 10))def forward(self, x):x = self.model(x)return x# 加载模型
model_path = "test_net49.pth"
model = torch.load("test_net49.pth", map_location=device, weights_only=False)
model = model.to(device)
model.eval()
print(f"已加载模型: {model_path}")# 准备输入数据
image = image.unsqueeze(0)
image = image.to(device)# 进行预测
with torch.no_grad():output = model(image)# 处理预测结果
probabilities = torch.nn.functional.softmax(output[0], dim=0)
predicted_idx = output.argmax(1).item()
predicted_class = class_names[predicted_idx]
confidence = probabilities[predicted_idx].item()# 输出详细预测结果
print("\n===== 预测结果 =====")
print(f"预测类别索引: {predicted_idx}")
print(f"预测类别名称: {predicted_class}")
print(f"预测置信度: {confidence:.4f}")# 输出top-3预测结果
top3_prob, top3_idx = torch.topk(probabilities, 3)
print("\nTop-3 预测:")
for i, (idx, prob) in enumerate(zip(top3_idx, top3_prob)):print(f"{i + 1}. {class_names[idx.item()]}: {prob.item():.4f}")

运行结果

总结与拓展

  1. 设备一致性 (Device Consistency):

    • 黄金法则:模型、损失函数、输入数据必须在同一个设备上(CPU或GPU)。

    • 务必使用 .to(device) 来移动模型和张量。

  2. 模式切换 (Train/Eval Mode):

    • model.train():训练模式,启用Dropout、BatchNorm的更新。

    • model.eval():评估模式,关闭Dropout、固定BatchNorm的统计量。

    • 忘记切换是常见错误,会导致评估结果不一致。

  3. 梯度管理 (Gradient Management):

    • 每次反向传播前必须 optimizer.zero_grad(),否则梯度会累积。

    • with torch.no_grad(): 上下文管理器用于禁用梯度计算,节省测试时的内存和计算。

  4. 数据预处理一致性:

    • 测试/推理时的预处理必须与训练时完全一致(相同的Resize、归一化方式等),否则性能会急剧下降。

  5. 模型保存与加载:

    • 保存整个模型 (torch.save(model, path)):方便但文件大,对代码版本有依赖。

    • 保存状态字典 (torch.save(model.state_dict(), path)):推荐方式,灵活轻便。加载时需要先实例化模型结构。

    • 加载时注意 map_location 参数,如果是在CPU上加载GPU训练的模型,需要 map_location=torch.device('cpu')

  6. 张量维度 (Tensor Dimensions):

    • 模型的输入通常是4D:(Batch_size, Channels, Height, Width)

    • 单张图片推理时需要用 unsqueeze(0) 或 reshape(1, C, H, W) 来添加批次维度。

deepseek给出的下一步建议学习的东西

  1. 下一步学习路径:

    • 深入理解模型结构:学习ResNet、VGG、Transformer等经典架构,理解其设计思想。

    • 数据增强 (Data Augmentation):在 transforms.Compose 中添加如随机裁剪、翻转、颜色抖动等,这是提升模型泛化能力的关键。

    • 学习更先进的优化器:如Adam、AdamW,它们通常比SGD收敛更快。

    • 学习使用学习率调度器 (Scheduler):如 StepLRCosineAnnealingLR,动态调整学习率以获得更好性能。

    • 尝试自己的数据集:不再使用CIFAR10等内置数据集,学习如何编写自定义的 Dataset 类来加载自己的数据。

  2. 调试技巧:

    • 多用 print(x.shape) 来检查张量维度,确保层与层之间匹配。

    • 使用TensorBoard或Matplotlib可视化训练过程、特征图,直观理解模型行为。

    • 遇到NaN损失(梯度爆炸):尝试减小学习率、添加梯度裁剪 (torch.nn.utils.clip_grad_norm_)。

  3. 项目实践:

    • Kaggle:参加入门级的图像分类比赛(如Dogs vs Cats)。

    • 复现论文:找一篇简单的CV领域论文(如2015年左右的),尝试复现其基本结果。

    • 部署尝试:学习使用ONNX、TorchScript或简单的Flask/FastAPI将模型封装成API,这是工业应用的重要一环。

  4. 理论基础:

    • 在实践的同时,务必补充《深度学习》(花书)、《统计学习方法》等理论基础,理解反向传播、优化、正则化等概念的数学本质。

大概就是可以试着去GitHub找一些代码复现一下,以及可以继续看李沐老师的深度学习了

http://www.dtcms.com/a/339277.html

相关文章:

  • PyTorch自动求导
  • PID调节
  • Go 进阶学习路线
  • 传统艾灸VS七彩喜艾灸机器人:同样的艾香,多了4分“巧”
  • 电脑出现‘无法启动此程序,因为计算机中丢失dll’要怎么办?2025最新的解决方法分析
  • 家庭健康能量站:微高压氧舱结合艾灸机器人,智享双重养生SPA
  • 大模型基础:Foundamentals of LLM
  • 关于物理世界、感知世界、认知世界与符号世界统一信息结构的跨领域探索
  • 最近常问的70道vue相关面试题
  • 豆包1.5 Vision Lite 对比 GPT-5-min,谁更适合你?实测AI模型选型利器 | AIBase
  • 【Langchain系列七】Langchain+FastAPI(字符串输出与OpenAI规范流式输出)+FastGPT
  • 《若依》项目结构分析
  • 温故而知新 再看设计模式
  • 2025.8.19总结
  • 防抖技术(一)——OIS光学防抖技术详解
  • 块存储 对象存储 文件存储的区别与联系
  • plantsimulation知识点25.8.19 工件不在RGV中心怎么办?
  • 技术详解及案例汇总|JY-V620半导体RFID读写器在晶圆盒追踪中的使用
  • Aiseesoft iPhone Unlocker:轻松解决iPhone锁屏问题
  • 量子计算和超级计算机将彻底改变技术
  • 重置iPhone会删除所有内容吗? 详细回答
  • 【Cocos】2D关节组件
  • canoe发送接收报文不通到底是接口问题还是配置问题如何处理
  • Codeforces 斐波那契立方体
  • 【Pycharm虚拟环境中安装Homebrew,会到系统中去吗】
  • k8sday11服务发现(2/2)
  • 机器学习(决策树2)
  • CMake进阶: CMake Modules---简化CMake配置的利器
  • C# NX二次开发:操作按钮控件Button和标签控件Label详解
  • 机器学习之决策树:从原理到实战(附泰坦尼克号预测任务)