PyTorch_张量拼接
张量的拼接操作在神经网络搭建过程中是非常常用的方法,例如:残差网络,注意力机制中都使用张量拼接。
torch.cat 函数的使用
可以将两个张量根据指定的维度拼接起来。
import torch
import numpy as np def test01():data1 = torch.randint(0, 10, [3, 4, 5])data2 = torch.randint(0, 10, [3, 4, 5])print(data1.shape)print(data2.shape)# dim 对应的值可以是负数,可以通过list来思考# 按照第 0 维度进行拼接new_data = torch.cat([data1, data2], dim = 0) # 是列表print(new_data.shape)# 按照第 1 维度进行拼接new_data = torch.cat([data1, data2], dim = 1)print(new_data.shape)# 按照第 2 维度进行拼接new_data = torch.cat([data1, data2], dim = 2)print(new_data.shape)if __name__ == "__main__":test01()
torch.stack 函数的使用
torch.stack 函数可以将两个张量根据指定的维度叠加起来,或者组合成新的元素。叠加
的意思:当两个元素叠在一起,我们就将这两个元素当作一个元素。
import torch
import numpy as np def test01():data1 = torch.randint(0, 10, [2, 3])data2 = torch.randint(0, 10, [2, 3])print(data1)print(data2)# 将两个张量 stack 叠加起来,像 cat 一样指定维度# 1. 按照第0维度进行叠加new_data = torch.stack([data1, data2], dim=0)print(new_data.shape)# 2. 按照第1维度进行叠加new_data = torch.stack([data1, data2], dim=1)print(new_data)# 3. 按照第2维度进行叠加new_data = torch.stack([data1, data2], dim=2)print(new_data)if __name__ == "__main__":test01()