Python 训练营打卡 Day 38
了解下CIFAR数据集,尝试获取其中一张图片
CIFAR数据集是计算机视觉领域常用的基准数据集,主要有两个版本:
1. CIFAR-10
- 包含10个类别的6万张32x32彩色图像
- 每个类别有6000张图像(5000训练+1000测试)
- 类别包括:飞机、汽车、鸟、猫、鹿、狗、青蛙、马、船、卡车
2. CIFAR-100
- 包含100个细粒度类别的6万张32x32彩色图像
- 每个类别有600张图像(500训练+100测试)
- 100个类别又分为20个超类(如"鱼"超类包含"鲑鱼"、"鲨鱼"等子类)
这两个数据集常用于:
- 图像分类任务基准测试
- 深度学习模型性能评估
- 计算机视觉算法研究
在PyTorch中可以通过torchvision.datasets.CIFAR10/CIFAR100加载:
from torchvision import datasets# 加载CIFAR-10
train_data = datasets.CIFAR10(root='./data', train=True, download=True)
test_data = datasets.CIFAR10(root='./data', train=False, download=True)# 加载CIFAR-100
train_data = datasets.CIFAR100(root='./data', train=True, download=True)
从CIFAR-10中获取一张图片的代码如下:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader , Dataset # DataLoader 是 PyTorch 中用于加载数据的工具
from torchvision import datasets, transforms # torchvision 是一个用于计算机视觉的库,datasets 和 transforms 是其中的模块
import matplotlib.pyplot as plt# 数据预处理先归一化,再标准化
transform = transforms.Compose([transforms.ToTensor(), # 转换为张量并归一化到[0,1]transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)) # cifar10的均值和标准差,用于标准化
])
# 加载cifar-10数据集,如果没有会自动下载
train_dataset = datasets.CIFAR10(root='./data',train=True,download=True,transform=transform
)test_dataset = datasets.CIFAR10(root='./data',train=False,transform=transform
)# 随机选择一张图片,可以重复运行,每次都会随机选择
sample_idx = torch.randint(0, len(train_dataset), size=(1,)).item() # 随机选择一张图片的索引
# len(train_dataset) 表示训练集的图片数量;size=(1,)表示返回一个索引;torch.randint() 函数用于生成一个指定范围内的随机数,item() 方法将张量转换为 Python 数字
image, label = train_dataset[sample_idx] # 获取图片和标签# 显示图片(需要先反标准化)
def imshow(img):img = img * torch.tensor([0.2470, 0.2435, 0.2616]).view(3,1,1) + torch.tensor([0.4914, 0.4822, 0.4465]).view(3,1,1) # 反标准化img = img.numpy().transpose((1, 2, 0)) # 从(C,H,W)转换为(H,W,C)plt.imshow(img)plt.show()# 显示图片和标签
print(f"随机索引: {sample_idx}")
print(f"标签: {train_dataset.classes[label]}")
imshow(image)