当前位置: 首页 > news >正文

Day 40:训练和测试的规范写法

1. 单通道图片的规范准备工作

首先,回顾一下基础设置。用 MNIST 数据集(单通道灰度图像)作为例子。代码开头导入必要的库,并设置设备和随机种子,确保结果可复现。

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings("ignore")
torch.manual_seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")

这里,transforms.Compose 用于数据预处理:转为张量、归一化(MNIST 的均值和标准差是固定的)。然后加载数据集,分成训练集和测试集,用 DataLoader 打包成批次(batch_size=64)。

transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))
])train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform)batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

模型定义是个简单的 MLP:展平图像、两层全连接 + ReLU。损失函数用交叉熵,优化器选 Adam。

class MLP(nn.Module):def __init__(self):super(MLP, self).__init__()self.flatten = nn.Flatten()self.layer1 = nn.Linear(784, 128)self.relu = nn.ReLU()self.layer2 = nn.Linear(128, 10)def forward(self, x):x = self.flatten(x)x = self.layer1(x)x = self.relu(x)x = self.layer2(x)return xmodel = MLP().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

这些都是老生常谈,但规范写法从这里开始体现:数据和模型分离,便于后续扩展。

2. 训练函数的封装:逻辑隔离,参数复用

以前写鸢尾花 MLP 时,训练代码直接扔在主程序里,看起来乱糟糟的。现在,我们用函数 train 封装整个过程。为什么这么做?一是参数(如 epochs)调整方便,不用翻代码;二是复用性强,以后多模型对比时,直接调用就好。
函数里记录每个 batch 的损失(iteration 级),每 100 batch 打印一次,epoch 结束时测试并绘图。注意,早停策略暂不加(需要验证集),但测试函数独立封装。

def train(model, train_loader, test_loader, criterion, optimizer, device, epochs):model.train()all_iter_losses = []iter_indices = []for epoch in range(epochs):running_loss = 0.0correct = 0total = 0for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)optimizer.zero_grad()output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()iter_loss = loss.item()all_iter_losses.append(iter_loss)iter_indices.append(epoch * len(train_loader) + batch_idx + 1)running_loss += iter_loss_, predicted = output.max(1)total += target.size(0)correct += predicted.eq(target).sum().item()if (batch_idx + 1) % 100 == 0:print(f'Epoch: {epoch+1}/{epochs} | Batch: {batch_idx+1}/{len(train_loader)} 'f'| 单Batch损失: {iter_loss:.4f} | 累计平均损失: {running_loss/(batch_idx+1):.4f}')epoch_train_loss = running_loss / len(train_loader)epoch_train_acc = 100. * correct / totalepoch_test_loss, epoch_test_acc = test(model, test_loader, criterion, device)print(f'Epoch {epoch+1}/{epochs} 完成 | 训练准确率: {epoch_train_acc:.2f}% | 测试准确率: {epoch_test_acc:.2f}%')plot_iter_losses(all_iter_losses, iter_indices)return epoch_test_acc

测试函数 test 也很干净:eval 模式、无梯度计算,计算平均损失和准确率。

def test(model, test_loader, criterion, device):model.eval()test_loss = 0correct = 0total = 0with torch.no_grad():for data, target in test_loader:data, target = data.to(device), target.to(device)output = model(data)test_loss += criterion(output, target).item()_, predicted = output.max(1)total += target.size(0)correct += predicted.eq(target).sum().item()avg_loss = test_loss / len(test_loader)accuracy = 100. * correct / totalreturn avg_loss, accuracy

绘图函数绘制 iteration 损失曲线,更直观地看训练过程。

def plot_iter_losses(losses, indices):plt.figure(figsize=(10, 4))plt.plot(indices, losses, 'b-', alpha=0.7, label='Iteration Loss')plt.xlabel('Iteration(Batch序号)')plt.ylabel('损失值')plt.title('每个 Iteration 的训练损失')plt.legend()plt.grid(True)plt.tight_layout()plt.show()

3. 运行与效果分析

直接调用训练:epochs=2(实际可调大),看看效果。

epochs = 2
print("开始训练模型...")
final_accuracy = train(model, train_loader, test_loader, criterion, optimizer, device, epochs)
print(f"训练完成!最终测试准确率: {final_accuracy:.2f}%")

运行后,会看到每 100 batch 的损失打印,epoch 结束的准确率,以及损失曲线图。MNIST 上 MLP 能到 96%+,但如果换 CIFAR-10,准确率可能卡在 50% 左右。为什么?MLP 忽略图像空间结构,全连接层参数爆炸,容易过拟合。

@浙大疏锦行

http://www.dtcms.com/a/340532.html

相关文章:

  • 【深度学习新浪潮】空天地数据融合技术在城市三维重建中的应用
  • 学习嵌入式的第二十二天——数据结构——双向链表
  • 前端图片压缩实战:体积直降 80%,LCP 提升 2 倍
  • 数字化图书管理系统设计实践(java)
  • 【考研408数据结构-04】 栈与队列:受限的线性表
  • Java FTPClient详解:高效文件传输指南
  • 用好 Elasticsearch Ruby 传输层elastic-transport
  • Redisson3.14.1及之后连接阿里云redis代理模式,使用分布式锁:ERR unknown command ‘WAIT‘
  • python中selenium怎么使用
  • KUKA机器人KUKA.ConveyorTech传送带跟踪程序举例解析
  • Python采集易贝(eBay)商品详情API接口,json数据返回
  • 今日科技风向|从AI芯片定制到阅兵高科技展示——聚焦技术前沿洞察
  • MySQL 数据库知识点与注意事项总结
  • spring整合JUnit
  • 阿里云ECS服务器的公网IP地址
  • WPF Alert弹框控件 - 完全使用指南
  • Non-stationary Diffusion For Probabilistic Time Series Forecasting论文阅读笔记
  • LoRa 网关与节点组网方案
  • 基于Java虚拟线程的高并发作业执行框架设计与性能优化实践指南
  • 【Bluedroid】A2DP Source 端会话启动流程与核心机制解析(btif_a2dp_source_start_session)
  • UIGestureRecognizer 各个子类以及其作用
  • iOS开发之UICollectionView为什么需要配合UICollectionViewFlowLayout使用
  • 氯化钇:科技与高性能材料的核心元素
  • C++高频知识点(三十)
  • 嵌入式音频开发(3)- AudioService核心功能
  • 机器学习数学基础与商业实践指南:从统计显著性到预测能力的认知升级
  • Node.js中的Prisma应用:现代数据库开发的最佳实践
  • 河南萌新联赛2025第六场 - 郑州大学
  • Java:将视频上传到腾讯云并通过腾讯云点播播放
  • 【Task02】:四步构建简单rag(第一章3节)