PyTorch Dataloader工作原理 之 default collate_fn操作
场景设定
假设我们有一个 Dataset
,它的 __getitem__
方法返回一个包含两项的元组:一个 (3, 32, 32)
的图像张量和一个整数标签。
import torch
from torch.utils.data import Datasetclass MyImageDataset(Dataset):def __len__(self):return 10 # 数据集里有10个样本def __getitem__(self, idx):# 模拟返回一个图像张量和一个标签# 图像形状: (通道, 高, 宽)image_tensor = torch.randn(3, 32, 32)label = idx % 2 # 标签是 0 或 1return image_tensor, labeldataset = MyImageDataset()
现在,我们创建一个 DataLoader
,设置 batch_size=4
,并且不指定 collate_fn
,这样它就会使用默认的 default_collate
。
from torch.utils.data import DataLoaderdata_loader = DataLoader(dataset, batch_size=4, shuffle=False)
当我们在 for
循环中迭代 data_loader
时,default_collate
会在幕后执行以下步骤:
default_collate
的工作流程
第 1 步:获取样本列表 (List of Samples)
DataLoader
首先从 dataset
中获取一个批次大小(这里是 4)的样本。它会得到一个列表,列表的每个元素都是 __getitem__
的返回值。
这个列表(我们称之为 batch_list
)看起来像这样:
batch_list = [( <Tensor_0 shape=(3,32,32)>, 0 ), # 第一个样本 (image_0, label_0)( <Tensor_1 shape=(3,32,32)>, 1 ), # 第二个样本 (image_1, label_1)( <Tensor_2 shape=(3,32,32)>, 0 ), # 第三个样本 (image_2, label_2)( <Tensor_3 shape=(3,32,32)>, 1 ) # 第四个样本 (image_3, label_3)
]
这是一个长度为 4 的列表,每个元素都是一个元组。
第 2 步:转置列表 (Transpose the List)
这是最关键的一步,也就是“堆叠对应部分”的第一阶段。default_collate
会“转置”这个列表。它将所有元组的第 0 个元素收集在一起,形成一个新的列表;然后将所有元组的第 1 个元素收集在一起,形成另一个新的列表。
- 原始结构:
List[Tuple[Image, Label]]
- 转置后结构:
Tuple[List[Image], List[Label]]
转置后的结果如下:
# 这是一个概念上的表示,不是实际代码
transposed_batch = (# 所有样本的第0个元素 (图像)[ <Tensor_0>, <Tensor_1>, <Tensor_2>, <Tensor_3> ],# 所有样本的第1个元素 (标签)[ 0, 1, 0, 1 ]
)
现在我们得到了一个元组,元组里有两个列表:一个是包含 4 个图像张量的列表,另一个是包含 4 个整数标签的列表。
第 3 步:打包/堆叠每一部分 (Collate/Stack Each Part)
现在,default_collate
会遍历转置后元组中的每一个列表,并尝试将它们“打包”成一个批次张量。
对于第一个列表(图像张量列表):
- 输入是
[ <Tensor_0 shape=(3,32,32)>, <Tensor_1 shape=(3,32,32)>, ... ]
。 default_collate
发现这些都是 PyTorch 张量,于是它会使用torch.stack(list_of_tensors, dim=0)
。torch.stack
会创建一个新的维度(批次维度)在第 0 维,然后将这些张量沿着这个新维度拼接起来。- 结果: 一个单一的张量,形状为
(4, 3, 32, 32)
。这里的4
就是batch_size
。
对于第二个列表(标签列表):
- 输入是
[0, 1, 0, 1]
。 default_collate
发现这些是 Python 的数字。它会先将它们转换成一个张量列表[tensor(0), tensor(1), tensor(0), tensor(1)]
。- 然后,它同样对这个张量列表使用
torch.stack
。 - 结果: 一个单一的张量,形状为
(4,)
。
第 4 步:返回最终批次
最后,default_collate
将打包好的各个部分重新组合成一个元组(保持原始的结构),并返回。
所以,你在 for
循环中得到的 batch
变量实际上是:
# batch[0] 是图像批次, batch[1] 是标签批次
batch = (<Tensor shape=(4, 3, 32, 32)>,<Tensor shape=(4,)>
)
这就是为什么你可以这样解包:
for images_batch, labels_batch in data_loader:
如果样本是字典呢?
default_collate
同样智能地处理字典。如果你的 __getitem__
返回一个字典:
def __getitem__(self, idx):return {'image': torch.randn(3, 32, 32),'label': idx % 2}
那么 default_collate
的工作流程会是:
-
获取样本列表:
batch_list = [{'image': <Tensor_0>, 'label': 0},{'image': <Tensor_1>, 'label': 1},{'image': <Tensor_2>, 'label': 0},{'image': <Tensor_3>, 'label': 1} ]
-
转置/重组: 它会将所有字典的
'image'
值收集到一个列表,将所有'label'
值收集到另一个列表。 -
打包/堆叠:
[<Tensor_0>, <Tensor_1>, ...]
被堆叠成一个(4, 3, 32, 32)
的张量。[0, 1, 0, 1]
被堆叠成一个(4,)
的张量。
-
返回最终批次: 它会返回一个字典,字典的键和样本字典的键相同,但值是堆叠后的批次张量。
batch = {'image': <Tensor shape=(4, 3, 32, 32)>,'label': <Tensor shape=(4,)> }
总结
default_collate
的“堆叠对应部分”是一个递归的过程:
- 它检查一批样本的数据结构(是元组、字典还是其他)。
- 它“转置”这个结构,将所有样本的“第一部分”放在一起,所有样本的“第二部分”放在一起,以此类推。
- 对于每个集合起来的部分,它会根据其数据类型应用
torch.stack
(如果是张量、数字等)或递归地调用collate
过程(如果是更复杂的嵌套结构)。
这个过程的前提是,所有待堆叠的张量必须具有完全相同的形状。如果形状不一(例如,变长的文本序列),torch.stack
就会失败,这就是为什么在这种情况下你需要提供一个自定义的 collate_fn
来进行 padding 等操作。