当前位置: 首页 > news >正文

torch.gather()和torch.sort

torch.gather()

def semantic_neighbor(x, index):
'''
假设x.shape=[B,L,C]=[2,3,4]   index.shape=[B,L]=[2,3]
x = torch.tensor([[[1, 2, 3, 4],    # 样本1的3个元素,每个元素4维特征[5, 6, 7, 8],[9, 10, 11, 12]],[[13, 14, 15, 16], # 样本2的3个元素[17, 18, 19, 20],[21, 22, 23, 24]]
])# 索引张量 index (B=2, L=3)
index = torch.tensor([[1, 0, 1],  # 样本1的重组索引[2, 1, 0]   # 样本2的重组索引
])'''dim = index.dim()#dim=2assert x.shape[:dim] == index.shape, "x ({:}) and index ({:}) shape incompatible".format(x.shape, index.shape)for _ in range(x.dim() - index.dim()):index = index.unsqueeze(-1)'''x.index=[2,3]index = torch.tensor([[[1],[0], [1]], [[2], [1], [0]]  ])'''index = index.expand(x.shape)'''x.index=[2,3,4]index = torch.tensor([[[1,1,1,1],[0,0,0,0], [1,1,1,1]], [[2,2,2,2], [1,1,1,1], [0,0,0,0]]  ])'''shuffled_x = torch.gather(x, dim=dim - 1, index=index)'''tensor([[[ 5,  6,  7,  8],  # 来自原始位置1[ 1,  2,  3,  4],  # 来自原始位置0[ 5,  6,  7,  8]], # 来自原始位置1[[21, 22, 23, 24],  # 来自原始位置2[17, 18, 19, 20],  # 来自原始位置1[13, 14, 15, 16]]  # 来自原始位置0
])'''return shuffled_x'''
另一个简单的示例:
源张量(3x4矩阵)
x = torch.tensor([[1, 2, 3, 4],[5, 6, 7, 8],[9, 10, 11, 12]])索引张量(2x3矩阵)
index = torch.tensor([[0, 1, 2],[2, 1, 0]])沿dim=0(行方向)收集
out = torch.gather(x, dim=0, index=index)结果:
[[1,  6, 11],  # 取x[0][0], x[1][1], x[2][2][9,  6,  3]]  # 取x[2][0], x[1][1], x[0][2]]
'''

x.sort()
x_sort_values, x_sort_indices = torch.sort(detached_index, dim=-1, stable=False)

  • torch.sort:对 detached_index 沿 dim=-1(即 n 维度)进行排序。
  • detached_index=[[2,0,1,0]]那么detached_index 排序后的值是 [[0, 0, 1, 2]](即 x_sort_values)。
  • x_sort_indices[[1, 3, 2, 0]],表示:
    • 排序后的第0个元素来自原始位置1(值是0),
    • 第1个元素来自原始位置3(值是0),
    • 第2个元素来自原始位置2(值是1),
    • 第3个元素来自原始位置0(值是2)。

相关文章:

  • Human DiO-LDL,绿色荧光标记人源低密度脂蛋白,研究细胞内吞
  • vscode include总是报错
  • 印度语言指令驱动的无人机导航!UAV-VLN:端到端视觉语言导航助力无人机自主飞行
  • nltk-英文句子分词+词干化
  • 如何顺利地将应用程序从 Android 转移到Android
  • 微服务架构中的 RabbitMQ:异步通信与服务解耦(一)
  • 第六部分:阶段项目 5:构建 NestJS RESTful API 服务器
  • 5G 网络全场景注册方式深度解析:从信令交互到报文分析
  • Day124 | 灵神 | 二叉树 | 二叉树最小深度
  • 什么是VR展馆?VR展馆的实用价值有哪些?
  • 110kV/630mm2电缆5km的交流耐压试验兼顾110kVGIS开关用
  • jquery.table2excel方法导出
  • Cause: org.apache.ibatis.ognl.OgnlException: sqlSegment
  • 新手到资深的Java开发编码规范
  • 游戏如何应对反编译工具dnspy
  • b/s开发 1.0
  • C++ JSON解析技术详解
  • YOLOv11 性能评估与横向对比
  • pdf图片导出(Visio和Origin)
  • X82Y文字aI连线验证码
  • 网站建设费 无形资产/软文推广页面
  • 写作网站vir/北京网站优化托管
  • 电子商务系统 网站建设/湘潭网站建设
  • 莱芜网站优化/太仓网站制作
  • 网销怎么找客户/网站优化流程
  • 网站设计 推广/seo优化是利用规则提高排名