当前位置: 首页 > news >正文

使用全连接神经网络训练和预测MNIST以及CIFAR10

深度学习 — 使用全连接神经网络训练和预测 MNIST以及CIFAR10


文章目录

  • 深度学习 --- 使用全连接神经网络训练和预测 MNIST以及CIFAR10
  • 一,使用全连接神经网络训练和预测 MNIST
  • 二,使用全连接神经网络训练和预测CIFAR10
  • 三,所需链接
      • ✅ MNIST 数据集
      • ✅ CIFAR-10 数据集


一,使用全连接神经网络训练和预测 MNIST

'''
使用全连接神经网络训练和预测MNIST手写数字数据集
1.数据准备,通过数据集加载官方提供的MNIST数据集
2.构建神经网络
3.实现训练方法,使用交叉熵损失函数和Adam优化器
4.实现验证方法
5.通过测试图片预测结果
'''import torch
from torchvision import datasets,transforms
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.optim as optim
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import osdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")def build_data():transform =transforms.Compose([transforms.Resize((28,28)),transforms.ToTensor(),#转换为张量])#训练数据集dataset=datasets.MNIST(root="./datasets",train=True,download=True,transform=transform)#测试数据集val_dataset=datasets.MNIST(root="./datasets",train=False,download=True,transform=transform)#构建训练数据加载器train_dataloader=DataLoader(dataset,batch_size=64,shuffle=True)#构建测试数据加载器val_dataloader=DataLoader(val_dataset,batch_size=64,shuffle=True)return train_dataloader,val_dataloader#构建神经网络
class MNISTNet(nn.Module):def __init__(self,in_features,out_features):super().__init__()self.fc1=nn.Linear(in_features,128)self.bn1=nn.BatchNorm1d(128)self.relus=nn.ReLU()self.fc2=nn.Linear(128,64)self.bn2=nn.BatchNorm1d(64)self.relus=nn.ReLU()self.fc3=nn.Linear(64,out_features)def forward(self,x):x = x.view(x.size(0), -1)x=self.fc1(x)x=self.bn1(x)x=self.relus(x)x=self.fc2(x)x=self.bn2(x)x=self.relus(x)x=self.fc3(x)return x#训练
def train(model,train_dataloader,val_dataloader,lr,epochs):model.train()#损失函数和优化器criterion=nn.CrossEntropyLoss()optimizer=optim.Adam(model.parameters(),lr=lr,betas=(0.9,0.999),eps=1e-08,weight_decay=0.01)for epoch in range(epochs):correct=0for tx,ty in train_dataloader:tx, ty = tx.to(device), ty.to(device)y_pred=model(tx)loss=criterion(y_pred,ty)optimizer.zero_grad()loss.backward()optimizer.step()_,pred=torch.max(y_pred.data,dim=1)correct+=(pred==ty).sum().item()acc=correct/len(train_dataloader.dataset)print(f'Epoch:{epoch+1}/{epochs},loss:{loss.item():.4f},acc:{acc:.4f}')#验证
def eval(model,val_dataloader):model.eval()crirerion=nn.CrossEntropyLoss()correct=0for vx,vy in val_dataloader:vx,vy = vx.to(device), vy.to(device)with torch.no_grad():y_pred=model(vx)loss=crirerion(y_pred,vy)_,pred=torch.max(y_pred,dim=1)correct+=(pred==vy).sum().item()acc=correct/len(val_dataloader.dataset)acc=correct/len(val_dataloader.dataset)print(f'loss:{loss.item():.4f},acc:{acc:.4f}')#模型保存
def save_model(model,path):torch.save(model.state_dict(), path)
#模型加载
def load_model(path):model = MNISTNet(1*28*28, 10)model.load_state_dict(torch.load(path, map_location=device, weights_only=True))return model#测试
# 修改 predict 函数:确保模型和输入在同一设备
def predict(model, test_img, model_path):transform = transforms.Compose([transforms.Resize((28, 28)),transforms.ToTensor(),])image = Image.open(test_img).convert('L')#转换为灰度图t_img = transform(image).unsqueeze(0).to(device)  # 直接放到目标设备model = load_model(model_path)#加载模型model.to(device)# 显式把模型移到同一设备model.eval()#测试模式with torch.no_grad():y_pred = model(t_img)_, pred = torch.max(y_pred, dim=1)print(f'预测结果为:{pred.item()}')if __name__=='__main__':train_dataloader,val_dataloader=build_data()#获取数据model=MNISTNet(1*28*28,10).to(device)#构建模型train(model,train_dataloader,val_dataloader,lr=0.001,epochs=2)#训练eval(model,val_dataloader)#验证os.makedirs('./model', exist_ok=True)save_model(model,'./model/mnist_model.pt')#模型保存predict(model,'torch-fcnn/fcnn-demo/datasets/3.png','./model/mnist_model.pt')#测试

