当前位置: 首页 > news >正文

深度学习:从手写数字识别案例认识pytorch框架

目录

一、PyTorch 核心优势与框架定位

二、实战基础:核心库与数据准备

1. 关键库导入与功能说明

2. MNIST 数据集加载与可视化

(1)数据集下载与封装

(2)数据集可视化(可选)

3. DataLoader:批次数据加载

三、模型构建:全连接神经网络设计

1. 模型结构设计

2. 模型代码实现与设备部署

四、训练与测试:模型学习流程

1. 核心组件配置

(1)损失函数:CrossEntropyLoss

(2)优化器:Adam

(3)训练轮次:epochs=10

2. 训练函数实现

3. 测试函数实现

4. 启动训练与测试

五、训练结果与优化方向

1. 预期训练结果

2. 模型优化方向(提升准确率至 99%+)

六、总结


在深度学习领域,PyTorch 凭借动态图机制、简洁 API 和灵活的模型构建方式,成为初学者入门与科研落地的优选框架。本文以 MNIST 手写数字识别任务为核心,结合完整 PyTorch 代码与关键理论知识,从数据加载、模型构建到训练测试,带你掌握 PyTorch 深度学习实战的核心流程。


一、PyTorch 核心优势与框架定位

在开始实战前,先明确 PyTorch 在主流深度学习框架中的独特价值。相较于 Caffe(配置繁琐、更新停滞)、TensorFlow(1.x 版本冗余、2.x 兼容性问题),PyTorch 具备两大核心优势:

  1. 低门槛易上手:支持动态图调试,代码逻辑与 Python 原生语法高度一致,无需复杂配置即可搭建模型;
  2. 灵活性与效率平衡:既支持快速原型验证(如本文 MNIST 任务),也能通过 CUDA 无缝对接 GPU,满足大规模训练需求。

正如 PyTorch 官方设计理念:“让科研与工程的边界更模糊”,其 “代码即模型” 的特性,非常适合从基础任务入门深度学习。


二、实战基础:核心库与数据准备

1. 关键库导入与功能说明

首先导入任务所需的 PyTorch 核心库,各库功能与后续作用如下表:

导入语句核心功能实战作用
import torchPyTorch 核心库,提供 Tensor 与自动求导数据存储、梯度计算基础
from torch import nn神经网络模块,封装各类层与损失函数定义全连接层、展平层等
from torch.utils.data import DataLoader数据批次管理工具按批次加载数据,避免显存溢出
from torchvision import datasets图像数据集封装,含 MNIST 等经典数据集一键下载手写数字数据集
from torchvision.transforms import ToTensor图像数据转换工具将 PIL 图像转为 PyTorch 可处理的 Tensor
import matplotlib.pyplot as plt图像可视化库查看数据集中的手写数字样本

导入代码与基础配置(解决中文显示问题)如下:

# 导入必要的库
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
from matplotlib import pyplot as plt
import matplotlib# 配置matplotlib:解决中文乱码和负号显示问题
matplotlib.rcParams["font.family"] = ["SimHei"]
matplotlib.rcParams['axes.unicode_minus'] = False# 验证PyTorch环境(打印版本,确认安装成功)
print("PyTorch版本:", torch.__version__)

2. MNIST 数据集加载与可视化

MNIST 是手写数字识别的 “入门数据集”,包含 60000 张训练样本和 10000 张测试样本,每张图片为 28×28 像素的灰度图,标签为 0-9 的整数。

(1)数据集下载与封装

通过datasets.MNIST类自动下载并封装数据,核心参数说明:

  • root="data":数据保存路径(本地不存在时自动创建);
  • train=True/False:区分训练集(True)与测试集(False);
  • download=True:本地无数据时从 PyTorch 官网自动下载;
  • transform=ToTensor():将原始图像(PIL 格式)转为 Tensor,同时归一化像素值到[0,1]区间(模型训练需标准化输入)。

代码实现

# 下载并加载训练集(用于模型学习)
training_data = datasets.MNIST(root="data",train=True,download=True,transform=ToTensor()
)# 下载并加载测试集(用于验证模型泛化能力)
test_data = datasets.MNIST(root="data",train=False,download=True,transform=ToTensor()
)# 查看数据集规模(MNIST标准规模:训练6万、测试1万)
print("训练集样本数量:", len(training_data))  # 输出:60000
print("测试集样本数量:", len(test_data))      # 输出:10000

(2)数据集可视化(可选)

通过matplotlib查看数据集中的手写数字,验证 “图像 - 标签” 对应关系:

