5.27 打卡
知识点回顾:
- Dataset类的__getitem__和__len__方法(本质是python的特殊方法)
- Dataloader类
- minist手写数据集的了解
作业:了解下cifar数据集,尝试获取其中一张图片
import torch
import torchvision
from torchvision import transforms
import matplotlib.pyplot as plt
import numpy as np# 1. 定义数据预处理方式
# ToTensor: 将PIL Image或NumPy array转换为PyTorch Tensor (HWC -> CHW),并归一化到[0.0, 1.0]
# 对于显示图片,Normalize可以先不加,或者如果加了,显示前需要逆标准化。
# 为了直接显示原始像素值,我们这里只用ToTensor。
transform = transforms.Compose([transforms.ToTensor()
])# CIFAR-10 的类别名称,方便显示
cifar10_classes = ['airplane', 'automobile', 'bird', 'cat', 'deer','dog', 'frog', 'horse', 'ship', 'truck'
]# 2. 加载CIFAR-10训练集 (如果本地没有会自动下载)
# download=True 会自动处理下载
cifar10_dataset = torchvision.datasets.CIFAR10(root='./data', # 数据下载和存放的根目录train=True, # 加载训练集download=True, # 如果没有本地文件则下载transform=transform # 应用预处理
)print(f"CIFAR-10 训练集大小: {len(cifar10_dataset)}")# 3. 获取一张图片及其标签 (直接通过索引访问 Dataset)
# dataset[0] 会调用数据集的 __getitem__ 方法
image_tensor, label_id = cifar10_dataset[0]print(f"获取到的图像 Tensor 形状: {image_tensor.shape}") # (Channels, Height, Width) -> (3, 32, 32)
print(f"获取到的图像标签 ID: {label_id}")
print(f"获取到的图像标签名称: {cifar10_classes[label_id]}")# 4. 显示图片
# Matplotlib 的 imshow 需要图像为 (Height, Width, Channels) 格式
# PyTorch 的 Tensor 是 (Channels, Height, Width) 格式
# 所以需要使用 .permute() 进行维度转换
# .numpy() 将 Tensor 转换成 NumPy 数组
plt.imshow(image_tensor.permute(1, 2, 0).numpy())
plt.title(f"CIFAR-10 Image - Class: {cifar10_classes[label_id]}")
plt.axis('off') # 不显示坐标轴
plt.show()# 尝试获取另一张图片,例如第 100 张
image_tensor_100, label_id_100 = cifar10_dataset[99] # 索引从0开始plt.imshow(image_tensor_100.permute(1, 2, 0).numpy())
plt.title(f"CIFAR-10 Image - Class: {cifar10_classes[label_id_100]}")
plt.axis('off')
plt.show()print("\n--- 尝试使用 Dataloader 获取并显示一个批次的第一张图片 ---")
# 也可以通过Dataloader来获取图片,虽然对于获取单张图片有点“杀鸡用牛刀”
cifar10_loader = DataLoader(dataset=cifar10_dataset,batch_size=1, # 这里设置batch_size=1,方便直接取出第一张shuffle=False # 不打乱,保证每次取到的是同一张
)data_iter = iter(cifar10_loader)
batch_images, batch_labels = next(data_iter)# batch_images 的形状是 (batch_size, C, H, W) -> (1, 3, 32, 32)
# batch_labels 的形状是 (batch_size) -> (1)single_image_tensor = batch_images[0]
single_label_id = batch_labels[0].item() # .item() 将单个Tensor值转换为Python标量plt.imshow(single_image_tensor.permute(1, 2, 0).numpy())
plt.title(f"CIFAR-10 Image from Dataloader - Class: {cifar10_classes[single_label_id]}")
plt.axis('off')
plt.show()