【项目实战】——深度学习.全连接神经网络
目录
1.使用全连接网络训练和验证MNIST数据集
2.使用全连接网络训练和验证CIFAR10数据集
1.使用全连接网络训练和验证MNIST数据集
import torch
from torch import nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch import optim
from PIL import Image
import os# 数据预处理
transform = transforms.Compose([transforms.ToTensor()])
# Data preparation: MNIST train/test splits under ./data (download=True fetches them if missing)
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
eval_dataset = datasets.MNIST(root='./data', train=False, transform=transform, download=True)
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
# Evaluation runs in eval mode and only sums loss/hits, so sample order does
# not matter; iterating the test set in a fixed order keeps runs reproducible.
eval_loader = DataLoader(dataset=eval_dataset, batch_size=512, shuffle=False)
# Define the network structure
class MyNet(nn.Module):
    """Fully connected classifier for MNIST-sized inputs.

    Each sample is flattened to 28 * 28 = 784 values and passed through
    784 -> 256 -> 128 -> 10, with BatchNorm1d followed by ReLU after each
    hidden layer. The forward pass returns raw (unnormalised) class logits.
    """

    def __init__(self):
        super().__init__()
        # Registration order fixes the state_dict key order and the order of
        # seeded weight initialisation, so it is kept as-is.
        self.fc1 = nn.Linear(784, 256)
        self.bn1 = nn.BatchNorm1d(256)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(256, 128)
        self.bn2 = nn.BatchNorm1d(128)
        self.fc3 = nn.Linear(128, 10)

    def forward(self, x):
        flat = x.view(-1, 28 * 28)
        hidden1 = self.relu(self.bn1(self.fc1(flat)))
        hidden2 = self.relu(self.bn2(self.fc2(hidden1)))
        return self.fc3(hidden2)


model = MyNet()
# Define the loss function
criterion = nn.CrossEntropyLoss()
# Define the optimizer: Adam over every parameter of the global `model`
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)
# Training
def train(model, train_loader, epochs, loss_fn=None, opt=None):
    """Train ``model`` for ``epochs`` passes over ``train_loader``.

    After each epoch, prints the sample-weighted mean training loss and the
    training accuracy (a fraction in [0, 1]).

    Args:
        model: network to train; it is switched to training mode first.
        train_loader: DataLoader yielding ``(data, target)`` batches.
        epochs: number of full passes over the loader.
        loss_fn: loss callable; defaults to the module-level ``criterion``.
            Assumed to average over the batch (``reduction='mean'``).
        opt: optimizer over ``model``'s parameters; defaults to the
            module-level ``optimizer``, which is bound to the global model.
            Pass one explicitly when training any other model.

    Returns:
        ``(mean_loss, accuracy)`` of the last epoch, or ``None`` when
        ``epochs`` is not positive.
    """
    loss_fn = criterion if loss_fn is None else loss_fn
    opt = optimizer if opt is None else opt
    model.train()
    stats = None
    for epoch in range(epochs):
        total_loss = 0.0
        correct = 0
        seen = 0
        for data, target in train_loader:
            output = model(data)
            loss = loss_fn(output, target)
            opt.zero_grad()
            loss.backward()
            opt.step()
            batch = data.size(0)
            # loss is a per-batch mean; weight it by the batch size so the
            # epoch figure is a per-sample mean even if the last batch is short.
            total_loss += loss.item() * batch
            correct += output.argmax(dim=1).eq(target).sum().item()
            seen += batch
        mean_loss = total_loss / seen if seen else 0.0
        acc = correct / seen if seen else 0.0
        print(f'Train Epoch: {epoch} , loss: {mean_loss:.4f},acc:{acc:.4f}')
        stats = (mean_loss, acc)
    return stats