# 创建3×3网格的图像窗口,大小8×8英寸
figure = plt.figure(figsize=(8, 8))
for i in range(9):# 获取第59080+i张样本(图像Tensor + 标签)img, label = training_data[i + 59080]# 添加子图(3行3列,第i+1个位置)figure.add_subplot(3, 3, i + 1)# 设置子图标题(显示标签)plt.title(f"标签: {label}")# 关闭坐标轴,避免干扰plt.axis("off")# 显示图像:squeeze()去除维度为1的通道([1,28,28]→[28,28]),cmap="gray"用灰度显示plt.imshow(img.squeeze(), cmap="gray")
# 弹出窗口展示图像
plt.show()

运行后将看到 9 张手写数字图片及对应标签,如下:

 

3. DataLoader:批次数据加载

直接遍历training_data会逐样本输入模型,效率低且易导致 GPU 显存溢出。DataLoader的核心作用是按批次打包数据,平衡训练效率与硬件资源。

核心参数与代码实现:

batch_size = 64  # 每批次加载64张图片(经验值,可根据GPU显存调整)# 训练集DataLoader:按批次加载,支持多线程(默认)
train_dataloader = DataLoader(training_data, batch_size=batch_size)
# 测试集DataLoader:与训练集批次大小一致,仅用于推理(无参数更新)
test_dataloader = DataLoader(test_data, batch_size=batch_size)# 验证DataLoader输出的数据形状(确保符合模型输入要求)
for X, y in test_dataloader:# 图像形状:[N, C, H, W] = [批次大小, 通道数, 高度, 宽度]print(f"图像Tensor形状 [N, C, H, W]: {X.shape}")  # 输出:torch.Size([64, 1, 28, 28])# 标签形状:一维Tensor,长度=批次大小,类型为整数print(f"标签形状与类型: {y.shape} {y.dtype}")      # 输出:torch.Size([64]) torch.int64break  # 仅查看第一个批次,避免循环打印

关键说明:PyTorch 模型输入需为[N, C, H, W]格式的 Tensor,其中N=batch_size(64)、C=1(灰度图单通道)、H=28W=28,上述输出验证了数据格式正确性。


三、模型构建:全连接神经网络设计

在 PyTorch 中,自定义神经网络需继承nn.Module基类,并实现两个核心方法:

  1. __init__:定义模型的层结构(如展平层、全连接层);
  2. forward:定义数据在模型中的前向传播路径(必须命名为forward,PyTorch 会自动调用)。

1. 模型结构设计

针对 MNIST 任务,设计 “展平层 + 2 个隐藏层 + 1 个输出层” 的全连接网络,结构如下:

  • 展平层(nn.Flatten):将[64,1,28,28]的 2D 图像转为[64,784]的 1D 向量(全连接层仅接受 1D 输入);
  • 隐藏层 1(nn.Linear (784, 128)):输入 784 个神经元(28×28 像素),输出 128 个神经元,进行线性变换;
  • 激活函数(torch.sigmoid):引入非线性关系(若没有激活函数,多层网络等价于单层,无法拟合复杂特征);
  • 隐藏层 2(nn.Linear (128, 256)):输入 128 个神经元,输出 256 个神经元,进一步提取特征;
  • 输出层(nn.Linear (256, 10)):输入 256 个神经元,输出 10 个神经元(对应 0-9 共 10 个类别,输出值为各分类的预测分数)。

2. 模型代码实现与设备部署

PyTorch 支持 CPU/GPU 自动切换,优先使用 GPU 可显著提升训练速度(需 NVIDIA 显卡支持 CUDA)。代码中通过torch.cuda.is_available()判断 GPU 是否可用,将模型与数据部署到同一设备(否则会报错)。

完整代码:

# 自动选择训练设备:CUDA(NVIDIA GPU)> MPS(苹果M系列GPU)> CPU
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"使用训练设备: {device}")  # 输出示例:cuda / cpu# 定义神经网络类
class NeuralNetwork(nn.Module):def __init__(self):super().__init__()  # 调用父类nn.Module的初始化方法# 1. 展平层:将2D图像转为1D向量self.flatten = nn.Flatten()# 2. 全连接隐藏层1:784→128self.hidden1 = nn.Linear(28 * 28, 128)# 3. 全连接隐藏层2:128→256self.hidden2 = nn.Linear(128, 256)# 4. 输出层:256→10(10个类别)self.out = nn.Linear(256, 10)# 前向传播:定义数据流动路径def forward(self, x):x = self.flatten(x)          # 步骤1:展平([64,1,28,28]→[64,784])x = self.hidden1(x)          # 步骤2:隐藏层1线性变换x = torch.sigmoid(x)         # 步骤3:sigmoid激活函数(引入非线性)x = self.hidden2(x)          # 步骤4:隐藏层2线性变换x = torch.sigmoid(x)         # 步骤5:再次激活x = self.out(x)              # 步骤6:输出层(预测分数)return x# 创建模型实例,并部署到指定设备(GPU/CPU)
model = NeuralNetwork().to(device)
# 打印模型结构,验证层定义是否正确
print("\n神经网络模型结构:")
print(model)

