当前位置: 首页 > news >正文

用 PyTorch 轻松实现 MNIST 手写数字识别

用 PyTorch 轻松实现 MNIST 手写数字识别

引言

在深度学习领域,MNIST 数据集就像是 “Hello World” 级别的经典入门项目。它包含大量手写数字图像及对应标签,非常适合新手学习如何搭建和训练神经网络模型。本文将基于 PyTorch 框架,详细拆解如何完成 MNIST 手写数字识别任务,让你轻松入门深度学习实践。

1. 数据加载与预处理

首先,我们利用torchvision库中的datasets.MNIST函数来加载 MNIST 数据集。代码如下:

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensortraining_data = datasets.MNIST(root="data",train=True,download=True,transform=ToTensor(),
)
test_data = datasets.MNIST(root="data",train=False,download=True,transform=ToTensor(),
)

在这段代码中,root="data"指定了数据集的存储路径;train=True表示加载训练集,train=False则用于加载测试集;download=True确保如果本地没有数据集,会自动从网络下载;transform=ToTensor()将图像数据转换为 PyTorch 能够处理的张量格式,同时将像素值从 0-255 归一化到 0-1 区间 。

为了直观感受数据集,我们还可以绘制几张图像:

python

from matplotlib import pyplot as plt
figure = plt.figure()
for i in range(9):img, label = training_data[i + 59000]figure.add_subplot(3, 3, i + 1)plt.title(label)plt.axis("off")plt.imshow(img.squeeze(), cmap="gray")a = img.squeeze()
plt.show()

上述代码从训练集中选取了 9 张图像,绘制出图像及其对应的标签,方便我们对数据有更直观的认识。

接下来,使用DataLoader对数据集进行封装,以方便后续按批次训练和测试:

train_dataloader = DataLoader(training_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)

batch_size=64表示每次训练或测试时,模型会同时处理 64 个样本,这有助于提高计算效率和训练稳定性。

2. 模型构建

我们定义一个简单的全连接神经网络类NeuralNetwork

class NeuralNetwork(nn.Module):def __init__(self):super().__init__()self.flatten = nn.Flatten()self.hidden1 = nn.Linear(28 * 28, 128)self.hidden2 = nn.Linear(128, 256)self.out = nn.Linear(256, 10)def forward(self, x):x = self.flatten(x)x = self.hidden1(x)x = torch.relu(x)x = self.hidden2(x)x = torch.relu(x)x = self.out(x)return x

__init__函数中,nn.Flatten()用于将输入的二维图像张量展平为一维向量;nn.Linear()是全连接层,我们依次构建了两个隐藏层和一个输出层,输出层有 10 个神经元,对应 0-9 这 10 个数字类别。在forward函数中,定义了数据的前向传播过程,包括线性变换和激活函数torch.relu()的应用,激活函数能为模型引入非线性,使其能够学习更复杂的模式。

然后将模型移动到合适的设备(GPU 或 CPU)上:

device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using {device} device")
model = NeuralNetwork().to(device)
print(model)

3. 训练与测试

3.1 训练函数

def train(dataloader, model, loss_fn, optimizer):model.train()batch_size_num = 1for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model.forward(X)loss = loss_fn(pred, y)optimizer.zero_grad()loss.backward()optimizer.step()loss_value = loss.item()if batch_size_num % 100 == 0:print(f"loss: {loss_value:>7f} [number:{batch_size_num}]")batch_size_num += 1

在训练函数中,首先通过model.train()将模型设置为训练模式,然后遍历数据加载器中的每一批数据。对于每一批数据,将数据和标签移动到指定设备上,进行前向传播计算预测值,通过损失函数nn.CrossEntropyLoss()计算预测值与真实标签之间的损失。接着使用optimizer.zero_grad()清空梯度,loss.backward()进行反向传播计算梯度,最后optimizer.step()根据计算得到的梯度更新模型参数。每训练 100 个批次,打印当前的损失值。

3.2 测试函数

def test(dataloader, model, loss_fn):size = len(dataloader.dataset)num_batches = len(dataloader)model.eval()test_loss, correct = 0, 0with torch.no_grad():for X, y in dataloader:pred = model(X)test_loss += loss_fn(pred, y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()test_loss /= num_batchescorrect /= sizeprint(f"Test Error: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")return test_loss, correct

测试函数中,先将模型设置为评估模式model.eval(),关闭一些在训练过程中使用的操作(如 Dropout)。在测试过程中,不需要计算梯度,因此使用with torch.no_grad()。通过遍历测试数据加载器,计算模型预测结果与真实标签之间的损失,并统计正确预测的样本数量,最后计算平均损失和准确率并打印输出。

3.3 执行训练与测试

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
epochs = 10
for t in range(epochs):print(f"Epoch {t + 1}\n--------------------")train(train_dataloader, model, loss_fn, optimizer)
print("Done!")
test(test_dataloader, model, loss_fn)

我们选择交叉熵损失函数nn.CrossEntropyLoss()作为损失计算方式,Adam 优化器torch.optim.Adam()来更新模型参数,学习率设置为 0.01。通过循环 10 个训练周期,不断训练模型,训练完成后进行测试,得到模型在测试集上的准确率和平均损失。

4. 总结

通过上述步骤,我们基于 PyTorch 完成了 MNIST 手写数字识别任务。从数据加载、模型构建,到训练和测试,每个环节都紧密相连。这个项目不仅让我们熟悉了 PyTorch 的基本使用流程,也对神经网络的工作原理有了更直观的认识。后续我们可以通过调整模型结构、超参数等方式进一步优化模型性能,探索更多深度学习的奥秘。

相关文章:

  • 【MySQL】索引(重要)
  • [Java]Java的三个阶段
  • C++类_成员函数指针
  • vae笔记
  • 修复笔记:SkyReels-V2项目中的 from_config 警告
  • 学习黑客Linux权限
  • bc 命令
  • 31.软件时序控制方式抗干扰
  • 四年级数学知识边界总结思考-上册
  • FPGA----基于ZYNQ 7020实现EPICS通信系统
  • CATIA高效工作指南——曲面设计篇(一)
  • [GESP202503 四级] 二阶矩阵c++
  • [python]非零基础上手之文件操作
  • 【人工智能学习笔记 二】 MCP 和 Function Calling的区别与联系
  • 动态规划(5)路径问题--剑指offer -珠宝的最大值
  • 【AI论文】Phi-4-reasoning技术报告
  • nginx 核心功能 02
  • 软件架构方之旅(5):SAAM 在软件技术架构评估中的应用与发展研究
  • 基于python生成taskc语言文件--时间片轮询
  • 0.0973585?探究ts_rank的score为什么这么低
  • 五一上海楼市热闹开局:售楼处全员到岗,热门楼盘连续触发积分
  • 五一车市消费观察:政策赋能、企业发力,汽车消费火热
  • 重庆市大渡口区区长黄红已任九龙坡区政协党组书记
  • A股2024年年报披露收官,四分之三公司盈利
  • 2025年第一批“闯中人”已经准备好了
  • 200枚篆刻聚焦北京中轴线,“印记”申遗往事