python第51天
1.读取数据
使用CIFAR-10图像数据
import torch
from torchvision import datasets, transforms# 定义图像预处理流程
image_transform = transforms.Compose([transforms.ToTensor(), # 将PIL图像转换为张量transforms.Normalize(mean=(0.5, 0.5, 0.5), # RGB三通道均值std=(0.5, 0.5, 0.5)) # RGB三通道标准差
])# 获取训练数据集
trainset = datasets.CIFAR10(root='./data', # 数据集存储路径train=True, # 使用训练集transform=image_transform,download=True # 如果本地不存在则下载
)# 获取测试数据集
testset = datasets.CIFAR10(root='./data',train=False, # 使用测试集transform=image_transform,download=True
)# 配置数据加载器
train_loader = torch.utils.data.DataLoader(dataset=trainset,batch_size=128, # 每批样本数量shuffle=True # 训练时打乱顺序
)test_loader = torch.utils.data.DataLoader(dataset=testset,batch_size=128,shuffle=False # 测试时保持原始顺序
)
2.模型建立
(1)建立CNN模型
import torch
import torch.nn as nnclass SimpleCNN(nn.Module):def __init__(self):super(SimpleCNN, self).__init__()self.conv1 = nn.Conv2d(3, 16, 3, padding=1)self.pool = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(16, 32, 3, padding=1)self.fc1 = nn.Linear(32 * 8 * 8, 256)self.fc2 = nn.Linear(256, 10)self.relu = nn.ReLU()def forward(self, x):x = self.pool(self.relu(self.conv1(x))) # 16x16x16x = self.pool(self.relu(self.conv2(x))) # 32x8x8x = x.view(-1, 32 * 8 * 8)x = self.relu(self.fc1(x))x = self.fc2(x)return x
@浙大疏锦行