当前位置: 首页 > news >正文

PyTorch DataLoader 接受的返回值类型

通常,数据集通过__getitem__方法返回单个样本,而DataLoader负责将这些样本批量组合。以下是常见的返回值类型:

  1. 张量(Tensor):最常见的情况,返回一个或多个张量。DataLoader会自动将多个样本的张量堆叠成批次。

  2. 列表(List):可以返回一个列表,其中包含多个张量或其他类型。DataLoader会尝试将列表中的每个元素分别批量处理。

  3. 字典(Dictionary):返回一个字典,键是数据字段名,值是对应的张量或数据。DataLoader会按字段名分别批量处理。

  4. 元组(Tuple):返回一个元组,其中包含多个张量或其他类型。DataLoader会分别对元组中的每个元素进行批量处理。

  5. 命名元组(NamedTuple):类似于元组,但可以通过字段名访问,DataLoader处理方式与元组类似。

  6. 自定义数据类型:如果返回的是自定义类型,需要确保DataLoader知道如何组合这些类型。通常,自定义类型需要实现相应的拼接方法,或者使用default_collate函数能够处理。

示例1:返回张量

import torch
from torch.utils.data import Dataset, DataLoaderclass TensorDataset(Dataset):def __len__(self):return 10def __getitem__(self, index):return torch.tensor([index, index*2])dataset = TensorDataset()
dataloader = DataLoader(dataset, batch_size=2)for batch in dataloader:print(batch)# 输出:一个批次的张量,形状为[2, 2]

示例2:返回元组(多个张量)

class TupleDataset(Dataset):def __len__(self):return 10def __getitem__(self, index):return torch.tensor(index), torch.tensor(index*2)dataset = TupleDataset()
dataloader = DataLoader(dataset, batch_size=2)for batch in dataloader:print(batch)# 输出:一个元组,包含两个张量,每个张量的形状为[2]

示例3:返回字典

class DictDataset(Dataset):def __len__(self):return 10def __getitem__(self, index):return {'data': torch.tensor(index), 'label': torch.tensor(index*2)}dataset = DictDataset()
dataloader = DataLoader(dataset, batch_size=2)for batch in dataloader:print(batch)# 输出:一个字典,包含两个键,每个键对应一个形状为[2]的张量

示例4:返回列表

class ListDataset(Dataset):def __len__(self):return 10def __getitem__(self, index):return [torch.tensor(index), torch.tensor(index*2)]dataset = ListDataset()
dataloader = DataLoader(dataset, batch_size=2)for batch in dataloader:print(batch)# 输出:一个列表,包含两个张量,每个张量的形状为[2]

示例5:自定义collate_fn

def custom_collate_fn(batch):# 假设batch是多个样本的列表,每个样本是一个张量,但张量长度不同# 这里我们使用填充0到最大长度data = [item[0] for item in batch]  # 假设每个样本是一个元组,第一个元素是张量labels = [item[1] for item in batch]# 填充数据lengths = [len(d) for d in data]max_len = max(lengths)padded_data = torch.zeros(len(batch), max_len)for i, d in enumerate(data):padded_data[i, :lengths[i]] = dreturn padded_data, torch.tensor(labels)class VariableLengthDataset(Dataset):def __len__(self):return 10def __getitem__(self, index):length = torch.randint(1, 5, (1,)).item()data = torch.randn(length)label = index % 2return data, labeldataset = VariableLengthDataset()
dataloader = DataLoader(dataset, batch_size=2, collate_fn=custom_collate_fn)for batch in dataloader:print(batch)break

1. 基本数据类型

单个张量

# Dataset 返回单个张量
class SingleTensorDataset:def __getitem__(self, index):return torch.tensor([index, index*2, index*3])def __len__(self):return 100# DataLoader 会自动堆叠成批次
dataloader = DataLoader(SingleTensorDataset(), batch_size=4)
for batch in dataloader:print(batch.shape)  # torch.Size([4, 3])

元组 (最常用)

# 返回 (input, target) 元组
class TupleDataset:def __getitem__(self, index):x = torch.randn(10)  # 特征y = torch.tensor(index % 3)  # 标签return x, y  # 返回元组dataloader = DataLoader(TupleDataset(), batch_size=4)
for inputs, targets in dataloader:print(inputs.shape, targets.shape)  # torch.Size([4, 10]), torch.Size([4])

字典

# 返回字典,键值对形式
class DictDataset:def __getitem__(self, index):return {'input_ids': torch.randint(0, 100, (10,)),'attention_mask': torch.ones(10),'labels': torch.tensor(index % 2)}dataloader = DataLoader(DictDataset(), batch_size=4)
for batch in dataloader:print(batch.keys())  # dict_keys(['input_ids', 'attention_mask', 'labels'])print(batch['input_ids'].shape)  # torch.Size([4, 10])

列表

# 返回列表
class ListDataset:def __getitem__(self, index):return [torch.tensor(index), torch.randn(5), f"sample_{index}"]dataloader = DataLoader(ListDataset(), batch_size=4)
for batch in dataloader:print(batch)  # [tensor, tensor, list_of_strings]

2. 复杂嵌套结构

嵌套字典

class NestedDictDataset:def __getitem__(self, index):return {'image': torch.randn(3, 224, 224),'metadata': {'filename': f'img_{index}.jpg','size': (224, 224),'timestamp': index * 1000},'labels': {'class': torch.tensor(index % 10),'bbox': torch.tensor([0.1, 0.2, 0.8, 0.9])}}dataloader = DataLoader(NestedDictDataset(), batch_size=4)
for batch in dataloader:print(batch['image'].shape)  # torch.Size([4, 3, 224, 224])print(batch['metadata']['filename'])  # 列表: ['img_0.jpg', ...]

