当前位置: 首页 > news >正文

基于PyTorch的CIFAR10加载与TensorBoard可视化实践

视频学习来源:https://www.bilibili.com/video/BV1hE411t7RN?t=1.1&p=15

import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriterfrom test_03 import writer# 添加 添加 download=True 参数来下载数据集
test_data = torchvision.datasets.CIFAR10(root=".dataset",train=False,transform=torchvision.transforms.ToTensor(),download=True  # 新增此行,用于下载数据集
)test_loader = DataLoader(dataset=test_data,batch_size=64,shuffle=True,num_workers=0,drop_last=False
)img, target = test_data[0]
print(img.shape)
print(target)
writer = SummaryWriter("dataloader")
for epoch in range(2):step = 0for data in test_loader:imgs, targets = data# print(imgs.shape)# print(targets)writer.add_image("Epoch:{}".format(epoch), imgs, step)step =step +1writer.close()

这段代码结合了 PyTorch 的数据加载、图像处理和 TensorBoard 可视化功能,是深度学习中数据预处理和可视化的典型流程。


一、整体功能概览

这段代码的核心作用是:

  1. 加载 CIFAR10 测试数据集
  2. 用 DataLoader 按批次组织数据
  3. 通过 TensorBoard 可视化不同批次的图像数据
  4. 对比不同训练轮次(epoch)的数据分布

二、逐行代码详解

1. 导入库

import torchvision  # 计算机视觉工具库
from torch.utils.data import DataLoader  # 数据加载工具
from torch.utils.tensorboard import SummaryWriter  # TensorBoard可视化工具# 从test_03.py文件导入writer(这里实际被后面的代码覆盖了,暂时忽略)
from test_03 import writer

基础知识拓展

  • torchvision:PyTorch 官方的计算机视觉库,包含常用数据集(如 CIFAR10、MNIST)、预训练模型(如 ResNet、VGG)和图像预处理工具。
  • DataLoader:PyTorch 的核心数据加载工具,负责将数据集按批次加载,支持并行处理和数据打乱。
  • SummaryWriter:TensorBoard 的 PyTorch 接口,用于记录和可视化训练过程(图像、损失值、权重分布等)。

2. 加载 CIFAR10 测试数据集

test_data = torchvision.datasets.CIFAR10(root=".dataset",  # 数据集保存路径train=False,      # 是否为训练集(False表示测试集)transform=torchvision.transforms.ToTensor(),  # 数据转换download=True     # 自动下载数据集
)

参数详解

  • root:数据集存储的本地路径(这里是当前文件夹下的.dataset文件夹)。如果该路径不存在,会自动创建。
  • train=False:CIFAR10 分为训练集(50000 张图片)和测试集(10000 张图片),train=False表示加载测试集。
  • transform=ToTensor():将图像从 PIL 格式(Python 图像库格式)转换为 PyTorch 的 Tensor 格式,同时完成两个关键操作:
    • 像素值从[0, 255]归一化到[0, 1](神经网络对小范围数值更敏感)
    • 维度从(高度, 宽度, 通道)转换为(通道, 高度, 宽度)(PyTorch 的默认格式)
  • download=True:如果root路径下没有数据集,自动从官方地址下载(约 160MB)。

CIFAR10 数据集细节

  • 包含 10 个类别:飞机(0)、汽车(1)、鸟(2)、猫(3)、鹿(4)、狗(5)、青蛙(6)、马(7)、船(8)、卡车(9)。
  • 每张图片都是 32x32 像素的彩色图像(RGB 三通道)。

3. 创建 DataLoader

test_loader = DataLoader(dataset=test_data,    # 传入数据集batch_size=64,        # 每批次64个样本shuffle=True,         # 打乱数据顺序num_workers=0,        # 单进程加载(Windows推荐0)drop_last=False       # 保留最后一个不完整批次
)

核心作用:将test_data这个数据集转换为可迭代的批次数据,方便模型批量处理。