二,使用全连接神经网络训练和预测CIFAR10

import torch
from torchvision import datasets,transforms
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.optim as optim
from PIL import Image
import osdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")def build_data():#转换为张量transform=transforms.Compose([transforms.Resize((32,32)),transforms.RandomHorizontalFlip(p=0.5),  transforms.RandomRotation(15),           transforms.ColorJitter(brightness=0.2, contrast=0.2),  transforms.ToTensor(),#转换为张量transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))])#加载训练cifar-10数据集dataset=datasets.CIFAR10(root="./datasets",train=True,download=True,transform=transform)#测试数据集val_dataset=datasets.CIFAR10(root="./datasets",train=False,download=True,transform=transform)#构建训练数据加载器train_dataloader=DataLoader(dataset,batch_size=64,shuffle=True)#构建测试数据加载器val_dataloader=DataLoader(val_dataset,batch_size=64,shuffle=True)return train_dataloader,val_dataloader#构建神经网络模型
class CIFAR10Net(nn.Module):def __init__(self,in_features,out_features):super().__init__()self.fc1=nn.Linear(3072,512)self.bn1=nn.BatchNorm1d(512)self.relu1=nn.ReLU()self.fc2=nn.Linear(512,256)self.bn2=nn.BatchNorm1d(256)self.relu2=nn.ReLU()self.fc3=nn.Linear(256,128)self.bn3=nn.BatchNorm1d(128)self.relu3=nn.ReLU()self.fc4=nn.Linear(128,64)self.bn4=nn.BatchNorm1d(64)self.relu4=nn.ReLU()self.fc5=nn.Linear(64,out_features)def forward(self,x):x=x.view(x.size(0),-1)x=self.fc1(x)x=self.bn1(x)x=self.relu1(x)x=self.fc2(x)x=self.bn2(x)x=self.relu2(x)x=self.fc3(x)x=self.bn3(x)x=self.relu3(x)x=self.fc4(x)x=self.bn4(x)x=self.relu4(x)x=self.fc5(x)return x#训练
def train(model,train_dataloader,val_dataloader,lr,epochs):model.train()#损失函数和优化器criterion=nn.CrossEntropyLoss()optimizer=optim.Adam(model.parameters(),lr=lr,betas=(0.9,0.999),eps=1e-08,weight_decay=0.01)for epoch in range(epochs): correct=0for tx,ty in train_dataloader:tx, ty = tx.to(device), ty.to(device)y_pred=model(tx)loss=criterion(y_pred,ty)optimizer.zero_grad()loss.backward()optimizer.step()_,pred=torch.max(y_pred.data,dim=1)correct+=(pred==ty).sum().item()acc=correct/len(train_dataloader.dataset)print(f'Epoch:{epoch+1}/{epochs},loss:{loss.item():.4f},acc:{acc:.4f}')#验证
def eval(model, val_dataloader):model.eval()criterion = nn.CrossEntropyLoss()total_loss = 0.0correct = 0total_samples = 0with torch.no_grad():for vx, vy in val_dataloader:vx, vy = vx.to(device), vy.to(device)y_pred = model(vx)loss = criterion(y_pred, vy)total_loss += loss.item() * vx.size(0)  _, pred = torch.max(y_pred, dim=1)correct += (pred == vy).sum().item()total_samples += vy.size(0)avg_loss = total_loss / total_samplesacc = correct / total_samplesprint(f'Val Loss: {avg_loss:.4f}, Val Acc: {acc:.4f}')return avg_loss  #模型保存
def save_model(model,path):torch.save(model.state_dict(),path)#模型加载
def load_model(path):model=CIFAR10Net(1*32*32,10).to(device)model.load_state_dict(torch.load(path,map_location=device,weights_only=True))model.eval()return model#测试模型
def predict(model,test_img,model_path):transform=transforms.Compose([transforms.Resize((32,32)),transforms.RandomHorizontalFlip(p=0.5),  transforms.RandomRotation(15),           transforms.ColorJitter(brightness=0.2, contrast=0.2),transforms.ToTensor(),transforms.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5])])image = Image.open(test_img).convert('RGB')t_img=transform(image).unsqueeze(0).to(device)model=load_model(model_path)model.to(device)model.eval()classes = ['airplane', 'automobile', 'bird', 'cat', 'deer','dog', 'frog', 'horse', 'ship', 'truck']with torch.no_grad():   y_pred=model(t_img)_,pred=torch.max(y_pred,dim=1)print(f'预测类别编号:{pred.item()},类别名称:{classes[pred.item()]}')if __name__ == '__main__':train_dataloader,val_dataloader=build_data()model=CIFAR10Net(1*32*32,10).to(device)train(model,train_dataloader,val_dataloader,lr=0.0001,epochs=50)eval(model,val_dataloader)os.makedirs('models',exist_ok=True)save_model(model,'models/cifar10_model.pt')predict(model,'torch-fcnn/fcnn-demo/datasets/100.jpg','models/cifar10_model.pt')

