当前位置: 首页 > news >正文

深度学习(6)python数据处理

2. 数据处理核心模块
2.1 utils.data:数据集定义与批量加载

utils.data是 PyTorch 中处理数据集的基础模块,分为Dataset(定义数据结构)和DataLoader(批量加载)两部分,二者需配合使用。

2.1.1 Dataset:自定义数据集
  • 核心作用:定义数据集的存储格式、样本读取逻辑,解决 “数据如何组织” 的问题。
  • 实现要求:必须继承torch.utils.data.Dataset,并重写以下 3 个方法:
    1. __init__:初始化数据与标签(如从 numpy 数组、文件加载);
    2. __getitem__(self, index):按索引index获取单个样本,需将数据转换为 PyTorch 的Tensor类型;
    3. __len__:返回数据集的总样本数。
  • 代码示例

    python

    import torch
    from torch.utils import data
    import numpy as npclass TestDataset(data.Dataset):def __init__(self):self.Data = np.asarray([[1,2],[3,4],[2,1],[3,4],[4,5]])  # 2维数据self.Label = np.asarray([0,1,0,1,2])  # 对应标签def __getitem__(self, index):txt = torch.from_numpy(self.Data[index])  # numpy→Tensorlabel = torch.tensor(self.Label[index])return txt, labeldef __len__(self):return len(self.Data)Test = TestDataset()
    print(Test[2])  # 调用__getitem__(2),输出单个样本:(tensor([2, 1], dtype=torch.int32), tensor(0, dtype=torch.int32))
    print(Test.__len__())  # 输出样本总数:5
    
2.1.2 DataLoader:批量加载数据
  • 核心作用:将Dataset输出的 “单个样本” 组合成 “批次数据”,支持多进程加载、数据打乱,解决 “数据如何高效供模型训练” 的问题。
  • 核心参数说明
参数名称说明默认值
dataset待加载的数据集(必须是Dataset类实例)-
batch_size每批数据的样本数量,决定一次训练的样本规模1
shuffle每次加载前是否打乱数据集顺序(训练集常用True,测试集常用FalseFalse
num_workers多进程加载数据的进程数,0代表不使用多进程(Windows 系统建议设为 0)0
drop_last若总样本数非batch_size整数倍,是否丢弃最后不足一批的样本False
collate_fn自定义批次数据的拼接逻辑,默认按 Tensor 维度拼接默认函数
  • 代码示例与运行结果

    python

    test_loader = data.DataLoader(Test, batch_size=2, shuffle=False, num_workers=0)
    for i, traindata in enumerate(test_loader):print('i:', i)Data, Label = traindataprint('data:', Data)print('Label:', Label)
    
    运行结果:

    plaintext

    i: 0
    data: tensor([[1, 2], [3, 4]], dtype=torch.int32)
    Label: tensor([0, 1], dtype=torch.int32)
    i: 1
    data: tensor([[2, 1], [3, 4]], dtype=torch.int32)
    Label: tensor([0, 1], dtype=torch.int32)
    i: 2
    data: tensor([[4, 5]], dtype=torch.int32)  # 因drop_last=False,保留最后1个样本
    Label: tensor([2], dtype=torch.int32)
    
  • 注意事项DataLoader本身不是迭代器,需通过iter()函数转换后才能用next()获取单批数据。
2.2 torchvision:图像数据专用处理工具

torchvision是 PyTorch 针对图像数据的扩展库,包含transforms(预处理 / 增强)和ImageFolder(多目录图像加载),专门解决图像数据的处理痛点。

2.2.1 transforms:数据预处理与增强
  • 核心作用:对图像数据(PIL Image 或 Tensor)执行标准化、裁剪、翻转等操作,提升模型泛化能力,适配模型输入格式。
  • 支持的操作分类
操作对象操作名称功能说明
PIL ImageResize/Scale调整图像尺寸,保持长宽比
CenterCrop/RandomCrop中心裁剪 / 随机裁剪(如RandomCrop(20)裁剪为 20×20 像素)
RandomHorizontalFlip以 50% 概率随机水平翻转图像
ColorJitter随机调整图像亮度、对比度、饱和度
ToTensor将 [0,255] 的 PIL Image 转为 [0,1] 的 Tensor,形状从 (H,W,C) 转为 (C,H,W)
TensorNormalize标准化:(Tensor - mean) / std(如文档中mean=(0.5,0.5,0.5)标准化到 [-1,1])
ToPILImage将 Tensor 转回 PIL Image,用于图像可视化
  • 关键工具:Compose
    • 作用:将多个transforms操作按顺序拼接成一个 “处理管道”,避免手动依次调用操作,类似nn.Sequential对网络层的组合。
    • 代码示例:

      python

      import torchvision.transforms as transforms
      transform = transforms.Compose([transforms.CenterCrop(10),  # 步骤1:中心裁剪为10×10transforms.RandomCrop(20, padding=0),  # 步骤2:随机裁剪为20×20transforms.ToTensor(),  # 步骤3:转为Tensortransforms.Normalize(mean=(0.5,0.5,0.5), std=(0.5,0.5,0.5))  # 步骤4:标准化
      ])
      
2.2.2 ImageFolder:多目录图像加载
  • 核心作用:读取按 “目录分类” 存储的图像数据(如data/cat/001.jpgdata/dog/002.jpg),自动将目录名作为类别标签,解决多分类图像的加载问题。
  • 代码示例

    python

    from torchvision import datasets, transforms
    from torch.utils import data# 定义预处理管道
    my_trans = transforms.Compose([transforms.RandomResizedCrop(224),  # 随机大小裁剪为224×224transforms.RandomHorizontalFlip(),  # 随机水平翻转transforms.ToTensor()  # 转为Tensor
    ])# 加载多目录图像(root为数据根目录)
    train_data = datasets.ImageFolder(root='../data/torchvision_data', transform=my_trans)
    # 批量加载
    train_loader = data.DataLoader(train_data, batch_size=8, shuffle=True)# 可视化第一批数据
    import matplotlib.pyplot as plt
    import torchvision.utils as utils
    for i_batch, (img, label) in enumerate(train_loader):if i_batch == 0:print(label)  # 输出第一批样本的标签grid = utils.make_grid(img)  # 生成图像网格plt.imshow(grid.numpy().transpose((1,2,0)))  # 转换形状为(H,W,C)plt.show()utils.save_image(grid, 'test001.png')  # 保存图像break
    
  • 优势:无需手动定义标签,自动关联 “目录名 - 类别”,大幅简化多分类图像数据集的加载代码。
3. 可视化工具:TensorBoard

TensorBoard 是 Google TensorFlow 配套的可视化工具,PyTorch 通过torch.utils.tensorboard模块支持其功能,核心用于监控训练过程、可视化模型结构。

3.1 核心使用步骤(3 步)
  1. 实例化 SummaryWriter:指定日志存储目录(不存在则自动创建),负责记录可视化数据。

    python

    from torch.utils.tensorboard import SummaryWriter
    writer = SummaryWriter(log_dir='logs')  # 日志存于当前目录的logs文件夹
    
  2. 调用 add_xxx 接口:按需求记录不同类型的数据,格式为add_xxx(标签名, 数据对象, 迭代次数)
  3. 启动服务与访问
    • 命令行输入:tensorboard --logdir=logs --port 6006logdir为日志目录,port为端口);
    • 浏览器访问:http://localhost:6006(本机)或http://服务器IP:6006(远程服务器)。