命名元组

from collections import namedtupleSample = namedtuple('Sample', ['features', 'label', 'id'])class NamedTupleDataset:def __getitem__(self, index):return Sample(features=torch.randn(10),label=torch.tensor(index % 3),id=f"sample_{index}")dataloader = DataLoader(NamedTupleDataset(), batch_size=4)
for batch in dataloader:print(type(batch))  # <class '__main__.Sample'>print(batch.features.shape)  # torch.Size([4, 10])

3. 自定义数据类型

自定义类实例

class DataSample:def __init__(self, data, target, meta):self.data = dataself.target = targetself.meta = metaclass CustomClassDataset:def __getitem__(self, index):return DataSample(data=torch.randn(10),target=torch.tensor(index % 2),meta={'index': index, 'name': f'sample_{index}'})# 需要自定义 collate_fn
def custom_collate(batch):return DataSample(data=torch.stack([sample.data for sample in batch]),target=torch.stack([sample.target for sample in batch]),meta=[sample.meta for sample in batch])dataloader = DataLoader(CustomClassDataset(), batch_size=4, collate_fn=custom_collate)

4. 混合数据类型

张量 + Python 基本类型

class MixedDataset:def __getitem__(self, index):return (torch.randn(10),           # 张量index % 3,                 # Python整数f"sample_{index}",         # 字符串[index, index*2],          # 列表{'idx': index}             # 字典)dataloader = DataLoader(MixedDataset(), batch_size=4)
for tensor_data, int_data, str_data, list_data, dict_data in dataloader:print(tensor_data.shape)  # torch.Size([4, 10]) - 张量被堆叠print(int_data)           # tensor([0, 1, 2, 3]) - 数字被转换为张量print(str_data)           # ['sample_0', ...] - 字符串保持为列表

5. 特殊返回值处理

None 值处理

class DatasetWithNone:def __getitem__(self, index):if index % 5 == 0:  # 每5个样本返回Nonereturn Nonereturn torch.randn(10), torch.tensor(index % 3)# 需要过滤 None 值
def filter_none_collate(batch):batch = [sample for sample in batch if sample is not None]return default_collate(batch) if batch else Nonedataloader = DataLoader(DatasetWithNone(), batch_size=4, collate_fn=filter_none_collate)

可变长度序列

class VariableLengthDataset:def __getitem__(self, index):length = torch.randint(5, 15, (1,)).item()sequence = torch.randn(length, 10)  # 可变长度序列return sequence, torch.tensor(index % 3)# 使用 pad_sequence 处理可变长度
from torch.nn.utils.rnn import pad_sequencedef pad_collate(batch):sequences, labels = zip(*batch)padded_sequences = pad_sequence(sequences, batch_first=True)labels = torch.stack(labels)return padded_sequences, labelsdataloader = DataLoader(VariableLengthDataset(), batch_size=4, collate_fn=pad_collate)

6. 实际应用示例

计算机视觉任务

class VisionDataset:def __getitem__(self, index):return {'image': torch.randn(3, 224, 224),      # 图像'label': torch.tensor(index % 1000),    # 分类标签'bbox': torch.tensor([[10, 20, 100, 150]]),  # 检测框'mask': torch.randn(224, 224) > 0.5,    # 分割掩码'image_id': index                       # 图像ID}

NLP 任务

class NLPDataset:def __getitem__(self, index):return {'input_ids': torch.randint(0, 1000, (128,)),'attention_mask': torch.ones(128),'token_type_ids': torch.zeros(128),'labels': torch.randint(0, 2, (1,)),'text': f"这是第{index}个样本"}

http://www.dtcms.com/a/403109.html

相关文章:

  • rust slint android 安卓
  • 网站后台建设怎么进入超级优化小说
  • 游戏对象AI类型释义
  • Harnessing Text Insights with Visual Alignment for Medical Image Segmentation
  • 网上做网站 干对缝儿生意外贸网站推广优化
  • 【Java后端】MyBatis 和 MyBatis-Plus (MP) 的区别
  • iOS PPBluetoothKit接入无法找到头文件问题
  • leetcode orb slam3 3/99--> leetcode49 Group Anagrams
  • c# 读取xml到datagridview
  • 开源的 CSS 动画库
  • (三)过滤器及组件化开发
  • [NewBeeBox] A JavaScript error occurred in the main process
  • 【LangGraph】ReAct构建-LangGraph简单实现
  • 做毕业设计哪个网站好网站怎样做百度推广
  • Python高效合并Excel多Sheet工作表,告别繁琐手动操作
  • 自动跳转到wap网站外贸网站建设制作设计案例
  • 【Linux】 服务器无 sz 命令时的文件传输和日志查看方案
  • 【TVM 教程】设置 RPC 系统
  • 在ssh远程连接的autodl服务器(中国无root权限服务器)上使用copilt的Claude模型
  • Ansible 自动化运维:集中化管理服务器实战指南
  • 自动化运维工具 Ansible 集中化管理服务器
  • 【好玩的开源项目】使用Docker部署LMS轻量级音乐服务器
  • Netty从0到1系列之RPC通信
  • Coze源码分析-资源库-创建数据库-后端源码-安全与错误处理
  • LeetCode:52.腐烂的橘子
  • LeetCode算法日记 - Day 52: 求根节点到叶节点数字之和、二叉树剪枝
  • 四种方法解决——力扣189.轮转数组
  • ⸢ 伍-Ⅱ⸥ ⤳ 默认安全治理实践:水平越权检测 前端安全防控
  • 力扣856
  • Leetcode94.二叉数的中序遍历练习