当前位置: 首页 > news >正文

基于 PyTorch 构建 Dataset 与 DataLoader:从 TXT 文件读取到新增类别全流程指南

基于 PyTorch 构建 Dataset 与 DataLoader:从 TXT 文件读取到新增类别全流程指南

在深度学习计算机视觉任务中,数据加载与预处理是模型训练的基础环节,直接影响模型的训练效率与最终性能。PyTorch 作为主流深度学习框架,提供了Dataset和DataLoader两大核心组件,帮助开发者高效管理数据流程。本文将以 “读取 TXT 文件中的图片路径与标签、处理图片数据” 为核心场景,从零开始构建完整的数据加载 pipeline,并重点分析 “新增数据集类别(如新增‘汉堡’类别)” 时需要重新运行的关键步骤,同时深入探讨数据加载过程中的优化技巧与常见问题解决方案。全文内容涵盖理论讲解、代码实现、案例演示与拓展思考,适合深度学习初学者与进阶开发者参考。

一、核心概念解析:为什么需要 Dataset 与 DataLoader?

在正式开始代码实现前,我们首先需要明确Dataset和DataLoader的设计初衷与核心作用 —— 理解这两个组件的本质,能帮助我们在后续开发中更灵活地应对复杂数据场景。

1.1 数据加载的核心痛点

在深度学习训练中,直接使用原始数据(如图片文件)存在以下痛点:

  • 数据量庞大:一张图片通常占用数 MB 存储空间,若训练集包含 10 万张图片,直接一次性加载到内存会导致内存溢出(OOM);
  • 数据格式不统一:图片可能存在不同尺寸(如 224×224、512×512)、不同通道(RGB/Gray)、不同格式(JPG/PNG),需要统一预处理;
  • 标签与数据关联复杂:图片路径与类别标签通常分开存储(如 TXT 文件、CSV 文件),需要高效关联并避免标签错乱;
  • 训练效率需求:训练过程中需要随机打乱数据(shuffle)、按批次加载(batch)、多线程并行处理,以提升 GPU 利用率。

1.2 Dataset:数据的 “容器” 与 “接口”

Dataset是 PyTorch 中抽象数据类,其核心作用是定义数据的 “读取规则”,即如何从原始数据(如 TXT 文件、图片文件夹)中获取单条数据(如一张图片 + 对应的标签)。开发者需要通过继承Dataset类,并实现以下两个核心方法:

  • __len__():返回数据集的总样本数量,用于DataLoader判断数据集的迭代次数;
  • __getitem__(idx):根据索引idx返回单条样本(通常是(image, label)的元组),是数据加载的核心逻辑所在。

简单来说,Dataset就像一个 “数据仓库管理员”,负责按规则取出单条数据,但不负责 “批量运输” 和 “并行调度”—— 这部分工作由DataLoader完成。

1.3 DataLoader:数据的 “运输队” 与 “调度员”

DataLoader是基于Dataset的上层组件,其核心作用是优化数据加载的效率,解决 “批量加载”“并行处理”“数据打乱” 等工程问题。它通过以下参数实现高效调度:

  • dataset:传入自定义的Dataset实例,指定数据来源;
  • batch_size:指定每次加载的样本数量(如 32、64),平衡内存占用与训练效率;
  • shuffle:布尔值,指定是否在每个 epoch 开始前打乱数据顺序,避免模型学习到数据的顺序依赖;
  • num_workers:指定用于数据加载的子进程数量(通常设为 CPU 核心数的 1-2 倍),实现并行预处理,减少 GPU 等待时间;
  • pin_memory:布尔值,若为True,则将加载的数据固定到内存中,加速数据从 CPU 到 GPU 的传输(建议开启,尤其当 GPU 显存充足时);
  • drop_last:布尔值,若为True,则丢弃最后一个不足batch_size的样本(避免训练时因批次大小不一致导致的报错)。

Dataset与DataLoader的配合流程可总结为:DataLoader通过多线程调用Dataset的__getitem__()方法获取单条数据,再将多条数据打包成batch,最终传输给 GPU 用于模型训练。

二、前期准备:环境搭建与数据组织