3.2 常用可视化类型
可视化类型函数应用场景
Scalaradd_scalar可视化单一数值变化(如训练损失值、验证准确率随迭代次数的趋势)
Imageadd_image可视化图像数据(如输入图像、卷积层特征图)
Graphadd_graph可视化神经网络的计算图结构(展示层连接关系与参数)
Histogramadd_histogram可视化参数分布(如权重、偏置的分布变化)
Embeddingadd_embedding可视化高维数据的低维表示(如词向量、图像特征的 t-SNE 降维结果)
3.3 关键应用场景
  1. 可视化神经网络结构

    • 步骤:定义网络类(如含卷积层、全连接层的Net)→调用add_graph(model, input_tensor)→启动 TensorBoard 查看;
    • 价值:直观展示层顺序(如conv1→maxpool→conv2→fc1)与各层参数(如conv1: kernel_size=5, in_channels=1)。
  2. 可视化训练损失值

    • 代码示例(核心片段):

      python

      for epoch in range(60):  # 迭代60次# 训练步骤(前向传播→计算损失→反向传播→参数更新)loss = criterion(output, targets)optimizer.step()# 记录损失值(标签为“训练损失值”,迭代次数为epoch)writer.add_scalar('训练损失值', loss, epoch)
      
    • 结果:TensorBoard 中显示损失值随epoch下降的曲线,可判断模型是否收敛。
  3. 可视化卷积层特征图

    • 核心逻辑:遍历网络层,对卷积层输出的特征图用utils.make_grid生成网格→调用add_image记录;
    • 价值:观察不同卷积层提取的特征(浅层提取边缘、纹理,深层提取语义信息),辅助分析网络特征学习能力。

4. 关键问题

问题 1:utils.data中的DatasetDataLoader是什么关系?二者分别解决了数据处理中的什么问题?

