Deep Learning: Training and Predicting on MNIST and CIFAR-10 with Fully Connected Neural Networks
Contents
- Deep Learning: Training and Predicting on MNIST and CIFAR-10 with Fully Connected Neural Networks
- 1. Training and predicting on MNIST with a fully connected neural network
- 2. Training and predicting on CIFAR-10 with a fully connected neural network
- 3. Download links
- ✅ MNIST dataset
- ✅ CIFAR-10 dataset
1. Training and predicting on MNIST with a fully connected neural network
```python
'''
Train and predict on the MNIST handwritten digit dataset with a fully connected neural network
1. Data preparation: load the official MNIST dataset via torchvision
2. Build the neural network
3. Implement the training loop with cross-entropy loss and the Adam optimizer
4. Implement the validation loop
5. Predict the class of a test image
'''
import torch
from torchvision import datasets, transforms
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.optim as optim
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import os

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def build_data():
    transform = transforms.Compose([
        transforms.Resize((28, 28)),
        transforms.ToTensor(),  # convert to tensor
    ])
    # training dataset
    dataset = datasets.MNIST(root="./datasets", train=True, download=True, transform=transform)
    # test dataset
    val_dataset = datasets.MNIST(root="./datasets", train=False, download=True, transform=transform)
    # training data loader
    train_dataloader = DataLoader(dataset, batch_size=64, shuffle=True)
    # test data loader
    val_dataloader = DataLoader(val_dataset, batch_size=64, shuffle=True)
    return train_dataloader, val_dataloader


# build the neural network
class MNISTNet(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.fc1 = nn.Linear(in_features, 128)
        self.bn1 = nn.BatchNorm1d(128)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(128, 64)
        self.bn2 = nn.BatchNorm1d(64)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(64, out_features)

    def forward(self, x):
        x = x.view(x.size(0), -1)  # flatten to (batch, 1*28*28)
        x = self.fc1(x)
        x = self.bn1(x)
        x = self.relu1(x)
        x = self.fc2(x)
        x = self.bn2(x)
        x = self.relu2(x)
        x = self.fc3(x)
        return x


# training
def train(model, train_dataloader, val_dataloader, lr, epochs):
    model.train()
    # loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.01)
    for epoch in range(epochs):
        correct = 0
        for tx, ty in train_dataloader:
            tx, ty = tx.to(device), ty.to(device)
            y_pred = model(tx)
            loss = criterion(y_pred, ty)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            _, pred = torch.max(y_pred.data, dim=1)
            correct += (pred == ty).sum().item()
        acc = correct / len(train_dataloader.dataset)
        print(f'Epoch:{epoch+1}/{epochs},loss:{loss.item():.4f},acc:{acc:.4f}')


# validation
def eval(model, val_dataloader):
    model.eval()
    criterion = nn.CrossEntropyLoss()
    correct = 0
    for vx, vy in val_dataloader:
        vx, vy = vx.to(device), vy.to(device)
        with torch.no_grad():
            y_pred = model(vx)
            loss = criterion(y_pred, vy)
            _, pred = torch.max(y_pred, dim=1)
            correct += (pred == vy).sum().item()
    acc = correct / len(val_dataloader.dataset)
    print(f'loss:{loss.item():.4f},acc:{acc:.4f}')


# save the model
def save_model(model, path):
    torch.save(model.state_dict(), path)


# load the model
def load_model(path):
    model = MNISTNet(1 * 28 * 28, 10)
    model.load_state_dict(torch.load(path, map_location=device, weights_only=True))
    return model


# predict: make sure the model and the input tensor end up on the same device
def predict(model, test_img, model_path):
    transform = transforms.Compose([
        transforms.Resize((28, 28)),
        transforms.ToTensor(),
    ])
    image = Image.open(test_img).convert('L')     # convert to grayscale
    t_img = transform(image).unsqueeze(0).to(device)  # move the input straight to the target device
    model = load_model(model_path)                # load the saved weights
    model.to(device)                              # explicitly move the model to the same device
    model.eval()                                  # evaluation mode
    with torch.no_grad():
        y_pred = model(t_img)
        _, pred = torch.max(y_pred, dim=1)
        print(f'Prediction: {pred.item()}')


if __name__ == '__main__':
    train_dataloader, val_dataloader = build_data()                      # load the data
    model = MNISTNet(1 * 28 * 28, 10).to(device)                         # build the model
    train(model, train_dataloader, val_dataloader, lr=0.001, epochs=2)   # train
    eval(model, val_dataloader)                                          # validate
    os.makedirs('./model', exist_ok=True)
    save_model(model, './model/mnist_model.pt')                          # save the model
    predict(model, 'torch-fcnn/fcnn-demo/datasets/3.png', './model/mnist_model.pt')  # predict
```
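If you do not have a standalone digit image at hand, the saved checkpoint can also be sanity-checked directly against a sample from the MNIST test split. The following is a minimal sketch, not part of the original script: it reuses `load_model`, `MNISTNet` and `device` from the listing above, and assumes the checkpoint path `./model/mnist_model.pt` written by `save_model` in the main block.

```python
import torch
from torchvision import datasets, transforms

# reuses load_model / MNISTNet / device defined in the listing above
transform = transforms.Compose([transforms.Resize((28, 28)), transforms.ToTensor()])
test_set = datasets.MNIST(root="./datasets", train=False, download=True, transform=transform)

model = load_model('./model/mnist_model.pt').to(device)
model.eval()

img, label = test_set[0]                         # first test sample and its ground-truth label
with torch.no_grad():
    logits = model(img.unsqueeze(0).to(device))  # add the batch dimension before the forward pass
    pred = logits.argmax(dim=1).item()
print(f'ground truth: {label}, prediction: {pred}')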
2. Training and predicting on CIFAR-10 with a fully connected neural network
```python
import torch
from torchvision import datasets, transforms
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.optim as optim
from PIL import Image
import os

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def build_data():
    # training transform with data augmentation
    transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
        transforms.ToTensor(),  # convert to tensor
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    # evaluation transform: the test split should not be randomly augmented
    val_transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    # training CIFAR-10 dataset
    dataset = datasets.CIFAR10(root="./datasets", train=True, download=True, transform=transform)
    # test dataset
    val_dataset = datasets.CIFAR10(root="./datasets", train=False, download=True, transform=val_transform)
    # training data loader
    train_dataloader = DataLoader(dataset, batch_size=64, shuffle=True)
    # test data loader
    val_dataloader = DataLoader(val_dataset, batch_size=64, shuffle=True)
    return train_dataloader, val_dataloader


# build the neural network model
class CIFAR10Net(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.fc1 = nn.Linear(in_features, 512)
        self.bn1 = nn.BatchNorm1d(512)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(512, 256)
        self.bn2 = nn.BatchNorm1d(256)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(256, 128)
        self.bn3 = nn.BatchNorm1d(128)
        self.relu3 = nn.ReLU()
        self.fc4 = nn.Linear(128, 64)
        self.bn4 = nn.BatchNorm1d(64)
        self.relu4 = nn.ReLU()
        self.fc5 = nn.Linear(64, out_features)

    def forward(self, x):
        x = x.view(x.size(0), -1)  # flatten to (batch, 3*32*32 = 3072)
        x = self.fc1(x)
        x = self.bn1(x)
        x = self.relu1(x)
        x = self.fc2(x)
        x = self.bn2(x)
        x = self.relu2(x)
        x = self.fc3(x)
        x = self.bn3(x)
        x = self.relu3(x)
        x = self.fc4(x)
        x = self.bn4(x)
        x = self.relu4(x)
        x = self.fc5(x)
        return x


# training
def train(model, train_dataloader, val_dataloader, lr, epochs):
    model.train()
    # loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.01)
    for epoch in range(epochs):
        correct = 0
        for tx, ty in train_dataloader:
            tx, ty = tx.to(device), ty.to(device)
            y_pred = model(tx)
            loss = criterion(y_pred, ty)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            _, pred = torch.max(y_pred.data, dim=1)
            correct += (pred == ty).sum().item()
        acc = correct / len(train_dataloader.dataset)
        print(f'Epoch:{epoch+1}/{epochs},loss:{loss.item():.4f},acc:{acc:.4f}')


# validation
def eval(model, val_dataloader):
    model.eval()
    criterion = nn.CrossEntropyLoss()
    total_loss = 0.0
    correct = 0
    total_samples = 0
    with torch.no_grad():
        for vx, vy in val_dataloader:
            vx, vy = vx.to(device), vy.to(device)
            y_pred = model(vx)
            loss = criterion(y_pred, vy)
            total_loss += loss.item() * vx.size(0)
            _, pred = torch.max(y_pred, dim=1)
            correct += (pred == vy).sum().item()
            total_samples += vy.size(0)
    avg_loss = total_loss / total_samples
    acc = correct / total_samples
    print(f'Val Loss: {avg_loss:.4f}, Val Acc: {acc:.4f}')
    return avg_loss


# save the model
def save_model(model, path):
    torch.save(model.state_dict(), path)


# load the model
def load_model(path):
    # CIFAR-10 images have 3 channels, so the input size is 3*32*32 = 3072
    model = CIFAR10Net(3 * 32 * 32, 10).to(device)
    model.load_state_dict(torch.load(path, map_location=device, weights_only=True))
    model.eval()
    return model


# predict with the trained model
def predict(model, test_img, model_path):
    # inference transform: no random augmentation, only resize, tensor conversion and normalization
    transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])
    image = Image.open(test_img).convert('RGB')
    t_img = transform(image).unsqueeze(0).to(device)
    model = load_model(model_path)
    model.to(device)
    model.eval()
    classes = ['airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']
    with torch.no_grad():
        y_pred = model(t_img)
        _, pred = torch.max(y_pred, dim=1)
        print(f'Predicted class index: {pred.item()}, class name: {classes[pred.item()]}')


if __name__ == '__main__':
    train_dataloader, val_dataloader = build_data()
    model = CIFAR10Net(3 * 32 * 32, 10).to(device)
    train(model, train_dataloader, val_dataloader, lr=0.0001, epochs=50)
    eval(model, val_dataloader)
    os.makedirs('models', exist_ok=True)
    save_model(model, 'models/cifar10_model.pt')
    predict(model, 'torch-fcnn/fcnn-demo/datasets/100.jpg', 'models/cifar10_model.pt')
```
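Note that both `train` functions accept `val_dataloader` but never use it, so validation metrics are only printed once after training finishes. When training a fully connected network on CIFAR-10 for 50 epochs, overfitting is a real risk, so it can help to run the validation pass at the end of every epoch. Below is a minimal sketch of that change; the name `train_with_validation` is illustrative, and it reuses `eval` and `device` from the listing above. `model.train()` is restored at the top of each epoch because `eval()` switches the BatchNorm layers to inference mode.

```python
import torch.nn as nn
import torch.optim as optim

def train_with_validation(model, train_dataloader, val_dataloader, lr, epochs):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=0.01)
    for epoch in range(epochs):
        model.train()  # re-enable training mode; eval() below puts BatchNorm into inference mode
        correct = 0
        for tx, ty in train_dataloader:
            tx, ty = tx.to(device), ty.to(device)
            y_pred = model(tx)
            loss = criterion(y_pred, ty)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            correct += (y_pred.argmax(dim=1) == ty).sum().item()
        acc = correct / len(train_dataloader.dataset)
        print(f'Epoch:{epoch+1}/{epochs}, train acc:{acc:.4f}')
        eval(model, val_dataloader)  # prints Val Loss / Val Acc after every epoch
```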
3. Download links
Below are direct download links for the MNIST and CIFAR-10 datasets, all from official or commonly used, trusted sources:
✅ MNIST dataset
| Download source | Link |
|---|---|
| Official page (all files) | https://yann.lecun.com/exdb/mnist/ |
| Baidu Netdisk (JPEG format) | https://pan.baidu.com/s/1TaL3dCHxAj17LgvSSd_eTA?pwd=xl8n (access code: xl8n) |
| Baidu Netdisk (original format) | https://pan.baidu.com/s/1jAPlVKLYamJn6I63GD6HDg?pwd=azq2 (access code: azq2) |
✅ CIFAR-10 dataset
| Download source | Link |
|---|---|
| Official archive (Python version) | http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz |
| Official homepage (all versions) | http://www.cs.toronto.edu/~kriz/cifar.html |
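If you download the official Python-version archive instead of going through torchvision, each batch file (`data_batch_1` through `data_batch_5`, plus `test_batch`) is a pickled dict holding 10000 rows of 3072 uint8 pixel values (32x32x3, channel-first) and the matching labels. The following is a reading sketch based on the format documented on the CIFAR-10 homepage; `cifar-10-batches-py` is the directory produced by unpacking the tarball.

```python
import pickle
import numpy as np

def load_cifar_batch(path):
    # each batch file is a pickled dict with byte-string keys
    with open(path, 'rb') as f:
        batch = pickle.load(f, encoding='bytes')
    data = batch[b'data']                   # shape (10000, 3072), dtype uint8
    labels = np.array(batch[b'labels'])     # shape (10000,)
    images = data.reshape(-1, 3, 32, 32)    # channel-first, as stored in the file
    return images, labels

images, labels = load_cifar_batch('cifar-10-batches-py/data_batch_1')
print(images.shape, labels[:10])
```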
For quick use, you can simply let the PyTorch or TensorFlow APIs download the datasets automatically (the first run caches them locally):
```python
# PyTorch example
from torchvision.datasets import MNIST, CIFAR10
MNIST(root='./data', train=True, download=True)
CIFAR10(root='./data', train=True, download=True)
```
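Since TensorFlow is mentioned above as well, the roughly equivalent Keras calls are sketched below; `tf.keras` caches the downloads under `~/.keras/datasets` by default.

```python
# TensorFlow / Keras example (sketch)
from tensorflow.keras.datasets import mnist, cifar10

(mnist_x, mnist_y), (mnist_x_test, mnist_y_test) = mnist.load_data()      # 28x28 grayscale digits
(cifar_x, cifar_y), (cifar_x_test, cifar_y_test) = cifar10.load_data()    # 32x32 RGB images, 10 classes
```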
Downloading via the API is recommended.