参数详解

  • dataset:要加载的数据集(必须是Dataset类的实例)。
  • batch_size=64:每次迭代返回 64 张图片和对应的 64 个标签。为什么用批次?
    • 单次处理太多样本会占用过多内存
    • 批次梯度下降比单样本梯度下降更稳定
  • shuffle=True:每个 epoch(轮次)前打乱数据顺序。测试集一般不需要打乱(设为False),这里可能是为了演示效果。
  • num_workers=0:数据加载的进程数。0 表示在主进程中加载(Windows 系统设为非 0 可能会报错),Linux/Mac 可设为 4、8 等加速加载。
  • drop_last=False:如果数据集总样本数不能被batch_size整除,是否丢弃最后一个不完整的批次。例如 CIFAR10 测试集有 10000 张,10000 ÷ 64 = 156 余 16,drop_last=False会保留最后 16 张的批次。

4. 查看单个样本

img, target = test_data[0]  # 获取第1个样本(索引从0开始)
print(img.shape)  # 输出图片形状
print(target)     # 输出标签

输出解释

  • img.shape的结果是 torch.Size([3, 32, 32]),表示:
    • 3:RGB 三通道
    • 32:图像高度(像素)
    • 32:图像宽度(像素)
  • target的结果是一个整数(例如 3),对应 CIFAR10 的类别标签(3 代表 "猫")。


    for data in test_loader:imgs, targets = data# print(imgs.shape)# print(targets)writer.add_image("Epoch:{}".format(epoch), imgs, step)step =step +1

为什么用test_data[0]而不是test_loader[0]

  • test_dataDataset对象,支持按索引直接访问单个样本。
  • test_loaderDataLoader对象,不支持直接按索引访问,必须通过迭代器(for循环)访问。

5. TensorBoard 可视化设置

writer = SummaryWriter("dataloader")  # 创建日志写入器,日志保存到"dataloader"文件夹

TensorBoard 作用

  • 实时可视化训练过程中的图像、损失值、准确率等指标。
  • 支持对比不同实验的结果(如不同批次大小、不同学习率的效果)。