答案:二者是 “依赖与协作” 关系,DataLoader需基于Dataset提供的单样本数据,实现批量加载。

  • Dataset解决 “数据如何定义” 的问题:通过继承data.Dataset并重写__init__(初始化数据 / 标签)、__getitem__(按索引取单样本并转 Tensor)、__len__(返回样本总数),定义数据集的存储格式与单样本读取逻辑,确保数据可被索引访问;
  • DataLoader解决 “数据如何高效加载” 的问题:通过接收Dataset实例,设置batch_size(批大小)、shuffle(是否打乱)、num_workers(多进程)等参数,将单样本组合成批次数据,减少 IO 开销,适配模型的批量训练需求(如文档中test_loader = data.DataLoader(Test, batch_size=2),将TestDataset实例)的单样本转为每批 2 个样本)。
问题 2:torchvision.transforms.Compose的核心作用是什么?为什么在数据预处理中推荐使用它?

答案Compose的核心作用是将多个数据预处理 / 增强操作按顺序拼接成一个 “处理管道”,实现 “一次调用完成多步操作”,无需手动依次执行单个transforms操作。推荐使用的原因有 3 点:

  1. 简化代码:例如 “CenterCrop→RandomCrop→ToTensor→Normalize”4 步操作,只需定义Compose实例,调用一次即可完成,避免重复写 4 行调用代码;
  2. 保证顺序一致性:严格按Compose中传入的顺序执行操作(如必须先ToTensorNormalize,因Normalize仅支持 Tensor 类型,若顺序颠倒会报错),避免手动调用时的顺序错误;
  3. 提升复用性:组合好的Compose实例可作为统一的预处理逻辑,在ImageFolder、自定义Dataset中复用,确保训练集、测试集的预处理方式完全一致,避免数据偏差。
问题 3:使用 TensorBoard 可视化训练损失值的核心步骤是什么?该可视化对模型训练有什么指导意义?

答案:核心步骤(基于文档):

  1. 初始化日志记录器writer = SummaryWriter(log_dir='logs'),指定日志存储目录;
  2. 训练循环中记录损失:在每个迭代周期(epoch)或批次(batch)后,调用writer.add_scalar('训练损失值', loss, global_step),其中loss为当前计算的损失值,global_step为迭代次数(如epochbatch_idx);
  3. 启动 TensorBoard 查看:命令行输入tensorboard --logdir=logs --port 6006,浏览器访问http://localhost:6006,在 “Scalar” 标签下查看损失曲线。

指导意义:

  1. 判断模型收敛状态:若损失值随global_step持续下降并趋于稳定,说明模型收敛;若损失值波动大或不下降,可能存在学习率过高、数据量不足等问题;
  2. 识别过拟合风险:若训练损失持续下降,但验证损失上升,说明模型过拟合,需调整正则化(如增加 dropout)或数据增强策略;
  3. 优化训练参数:通过损失曲线判断迭代次数是否足够(如损失未收敛则需增加epoch),或是否需调整学习率(如损失下降缓慢则适当提高学习率)。
http://www.dtcms.com/a/411726.html

相关文章:

  • 何做好网站建设销售中小学网站建设方案
  • 【实时Linux实战系列】延迟 SLI/SLO/SLA 设计与观测体系
  • NetworkPolicy 工作原理详解
  • Matlab通过GUI实现点云的中值滤波(附最简版)
  • 网站篡改搜索引擎js网站 目录 结构
  • 企业网站设计行业crm管理系统定制
  • 论文《Inference for Iterated GMM Under Misspecification》的例子3
  • 计算机图形图像技术实验报告
  • 编译DuckDB c++插件模板并加载运行
  • 做logo什么网站河田镇建设局网站
  • OA、PMES、TMES、SAP、PPM、CRM、DMS、HR系统
  • C语言 ——— 指针
  • 内力网站建设seo简单优化
  • 大模型-自编码器(AutoEncoder)原理(上)
  • Promise开发【进阶】
  • 建立网站需要备案吗网络科技公司起名字大全免费
  • solidworks ppo 试做1
  • Matter over Thread方案,如何助力智能家居生态互通?
  • 创办网站需要怎么做wordpress 点评类网站
  • 网站开发与运营怎么样0基础网站开发
  • mp3链接地址制作网站网站建设与开发选题
  • Dify 从入门到熟悉100 天的学习大纲
  • 为什么做的网站要续费东莞cms建站模板
  • 安徽省高速公路建设指挥部网站为什么实验楼网站上做实验这么卡
  • Java Web应用开发——第一章:Java Web概述测验
  • 北京网站建设哪家好免费正能量不良网站推荐
  • 高端网站建设公司有哪些项目南京房地产开发公司
  • 网络编程
  • VGG改进(11):基于WaveletAttention的模型详解
  • 安徽建筑大学学工在线网站代理网游