在编写代码前,我们需要完成环境搭建与数据组织 —— 这是确保后续流程顺利进行的基础,尤其需要注意数据路径的规范性(避免后续因路径错误导致的加载失败)。

2.1 环境搭建:PyTorch 与依赖库安装

本文使用的核心依赖库包括:

  • PyTorch:核心深度学习框架,提供Dataset与DataLoader;
  • torchvision:PyTorch 官方计算机视觉工具库,提供图片预处理(如Resize、ToTensor)、常用数据集等;
  • Pillow(PIL):Python 图像处理库,用于读取和处理图片文件;
  • numpy:数值计算库,用于图片数组的处理;
  • matplotlib:可视化库,用于展示加载后的图片与标签,验证数据正确性。
2.1.1 安装命令(conda 环境为例)

# 创建并激活conda环境

conda create -n pytorch_data python=3.9

conda activate pytorch_data

# 安装PyTorch(根据GPU型号选择,此处以CUDA 11.8为例,CPU版本可省略cuda参数)

conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia

# 安装其他依赖库

conda install pillow numpy matplotlib

2.1.2 环境验证

安装完成后,通过以下代码验证环境是否正常:

import torch

import torchvision

from PIL import Image

import numpy as np

import matplotlib.pyplot as plt

# 验证PyTorch版本与GPU可用性

print(f"PyTorch版本:{torch.__version__}")

print(f"GPU是否可用:{torch.cuda.is_available()}")

print(f"GPU设备数量:{torch.cuda.device_count()}")

if torch.cuda.is_available():

print(f"当前GPU型号:{torch.cuda.get_device_name(0)}")

# 验证Pillow与matplotlib

img = Image.new("RGB", (100, 100), color="red")

plt.imshow(img)

plt.title("环境验证")

plt.show()

若代码能正常输出 PyTorch 版本、GPU 信息,并显示红色图片,则环境搭建成功。

2.2 数据组织:图片与 TXT 标签文件的规范结构

本文以 “食物分类” 任务为例,初始数据集包含 3 个类别:“苹果”(apple)、“香蕉”(banana)、“橙子”(orange)。后续将新增 “汉堡”(hamburger)类别,因此数据组织需考虑可扩展性。

2.2.1 数据文件夹结构

建议采用以下结构(路径可自定义,此处以./data/food/为例):

./data/food/

├── images/ # 所有图片存储文件夹

│ ├── apple/ # 苹果类别图片(100张示例)

│ │ ├── apple_001.jpg

│ │ ├── apple_002.jpg

│ │ └── ...

│ ├── banana/ # 香蕉类别图片(100张示例)

│ │ ├── banana_001.jpg

│ │ └── ...

│ └── orange/ # 橙子类别图片(100张示例)

│ ├── orange_001.jpg

│ └── ...

├── train.txt # 训练集标签文件(图片路径+标签)

└── val.txt # 验证集标签文件(图片路径+标签)

2.2.2 TXT 标签文件的格式规范

train.txt和val.txt的每行代表一条样本,格式为 “图片相对路径 类别索引”,示例如下(train.txt):

./data/food/images/apple/apple_001.jpg 0

./data/food/images/apple/apple_002.jpg 0

./data/food/images/banana/banana_001.jpg 1

./data/food/images/banana/banana_002.jpg 1

./data/food/images/orange/orange_001.jpg 2

./data/food/images/orange/orange_002.jpg 2

其中:

  • 图片路径:建议使用相对路径(相对于代码运行目录),避免绝对路径导致的环境迁移问题;
  • 类别索引:需与类别名称一一对应,建议通过字典记录映射关系(如{"apple":0, "banana":1, "orange":2}),后续新增类别时只需扩展该字典。
2.2.3 TXT 文件的生成方法(Python 脚本)

若手动编写 TXT 文件效率低,可通过以下 Python 脚本自动生成train.txt和val.txt(支持按比例划分训练集与验证集,此处以 8:2 为例):

import os

import random

from pathlib import Path

# 1. 配置参数

data_root = Path("./data/food") # 数据根目录

image_dir = data_root / "images" # 图片文件夹路径

train_txt_path = data_root / "train.txt" # 训练集TXT路径

