PyTorch学习之:torch.gather是什么?
torch.gather的定义:
torch.gather
是 PyTorch 中的一个张量操作函数,其作用是根据指定的维度(dim
)和索引张量(index
),从输入张量(input
)中收集元素,生成一个与索引张量形状相同的输出张量。总体来说,就是维度dim和索引张量index决定一个收集数的规则,然后,基于这个规则从输入张量中获取需要的元素。
核心部分:
1.输入张量(input
):
- 任意形状的张量。
2.索引张量(index
):
-
形状必须与输入张量在除
dim
外的其他维度上一致。 -
索引值必须在输入张量
dim
维度的有效范围内(即0
到size(dim)-1
)。
3.输出张量(output
):
-
形状与索引张量相同。
-
每个元素的值由以下规则确定:
output[i][j][k] = input[i][index[i][j][k]][k] # 当 dim=1 时
举例详解:
示例 1:二维张量,dim=1
import torchinput = torch.tensor([[1, 2], [3, 4]])
index = torch.tensor([[0, 0], [1, 0]], dtype=torch.long)output = torch.gather(input, dim=1, index=index)
print(output)
输出:
tensor([[1, 1],[4, 3]])
解释:
输入是一个2x2的矩阵,因为dim是1,所以我们参考下面的公式:
output[i][j] = input[i][index[i][j]] # 当 dim=1 时
对于输出的第0行第0列(i = 0, j = 0),index对应的位置为0(因为index[0][0]为0),所以,对应的输出等于input[0][0](即为1)。
对于输出的第0行第1列(i = 0, j = 1),index对应的位置为0(因为index[0][1]为0),所以,对应的输出等于input[0][0](即为1)。
对于输出的第1行第0列(i = 1, j = 0),index对应的位置为1(因为index[1][0]为1),所以,对应的输出等于input[1][1](即为4)。
对于输出的第1行第1列(i = 1, j = 1),index对应的位置为0(因为index[1][1]为0),所以,对应的输出等于input[1][0](即为3)。
所以,最后的结果为:
tensor([[1, 1],[4, 3]])
示例 2:二维张量,dim=0
import torchinput = torch.tensor([[1, 2], [3, 4]])
index = torch.tensor([[0, 0], [1, 0]], dtype=torch.long)output = torch.gather(input, dim=0, index=index)
print(output)
输出:
tensor([[1, 2],[3, 2]])
解释:
输入是一个2x2的矩阵,因为dim是0,所以我们参考下面的公式:
output[i][j] = input[index[i][j]][j] # 当 dim=0 时
对于输出的第0行第0列(i = 0, j = 0),index对应的位置为0(因为index[0][0]为0),所以,对应的输出等于input[0][0](即为1)。
对于输出的第0行第1列(i = 0, j = 1),index对应的位置为0(因为index[0][1]为0),所以,对应的输出等于input[0][1](即为2)。
对于输出的第1行第0列(i = 1, j = 0),index对应的位置为1(因为index[1][0]为1),所以,对应的输出等于input[1][0](即为3)。
对于输出的第1行第1列(i = 1, j = 1),index对应的位置为0(因为index[1][1]为0),所以,对应的输出等于input[0][1](即为2)。
所以,最后的结果为:
tensor([[1, 2],[3, 2]])