当前位置: 首页 > news >正文

2025-05-21 Python深度学习5——数据读取

文章目录

  • 1 数据准备
  • 2 `Dataset`
    • 2.1 自定义 Dataset
    • 2.2 使用示例
  • 3 `TensorBoard`
    • 3.1 安装
    • 3.2 标量可视化(Scalars)
    • 3.3 图像可视化(Images)
    • 3.4 其他常用功能
  • 4 `transform`
    • 4.1 `ToTensor()`
    • 4.2 `Normalize()`
    • 4.3 `Resize()`
    • 4.4 `Compose()`
    • 4.5 Crop
      • `RandomCrop()`
      • `CenterCrop()`
      • `RandomResizedCrop()`
      • `FiveCrop() / TenCrop()`
    • 4.6 `RandomHorizontalFlip() / RandomVerticalFlip()`
    • 4.7 Random
      • `RandomChoice()`
      • `RandomApply()`
      • `RandomOrder()`
    • 4.8 自定义 transforms
    • 4.9 其他
  • 5 torchvision 中的数据集
    • 5.1 CIFAR10 数据集
    • 5.2 数据加载
    • 5.3 数据格式
    • 5.4 可视化
  • 6 `DataLoader`
    • 6.1 读取数据
    • 6.2 可视化

本文环境:

  • Pycharm 2025.1
  • Python 3.12.9
  • Pytorch 2.6.0+cu124

1 数据准备

​ 数据集放在 dataset 路径下,其中 train 文件夹存放训练数据,包括 ants 和 bees,val 文件夹存放测试数据,包括 ants 和 bees。

image-20250520004935472

​ 训练数据为若干个 .jpg 图片。

image-20250520005158000
  • 数据集下载链接:https://pan.baidu.com/s/1jZoTmoFzaTLWh4lKBHVbEA;密码: 5suq。
  • 参考教程视频链接:P7. TensorBoard的使用(一)_哔哩哔哩_bilibili。

2 Dataset

​ Dataset 是抽象类,定义数据来源及单样本读取逻辑,所有自定义的 Dataset 均需要继承它。

​ 必须实现的方法:

  • __getitem__(self, index):根据索引返回(样本, 标签)。
  • __len__(self):返回数据集大小。

2.1 自定义 Dataset

import os
from PIL import Image
from torch.utils.data import Datasetclass MyData(Dataset):def __init__(self, root_dir, label_dir):# 初始化函数,传入根目录和标签目录self.root_dir = root_dir  # 数据集根目录self.label_dir = label_dir  # 标签目录(也是类别名)# 拼接完整路径self.path = os.path.join(self.root_dir, self.label_dir)# 获取该路径下所有图片文件名self.img_path = os.listdir(self.path)def __getitem__(self, index):# 根据索引获取图片名img_name = self.img_path[index]# 拼接图片完整路径img_item_path = os.path.join(self.path, img_name)# 使用PIL打开图片img = Image.open(img_item_path)# 标签就是目录名(如"ants"或"bees")label = self.label_dirreturn img, label  # 返回图片和标签def __len__(self):# 返回数据集大小(图片数量)return len(self.img_path)
  1. __init__()
  • 接收两个参数:root_dir(根目录)和label_dir(标签目录)。

  • 使用os.path.join拼接完整路径。

  • os.listdir获取该目录下所有文件名列表。

  1. __getitem__()

    • 根据索引获取对应图片。

    • 使用 PIL.Image 打开图片。

    • 标签直接使用目录名(简单示例中)。

    • 返回(图片, 标签)元组。

  2. __len__()

    • 返回数据集大小(图片数量)。

    • 用于 DataLoader 确定迭代次数。

2.2 使用示例

# 定义路径
root_dir = "dataset/train"  # 训练集根目录
ants_label_dir = "ants"    # 蚂蚁图片目录
bees_label_dir = "bees"    # 蜜蜂图片目录# 创建数据集实例
ants_dataset = MyData(root_dir, ants_label_dir)
bees_dataset = MyData(root_dir, bees_label_dir)# 合并数据集(实际应该使用ConcatDataset)
train_dataset = ants_dataset + bees_dataset  # 简单合并

