torch.cat 函数介绍
torch.cat
是 PyTorch 中用于将多个张量沿着指定维度拼接(concatenate)的函数。它广泛应用于深度学习中,例如在神经网络中合并特征、拼接不同维度的数据等场景。
功能与用途
torch.cat
的主要功能是将多个张量沿着指定的维度拼接成一个新的张量。拼接的张量在其他维度上必须具有相同的形状,否则会报错。
函数签名
torch.cat(tensors, dim=0, *, out=None) → Tensor
参数说明
-
tensors
(sequence of Tensors):需要拼接的张量序列,可以是列表(list
)或元组(tuple
)。这些张量必须具有相同的形状,除了拼接维度外。 -
dim
(int, optional):指定拼接的维度,默认值为0
。例如:-
dim=0
表示沿着第一个维度(行)拼接。 -
dim=1
表示沿着第二个维度(列)拼接。
-
-
out
(Tensor, optional):可选参数,用于指定输出张量。如果提供,结果将存储在该张量