模型结构打印输出(示例):

NeuralNetwork((flatten): Flatten(start_dim=1, end_dim=-1)(hidden1): Linear(in_features=784, out_features=128, bias=True)(hidden2): Linear(in_features=128, out_features=256, bias=True)(out): Linear(in_features=256, out_features=10, bias=True)
)

四、训练与测试:模型学习流程

1. 核心组件配置

模型训练需三大核心组件:损失函数(衡量预测误差)、优化器(更新模型参数)、训练轮次(遍历训练集的次数)。

(1)损失函数:CrossEntropyLoss

MNIST 是 10 分类任务,选择nn.CrossEntropyLoss(交叉熵损失),其核心作用:

  • 自动将模型输出的 “预测分数” 通过 Softmax 函数转为概率分布;
  • 计算预测概率与真实标签(独热编码)的交叉熵,量化误差大小。

(2)优化器:Adam

优化器负责根据损失反向传播的梯度更新模型参数(如全连接层的权重w和偏置b)。本文选择torch.optim.Adam,相比传统 SGD(随机梯度下降),Adam 具备 “自适应学习率” 特性,收敛速度更快,泛化能力更强。

核心参数:

  • model.parameters():传入模型中所有可学习的参数(需更新的权重和偏置);
  • lr=0.001:学习率(步长),控制参数更新的幅度(过大会导致训练震荡,过小会导致收敛缓慢)。

(3)训练轮次:epochs=10

1 个轮次(epoch)表示完整遍历一次训练集。MNIST 任务较简单,10 轮即可达到较高准确率(97%+)。

配置代码:

# 1. 损失函数:多分类任务用交叉熵损失
loss_fn = nn.CrossEntropyLoss()# 2. 优化器:Adam,学习率0.001
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)# 3. 训练轮次:10轮
epochs = 10

2. 训练函数实现

训练函数的核心逻辑是 “遍历所有训练批次,完成前向传播→损失计算→反向传播→参数更新” 的闭环,代码与注释如下:

def train(dataloader, model, loss_fn, optimizer):model.train()  # 开启训练模式:告诉模型“需要更新参数”(部分层如Dropout会生效)batch_num = 1  # 批次计数器,用于打印训练进度# 遍历训练集中的每个批次(X:图像数据,y:真实标签)for X, y in dataloader:# 关键步骤:将数据部署到与模型同一设备(GPU/CPU)X, y = X.to(device), y.to(device)# 1. 前向传播:计算模型预测结果pred = model(X)  # 等价于model.forward(X),PyTorch自动调用forward# 2. 计算损失:预测值与真实标签的误差loss = loss_fn(pred, y)# 3. 反向传播与参数更新(核心步骤)optimizer.zero_grad()  # 梯度清零:避免前一批次的梯度累积影响当前更新loss.backward()        # 反向传播:计算每个可学习参数的梯度(损失对参数的偏导数)optimizer.step()       # 参数更新:根据梯度和学习率调整权重和偏置# 每100个批次打印一次损失(监控训练进度,判断是否收敛)if batch_num % 100 == 1:loss_value = loss.item()  # 从Tensor中提取损失数值print(f"批次 {batch_num} | 损失: {loss_value:.4f}")batch_num += 1  # 批次计数器加1

3. 测试函数实现

测试函数的核心逻辑是 “遍历所有测试批次,计算准确率和平均损失,验证模型泛化能力”,需注意:

  • 测试时无需更新参数,因此要开启model.eval()模式;
  • 禁用梯度计算(with torch.no_grad()),节省内存并加速计算。

代码与注释如下:

def test(dataloader, model, loss_fn):model.eval()  # 开启测试模式:告诉模型“不需要更新参数”(部分层如Dropout会关闭)size = len(dataloader.dataset)  # 测试集总样本数(10000)num_batches = len(dataloader)   # 测试集总批次数(10000/64≈157)test_loss = 0  # 测试集总损失correct = 0    # 预测正确的样本数# 禁用梯度计算:测试时无需反向传播,大幅节省内存with torch.no_grad():for X, y in dataloader:# 数据部署到同一设备X, y = X.to(device), y.to(device)# 前向传播:计算预测结果pred = model(X)# 累积当前批次的损失test_loss += loss_fn(pred, y).item()# 累积正确样本数:pred.argmax(1)取预测分数最大的类别(维度1为类别维度),与y比较correct += (pred.argmax(1) == y).type(torch.float).sum().item()# 计算平均损失和准确率avg_loss = test_loss / num_batches  # 平均损失=总损失/总批次数accuracy = correct / size           # 准确率=正确数/总样本数# 打印测试结果(准确率保留2位小数,损失保留4位小数)print(f"\n测试集性能:")print(f"准确率: {(100 * accuracy):.2f}% | 平均损失: {avg_loss:.4f}\n")print("-" * 50)

