使用 PyTorch 的 torchvision 库加载 CIFAR-10 数据集
CIFAR-10是一个更接近普适物体的彩色图像数据集。CIFAR-10 是由Hinton 的学生Alex Krizhevsky 和Ilya Sutskever 整理的一个用于识别普适物体的小型数据集。一共包含10 个类别的RGB 彩色图片:飞机( airplane )、汽车( automobile )、鸟类( bird )、猫( cat )、鹿( deer )、狗( dog )、蛙类( frog )、马( horse )、船( ship )和卡车( truck )。
每个图片的尺寸为32 × 32 ,每个类别有6000个图像,数据集中一共有50000 张训练图片和10000 张测试图片。
import torchvision
import torchvision.transforms as transforms# 定义数据预处理
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])# 自动下载训练集
trainset = torchvision.datasets.CIFAR10(root='./data', # 数据保存路径train=True,download=True, # 设置为True自动下载transform=transform
)# 自动下载测试集
testset = torchvision.datasets.CIFAR10(root='./data',train=False,download=True,transform=transform
)
1. 导入必要的库
import torchvision import torchvision.transforms as transforms |
- torchvision:PyTorch 的视觉库,提供常用数据集、模型架构和图像转换工具。
- transforms:用于图像预处理的模块,如缩放、归一化等。
2. 定义数据预处理流程
transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) ]) |
- transforms.Compose:将多个预处理操作按顺序组合。
- transforms.ToTensor():
- 将 PIL 图像或 NumPy 数组(H×W×C,范围 0-255)转换为 PyTorch 张量(C×H×W,范围 0.0-0)。
- transforms.Normalize(mean, std):
- 对每个通道进行归一化:output = (input - mean) / std。
- 这里mean=(0.5, 0.5, 0.5)和std=(0.5, 0.5, 0.5)将像素值从[0.0, 1.0]映射到[-1.0, 1.0](例如,0.0→-1.0,1.0→1.0)。
3. 下载并加载训练集
trainset = torchvision.datasets.CIFAR10( root='./data', # 数据保存路径 train=True, # True表示训练集(50,000张) download=True, # 自动下载(如果数据不存在) transform=transform # 应用预处理 ) |
- torchvision.datasets.CIFAR10:CIFAR-10 数据集类,包含 10 个类别(如飞机、汽车、鸟类等)的 60,000 张 32×32 彩色图像。
- 参数说明:
- root='./data':数据将下载到当前目录的data文件夹中。
- train=True:加载训练集(50,000 张);若为False则加载测试集(10,000 张)。
- download=True:若数据不存在,自动从互联网下载(约 170MB)。
- transform=transform:对图像应用之前定义的预处理(转为张量并归一化)。
4. 下载并加载测试集
testset = torchvision.datasets.CIFAR10( root='./data', # 与训练集路径一致 train=False, # 加载测试集 download=True, # 自动下载 transform=transform # 应用相同的预处理 ) |
- 测试集与训练集结构相同,但用于模型评估,不参与训练。
5. 数据验证与使用
下载完成后,数据将存储在./data/cifar-10-batches-py目录中。你可以:
- 查看数据集大小:
print(len(trainset)) # 输出: 50000 print(len(testset)) # 输出: 10000 |
- 访问单个样本:
image, label = trainset[0] # 获取第一张图像及其标签 print(image.shape) # 输出: (3, 32, 32) print(label) # 输出: 6(对应类别索引) |
- 使用数据加载器批量处理数据:
from torch.utils.data import DataLoader trainloader = DataLoader(trainset, batch_size=32, shuffle=True) testloader = DataLoader(testset, batch_size=32, shuffle=False) |
注意事项
- 下载路径:
- 若指定路径(如./data)已存在 CIFAR-10 数据,download=True不会重复下载。
- 若路径错误或无写入权限,会抛出异常(如PermissionError)。
- 网络问题:
- 首次下载需联网,可能需要几分钟。若下载中断,可删除./data目录后重新运行。
- 数据预处理:
- 归一化参数mean和std通常根据数据集的统计特性设定。对于 CIFAR-10,常用(0.5, 0.5, 0.5)进行简单归一化。
- 若需要更精确的归一化,可计算数据集的真实均值和标准差(如mean=[0.4914, 0.4822, 0.4465],std=[0.2470, 0.2435, 0.2616])。
扩展应用
加载数据后,可用于训练 CNN 模型(如之前创建的SimpleCNN):
# 假设model已定义 from torch import nn, optim criterion = nn.CrossEntropyLoss() optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9) # 训练循环 for epoch in range(5): # 训练5个轮次 for inputs, labels in trainloader: optimizer.zero_grad() outputs = model(inputs) loss = criterion(outputs, labels) loss.backward() optimizer.step() print(f"Epoch {epoch+1} completed") |