torch.gather()和torch.sort
torch.gather()
def semantic_neighbor(x, index):
'''
假设x.shape=[B,L,C]=[2,3,4] index.shape=[B,L]=[2,3]
x = torch.tensor([[[1, 2, 3, 4], # 样本1的3个元素,每个元素4维特征[5, 6, 7, 8],[9, 10, 11, 12]],[[13, 14, 15, 16], # 样本2的3个元素[17, 18, 19, 20],[21, 22, 23, 24]]
])# 索引张量 index (B=2, L=3)
index = torch.tensor([[1, 0, 1], # 样本1的重组索引[2, 1, 0] # 样本2的重组索引
])'''dim = index.dim()#dim=2assert x.shape[:dim] == index.shape, "x ({:}) and index ({:}) shape incompatible".format(x.shape, index.shape)for _ in range(x.dim() - index.dim()):index = index.unsqueeze(-1)'''x.index=[2,3]index = torch.tensor([[[1],[0], [1]], [[2], [1], [0]] ])'''index = index.expand(x.shape)'''x.index=[2,3,4]index = torch.tensor([[[1,1,1,1],[0,0,0,0], [1,1,1,1]], [[2,2,2,2], [1,1,1,1], [0,0,0,0]] ])'''shuffled_x = torch.gather(x, dim=dim - 1, index=index)'''tensor([[[ 5, 6, 7, 8], # 来自原始位置1[ 1, 2, 3, 4], # 来自原始位置0[ 5, 6, 7, 8]], # 来自原始位置1[[21, 22, 23, 24], # 来自原始位置2[17, 18, 19, 20], # 来自原始位置1[13, 14, 15, 16]] # 来自原始位置0
])'''return shuffled_x'''
另一个简单的示例:
源张量(3x4矩阵)
x = torch.tensor([[1, 2, 3, 4],[5, 6, 7, 8],[9, 10, 11, 12]])索引张量(2x3矩阵)
index = torch.tensor([[0, 1, 2],[2, 1, 0]])沿dim=0(行方向)收集
out = torch.gather(x, dim=0, index=index)结果:
[[1, 6, 11], # 取x[0][0], x[1][1], x[2][2][9, 6, 3]] # 取x[2][0], x[1][1], x[0][2]]
'''
x.sort()
x_sort_values, x_sort_indices = torch.sort(detached_index, dim=-1, stable=False)
torch.sort
:对detached_index
沿dim=-1
(即n
维度)进行排序。- 若
detached_index=[[2,0,1,0]]
那么detached_index
排序后的值是[[0, 0, 1, 2]]
(即x_sort_values
)。 x_sort_indices
是[[1, 3, 2, 0]]
,表示:- 排序后的第0个元素来自原始位置1(值是0),
- 第1个元素来自原始位置3(值是0),
- 第2个元素来自原始位置2(值是1),
- 第3个元素来自原始位置0(值是2)。