当前位置: 首页 > news >正文

torch.argsorttorch.gather

文章目录

  • 1. 举例说明
  • 2. pytorch 代码

1. 举例说明

torch.argsort 的作用是可以将矩阵中的元素进行从小到大排序,得到对应的序号。假设我们有一个向量a表示如下
a = [ 8 , 7 , 6 , 9 , 7 ] \begin{equation} a=[8,7,6,9,7] \end{equation} a=[8,7,6,9,7]
那么从小到大可以得到排序向量为b
b = [ 2 , 1 , 4 , 0 , 3 ] \begin{equation} b=[2,1,4,0,3] \end{equation} b=[2,1,4,0,3]
如果我想通过序号向量b来直接从小到大排序的向量c,那么就需要torch.gather函数
c = [ 6 , 7 , 7 , 8 , 9 ] \begin{equation} c=[6,7,7,8,9] \end{equation} c=[6,7,7,8,9]

2. pytorch 代码

  • python 代码描述:
import torch
torch.manual_seed(23231)

torch.set_printoptions(precision=3, sci_mode=False)
# torch.seed()
if __name__ == "__main__":
    run_code = 0
    a_vector =torch.randint(low=1,high=10,size=(5,))
    print(f"a_vector=\n{a_vector}")
    a_argsort = torch.argsort(input=a_vector)
    print(f"a_argsort=\n{a_argsort}")
    a_restore = torch.argsort(a_argsort)
    print(f"a_restore=\n{a_restore}")
    a_gather = torch.gather(input=a_vector, dim=0, index=a_argsort)
    print(f"a_gather={a_gather}")
    a_matrix = torch.randint(0, 10, (3, 4))
    matrix_argsort = torch.argsort(input=a_matrix, dim=1)
    print(f"a_matrix=\n{a_matrix}")
    print(f"matrix_argsort=\n{matrix_argsort}")
    matrix_gather = torch.gather(input=a_matrix,dim=1,index=matrix_argsort)
    print(f"matrix_gather=\n{matrix_gather}")
  • result:
a_vector=
tensor([8, 7, 6, 9, 7])
a_argsort=
tensor([2, 1, 4, 0, 3])
a_restore=
tensor([3, 1, 0, 4, 2])
a_gather=tensor([6, 7, 7, 8, 9])
a_matrix=
tensor([[0, 2, 9, 5],
        [0, 6, 8, 5],
        [0, 8, 3, 7]])
matrix_argsort=
tensor([[0, 1, 3, 2],
        [0, 3, 1, 2],
        [0, 2, 3, 1]])
matrix_gather=
tensor([[0, 2, 5, 9],
        [0, 5, 6, 8],
        [0, 3, 7, 8]])

相关文章:

  • 工程化与框架系列(36)--前端监控告警实践
  • 多任务学习与持续学习微调:深入探索大型语言模型的性能与适应性
  • L2-3 花非花,雾非雾
  • 从FFmpeg命令行到Rust:多场景实战指南
  • StarRocks SQL使用与MySql的差异及规范注意事项
  • 时区转换工具
  • 详细介绍GetDlgItem()
  • TypeScript接口 interface 高级用法完全解析
  • 使用EasyExcel进行简单的导入、导出
  • JxBrowser 8.5.0 版本发布啦!
  • 为什么手机上用 mA 和 mAh 来表示功耗和能耗?
  • MiddleVR for Unity插件
  • S32K144外设实验(一):LPIT的周期中断
  • 【MySQL】MySQL审计工具Audit Plugin安装使用
  • Dify平台离线镜像部署
  • 字母~~~
  • vllm-openai多服务器集群部署AI模型
  • MyBatis SqlSession 是如何创建的? 它与 SqlSessionFactory 有什么关系?
  • V2X验证
  • C#入门学习记录(三)C#中的隐式和显示转换
  • 医学统计专家童新元逝世,终年61岁
  • 河北:开展领导干部任性用权等形式主义官僚主义问题专项整治
  • 金融创新破局记:中小微企业转型背后的金融力量
  • 申花四连胜领跑中超,下轮榜首大战对蓉城将是硬仗考验
  • 文化体验+商业消费+服务创新,上海搭建入境旅游新模式
  • 乌称泽连斯基与特朗普进行简短会谈