当前位置: 首页 > news >正文

实战:用 PyTorch 复现一个 3 层全连接网络,训练 MNIST,达到 95%+ 准确率

1. 使用 Anaconda 创建一个新环境,包括 python 和 与你显卡对应的 torch

2. PyCharm(2025.1.3.1)绑定 Conda 环境-CSDN博客

3. 

import torch
from torch import nn, optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from tqdm import tqdm# 一次给模型看多少张图片
BATCH_SIZE = 64
# 把全部训练数据重复看多少遍
EPOCHS = 10
LR = 1e-3
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))
])
train_set = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
test_set  = datasets.MNIST(root="./data", train=False, download=True, transform=transform)# 原始数据集中,一张 MNIST 图片的形状是 (1, 28, 28) ← 1 个通道(灰度),高 28,宽 28。
# 当 DataLoader 按 batch_size=64 打包后,它把 64 张这样的图片堆在一起,形成一个新的 4 维张量,形状变成 (64, 1, 28, 28)
# shuffle = True 的作用:在每个 epoch 开始时,把训练集里的 60 000 张图片顺序彻底打乱一次。
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
test_loader  = DataLoader(test_set,  batch_size=BATCH_SIZE)# 搭建神经网络:把图片拉成一条长条 → 过 128 个神经元 → 再过 64 个神经元 → 最后给出 10 个数字的得分
class Net(nn.Module):def __init__(self):super().__init__()self.net = nn.Sequential(nn.Flatten(),nn.Linear(784, 128), nn.ReLU(),nn.Linear(128, 64),  nn.ReLU(),nn.Linear(64, 10))def forward(self, x):return self.net(x)model = Net().to(DEVICE)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LR)# 训练
for epoch in range(1, EPOCHS + 1):model.train()pbar = tqdm(train_loader, desc=f"Epoch {epoch}")for x, y in pbar:x, y = x.to(DEVICE), y.to(DEVICE)optimizer.zero_grad()loss = criterion(model(x), y)loss.backward()optimizer.step()pbar.set_postfix(loss=loss.item())model.eval()
correct = total = 0
with torch.no_grad():for x, y in test_loader:x, y = x.to(DEVICE), y.to(DEVICE)pred = model(x).argmax(1)correct += (pred == y).sum().item()total += y.size(0)
print(f"Test Accuracy: {100*correct/total:.2f}%")

4. 运行

http://www.dtcms.com/a/325416.html

相关文章:

  • 软考高级资格推荐与选择建议
  • 大语言模型(LLM)核心概念与应用技术全解析:从Prompt设计到向量检索
  • STM32蓝牙模块驱动开发
  • 什么是结构化思维?什么是结构化编程?
  • 获取MaixPy系列开发板机器码——MaixHub 模型下载机器码获取方法
  • 【Python】在rk3588开发板排查内存泄漏问题过程记录
  • 视频前处理技术全解析:从基础到前沿
  • DreaMoving:基于扩散模型的可控视频生成框架
  • 安全合规4--下一代防火墙组网
  • GaussDB 数据库架构师修炼(十三)安全管理(1)-账号的管理
  • vue+flask基于规则的求职推荐系统
  • CentOS7搭建安全FTP服务器指南
  • 【安全发布】微软2025年07月漏洞通告
  • C语言如何安全的进行字符串拷贝
  • MQTT:Vue集成MQTT
  • GaussDB安全配置全景指南:构建企业级数据库防护体系
  • 【vue(一))路由】
  • uncalled4
  • 昆仑万维SkyReels-A3模型发布:照片开口说话,视频创作“一键改台词”
  • 使用行为树控制机器人(二) —— 黑板
  • 哈希、存储、连接:使用 ES|QL LOOKUP JOIN 的日志去重现代解决方案
  • Logistic Loss Function|逻辑回归代价函数
  • 实习学习记录
  • 集成电路学习:什么是URDF Parser统一机器人描述格式解析器
  • ttyd终端工具移植到OpenHarmony
  • 工业相机与智能相机的区别
  • 5G与云计算对代理IP行业的深远影响
  • 用 Python 绘制企业年度财务可视化报告 —— 从 Excel 到 9 种图表全覆盖
  • nvm安装详细教程(卸载旧的nodejs,安装nvm、node、npm、cnpm、yarn及环境变量配置)
  • 论文中PDF的公式如何提取-公式提取