当前位置: 首页 > news >正文

5.27 打卡

知识点回顾:

  1. Dataset类的__getitem__和__len__方法(本质是python的特殊方法)
  2. Dataloader类
  3. minist手写数据集的了解

作业:了解下cifar数据集,尝试获取其中一张图片

import torch
import torchvision
from torchvision import transforms
import matplotlib.pyplot as plt
import numpy as np# 1. 定义数据预处理方式
# ToTensor: 将PIL Image或NumPy array转换为PyTorch Tensor (HWC -> CHW),并归一化到[0.0, 1.0]
# 对于显示图片,Normalize可以先不加,或者如果加了,显示前需要逆标准化。
# 为了直接显示原始像素值,我们这里只用ToTensor。
transform = transforms.Compose([transforms.ToTensor()
])# CIFAR-10 的类别名称,方便显示
cifar10_classes = ['airplane', 'automobile', 'bird', 'cat', 'deer','dog', 'frog', 'horse', 'ship', 'truck'
]# 2. 加载CIFAR-10训练集 (如果本地没有会自动下载)
# download=True 会自动处理下载
cifar10_dataset = torchvision.datasets.CIFAR10(root='./data',         # 数据下载和存放的根目录train=True,            # 加载训练集download=True,         # 如果没有本地文件则下载transform=transform    # 应用预处理
)print(f"CIFAR-10 训练集大小: {len(cifar10_dataset)}")# 3. 获取一张图片及其标签 (直接通过索引访问 Dataset)
# dataset[0] 会调用数据集的 __getitem__ 方法
image_tensor, label_id = cifar10_dataset[0]print(f"获取到的图像 Tensor 形状: {image_tensor.shape}") # (Channels, Height, Width) -> (3, 32, 32)
print(f"获取到的图像标签 ID: {label_id}")
print(f"获取到的图像标签名称: {cifar10_classes[label_id]}")# 4. 显示图片
# Matplotlib 的 imshow 需要图像为 (Height, Width, Channels) 格式
# PyTorch 的 Tensor 是 (Channels, Height, Width) 格式
# 所以需要使用 .permute() 进行维度转换
# .numpy() 将 Tensor 转换成 NumPy 数组
plt.imshow(image_tensor.permute(1, 2, 0).numpy())
plt.title(f"CIFAR-10 Image - Class: {cifar10_classes[label_id]}")
plt.axis('off') # 不显示坐标轴
plt.show()# 尝试获取另一张图片,例如第 100 张
image_tensor_100, label_id_100 = cifar10_dataset[99] # 索引从0开始plt.imshow(image_tensor_100.permute(1, 2, 0).numpy())
plt.title(f"CIFAR-10 Image - Class: {cifar10_classes[label_id_100]}")
plt.axis('off')
plt.show()print("\n--- 尝试使用 Dataloader 获取并显示一个批次的第一张图片 ---")
# 也可以通过Dataloader来获取图片,虽然对于获取单张图片有点“杀鸡用牛刀”
cifar10_loader = DataLoader(dataset=cifar10_dataset,batch_size=1, # 这里设置batch_size=1,方便直接取出第一张shuffle=False # 不打乱,保证每次取到的是同一张
)data_iter = iter(cifar10_loader)
batch_images, batch_labels = next(data_iter)# batch_images 的形状是 (batch_size, C, H, W) -> (1, 3, 32, 32)
# batch_labels 的形状是 (batch_size) -> (1)single_image_tensor = batch_images[0]
single_label_id = batch_labels[0].item() # .item() 将单个Tensor值转换为Python标量plt.imshow(single_image_tensor.permute(1, 2, 0).numpy())
plt.title(f"CIFAR-10 Image from Dataloader - Class: {cifar10_classes[single_label_id]}")
plt.axis('off')
plt.show()

相关文章:

  • MySQL问题:MySQL中使用索引一定有效吗?如何排查索引效果
  • 《Python基础》第1期:人生苦短,我用Python
  • 第四十七篇-Tesla P40+Qwen3-30B-A3B部署与测试
  • SD07_NVM的安装及相关操作
  • qiankun 子应用怎样通过 props拿到子应用【注册之后挂载之前】主应用中发生变更的数据
  • 6个月Python学习计划 Day 6 - 综合实战:学生信息管理系统
  • 【系分】论文模版
  • 开源酷炫大数据可视化大屏html+eacher 100+套
  • 2025 海外短剧 CPS 系统开发:技术驱动下的全球化内容分销新范式
  • Spark、Hadoop对比
  • Day04
  • cursor-stats 实时监控 Cursor IDE 的使用情况和订阅状态
  • 体现物联网环境下安全防护的紧迫性 :物联网环境下的个人信息安全:隐忧与防护之道
  • Linux升级内核回退到旧内核启动
  • 2025上半年软考系统架构设计师选择题试题与答案
  • spring4第2课-ioc控制反转-依赖注入,是为了解决耦合问题
  • springboot--实战--大事件--用户接口开发
  • TS.43规范-1
  • winsock对话设计框架
  • 全志V853 mpp程序开发
  • 武汉网站建设开发/免费自建网站有哪些
  • 淮安网站优化/培训心得体会范文500字
  • 西宁网站设计/网页推广平台
  • 网站建设公司怎么算专业/时事新闻最新消息
  • 网站 流量 不够用/泉州百度seo
  • 临沂做网站哪家好/百度seo文章