​ 示例中直接使用+合并数据集不是很规范,PyTorch 提供ConcatDataset类来正确合并多个 Dataset:

from torch.utils.data import ConcatDataset
train_dataset = ConcatDataset([ants_dataset, bees_dataset])

3 TensorBoard

​ TensorBoard 是 TensorFlow 提供的可视化工具,PyTorch 通过torch.utils.tensorboard模块也支持使用 TensorBoard。它可以帮助我们:

  • 监控训练过程中的指标变化(如损失、准确率)。
  • 可视化模型结构。
  • 查看图像、音频等数据。
  • 分析参数分布和直方图。

3.1 安装

​ 使用以下命令进行安装:

pip install tensorboard

​ 基本使用流程如下:

  1. 创建SummaryWriter实例。
  2. 使用各种add_*方法记录数据。
  3. 关闭SummaryWriter。
  4. 启动 TensorBoard 服务查看结果。

3.2 标量可视化(Scalars)

add_scalar(tag, scalar_value, global_step)

  • tag: 数据标识(图表标题)。
  • scalar_value: 要记录的标量值。
  • global_step: 训练步数/迭代次数。
from torch.utils.tensorboard import SummaryWriterwriter = SummaryWriter("logs")  # 日志保存到logs目录for i in range(100):writer.add_scalar("y=2x", 2*i, i)  # 记录y=2x函数值writer.close()

​ 上述代码运行后,在当前目录下会生成 logs 目录,里面存放 TensorBoard 的可视化数据。

image-20250520011221188

​ 进入 Pycharm 终端,在命令行运行如下命令:

tensorboard --logdir=Pytorch教程/logs --port=6006
  • logdir:可视化数据的存储路径在哪里。
  • port:从哪个端口打开可视化网页。
image-20250520011428889

​ 在浏览器访问http://localhost:6006即可查看可视化结果。

image-20250520011635433

注意

​ 请确保从 cmd 中输入命令,而不是 Powershell。

​ 若出现如下报错,说明使用 Powershell 打开而不是 cmd。

image-20250520011743283

解决方案

​ 打开 Pycharm 的设置选项,进入“终端”页面,将默认页签更改为 cmd.exe,重新在 Pycharm 中打开终端即可。

image-20250520011933512

3.3 图像可视化(Images)

add_image(tag, img_tensor, global_step, dataformats)

  • tag: 数据标识(图表标题)。

  • img_tensor: 图像数据(numpy 数组或 torch tensor)。

  • global_step: 训练步数/迭代次数。

  • dataformats: 指定数据格式,如 ‘HWC’(高度、宽度、通道)。

from PIL import Image
import numpy as npimg_PIL = Image.open("image.jpg")
img_array = np.array(img_PIL)  # 转换为numpy数组writer.add_image("train", img_array, 1, dataformats='HWC')

​ 上述代码运行后,在可视化网页中点击刷新,即可显示图像。

image-20250520012444118

3.4 其他常用功能

  • add_graph(): 可视化模型结构。
  • add_histogram(): 记录参数分布。
  • add_text(): 记录文本信息。
  • add_embedding(): 可视化高维数据降维结果。

4 transform

torchvision.transforms 是 PyTorch 中用于图像预处理的强大工具包,提供丰富的图像转换操作。主要功能包括:

  • 图像格式转换(如 PIL Image ↔ Tensor)。
  • 图像尺寸调整(缩放、裁剪)。
  • 数据增强(翻转、旋转、颜色变换)。
  • 数据标准化。
image-20250520031711124

4.1 ToTensor()

image-20250520032159875
  • 将 PIL/Numpy 图像转为 PyTorch Tensor。
  • 将像素值从 [0, 255] 缩放到 [0.0, 1.0]。
  • 调整维度顺序为 (C, H, W)。
