当前位置: 首页 > news >正文

【项目实战】——深度学习.全连接神经网络

目录

1.使用全连接网络训练和验证MNIST数据集

2.使用全连接网络训练和验证CIFAR10数据集


1.使用全连接网络训练和验证MNIST数据集

import torch
from torch import nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch import optim
from PIL import Image
import os# 数据预处理
transform = transforms.Compose([transforms.ToTensor()])# 数据准备
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
eval_dataset = datasets.MNIST(root='./data', train=False, transform=transform, download=True)train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
eval_loader = DataLoader(dataset=eval_dataset, batch_size=512, shuffle=True)# 定义网络结构
class MyNet(nn.Module):def __init__(self):super(MyNet, self).__init__()self.fc1 = nn.Linear(784, 256)self.bn1 = nn.BatchNorm1d(256)self.relu = nn.ReLU()self.fc2 = nn.Linear(256, 128)self.bn2 = nn.BatchNorm1d(128)self.fc3 = nn.Linear(128, 10)def forward(self, x):x = x.view(-1, 28 * 28)x = self.bn1(self.fc1(x))x = self.relu(x)x = self.bn2(self.fc2(x))x = self.relu(x)x = self.fc3(x)return xmodel = MyNet()
# 定义损失函数
criterion = nn.CrossEntropyLoss()
# 定义优化器
optimizer = optim.Adam(model.parameters(), lr=0.001)# 训练
def train(model, train_loader, epochs):model.train()for epoch in range(epochs):correct = 0for data, target in train_loader:output = model(data)loss = criterion(output, target)optimizer.zero_grad()loss.backward()optimizer.step()_, predicted = torch.max(output.data, 1)correct += (predicted.eq(target)).sum().item()correct /= len(train_loader.dataset)print(f'Train Epoch:  {epoch} , loss: {loss.item():.4f},acc:{correct:.4f}')# 验证
def eval(model, eval_loader):model.eval()eval_loss = 0correct = 0with torch.no_grad():for data, target in eval_loader:output = model(data)eval_loss += criterion(output, target).item()_, predicted = torch.max(output.data, 1)correct += (predicted.eq(target)).sum().item()eval_loss /= len(eval_loader.dataset)acc = 100.0 * correct / len(eval_loader.dataset)print(f'loss: {eval_loss:.4f}, acc: {acc:.4f}')# 保存模型
def save_model():torch.save(model.state_dict(), 'mnist_fc_model.pt')# 预测
def predict(img_path):model = MyNet()model.load_state_dict(torch.load('mnist_fc_model.pt'))model.eval()img = Image.open(img_path).convert('L')transform = transforms.Compose([transforms.Resize((28, 28)),transforms.ToTensor()])t_img = transform(img).unsqueeze(0)print(t_img.shape)with torch.no_grad():output = model(t_img)_, predicted = torch.max(output.data, 1)print(predicted.item())epochs = 5train(model, train_loader, epochs)
eval(model, eval_loader)save_model()img_path = './img/7.png'
predict(img_path)

2.使用全连接网络训练和验证CIFAR10数据集

import torch
from torch import nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch import optim# 数据预处理
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])# 数据准备
train_dataset = datasets.CIFAR10(root='./cifar10', train=True, transform=transform, download=True)
eval_dataset = datasets.CIFAR10(root='./cifar10', train=False, transform=transform, download=True)train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
eval_loader = DataLoader(dataset=eval_dataset, batch_size=512, shuffle=True)# 定义网络结构
class MyNet(nn.Module):def __init__(self):super(MyNet, self).__init__()self.fc1 = nn.Linear(32 * 32 * 3, 1024)self.bn1 = nn.BatchNorm1d(1024)self.dropout1 = nn.Dropout(0.3)self.fc2 = nn.Linear(1024, 512)self.bn2 = nn.BatchNorm1d(512)self.dropout2 = nn.Dropout(0.3)self.fc3 = nn.Linear(512, 256)  # 增加第三层self.bn3 = nn.BatchNorm1d(256)self.fc4 = nn.Linear(256, 10)self.relu = nn.ReLU()def forward(self, x):x = x.view(-1, 32 * 32 * 3)x = self.dropout1(self.bn1(self.fc1(x)))x = self.relu(x)x = self.dropout2(self.bn2(self.fc2(x)))x = self.relu(x)x = self.bn3(self.fc3(x))x = self.relu(x)x = self.fc4(x)return xmodel = MyNet()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)def train(model, train_loader, epochs):model.train()for epoch in range(epochs):correct = 0for data, target in train_loader:data, target = data.to(device), target.to(device)output = model(data)loss = criterion(output, target)optimizer.zero_grad()loss.backward()optimizer.step()_, predicted = torch.max(output.data, 1)correct += (predicted.eq(target)).sum().item()correct /= len(train_loader.dataset)print(f'Train Epoch:  {epoch} , loss: {loss.item():.4f},acc:{correct:.4f}')def eval(model, eval_loader):model.eval()eval_loss = 0correct = 0with torch.no_grad():for data, target in eval_loader:data, target = data.to(device), target.to(device)output = model(data)eval_loss += criterion(output, target).item()_, predicted = torch.max(output.data, 1)correct += (predicted.eq(target)).sum().item()eval_loss /= len(eval_loader.dataset)acc = 100.0 * correct / len(eval_loader.dataset)print(f'loss: {eval_loss:.4f}, acc: {acc:.4f}')epochs = 25train(model, train_loader, epochs)
eval(model, eval_loader)

思考:为什么CIFAR10数据集的准确率很低?

http://www.dtcms.com/a/291292.html

相关文章:

  • PostgreSQL SysCache RelCache
  • Java API (二):从 Object 类到正则表达式的核心详解
  • DevOps是什么?
  • Flutter中 Provider 的基础用法超详细讲解(一)
  • C++的“链”珠妙笔:list的编程艺术
  • JAVA序列化知识小结
  • mac终端设置代理
  • 拟合算法(1)
  • socket编程(UDP)
  • QGIS、ArcMap、ArcGIS Pro中的书签功能、场景裁剪
  • 本地部署Dify、Docker重装
  • 时序论文43 | WPMixer:融合小波分解的多分辨率长序列预测模型
  • Nginx配置proxy protocol代理获取真实ip
  • ubuntu远程桌面不好使
  • 修复echarts由4.x升级5.x出现地图报错echarts/map/js/china.js未找到
  • 卷积神经网络基本概念
  • 深度学习之参数初始化和损失函数(四)
  • 深入解析MIPI C-PHY (二)C-PHY三线魔术:如何用6种“符号舞步”榨干每一滴带宽?
  • 设计模式六:工厂模式(Factory Pattern)
  • C语言:20250721笔记
  • 在 Conda 中删除环境及所有安装的库
  • 《使用 IDEA 部署 Docker 应用指南》
  • Linux-rpm和yum
  • Shell脚本编程:从入门到精通的实战指南
  • 从零开始:用Python库轻松搭建智能AI代理
  • Djoser 详解
  • 深度学习中的数据增强:从理论到实践
  • hot100回归复习(算法总结1-38)
  • 力扣面试150(35/150)
  • 【安全篇 / 反病毒】(7.6) ❀ 01. 查杀HTTPS加密网站病毒 ❀ FortiGate 防火墙