torch.gather
torch.gather
介绍
torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor
沿由 dim 指定的轴收集值。
对于三维张量,输出按如下方式确定:
out[i][j][k] = input[index[i][j][k]][j][k] # 如果 dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # 如果 dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # 如果 dim == 2
input 和 index 必须具有相同的维度数。同时要求对于所有不等于 dim 的维度 d,满足 index.size(d) <= input.size(d)。输出的形状将与 index 相同。注意 input 和 index 不会相互广播。
参数
-
input (Tensor) – 源张量
-
dim (int) – 进行索引的轴
-
index (LongTensor) – 要收集的元素的索引
关键字参数
-
sparse_grad (bool, 可选) – 如果为 True,则关于 input 的梯度将是一个稀疏张量。
-
out (Tensor, 可选) – 目标张量
示例:
t = torch.tensor([[1, 2], [3, 4]])
torch.gather(t, 1, torch.tensor([[0, 0], [1, 0]]))
tensor([[ 1, 1], [ 4, 3]])
举例
其实torch文档给的形式非常清晰,只是一上来可能不太好理解
假如input是一个shape=[2,2]的矩阵,此时dim只能等于0或者1,index的shape也只能大于或者等于[2,2]
input=torch.tensor([[1,2][3,4]])
index = torch.tensor([[0, 1], [1, 2]])output = torch.gather(input,dim=0,index)
output[[],[]]
上面dim=0表示 output[i][j] = t[ index[i][j] ][ j ]
意思新的output矩阵行索引取值input矩阵的行索引,列索引取index矩阵中的元素值
所以取值如下
[[input[0,0],input[0,1]],[input[1,1],input[1,2]]
]
[1,23,4
]
总结
将index矩阵中的元素当成对input取值的行索引或者列索引,同时注意index矩阵中的元素值不能超过input的行或者列大小, 比如dim=0,那么index中元素值不能超过input的列大小2,否则就会报错