4. 启动训练与测试

遍历每个训练轮次,先调用train函数完成当前轮次的训练,再调用test函数验证模型在测试集上的性能:

print("=" * 50)
print("开始训练")
print("=" * 50)for epoch in range(epochs):# 打印当前轮次信息print(f"\n轮次 {epoch + 1}/{epochs}")print("-" * 50)# 1. 训练当前轮次(遍历所有训练批次)train(train_dataloader, model, loss_fn, optimizer)# 2. 测试当前轮次(验证泛化能力)test(test_dataloader, model, loss_fn)print("=" * 50)
print("训练结束")
print("=" * 50)

五、训练结果与优化方向

1. 预期训练结果

在 GPU(如 RTX 3060)上运行,10 轮训练后输出示例:

轮次 10/10
--------------------------------------------------
Loss: 0.0130 [number:1]
Loss: 0.0594 [number:101]
Loss: 0.0244 [number:201]
Loss: 0.0326 [number:301]
Loss: 0.0205 [number:401]
Loss: 0.0444 [number:501]
Loss: 0.0048 [number:601]
Loss: 0.0325 [number:701]
Loss: 0.0665 [number:801]
Loss: 0.0314 [number:901]测试集性能:
准确率: 97.45% | 平均损失: 0.0863--------------------------------------------------
==================================================
训练结束
==================================================

关键结论:

  • 训练损失逐步下降,说明模型在持续学习;
  • 测试准确率达 97.45%,说明模型具备良好的泛化能力(未过拟合)。

2. 模型优化方向

若需进一步提升性能,可尝试以下优化手段:

  1. 更换激活函数:将sigmoid改为ReLUtorch.relu(x)),解决 sigmoid 在深层网络中的梯度消失问题;
  2. 增加网络深度 / 宽度:添加隐藏层(如nn.Linear(256, 512))或增加神经元数量(如nn.Linear(784, 256));
  3. 数据增强:通过torchvision.transforms添加旋转(RandomRotation(5))、平移(RandomAffine(5))等操作,提升模型抗干扰能力;
  4. 正则化:添加nn.Dropout(p=0.2)层,随机 “关闭” 部分神经元,避免过拟合;
  5. 调整超参数:增大训练轮次(如 20 轮)、微调学习率(如 0.0005)或批次大小(如 128)。

http://www.dtcms.com/a/350302.html

相关文章:

  • 用 GSAP + ScrollTrigger 打造沉浸式视频滚动动画
  • 《零基础学 C 语言文件顺序读写:fputc/fgetc 到 fread/fwrite 函数详解》
  • 并行算法与向量化指令集的实战经验
  • 【Linux内核实时】实时互斥锁 - sched_rt_mutex
  • 寂静之歌 单机+联机(Songs Of Silence)免安装中文版
  • 数据存储的思考——从RocketMQ和Mysql的架构入手
  • 力扣498 对角线遍历
  • Qwen2-Plus与DeepSeek-V3深度测评:从API成本到场景适配的全面解析
  • 消费场景的构建来自哪些方面?
  • KEPServerEX——工业数据采集与通信的标准化平台
  • 处理端口和 IP 地址
  • 最新刀客IP地址信息查询系统源码_含API接口_首发
  • AI被干冒烟了
  • HTML+CSS+JavaScript实现的AES加密工具网页应用,包含完整的UI界面和加密/解密功能
  • 系统开发 Day4
  • idea官网选择具体版本的下载步骤
  • 解决VSCode终端中文乱码问题
  • Cursor入门
  • Node.js面试题及详细答案120题(43-55) -- 性能优化与内存管理篇
  • HarmonyOS 中的 @Prop:深入理解单向数据传递机制
  • Java多态大冒险:当动物们开始“造反”
  • K8s高可用:Master与候选节点核心解析
  • STM32高级定时器-输出比较模式
  • 基于周期因子的资金流入流出预测
  • 区间和使用前缀和方法得到的时间复杂度
  • 2025 高教社杯全国大学生数学建模竞赛A题B题C题D题E题思路+模型+代码+论文(9.4开赛后第一时间更新)
  • AD画PCB时不小心移除的焊盘如何恢复
  • 玩转ChatGPT:Kimi深度研究功能
  • 模拟IC设计基础系列10-virtuoso常用快捷键整理(基础操作)
  • 驱动清理工具Driver Store Explorer(驱动程序资源管理器) 中文便携版