val_txt_path = data_root / "val.txt" # 验证集TXT路径

val_ratio = 0.2 # 验证集比例(20%)

seed = 42 # 随机种子(保证划分结果可复现)

# 2. 定义类别与索引的映射(后续新增类别需在此扩展)

class_to_idx = {

"apple": 0,

"banana": 1,

"orange": 2

}

# 3. 遍历所有类别文件夹,收集图片路径与标签

all_samples = []

for class_name, class_idx in class_to_idx.items():

class_dir = image_dir / class_name # 单个类别图片文件夹

# 遍历文件夹下所有图片文件(仅保留JPG/PNG格式)

for img_name in os.listdir(class_dir):

if img_name.lower().endswith((".jpg", ".png")):

img_path = str(class_dir / img_name) # 图片完整路径

all_samples.append((img_path, class_idx))

# 4. 打乱样本顺序(固定种子确保可复现)

random.seed(seed)

random.shuffle(all_samples)

# 5. 划分训练集与验证集

val_size = int(len(all_samples) * val_ratio)

train_samples = all_samples[val_size:]

val_samples = all_samples[:val_size]

# 6. 写入TXT文件

def write_txt(samples, txt_path):

with open(txt_path, "w", encoding="utf-8") as f:

for img_path, label in samples:

f.write(f"{img_path} {label}\n")

write_txt(train_samples, train_txt_path)

write_txt(val_samples, val_txt_path)

# 7. 输出日志

print(f"数据集划分完成!")

print(f"总样本数:{len(all_samples)}")

print(f"训练集样本数:{len(train_samples)}({len(train_samples)/len(all_samples)*100:.1f}%)")

print(f"验证集样本数:{len(val_samples)}({len(val_samples)/len(all_samples)*100:.1f}%)")

print(f"训练集TXT路径:{train_txt_path}")

print(f"验证集TXT路径:{val_txt_path}")

运行该脚本后,将在./data/food/目录下生成train.txt和val.txt,无需手动编写标签文件,大幅提升效率。

三、核心实现:自定义 Dataset 类(读取 TXT 与处理图片)

自定义Dataset类是数据加载的核心步骤,需实现 “读取 TXT 文件→解析图片路径与标签→加载图片→预处理图片” 的完整逻辑。本节将分步骤讲解CustomImageDataset的实现,并加入详细注释,确保代码可复用、可扩展。

3.1 自定义 Dataset 类的完整代码

import torch

from torch.utils.data import Dataset

from torchvision import transforms

from PIL import Image

import os

from pathlib import Path

class CustomImageDataset(Dataset):

def __init__(self, txt_path, class_to_idx, transform=None):

"""

自定义图片数据集的初始化方法

参数说明:

- txt_path: str/Tuple[Path],TXT标签文件的路径(如./data/food/train.txt)

- class_to_idx: dict,类别名称到索引的映射(如{"apple":0, "banana":1})

- transform: torchvision.transforms.Compose,图片预处理流水线(可选)

"""

# 1. 存储输入参数

self.txt_path = Path(txt_path)

self.class_to_idx = class_to_idx

self.transform = transform # 若为None,将使用默认预处理

# 2. 验证TXT文件是否存在

if not self.txt_path.exists():

raise FileNotFoundError(f"TXT文件不存在:{self.txt_path}")

# 3. 从TXT文件中加载样本列表(图片路径+标签)

self.samples = self._load_samples_from_txt()

# 4. 定义默认预处理(若未传入transform)

if self.transform is None:

self.transform = transforms.Compose([

transforms.Resize((224, 224)), # 统一图片尺寸为224×224(适配大多数预训练模型)

transforms.ToTensor(), # 将PIL图片转为Tensor(维度:C×H×W,数值范围[0,1])

transforms.Normalize( # 标准化(使用ImageNet数据集的均值和标准差,适配预训练模型)

mean=[0.485, 0.456, 0.406],

std=[0.229, 0.224, 0.225]

)

])

def _load_samples_from_txt(self):

"""

从TXT文件中加载样本列表(内部辅助方法)

返回:list,每个元素为 tuple(img_path: str, label: int)

"""

samples = []

with open(self.txt_path, "r", encoding="utf-8") as f:

