当前位置: 首页 > news >正文

PyTorch学习之:torch.gather是什么?

torch.gather的定义:

torch.gather 是 PyTorch 中的一个张量操作函数,其作用是根据指定的维度dim)和索引张量index),从输入张量(input)中收集元素,生成一个与索引张量形状相同的输出张量。总体来说,就是维度dim和索引张量index决定一个收集数的规则,然后,基于这个规则从输入张量中获取需要的元素。

核心部分:

1.输入张量input):

  • 任意形状的张量。

2.索引张量index):

  • 形状必须与输入张量在除 dim 外的其他维度上一致。

  • 索引值必须在输入张量 dim 维度的有效范围内(即 0 到 size(dim)-1)。

3.输出张量output):

  • 形状与索引张量相同。

  • 每个元素的值由以下规则确定:

output[i][j][k] = input[i][index[i][j][k]][k]  # 当 dim=1 时

举例详解:

示例 1:二维张量,dim=1

import torchinput = torch.tensor([[1, 2], [3, 4]])
index = torch.tensor([[0, 0], [1, 0]], dtype=torch.long)output = torch.gather(input, dim=1, index=index)
print(output)

输出

tensor([[1, 1],[4, 3]])

 解释

输入是一个2x2的矩阵,因为dim是1,所以我们参考下面的公式:

output[i][j] = input[i][index[i][j]]  # 当 dim=1 时

对于输出的第0行第0列(i = 0, j = 0),index对应的位置为0(因为index[0][0]为0),所以,对应的输出等于input[0][0](即为1)。

对于输出的第0行第1列(i = 0, j = 1),index对应的位置为0(因为index[0][1]为0),所以,对应的输出等于input[0][0](即为1)。

对于输出的第1行第0列(i = 1, j = 0),index对应的位置为1(因为index[1][0]为1),所以,对应的输出等于input[1][1](即为4)。

对于输出的第1行第1列(i = 1, j = 1),index对应的位置为0(因为index[1][1]为0),所以,对应的输出等于input[1][0](即为3)。

所以,最后的结果为:

tensor([[1, 1],[4, 3]])

 示例 2:二维张量,dim=0

import torchinput = torch.tensor([[1, 2], [3, 4]])
index = torch.tensor([[0, 0], [1, 0]], dtype=torch.long)output = torch.gather(input, dim=0, index=index)
print(output)

输出

tensor([[1, 2],[3, 2]])

 解释

输入是一个2x2的矩阵,因为dim是0,所以我们参考下面的公式:

output[i][j] = input[index[i][j]][j]  # 当 dim=0 时

对于输出的第0行第0列(i = 0, j = 0),index对应的位置为0(因为index[0][0]为0),所以,对应的输出等于input[0][0](即为1)。

对于输出的第0行第1列(i = 0, j = 1),index对应的位置为0(因为index[0][1]为0),所以,对应的输出等于input[0][1](即为2)。

对于输出的第1行第0列(i = 1, j = 0),index对应的位置为1(因为index[1][0]为1),所以,对应的输出等于input[1][0](即为3)。

对于输出的第1行第1列(i = 1, j = 1),index对应的位置为0(因为index[1][1]为0),所以,对应的输出等于input[0][1](即为2)。

所以,最后的结果为:

tensor([[1, 2],[3, 2]])

相关文章:

  • MBSS-T1:基于模型的特定受试者自监督运动校正方法用于鲁棒心脏 T1 mapping|文献速递-深度学习医疗AI最新文献
  • InetAddress 类详解
  • 第一章 Proteus中Arduino的可视化程序
  • 宁夏建设工程专业技术职称评审条件
  • 今日行情明日机会——20250521
  • 掩膜合并代码
  • 关于TCP三次握手
  • Java异步编程利器:CompletableFuture 深度解析与实战
  • 5.21本日总结
  • 端口号详解(技术向)
  • 轩辕杯Wp
  • 从运维告警到业务决策:可观测性正在重新定义企业数据基础设施
  • AI工程师系列——面向copilot编程
  • 配电网运行状态综合评估方法研究
  • 使用 mutt 发送邮件:Linux 下轻量高效的命令行邮件工具
  • NV009NV010美光闪存颗粒NV011NV012
  • Java面试问题基础篇
  • BISS0001 PIR红外感应IC:高性能热释电信号处理解决方案
  • DNS服务搭建与配置详解
  • JS手写代码篇---手写Promise
  • 网店美工课本/抚顺优化seo
  • 法律建设网站/怎么创建网页链接
  • 网站建设课程简介图片/百度网站登录入口