from PIL import Image
from torch.utils.tensorboard import SummaryWriter
from torchvision import transformswriter = SummaryWriter("logs")
img = Image.open("image/pytorch.png")
print(img)# ToTensor
tensor_trans = transforms.ToTensor()
tensor_img = tensor_trans(img)
writer.add_image("Tensor_img", tensor_img)
print(tensor_img)  # 输出形状为(C, H, W)的tensor,值范围[0,1]writer.close()
image-20250520033534663

4.2 Normalize()

image-20250520041253221

​ 逐 Channel 对图像进行标准化。

  • mean:各通道的均值。

  • std:各通道的标准差。

  • inplace:是否原地操作。

  • 使用公式:output = (input - mean) / std

  • 此处将 [0, 1] 范围的数据转换到 [-1, 1] 范围。

  • 注意:必须先转换为 Tensor 才能使用 Normalize。

from PIL import Image
from torch.utils.tensorboard import SummaryWriter
from torchvision import transformswriter = SummaryWriter("logs")
img = Image.open("image/pytorch.png")
print(img)# ToTensor
tensor_trans = transforms.ToTensor()
tensor_img = tensor_trans(img)# Normalize
normalize = transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
normalize_img = normalize(tensor_img)
writer.add_image("Normalize_img", normalize_img)
print(normalize_img)writer.close()
image-20250520033833209

4.3 Resize()

image-20250520034359685
  • 将图像调整为指定尺寸 (512x512)。
from PIL import Image
from torch.utils.tensorboard import SummaryWriter
from torchvision import transformswriter = SummaryWriter("logs")
img = Image.open("image/pytorch.png")
print(img)# ToTensor
tensor_trans = transforms.ToTensor()
tensor_img = tensor_trans(img)# Resize
resize = transforms.Resize((512, 512))
resize_img = resize(tensor_img)
writer.add_image("Resize_img", resize_img)
print(tensor_img)writer.close()
image-20250520034528424

4.4 Compose()

image-20250520034932135
  • 将多个转换步骤按顺序组合。
  • 执行顺序:先 Resize 再 ToTensor。
from PIL import Image
from torch.utils.tensorboard import SummaryWriter
from torchvision import transformswriter = SummaryWriter("logs")
img = Image.open("image/pytorch.png")
print(img)# Compose
compose = transforms.Compose([transforms.Resize((512, 512)),transforms.ToTensor()
])
compose_img = compose(img)
writer.add_image("Compose_img", compose_img)
print(compose_img)writer.close()
image-20250520035024393

4.5 Crop

RandomCrop()

image-20250520041843554

​ 从图片中随机裁剪出尺寸为 size 的图片。

  • size:所需裁剪图片尺寸。
  • padding:设置填充大小。
    • 当为a时,上下左右均填充 a 个像素。
    • 当为 (a, b) 时,上下填充 b 个像素,左右填充 a 个像素。
    • 当为 (a, b, c, d) 时,左,上,右,下分别填充 a, b, c, d。
  • pad_if_needed:若图像小于设定 size,则填充。
from PIL import Image
from torch.utils.tensorboard import SummaryWriter
from torchvision import transformswriter = SummaryWriter("logs")
img = Image.open("image/pytorch.png")
print(img)# ToTensor
tensor_trans = transforms.ToTensor()
tensor_img = tensor_trans(img)# RandomCrop
random_crop = transforms.RandomCrop((512, 512))
random_crop_img = random_crop(tensor_img)
writer.add_image("RandomCrop_img", random_crop_img)
print(random_crop_img)writer.close()
image-20250520035251328

CenterCrop()

image-20250520042105453

​ 从图像中心裁剪图片。

size:所需裁剪图片尺寸。

RandomResizedCrop()

image-20250520042221052

