当前位置: 首页 > news >正文

用 PyTorch 实现 MNIST 手写数字识别:从入门到实践

手写数字识别是机器学习领域的经典入门案例,而 MNIST 数据集则是这个领域的 "Hello World"。本文将带你从零开始,使用 PyTorch 构建一个两层神经网络,完成 MNIST 手写数字的识别任务。无论你是机器学习新手还是想复习基础,这篇教程都能帮助你理解神经网络的基本原理和实现过程。

什么是 MNIST 数据集?

MNIST(Modified National Institute of Standards and Technology)数据集包含 60,000 个训练样本和 10,000 个测试样本,均为 28×28 像素的灰度手写数字图像(0-9)。它之所以成为入门经典,是因为:

  • 数据规模适中,不需要超级计算机也能训练
  • 任务明确(10 分类问题),评价指标简单(准确率)
  • 预处理简单,无需复杂的图像增强

环境准备与依赖库

本次实现基于 PyTorch 框架,需要安装以下依赖:

  • numpy:数值计算
  • torch:PyTorch 核心库
  • torchvision:包含 MNIST 数据集和图像处理工具
  • matplotlib:可视化工具
  • tensorboard:训练过程可视化

安装命令:

bash

pip install numpy torch torchvision matplotlib tensorboard

实现步骤详解

1. 超参数配置

在开始之前,我们先集中定义所有可配置的超参数,方便后续调试和优化:

python

运行

config = {"train_batch_size": 64,  # 训练批次大小"test_batch_size": 128,   # 测试批次大小"learning_rate": 0.01,    # 初始学习率"num_epochs": 20,         # 训练轮次"in_dim": 28 * 28,        # 输入维度(28x28像素)"n_hidden_1": 300,        # 第一个隐藏层神经元数"n_hidden_2": 100,        # 第二个隐藏层神经元数"out_dim": 10,            # 输出维度(10个数字类别)"log_dir": "logs",        # TensorBoard日志目录"data_root": "../data"    # 数据保存路径
}

超参数的选择对模型性能影响很大,后续可以通过调整这些参数来优化模型。

2. 数据加载与预处理

数据预处理是机器学习 pipeline 中至关重要的一步,直接影响模型性能:

python

运行

def load_data(data_root, train_batch_size, test_batch_size):# 定义预处理流程transform = transforms.Compose([transforms.ToTensor(),  # 转换为Tensor并归一化到[0,1]transforms.Normalize([0.5], [0.5])  # 标准化到[-1,1]])# 加载训练集和测试集train_dataset = MNIST(root=data_root,train=True,transform=transform,download=True)test_dataset = MNIST(root=data_root,train=False,transform=transform)# 数据加载器train_loader = DataLoader(train_dataset,batch_size=train_batch_size,shuffle=True  # 训练时打乱数据)test_loader = DataLoader(test_dataset,batch_size=test_batch_size,shuffle=False  # 测试时无需打乱)return train_loader, test_loader

预处理说明

  • ToTensor():将图像从 PIL 格式转换为 PyTorch 张量,并将像素值从 [0,255] 缩放到 [0,1]
  • Normalize():标准化处理,公式为(x - mean) / std,这里将数据调整为均值 0、标准差 0.5,最终范围为 [-1,1]
  • DataLoader:提供批量加载、打乱数据、多线程加载等功能

3. 数据可视化

加载数据后,我们可以随机可视化几个样本,验证数据加载是否正确:

python

运行

def visualize_samples(test_loader, num_samples=6):examples = enumerate(test_loader)batch_idx, (example_data, example_targets) = next(examples)fig = plt.figure(figsize=(8, 4))for i in range(num_samples):plt.subplot(2, 3, i + 1)plt.tight_layout()plt.imshow(example_data[i][0], cmap='gray')plt.title(f"标签: {example_targets[i].item()}")plt.xticks([])plt.yticks([])plt.show()

4. 神经网络模型设计

