Python打卡训练营打卡记录day38
知识点回顾:
- Dataset类的__getitem__和__len__方法(本质是python的特殊方法)
- DataLoader类
- MNIST手写数据集的了解
作业:了解下CIFAR-10数据集,尝试获取其中一张图片
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np # 用于图像转换# 设置随机种子,确保结果可复现
# Seed torch's global RNG so the torch.randint-based sample pick is repeatable.
torch.manual_seed(42)

# 1. Preprocessing (3-channel CIFAR-10 parameters).
transform = transforms.Compose([
    transforms.ToTensor(),  # convert to a tensor scaled to [0, 1]
    # Per-channel mean and standard deviation for CIFAR-10.
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
# 2. Load the CIFAR-10 training and test splits, both using `transform`.
train_dataset = datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,  # skipped when the data already exists locally
    transform=transform,
)
test_dataset = datasets.CIFAR10(
    root='./data',
    train=False,
    transform=transform,
)
# Draw one random index into the training set and fetch that (image, label) pair.
sample_idx = int(torch.randint(len(train_dataset), (1,))[0])
image, label = train_dataset[sample_idx]
def imshow(img, mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010)):
    """Undo per-channel normalization and display a channel-first image.

    Args:
        img: Image tensor laid out as (C, H, W) that was normalized with
            ``mean`` and ``std``.
        mean: Per-channel means used for normalization. Defaults to the
            CIFAR-10 values used elsewhere in this script.
        std: Per-channel standard deviations used for normalization.
            Defaults to the CIFAR-10 values used elsewhere in this script.
    """
    # Work on a detached CPU copy so the .numpy() conversion below is valid.
    img = img.detach().cpu()

    # Reshape the statistics to (C, 1, 1) so they broadcast over H and W.
    std_t = torch.tensor(std).view(-1, 1, 1)
    mean_t = torch.tensor(mean).view(-1, 1, 1)

    # Invert the normalization: x = x_norm * std + mean.
    img = img * std_t + mean_t

    # Floating-point round-off can land slightly outside [0, 1]; clamp so
    # matplotlib does not have to clip (and warn about) out-of-range values.
    img = img.clamp(0, 1)

    # matplotlib expects channel-last (H, W, C) data.
    img = img.permute(1, 2, 0).numpy()
    plt.imshow(img)
    plt.show()
# CIFAR-10 class names, looked up by the integer label of a sample.
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']

# Report the class of the selected sample, then display the image itself.
print(f"Label: {class_names[label]}")
imshow(image)
# 3. Create the data loaders.
train_loader = DataLoader(
    train_dataset,
    batch_size=64,
    shuffle=True,  # reshuffle the training samples every epoch
)
test_loader = DataLoader(
    test_dataset,
    batch_size=1000,
)
@浙大疏锦行