# Validation
def eval(model, eval_loader, loss_fn=None):
    """Evaluate ``model`` on ``eval_loader`` without tracking gradients.

    Prints the per-sample mean loss and the accuracy in percent.

    Args:
        model: network to evaluate; it is switched to eval mode.
        eval_loader: DataLoader yielding ``(data, target)`` batches.
        loss_fn: loss callable; defaults to the module-level ``criterion``.
            Assumed to average over the batch (``reduction='mean'``).

    Returns:
        ``(mean_loss, accuracy_percent)``; both are 0.0 for an empty loader.
    """
    loss_fn = criterion if loss_fn is None else loss_fn
    model.eval()
    total_loss = 0.0
    correct = 0
    seen = 0
    with torch.no_grad():
        for data, target in eval_loader:
            output = model(data)
            batch = data.size(0)
            # The loss is a per-batch mean: multiply by the batch size before
            # summing, otherwise dividing by the sample count below would
            # shrink the reported loss by roughly the batch size.
            total_loss += loss_fn(output, target).item() * batch
            correct += output.argmax(dim=1).eq(target).sum().item()
            seen += batch
    eval_loss = total_loss / seen if seen else 0.0
    acc = 100.0 * correct / seen if seen else 0.0
    print(f'loss: {eval_loss:.4f}, acc: {acc:.4f}')
    return eval_loss, acc
# Save the model
def save_model(path='mnist_fc_model.pt', net=None):
    """Save a model's ``state_dict`` to ``path``.

    Args:
        path: destination file; defaults to ``'mnist_fc_model.pt'``.
        net: module to save; defaults to the module-level ``model``.
    """
    net = model if net is None else net
    torch.save(net.state_dict(), path)
# Prediction
def predict(img_path, model_path='mnist_fc_model.pt'):
    """Classify one image with the saved MNIST model.

    Loads a fresh ``MyNet`` from ``model_path``, and converts the image to
    grayscale. It then resizes the image to 28x28, prints the input tensor
    shape, and prints the predicted class index.

    Args:
        img_path: path of the image to classify.
        model_path: checkpoint written by ``save_model``; defaults to
            ``'mnist_fc_model.pt'``.

    Returns:
        The predicted class index as a Python ``int``.
    """
    model = MyNet()
    # map_location lets a checkpoint saved from a GPU run load on a CPU-only machine.
    model.load_state_dict(torch.load(model_path, map_location='cpu'))
    model.eval()
    # Open via a context manager so the file handle is released promptly.
    with Image.open(img_path) as raw:
        img = raw.convert('L')
    # NOTE(review): MNIST digits are light strokes on a dark background;
    # dark-on-light inputs may need inverting first — confirm for your images.
    transform = transforms.Compose([transforms.Resize((28, 28)), transforms.ToTensor()])
    t_img = transform(img).unsqueeze(0)  # add the batch dimension
    print(t_img.shape)
    with torch.no_grad():
        output = model(t_img)
    predicted = output.argmax(dim=1).item()
    print(predicted)
    return predicted


epochs = 5
train(model, train_loader, epochs)
eval(model, eval_loader)
save_model()
img_path = './img/7.png'
predict(img_path)
2.使用全连接网络训练和验证CIFAR10数据集
import torch
from torch import nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch import optim# 数据预处理
# Per-channel normalisation with the mean/std constants below.
# NOTE(review): these std values are widely copied; the per-channel std of
# the CIFAR-10 training set is often quoted as about (0.247, 0.243, 0.261).
# Confirm before relying on exact unit-variance inputs.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
# Data preparation: CIFAR-10 train/test splits under ./cifar10 (download=True fetches them if missing)
train_dataset = datasets.CIFAR10(root='./cifar10', train=True, transform=transform, download=True)
eval_dataset = datasets.CIFAR10(root='./cifar10', train=False, transform=transform, download=True)
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
# Evaluation metrics do not depend on sample order, so the test set is iterated
# in a fixed order for reproducible runs.
eval_loader = DataLoader(dataset=eval_dataset, batch_size=512, shuffle=False)
# Define the network structure
class MyNet(nn.Module):
    """Four-layer fully connected classifier for CIFAR-10-sized inputs.

    Each sample is flattened to 32 * 32 * 3 = 3072 values and passed through
    3072 -> 1024 -> 512 -> 256 -> 10. Every hidden layer applies BatchNorm1d
    then ReLU; the first two hidden layers also apply Dropout(p=0.3) between
    the BatchNorm and the ReLU. The forward pass returns raw class logits.
    """

    def __init__(self):
        super().__init__()
        # Registration order fixes the state_dict key order and the order of
        # seeded weight initialisation, so it is kept as-is.
        self.fc1 = nn.Linear(32 * 32 * 3, 1024)
        self.bn1 = nn.BatchNorm1d(1024)
        self.dropout1 = nn.Dropout(0.3)
        self.fc2 = nn.Linear(1024, 512)
        self.bn2 = nn.BatchNorm1d(512)
        self.dropout2 = nn.Dropout(0.3)
        self.fc3 = nn.Linear(512, 256)  # extra (third) hidden layer
        self.bn3 = nn.BatchNorm1d(256)
        self.fc4 = nn.Linear(256, 10)
        self.relu = nn.ReLU()

    def forward(self, x):
        flat = x.view(-1, 32 * 32 * 3)
        hidden1 = self.relu(self.dropout1(self.bn1(self.fc1(flat))))
        hidden2 = self.relu(self.dropout2(self.bn2(self.fc2(hidden1))))
        hidden3 = self.relu(self.bn3(self.fc3(hidden2)))
        return self.fc4(hidden3)