​ 随机大小、长宽比裁剪图片。

  • size:所需裁剪图片尺寸。
  • scale:随机裁剪面积比例,默认 (0.08, 1)。
  • ratio:随机长宽比,默认(3/4, 4/3)。
  • interpolation:插值方法。
    • PIL.Image.NEAREST。
    • PIL.Image.BILINEAR。
    • PIL.Image.BICUBIC。

FiveCrop() / TenCrop()

image-20250520042512450 image-20250520042526938

​ FiveCrop 在图像的上下左右以及中心裁剪出尺寸为 size 的 5 张图片;

​ TenCrop 对这 5 张图片进行水平或者垂直镜像获得 10 张图片.

  • size:所需裁剪图片尺寸。
  • vertical_flip:是否垂直翻转。

4.6 RandomHorizontalFlip() / RandomVerticalFlip()

image-20250520042719668

​ 依概率水平(左右)或垂直(上下)翻转图片。

  • p:翻转概率。
from PIL import Image
from torch.utils.tensorboard import SummaryWriter
from torchvision import transformswriter = SummaryWriter("logs")
img = Image.open("image/pytorch.png")
print(img)# ToTensor
tensor_trans = transforms.ToTensor()
tensor_img = tensor_trans(img)# RandomHorizontalFlip
random_flip = transforms.RandomHorizontalFlip(1)
random_flip_img = random_flip(tensor_img)
writer.add_image("RandomFlip_img", random_flip_img)
print(random_flip_img)writer.close()
image-20250520042812255

4.7 Random

RandomChoice()

​ 从一系列 transforms 方法中随机挑选一个。

transforms.RandomChoice([transforms1, transforms2, transforms3])

RandomApply()

​ 依据概率执行一组 transforms 操作。

transforms.RandomApply([transforms1, transforms2, transforms3], p=0.5)

RandomOrder()

​ 对一组 transforms 操作打乱顺序。

transforms.RandomOrder([transforms1, transforms2, transforms3])

4.8 自定义 transforms

​ 自定义 transforms 要素:

  1. __init__():初始化方法。
  2. __call__:执行方法。
class YourTransforms(object):def __init__(self, ...):...def __call__(self, img):...return img

RandomChoice为例:

image-20250520044226420

注意

  • 仅接收一个参数,返回一个参数。
  • 注意上下游的输出与输入。

4.9 其他

  • Pad():对图片边缘进行填充。

  • ColorJitter():调整亮度、对比度、饱和度和色相。

  • Grayscale() / RandomGrayscale():依概率将图片转换为灰度图。

  • RandomAffine():对图像进行仿射变换。

    仿射变换是二维的线性变换,由五种基本原子变换构成,分别是旋转、平移、缩放、错切和翻转。

  • RandomErasing():对图像进行随机遮挡。

  • Lambda():用户自定义 Lambda 方法。

5 torchvision 中的数据集

​ 官方文档:Datasets — Torchvision 0.22 documentation。

image-20250521101319393

5.1 CIFAR10 数据集

​ 以 CIFAR10 为例:CIFAR10 — Torchvision 0.22 documentation。打开文档链接,以下是 CIFAR10 的创建方法。

​ 进入主页,可观看数据集详细介绍。

image-20250521103014615
  • 包含 60,000 张 32×32 像素的 RGB 彩色图像,分为 10 个类别(飞机、汽车、鸟、猫、鹿、狗、青蛙、马、船、卡车),每类 6,000 张。

    • 训练集

      50,000 张图像,分为 5 个批次(data_batch_1 至 data_batch_5),每批 10,000 张。每个类别在训练集中共 5,000 张图像,但单个批次内类别分布可能不均匀。

    • 测试集

      10,000 张图像(test_batch),每个类别均匀包含 1,000 张随机选择的图像,且与训练集无重叠。

  • 类别间完全互斥,例如“汽车”与“卡车”不重叠(汽车含轿车 / SUV,卡车仅含大型货车)。

