当前位置: 首页 > news >正文

【实战1】手写字识别 Pytoch(更新中)

1 数据集

引用

import torch
from torch import nn ##nn创建神经网络
from torch.utils.data import DataLoader #DataLoader加载数据
from torchvision import datasets #datasets加载数据集
from torchvision.transforms import ToTensor #ToTensor将数据转换为张量

下载数据集 

# 下载数据
training_data = datasets.MNIST(root="data",train=True,download=True,transform=ToTensor()
)testing_data = datasets.MNIST(root="data",train=False,download=True,transform=ToTensor()
)

导入数据

# 加载数据
batch_size = 64
train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(testing_data, batch_size=batch_size)# 打印第一个批次的数据形状 
for X, y in train_dataloader:print(f"Shape of X [N, C, H, W]: {X.shape}")print(f"Shape of y: {y.shape} {y.dtype}")break

使用keras构建神经网络的方法

构建神经网络

定义神经网络类,继承自nn.Moudule

Pytorch使用nn.Sequentail对象创建神经网络

神经网络必须实现前向传播forward()方法

使用nn.Flatten对象将图片转成列列向量【对比Kares ,pytorch是面向对象的思路】

nn.to方法可以让模型运行在不同的硬件上

device = ("cuda"if torch.cuda.is_available()else "mps"if torch.backends.mps.is_available() else "cpu")# 打印设备信息
print(f"Using {device} device")# 构建神经网络 2层 输入28*28,第一层128个神经元,激活函数relu,第二层10个神经元
class NeunalNetwork(nn.Module):def __init__(self):super().__init__()self.nmodel = nn.Sequential(nn.Linear(28*28, 128), # 输入层到隐藏层 nn.ReLU(), # 激活函数nn.Linear(128, 10) # 隐藏层到输出层)def forward(self, x):x = nn.Flatten(1)(x) # 将输入展平output = self.nmodel(x) #对模型训练return outputmodel = NeunalNetwork().to(device)
print(model)

编译神经网络

训练神经网络

评估神经网络

http://www.dtcms.com/a/289872.html

相关文章:

  • Codes 通过创新的重新定义 SaaS 模式,专治 “原教旨主义 SaaS 的水土不服
  • 一文速通《二次型》
  • 复盘与导出工具最新版V27.0版本更新-新增财联社涨停,自选股,表格拖拽功能
  • Agentic-R1 与 Dual-Strategy Reasoning
  • Raspi4 切换QNX系统
  • cmake语法学习笔记
  • 模电基础-开关电路和NE555
  • 【2025西门子信息化网络化决赛】模拟题+技术文档+实验vrrp standby vxlan napt 智能制造挑战赛 助力国赛!
  • Linux之conda安装使用
  • 【数据结构】栈和队列(接口超完整)
  • 实践教程:基于RV1126与ZeroTier的RTSP摄像头内网穿透与远程访问
  • InfluxDB 数据模型:桶、测量、标签与字段详解(一)
  • iptables -m connlimit导致内存不足
  • 数据存储方案h5py
  • jdk9 -> jdk17 编程方面的变化
  • Product Hunt 每日热榜 | 2025-07-20
  • Feign远程调用
  • LWJGL教程(2)——游戏循环
  • VMware中mysql无法连接端口3306不通
  • 暑假训练之动态规划---动态规划的引入
  • PrimeTime:高级片上变化(AOCV)
  • 1948. 删除系统中的重复文件夹
  • 16.TaskExecutor启动
  • Windows批量修改文件属性方法
  • pyhton基础【27】课后拓展
  • 【华为机试】169. 多数元素
  • C++ STL中迭代器学习笔记
  • day057-docker-compose案例与docker镜像仓库
  • 元学习算法的数学本质:从MAML到Reptile的理论统一与深度分析
  • Vision Transformer (ViT) 介绍