我们将构建一个包含两个隐藏层的全连接神经网络:

python

运行

class MNISTNet(nn.Module):def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):super(MNISTNet, self).__init__()self.flatten = nn.Flatten()  # 展平层# 第一个隐藏层(带批归一化)self.layer1 = nn.Sequential(nn.Linear(in_dim, n_hidden_1),nn.BatchNorm1d(n_hidden_1))# 第二个隐藏层(带批归一化)self.layer2 = nn.Sequential(nn.Linear(n_hidden_1, n_hidden_2),nn.BatchNorm1d(n_hidden_2))# 输出层self.out = nn.Linear(n_hidden_2, out_dim)def forward(self, x):x = self.flatten(x)  # 展平为1D向量x = F.relu(self.layer1(x))  # 第一层+ReLU激活x = F.relu(self.layer2(x))  # 第二层+ReLU激活x = F.softmax(self.out(x), dim=1)  # 输出层+softmaxreturn x

模型设计要点

  • nn.Flatten():将 28×28 的二维图像转换为 784 维的一维向量
  • 批归一化(BatchNorm1d):加速训练收敛,提高稳定性
  • ReLU 激活函数:解决梯度消失问题,引入非线性
  • Softmax 输出:将最后一层输出转换为概率分布(总和为 1)

5. 训练与评估流程

训练过程是模型学习的核心,我们需要定义损失函数、优化器,并实现完整的训练循环:

python

运行

def train_and_evaluate(config):# 加载数据train_loader, test_loader = load_data(**config)visualize_samples(test_loader)# 设备配置(自动选择GPU或CPU)device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")print(f"使用设备: {device}")# 初始化模型、损失函数和优化器model = MNISTNet(** config).to(device)criterion = nn.CrossEntropyLoss()  # 交叉熵损失optimizer = optim.SGD(model.parameters(),lr=config["learning_rate"],momentum=0.9  # 动量加速收敛)# 训练循环writer = SummaryWriter(log_dir=config["log_dir"])  # TensorBoard日志losses, eval_acces = [], []for epoch in range(config["num_epochs"]):# 训练阶段model.train()  # 训练模式train_loss, train_acc = 0.0, 0.0# 学习率调整if epoch % 5 == 0 and epoch != 0:optimizer.param_groups[0]['lr'] *= 0.9print(f"学习率调整为: {optimizer.param_groups[0]['lr']:.6f}")for img, label in train_loader:img, label = img.to(device), label.to(device)# 前向传播output = model(img)loss = criterion(output, label)# 反向传播与优化optimizer.zero_grad()  # 清空梯度loss.backward()        # 计算梯度optimizer.step()       # 更新参数# 计算损失和准确率train_loss += loss.item()_, pred = torch.max(output, 1)train_acc += (pred == label).sum().item() / img.size(0)# 评估阶段model.eval()  # 评估模式eval_acc = 0.0with torch.no_grad():  # 禁用梯度计算for img, label in test_loader:img, label = img.to(device), label.to(device)output = model(img)_, pred = torch.max(output, 1)eval_acc += (pred == label).sum().item() / img.size(0)# 记录指标avg_train_loss = train_loss / len(train_loader)avg_train_acc = train_acc / len(train_loader)avg_eval_acc = eval_acc / len(test_loader)losses.append(avg_train_loss)eval_acces.append(avg_eval_acc)print(f"Epoch [{epoch+1}/{config['num_epochs']}] | "f"训练损失: {avg_train_loss:.4f}, 训练准确率: {avg_train_acc:.4f} | "f"测试准确率: {avg_eval_acc:.4f}")# 可视化训练结果visualize_training(losses, eval_acces)writer.close()