5.2 数据加载

  • torchvision.datasets.CIFAR10

    PyTorch 内置的 CIFAR-10 数据集加载接口。

    • root='./dataset':数据集存储路径(若不存在会自动创建)。
    • train=True/False:分别加载训练集(50,000 张)或测试集(10,000 张)。
    • transform=None:接收 PIL 图像并返回转换后的版本的函数/转换。
    • target_transform=None:接收目标并对其进行转换的函数/转换。
    • download=True:自动下载数据集(若本地不存在)。
import torchvision# 定义数据集的转换方式
dataset_transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor(),
])# 加载训练集
train_set = torchvision.datasets.CIFAR10(root='./dataset',  # 数据集的根目录train=True,  # 是否为训练集transform=dataset_transform,  # 数据集的转换方式download=True  # 是否下载数据集
)
# 加载测试集
test_set = torchvision.datasets.CIFAR10(root='./dataset',  # 数据集的根目录train=False,  # 是否为训练集transform=dataset_transform,  # 数据集的转换方式download=True  # 是否下载数据集
)

5.3 数据格式

  • 未启用transform

    test_set[0]返回 PIL.Image 对象和标签(整数)。

    # 打印测试集的第一个样本
    print(test_set[0])# 获取测试集的第一个样本的图像和标签
    img, target = test_set[0]
    # 打印图像
    print(img)
    # 打印标签
    print(target)
    
    image-20250521105057749
  • 启用transform=ToTensor()

    图像转为 Tensor 格式(形状为 [3, 32, 32],值域 [0, 1])。

    # 获取测试集的第一个样本的图像和标签
    img, target = test_set[0]
    # 打印图像
    print(img.shape)
    # 打印标签
    print(target)
    
    image-20250521105254327
  • test_set.classes:直接输出 10 个类别的名称列表。

    # 打印测试集的类别
    print(test_set.classes)
    
    image-20250521105605127

5.4 可视化

# 创建一个SummaryWriter对象,用于记录训练过程中的数据
writer = SummaryWriter("logs")# 遍历测试集
for i in range(10):# 获取测试集中的第i个样本img, target = test_set[i]# 将第i个样本的图像添加到SummaryWriter中writer.add_image("test_set", img, i)writer.close()
image-20250521110026292

6 DataLoader

image-20250425181735945

​ 功能:构建可迭代的数据装载器。

​ 常用参数:

  • dataset:Dataset 类,决定数据从哪读取及如何读取。
  • batch_size:批大小。
  • shuffle:每个 Epoch 是否乱序。
  • num_workers:是否多进程读取数据,0 表示使用主线程读取。
  • drop_last:当样本数不能被 batch_size 整除时,是否舍弃最后一批数据。
  • Epoch:所有训练样本都已输入到模型中,称为一个 Epoch。

  • Iteration:一批样本输入到模型中,称之为一个 Iteration。

  • Batchsize:批大小,决定一个 Epoch 有多少个 Iteration。

样本总数:80,Batchsize:8

  • 1 Epoch 10 Iteration

样本总数:87,Batchsize:8

  • drop_last = True:1 Epoch = 10 Iteration
  • drop_last = False:1 Epoch = 11 Iteration

​ 其他参数

  • sampler:自定义采样策略(如按权重采样,与 shuffle 互斥)。
  • batch_sampler:直接生成批次索引(与 batch_size/shuffle/sampler 互斥)。
  • collate_fn:自定义批次合并逻辑(处理非规则数据,如变长序列)。
  • pin_memory:是否将数据复制到 CUDA 固定内存(加速 GPU 数据传输)。
  • timeout:从 worker 收集数据的超时时间(秒,0=无超时)。
  • worker_init_fn:每个 worker 的初始化函数(常用于设置随机种子)。
  • multiprocessing_context:多进程上下文(如 'spawn'/'fork',影响进程启动方式)。
  • generator:控制随机采样的随机数生成器(RNG)。
  • prefetch_factor:每个 worker 预加载的批次数量(默认 2,仅 num_workers>0 生效)。
  • persistent_workers:是否保持 worker 进程存活(避免重复初始化开销)。
  • pin_memory_device:指定固定内存的设备(如 'cuda:0')。
  • in_order:是否强制按 FIFO 顺序返回批次(num_workers>0 时生效)。