lines = f.readlines() # 读取所有行

for line_idx, line in enumerate(lines, 1): # 枚举行号(从1开始,便于报错定位)

line = line.strip() # 去除行首尾的空格、换行符

if not line: # 跳过空行

continue

# 按空格分割“图片路径”和“标签”(假设路径中不含空格)

parts = line.split()

if len(parts) != 2:

raise ValueError(

f"TXT文件第{line_idx}行格式错误!正确格式:'图片路径 类别索引',当前行:{line}"

)

img_path, label_str = parts

# 验证图片路径是否存在

if not os.path.exists(img_path):

raise FileNotFoundError(

f"TXT文件第{line_idx}行:图片不存在!路径:{img_path}"

)

# 将标签转为整数

try:

label = int(label_str)

except ValueError:

raise ValueError(

f"TXT文件第{line_idx}行:类别索引必须为整数!当前值:{label_str}"

)

# 验证标签是否在合法范围内(避免标签错误)

valid_labels = set(self.class_to_idx.values())

if label not in valid_labels:

raise ValueError(

f"TXT文件第{line_idx}行:类别索引{label}无效!合法索引:{sorted(valid_labels)}"

)

# 添加到样本列表

samples.append((img_path, label))

# 验证样本列表非空

if len(samples) == 0:

raise RuntimeError(f"TXT文件{self.txt_path}中无有效样本!")

return samples

def __len__(self):

"""返回数据集的总样本数量(必须实现)"""

return len(self.samples)

def __getitem__(self, idx):

"""

根据索引idx返回单条样本(必须实现)

参数:idx: int,样本索引

返回:tuple(image: torch.Tensor, label: int)

"""

# 1. 获取当前样本的图片路径与标签

img_path, label = self.samples[idx]

# 2. 加载图片(使用Pillow,保持图片为RGB格式)

try:

image = Image.open(img_path).convert("RGB") # 转为RGB(避免灰度图维度问题)

except Exception as e:

raise RuntimeError(f"加载图片{img_path}失败!错误信息:{str(e)}")

# 3. 图片预处理(应用transform</doubaocanvas>

http://www.dtcms.com/a/355337.html

相关文章:

  • AI大模型企业落地指南-笔记02
  • Spring 框架中事务传播行为的定义
  • 146. LRU缓存
  • python使用sqlcipher4对sqlite数据库加密
  • 【论文阅读】基于人工智能的下肢外骨骼辅助康复方法研究综述
  • 【电源专题】隐形守护者:防爆锂电池如何守护高危环境的安全防线
  • UE5提升分辨率和帧率的方法
  • 网站日志里面老是出现{pboot:if((\x22file_put_co\x22.\x22ntents\x22)(\x22temp.php\x22.....
  • Leetcode 深度优先搜索 (15)
  • 【大前端】React Native(RN)跨端的原理
  • 比较两个字符串的大小
  • 使用CDN后如何才不暴露IP
  • EtherNet/IP 转 Modbus 协议网关(三格电子)
  • SOME/IP-SD通信中的信息安全保证
  • leetcode_73 矩阵置零
  • (LeetCode 面试经典 150 题) 103. 二叉树的锯齿形层序遍历(广度优先搜索bfs)
  • [n8n] 工作流数据库管理SQLite | 数据访问层-REST API服务
  • 解决PyCharm打开PowerShell终端报错找不到conda-hook.ps1文件
  • 前端javascript在线生成excel,word模板-通用场景(免费)
  • Spring Boot 定时任务入门
  • 使用Java实现PDF文件安全检测:防止恶意内容注入
  • ubuntu20安装lammps
  • PDFMathTranslate,完全免费的电脑 PDF 文档翻译软件
  • 怎么保护信息安全?技术理论分析
  • Shell 脚本编程规范与变量
  • [调试][实现][原理]用Golang实现建议断点调试器
  • 裸金属服务器与虚拟机、物理机的核心差异是什么?
  • 鸿蒙Harmony-从零开始构建类似于安卓GreenDao的ORM数据库(二)
  • Kea DHCP高危漏洞CVE-2025-40779:单个数据包即可导致服务器崩溃
  • 获取小红书某个用户列表