三,所需链接

以下是 MNISTCIFAR-10 数据集的直接下载链接,均为官方或常用可信来源:


✅ MNIST 数据集

下载方式链接
官方页面(含所有文件)https://yann.lecun.com/exdb/mnist/
百度网盘(JPEG格式)https://pan.baidu.com/s/1TaL3dCHxAj17LgvSSd_eTA?pwd=xl8n 提取码:xl8n
百度网盘(原始格式)https://pan.baidu.com/s/1jAPlVKLYamJn6I63GD6HDg?pwd=azq2 提取码:azq2

✅ CIFAR-10 数据集

下载方式链接
官方压缩包(Python版)http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
官方主页(含多种版本)http://www.cs.toronto.edu/~kriz/cifar.html

如需快速使用,可直接用 PyTorch 或 TensorFlow 提供的 API 自动下载(首次运行会自动缓存到本地):

# PyTorch 示例
from torchvision.datasets import MNIST, CIFAR10
MNIST(root='./data', train=True, download=True)
CIFAR10(root='./data', train=True, download=True)

建议使用API下载

http://www.dtcms.com/a/310688.html

相关文章:

  • 十、SpringBootWeb快速入门-入门案例
  • 玻尔兹曼分布与玻尔兹曼探索
  • 户外广告牌识别误检率↓78%!陌讯动态感知算法实战解析
  • 力扣面试150题--数字范围按位与
  • 【文章素材】ACID 原子性与数据库
  • 五自由度机械臂阻抗控制下的力跟踪
  • 神经网络学习笔记
  • 台式机 Server 20.04 CUDA11.8
  • JAVA,Filter和Interceptor
  • ThreadLocal总结
  • 基于倍增的LCA + kruskal重构树 + 并查集
  • 可编辑234页PPT | 某制造集团供应链流程分析和数字化转型解决方案
  • JavaScript 语句和函数
  • ensp防火墙安全策略实验
  • 【全网首个公开VMware vCenter 靶场环境】 Vulntarget-o 正式上线
  • Linux权限提升
  • shell编程练习,实现循环创建账户、测试主机连通性、批量修改主机root密码等功能
  • Linux 用户与组管理:从配置文件到实操命令全解析
  • Lecture 7: Processes 4, Further Scheduling
  • 嵌入式系统中常用通信协议
  • 高压大电流与低压大电流电源的设计难点
  • QT中重写事件过滤失效(返回了多个事件)
  • Jetpack Compose Column组件之focusProperties修饰符
  • 基于C#和NModbus4库实现的Modbus RTU串口通信
  • 【工具分享】模拟接口请求响应的Chrome插件ModResponse
  • 光伏运维数据透明化,发电量提高45%
  • Cursor免费使用工具
  • 配置多数据源dynamic-datasource 开箱即用方案​
  • ubuntu使用man手册中文版办法
  • 同品牌的系列广告要如何保证宣传的连贯性?