model = MyNet()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)


def train(model, train_loader, epochs, loss_fn=None, opt=None):
    """Train ``model`` on ``device`` for ``epochs`` passes over ``train_loader``.

    After each epoch, prints the sample-weighted mean training loss and the
    training accuracy (a fraction in [0, 1]).

    Args:
        model: network to train (expected to already live on ``device``).
        train_loader: DataLoader yielding ``(data, target)`` batches.
        epochs: number of full passes over the loader.
        loss_fn: loss callable; defaults to the module-level ``criterion``.
            Assumed to average over the batch (``reduction='mean'``).
        opt: optimizer over ``model``'s parameters; defaults to the
            module-level ``optimizer``, which is bound to the global model.

    Returns:
        ``(mean_loss, accuracy)`` of the last epoch, or ``None`` when
        ``epochs`` is not positive.
    """
    loss_fn = criterion if loss_fn is None else loss_fn
    opt = optimizer if opt is None else opt
    model.train()
    stats = None
    for epoch in range(epochs):
        total_loss = 0.0
        correct = 0
        seen = 0
        for data, target in train_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss = loss_fn(output, target)
            opt.zero_grad()
            loss.backward()
            opt.step()
            batch = data.size(0)
            # Weight the per-batch mean loss by batch size for a per-sample epoch mean.
            total_loss += loss.item() * batch
            correct += output.argmax(dim=1).eq(target).sum().item()
            seen += batch
        mean_loss = total_loss / seen if seen else 0.0
        acc = correct / seen if seen else 0.0
        print(f'Train Epoch: {epoch} , loss: {mean_loss:.4f},acc:{acc:.4f}')
        stats = (mean_loss, acc)
    return stats


def eval(model, eval_loader, loss_fn=None):
    """Evaluate ``model`` on ``device`` without tracking gradients.

    Prints the per-sample mean loss and the accuracy in percent.

    Args:
        model: network to evaluate; it is switched to eval mode.
        eval_loader: DataLoader yielding ``(data, target)`` batches.
        loss_fn: loss callable; defaults to the module-level ``criterion``.
            Assumed to average over the batch (``reduction='mean'``).

    Returns:
        ``(mean_loss, accuracy_percent)``; both are 0.0 for an empty loader.
    """
    loss_fn = criterion if loss_fn is None else loss_fn
    model.eval()
    total_loss = 0.0
    correct = 0
    seen = 0
    with torch.no_grad():
        for data, target in eval_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            batch = data.size(0)
            # Multiply the per-batch mean by the batch size before summing;
            # dividing summed means by the sample count under-reports the loss.
            total_loss += loss_fn(output, target).item() * batch
            correct += output.argmax(dim=1).eq(target).sum().item()
            seen += batch
    eval_loss = total_loss / seen if seen else 0.0
    acc = 100.0 * correct / seen if seen else 0.0
    print(f'loss: {eval_loss:.4f}, acc: {acc:.4f}')
    return eval_loss, acc


epochs = 25
train(model, train_loader, epochs)
eval(model, eval_loader)
思考：为什么全连接网络在 CIFAR10 数据集上的准确率很低？