使用方法

  1. 代码运行后,会在当前目录生成 "dataloader" 文件夹,里面包含日志文件。
  2. 打开终端,运行命令:tensorboard --logdir=dataloader
    tensorboard --logdir=dataloader
  3. 在浏览器中访问提示的地址(通常是http://localhost:6006),即可查看可视化结果。

6. 多轮次可视化批次数据

for epoch in range(2):  # 循环2个轮次step = 0  # 记录每个轮次内的步数for data in test_loader:  # 迭代加载批次数据imgs, targets = data  # 拆分批次数据为图片和标签# 向TensorBoard写入图像,标签为"Epoch:0"或"Epoch:1",步数为stepwriter.add_image("Epoch:{}".format(epoch), imgs, step)step += 1  # 步数递增writer.close()  # 关闭写入器,释放资源

核心逻辑

  • 循环 2 个 epoch(轮次),模拟模型训练时多轮次处理数据的场景。
  • 每个 epoch 内,通过test_loader按批次加载数据,并用writer.add_image将批次图像写入 TensorBoard。

add_image参数详解

  • 第 1 个参数:图像标签(字符串),用于在 TensorBoard 中区分不同类别的图像。这里用"Epoch:0""Epoch:1"区分两个轮次。
  • 第 2 个参数:要显示的图像数据,必须是 Tensor 格式,形状可以是:
    • 单张图片:(通道数, 高度, 宽度)
    • 批次图片:(批次大小, 通道数, 高度, 宽度)(这里用的是这种格式,会自动显示网格状排列的多张图片)
  • 第 3 个参数:全局步数(step),用于在 TensorBoard 中按顺序展示。

为什么要分多个 epoch?

  • 在模型训练中,一个 epoch 表示遍历完所有训练数据一次。
  • 通常需要多个 epoch 才能让模型充分学习数据中的规律(如 10、20、50 个 epoch)。
  • 这里可视化不同 epoch 的批次数据,是为了观察数据打乱后的分布差异(因为shuffle=True)。

三、运行结果与 TensorBoard 查看

1. 控制台输出

torch.Size([3, 32, 32])  # 第1张图片的形状
3                       # 第1张图片的标签(例如"猫")

2. TensorBoard 可视化

打开 TensorBoard 后,在 "IMAGES" 标签页可以看到:

  • 两个类别:Epoch:0Epoch:1
  • 每个类别下有 157 张网格图片(因为 10000 ÷ 64 = 156 余 16,共 157 个批次)
  • 每张网格图包含 64 张(或最后一批 16 张)32x32 的彩色图像
  • 对比Epoch:0Epoch:1的同一 step,会发现图像顺序不同(因为shuffle=True

四、关键知识点拓展

1. Dataset 与 DataLoader 的关系

  • Dataset:负责 “数据在哪”“怎么读”(存储数据路径、定义单样本读取方式)。
  • DataLoader:负责 “怎么喂给模型”(批处理、打乱、并行加载)。
  • 类比:Dataset像仓库管理员(负责找到并取出单个商品),DataLoader像快递员(负责把商品打包成批,高效配送)。

2. 为什么需要 TensorBoard?

  • 深度学习训练周期长,需要实时监控模型状态。
  • 可以直观对比不同参数(如学习率、批次大小)对结果的影响。
  • 支持可视化图像、损失曲线、模型结构、梯度分布等,帮助调试模型。

3. 常见错误与解决

  • 数据集下载失败:检查网络连接,或手动下载数据集放到root路径下。
  • num_workers 报错:Windows 系统将num_workers设为 0(多进程在 Windows 上支持不好)。
  • TensorBoard 无法打开:确保日志路径正确,或尝试更换端口(tensorboard --logdir=dataloader --port=6007)。

五、实际应用场景

这段代码是深度学习的基础流程,实际训练时会在此基础上添加:

  1. 定义模型(如 CNN、ResNet)
  2. 定义损失函数(如交叉熵损失)
  3. 定义优化器(如 Adam、SGD)
  4. 在循环中加入模型训练逻辑(前向传播→计算损失→反向传播→参数更新)
  5. 用 TensorBoard 记录损失值、准确率等指标
http://www.dtcms.com/a/418534.html

相关文章:

  • 西安网站建设陕icp网站建设公司考察
  • Linux中安装es
  • flink批处理-水位线
  • Unity单元测试:C语言轻量级框架实战
  • 网站怎么做搜索引擎优化、中建官网
  • 构建并运行最小 Linux 内核
  • 粤港澳全运会网络安全防御体系深度解析:威胁态势与实战防护
  • 数据结构——包装类泛型
  • 中国建设银行贵州分行网站安卓app制作入门教程
  • 17. 整个网站建设中的关键是专业客户管理系统
  • RuoYi 学习笔记 2:常用功能
  • 负载均衡式的在线OJ项目编写(五)
  • USBKey智能密码钥匙:从硬件安全到未来信任架构的深度技术解析
  • K8s日志架构:Sidecar容器实践指南
  • 前端开发,iframe 相关经验总结
  • 前端-JS基础-day3
  • MIT 6.S081 文件系统的崩溃恢复
  • 图片展示模块网站做一个多少钱影视vip网站建设教程
  • 环境搭建,Ubuntu 安装、客户端使用与性能认知
  • 合肥市城乡和建设网站南充建设企业网站
  • Music Muse AI音乐生成器全面解析:免费创作高质量音乐的核心要素
  • Go 语言中的结构体
  • Nest 文件上传与下载
  • 2025-9-28学习笔记
  • 深度学习(十三):向量化与矩阵化
  • 矩阵结构体 图片绘制 超级玛丽demo6
  • 承接网站开发 app开发学校网站建设责任书
  • 网站 管理检察内网门户网站建设
  • LeetCode 390 消除游戏
  • 汕头seo建站新品发布会的作用