PyTorch DataLoader 接受的返回值类型
通常,数据集通过__getitem__
方法返回单个样本,而DataLoader负责将这些样本批量组合。以下是常见的返回值类型:
张量(Tensor):最常见的情况,返回一个或多个张量。DataLoader会自动将多个样本的张量堆叠成批次。
列表(List):可以返回一个列表,其中包含多个张量或其他类型。DataLoader会尝试将列表中的每个元素分别批量处理。
字典(Dictionary):返回一个字典,键是数据字段名,值是对应的张量或数据。DataLoader会按字段名分别批量处理。
元组(Tuple):返回一个元组,其中包含多个张量或其他类型。DataLoader会分别对元组中的每个元素进行批量处理。
命名元组(NamedTuple):类似于元组,但可以通过字段名访问,DataLoader处理方式与元组类似。
自定义数据类型:如果返回的是自定义类型,需要确保DataLoader知道如何组合这些类型。通常,自定义类型需要实现相应的拼接方法,或者使用
default_collate
函数能够处理。
示例1:返回张量
import torch
from torch.utils.data import Dataset, DataLoaderclass TensorDataset(Dataset):def __len__(self):return 10def __getitem__(self, index):return torch.tensor([index, index*2])dataset = TensorDataset()
dataloader = DataLoader(dataset, batch_size=2)for batch in dataloader:print(batch)# 输出:一个批次的张量,形状为[2, 2]
示例2:返回元组(多个张量)
class TupleDataset(Dataset):def __len__(self):return 10def __getitem__(self, index):return torch.tensor(index), torch.tensor(index*2)dataset = TupleDataset()
dataloader = DataLoader(dataset, batch_size=2)for batch in dataloader:print(batch)# 输出:一个元组,包含两个张量,每个张量的形状为[2]
示例3:返回字典
class DictDataset(Dataset):def __len__(self):return 10def __getitem__(self, index):return {'data': torch.tensor(index), 'label': torch.tensor(index*2)}dataset = DictDataset()
dataloader = DataLoader(dataset, batch_size=2)for batch in dataloader:print(batch)# 输出:一个字典,包含两个键,每个键对应一个形状为[2]的张量
示例4:返回列表
class ListDataset(Dataset):def __len__(self):return 10def __getitem__(self, index):return [torch.tensor(index), torch.tensor(index*2)]dataset = ListDataset()
dataloader = DataLoader(dataset, batch_size=2)for batch in dataloader:print(batch)# 输出:一个列表,包含两个张量,每个张量的形状为[2]
示例5:自定义collate_fn
def custom_collate_fn(batch):# 假设batch是多个样本的列表,每个样本是一个张量,但张量长度不同# 这里我们使用填充0到最大长度data = [item[0] for item in batch] # 假设每个样本是一个元组,第一个元素是张量labels = [item[1] for item in batch]# 填充数据lengths = [len(d) for d in data]max_len = max(lengths)padded_data = torch.zeros(len(batch), max_len)for i, d in enumerate(data):padded_data[i, :lengths[i]] = dreturn padded_data, torch.tensor(labels)class VariableLengthDataset(Dataset):def __len__(self):return 10def __getitem__(self, index):length = torch.randint(1, 5, (1,)).item()data = torch.randn(length)label = index % 2return data, labeldataset = VariableLengthDataset()
dataloader = DataLoader(dataset, batch_size=2, collate_fn=custom_collate_fn)for batch in dataloader:print(batch)break
1. 基本数据类型
单个张量
# Dataset 返回单个张量
class SingleTensorDataset:def __getitem__(self, index):return torch.tensor([index, index*2, index*3])def __len__(self):return 100# DataLoader 会自动堆叠成批次
dataloader = DataLoader(SingleTensorDataset(), batch_size=4)
for batch in dataloader:print(batch.shape) # torch.Size([4, 3])
元组 (最常用)
# 返回 (input, target) 元组
class TupleDataset:def __getitem__(self, index):x = torch.randn(10) # 特征y = torch.tensor(index % 3) # 标签return x, y # 返回元组dataloader = DataLoader(TupleDataset(), batch_size=4)
for inputs, targets in dataloader:print(inputs.shape, targets.shape) # torch.Size([4, 10]), torch.Size([4])
字典
# 返回字典,键值对形式
class DictDataset:def __getitem__(self, index):return {'input_ids': torch.randint(0, 100, (10,)),'attention_mask': torch.ones(10),'labels': torch.tensor(index % 2)}dataloader = DataLoader(DictDataset(), batch_size=4)
for batch in dataloader:print(batch.keys()) # dict_keys(['input_ids', 'attention_mask', 'labels'])print(batch['input_ids'].shape) # torch.Size([4, 10])
列表
# 返回列表
class ListDataset:def __getitem__(self, index):return [torch.tensor(index), torch.randn(5), f"sample_{index}"]dataloader = DataLoader(ListDataset(), batch_size=4)
for batch in dataloader:print(batch) # [tensor, tensor, list_of_strings]
2. 复杂嵌套结构
嵌套字典
class NestedDictDataset:def __getitem__(self, index):return {'image': torch.randn(3, 224, 224),'metadata': {'filename': f'img_{index}.jpg','size': (224, 224),'timestamp': index * 1000},'labels': {'class': torch.tensor(index % 10),'bbox': torch.tensor([0.1, 0.2, 0.8, 0.9])}}dataloader = DataLoader(NestedDictDataset(), batch_size=4)
for batch in dataloader:print(batch['image'].shape) # torch.Size([4, 3, 224, 224])print(batch['metadata']['filename']) # 列表: ['img_0.jpg', ...]
命名元组
from collections import namedtupleSample = namedtuple('Sample', ['features', 'label', 'id'])class NamedTupleDataset:def __getitem__(self, index):return Sample(features=torch.randn(10),label=torch.tensor(index % 3),id=f"sample_{index}")dataloader = DataLoader(NamedTupleDataset(), batch_size=4)
for batch in dataloader:print(type(batch)) # <class '__main__.Sample'>print(batch.features.shape) # torch.Size([4, 10])
3. 自定义数据类型
自定义类实例
class DataSample:def __init__(self, data, target, meta):self.data = dataself.target = targetself.meta = metaclass CustomClassDataset:def __getitem__(self, index):return DataSample(data=torch.randn(10),target=torch.tensor(index % 2),meta={'index': index, 'name': f'sample_{index}'})# 需要自定义 collate_fn
def custom_collate(batch):return DataSample(data=torch.stack([sample.data for sample in batch]),target=torch.stack([sample.target for sample in batch]),meta=[sample.meta for sample in batch])dataloader = DataLoader(CustomClassDataset(), batch_size=4, collate_fn=custom_collate)
4. 混合数据类型
张量 + Python 基本类型
class MixedDataset:def __getitem__(self, index):return (torch.randn(10), # 张量index % 3, # Python整数f"sample_{index}", # 字符串[index, index*2], # 列表{'idx': index} # 字典)dataloader = DataLoader(MixedDataset(), batch_size=4)
for tensor_data, int_data, str_data, list_data, dict_data in dataloader:print(tensor_data.shape) # torch.Size([4, 10]) - 张量被堆叠print(int_data) # tensor([0, 1, 2, 3]) - 数字被转换为张量print(str_data) # ['sample_0', ...] - 字符串保持为列表
5. 特殊返回值处理
None 值处理
class DatasetWithNone:def __getitem__(self, index):if index % 5 == 0: # 每5个样本返回Nonereturn Nonereturn torch.randn(10), torch.tensor(index % 3)# 需要过滤 None 值
def filter_none_collate(batch):batch = [sample for sample in batch if sample is not None]return default_collate(batch) if batch else Nonedataloader = DataLoader(DatasetWithNone(), batch_size=4, collate_fn=filter_none_collate)
可变长度序列
class VariableLengthDataset:def __getitem__(self, index):length = torch.randint(5, 15, (1,)).item()sequence = torch.randn(length, 10) # 可变长度序列return sequence, torch.tensor(index % 3)# 使用 pad_sequence 处理可变长度
from torch.nn.utils.rnn import pad_sequencedef pad_collate(batch):sequences, labels = zip(*batch)padded_sequences = pad_sequence(sequences, batch_first=True)labels = torch.stack(labels)return padded_sequences, labelsdataloader = DataLoader(VariableLengthDataset(), batch_size=4, collate_fn=pad_collate)
6. 实际应用示例
计算机视觉任务
class VisionDataset:def __getitem__(self, index):return {'image': torch.randn(3, 224, 224), # 图像'label': torch.tensor(index % 1000), # 分类标签'bbox': torch.tensor([[10, 20, 100, 150]]), # 检测框'mask': torch.randn(224, 224) > 0.5, # 分割掩码'image_id': index # 图像ID}
NLP 任务
class NLPDataset:def __getitem__(self, index):return {'input_ids': torch.randint(0, 1000, (128,)),'attention_mask': torch.ones(128),'token_type_ids': torch.zeros(128),'labels': torch.randint(0, 2, (1,)),'text': f"这是第{index}个样本"}