tensor连接和拆分
文章目录
- 连接
 - torch.cat()
 - 案例准备
 
- torch.stack()
 - 区别
 
- 拆分
 - torch.split()
 
连接
torch.cat()
函数目的: 在给定维度上对输入的张量序列 进行连接操作。
案例准备
a = torch.tensor([[1,2,3],[4,5,6],[7,8,9]], dtype=torch.float)
b = torch.tensor([[10,10,10,],[10,10,10],[10,10,10,]], dtype=torch.float)
 

# dim指的是维度,dim = 0就是行,所以下面的代码就是按行拼接
print("按行拼接:\n",torch.cat((a,b),dim=0))
print("按行拼接:\n",torch.cat((a,b),dim=0).shape) #6行3列
 

print("按列拼接:\n",torch.cat((a,b),dim=1))
print("按列拼接:\n",torch.cat((a,b),dim=1).shape)#3行6列
 

torch.stack()
沿着一个新维度对输入张量序列进行连接。 序列中所有的张量都应该为相同形状。
 也就是2维拼成3维,3维拼4维,以此类推。
print("按行拼接:\n",torch.stack((a,b),dim=0))
print("按行拼接:\n",torch.stack((a,b),dim=0).shape) 
 

print("按行拼接:\n",torch.stack((a,b),dim=1))
print("按行拼接:\n",torch.stack((a,b),dim=1).shape)
 

print("按行拼接:\n",torch.stack((a,b),dim=2))
print("按行拼接:\n",torch.stack((a,b),dim=2).shape)
 

区别
stack与cat的区别在于,torch.stack()函数要求输入张量的大小完全相同,得到的张量的维度会比输入的张量的大小多1,并且多出的那个维度就是拼接的维度,那个维度的大小就是输入张量的个数。
c = torch.tensor([[10,20],[30,40],[50,60]], dtype=torch.float)
a = torch.tensor([[1,2,3],[4,5,6],[7,8,9]], dtype=torch.float)
torch.cat((a,c),dim=1)
 

#但是以下情况就会出错
torch.cat((a,c),dim=0)
 

 如图,按行拼接会缺数据,报错吗,应该的。
 
torch.stack((a,c),dim=0)
###运行结果
RuntimeError: stack expects each tensor to be equal size, but got [3, 3] at entry 0 and [3, 2] at entry 1
再次验证stack需要两个大小一样的张量
 
拆分
torch.split()
def split(
tensor: Tensor, split_size_or_sections: Union[int, List[int]], dim: int = 0
) -> Tuple[Tensor, …]:
- 按块大小拆分张量 除不尽的取余数,返回一个元组
 
a = torch.tensor([[1,2,3],[4,5,6],[7,8,9]], dtype=torch.float)
print(torch.split(a,2,dim=0))	#按行拆,两行拆成一个
print(torch.split(a,1,dim=0))	#按行拆,一行拆成一个
print(torch.split(a,1,dim=1))	#按列拆,一列拆成一个
print(torch.split(a,2,dim=1)) 	#按列拆,两列拆成一个
 

- 按块数拆分张量
 
torch.chunk(a,2,dim=0)	#按行拆成两块
torch.split(a,2,dim=1)	#按列拆成两块
 