训练关键步骤解析

  1. 设备选择:自动检测并使用 GPU(如有),大幅加速训练
  2. 损失函数:使用交叉熵损失(CrossEntropyLoss),适合多分类问题
  3. 优化器:带动量的 SGD,动量有助于加速收敛并跳出局部最优
  4. 学习率调度:每 5 个 epoch 将学习率乘以 0.9,后期精细化优化
  5. 训练模式与评估模式model.train()model.eval()控制批归一化等层的行为
  6. 梯度管理optimizer.zero_grad()清空梯度,loss.backward()计算梯度,optimizer.step()更新参数

6. 结果可视化

训练结束后,我们可以可视化损失和准确率曲线,直观了解模型性能变化:

python

运行

def visualize_training(losses, eval_acces):fig = plt.figure(figsize=(10, 4))# 损失曲线plt.subplot(1, 2, 1)plt.title('训练损失')plt.plot(np.arange(len(losses)), losses, 'b-')plt.xlabel('Epoch')plt.ylabel('Loss')# 准确率曲线plt.subplot(1, 2, 2)plt.title('测试准确率')plt.plot(np.arange(len(eval_acces)), eval_acces, 'g-')plt.xlabel('Epoch')plt.ylabel('Accuracy')plt.tight_layout()plt.show()

模型优化方向

如果想进一步提高准确率,可以尝试以下方法:

  1. 增加网络深度或宽度(但要注意防止过拟合)
  2. 使用更先进的优化器(如 Adam 替代 SGD)
  3. 调整学习率调度策略
  4. 添加 dropout 层防止过拟合
  5. 尝试数据增强(旋转、平移等)

总结

本文详细介绍了使用 PyTorch 实现 MNIST 手写数字识别的完整流程,包括数据加载与预处理、模型设计、训练循环和结果可视化。通过这个案例,我们可以掌握神经网络的基本原理和实现方法:

  • 数据预处理对模型性能的重要性
  • 神经网络的基本组成(线性层、激活函数、批归一化)
  • 训练过程的核心步骤(前向传播、损失计算、反向传播、参数更新)
  • 如何评估模型性能并进行可视化分析
http://www.dtcms.com/a/478326.html

相关文章:

  • 设计模式篇之 代理模式 Proxy
  • 智联招聘网站建设情况wordpress 注册 密码
  • Mobius Protocol:在“去中心化”逐渐被遗忘的时代,重建秩序的尝试
  • 网站制作公司费用wordpress 宋体、
  • 长宁怎么做网站优化好住房城乡建设门户网站
  • MySQL InnoDB Cluster 高可用集群部署与应用实践(下)
  • commons-rng(伪随机数生成)
  • qemu 串口模拟输入的整个流程
  • 在git commit时利用AI自动生成并填充commit信息
  • 【完整源码+数据集+部署教程】可回收金属垃圾检测系统源码和数据集:改进yolo11-AggregatedAtt
  • HakcMyVM-Crack
  • emmc extcsd寄存器
  • 利用径向柱图探索西班牙语学习数据
  • wordpress建淘宝客网站吗上海网站制作最大的公司
  • 定制网站平台的安全设计房地产公司网站建设
  • 筛法(Sieve Method)简介
  • 【论文阅读】基于指数-高斯混合网络的视频观看时间预测的多粒度分布建模-小红书recsys25
  • 网站开发过程模型做电影网站怎么接广告
  • 手机群控软件在游戏运营中的行为模拟技术实践
  • MySQL----触发器
  • 汕头模板建站平台朝阳市做网站
  • C8051F351-GMR工业用 8051 MCU 开发板C8051F351-GMR嵌入式处理器和控制器,适用于高精度模拟信号处理
  • [嵌入式系统-107]:语音识别的信号处理流程和软硬件职责
  • OkHttp源码解析(一)
  • 拆分PDF.html 办公小工具
  • 网站编辑用什么软件有关于网站建设类似的文章
  • 陶瓷网站制作wordpress导购主题
  • 分割——双线性插值
  • 北京天仪建设工程质量检测所网站上海做网站的公司电话
  • 建站优化一条龙新闻型网站建设