6.1 读取数据

​ 以 CIFAR10 测试集为例,共包含 10000 个样本。

import torchvision
from torch.utils.data import DataLoader# 加载CIFAR10测试数据集
test_data = torchvision.datasets.CIFAR10(root='./dataset',  # 数据集的根目录train=False,  # 是否为训练集transform=torchvision.transforms.ToTensor(),  # 数据预处理
)

​ 设置batch_size=4,即每次读取 4 个样本,并且drop_last=False

# 创建数据加载器
test_loader = DataLoader(dataset=test_data,  # 数据集batch_size=4,  # 每个batch的大小shuffle=True,  # 是否打乱数据num_workers=0,  # 加载数据的线程数drop_last=False  # 是否丢弃最后一个batch
)

​ 遍历 test_loader 以获得每批次的数据,

i = 0
# 遍历test_loader中的数据
for data in test_loader:i += 1imgs, targets = data  # 获取每批次的数据与标签print(i, "  ", imgs.shape)print(targets)
image-20250521111845474
  • 数据集一共 10000 个样本,每批次取出 4 个,因此序号为 10000 / 4 = 2500。
  • 每个批次的张量形状为 torch.Size([4, 3, 32, 32]),是 4 个 torch.Size([3, 32, 32]) 的组合。
  • targets 为包含 4 个元素的列表,每个元素表示对应位置的数据标签。

6.2 可视化

​ 设置batch_size=64,并且drop_last=False

import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter# 加载CIFAR10测试数据集
test_data = torchvision.datasets.CIFAR10(root='./dataset',  # 数据集的根目录train=False,  # 是否为训练集transform=torchvision.transforms.ToTensor(),  # 数据预处理
)# 创建数据加载器
test_loader = DataLoader(dataset=test_data,  # 数据集batch_size=64,  # 每个batch的大小shuffle=True,  # 是否打乱数据num_workers=0,  # 加载数据的线程数drop_last=False  # 是否丢弃最后一个batch
)

​ 使用 tensorboard 可视化每个批次。

writer = SummaryWriter('logs')step = 0
# 遍历test_loader中的数据
for data in test_loader:# 获取每批次的数据与标签imgs, targets = data# 将imgs写入tensorboardwriter.add_images('test_data', imgs, step)step += 1writer.close()

​ 拖动滑动条,可看见第 32 step 的批次图像如下。

image-20250521113345428

​ 将滑动条拖动到最右侧,发现该批次只剩余 16 张图像,因为 10000 % 64 = 16。

image-20250521113614564

相关文章:

  • 用Recommenders,实现个性化推荐
  • Socket编程——TCP
  • 协议大和解:ETHERCAT转CANopen网关配置
  • 打卡第二十四天
  • 2025年Y2大型游乐设施操作证备考练习题
  • WordPress Elementor零基础教程
  • 【Java微服务组件】异步通信P2—Kafka与消息
  • 如何设计智慧工地系统的数据库?
  • JVM梳理(逻辑清晰)
  • RL电路的响应
  • 阿里云数据盘级别
  • 在 Excel xll 自动注册操作 中使用东方仙盟软件————仙盟创梦IDE
  • LVLM-AFAH论文精读
  • 标准IO(2)、文件IO
  • API面临哪些风险,如何做好API安全?
  • C语言指针深入详解(六):sizeof和strlen的对比,【题解】数组和指针笔试题解析、指针运算笔试题解析
  • 海洋探测利器:HY - 2C 卫星
  • 【已解决】docker search --limit 1 centos Error response from daemon
  • 逆向学习笔记1
  • Spring AI 1.0 GA 于 2025 年 5 月 20 日正式发布,都有哪些特性?