当前位置: 首页 > news >正文

PyTorch topk() 用法详解:取最大值

torch.topk(input, k) 返回张量中最大的 k 个元素以及它们在原张量中的 索引

函数原型

torch.topk(input, k, dim=None, largest=True, sorted=True)

参数说明:

参数说明
input输入张量
k要取出的前 k 个值
dim指定沿哪个维度取值(默认是最后一维)
largest是否取最大值(默认是 True,为 False 时返回最小值)
sorted返回的结果是否排序(默认是 True,按值从大到小)

示例:二维张量中使用 topk()dim=0 vs dim=1

我们来通过一个具体的 3x3 张量示例,观察在不同维度上使用 topk() 的结果。

import torch# 创建一个 3x3 的二维张量
x = torch.tensor([[0.1, 0.8, 0.6],[0.9, 0.2, 0.3],[0.5, 0.4, 0.7]
])

沿行取 Top-k:dim=1

print(torch.topk(x, k=2, dim=1))# 输出:
# values=tensor([[0.8000, 0.6000],
#         [0.9000, 0.3000],
#         [0.7000, 0.5000]]),
# indices=tensor([[1, 2],
#         [0, 2],
#         [2, 0]]))

每一行分别取出前两个最大值及其列索引

沿列取 Top-k:dim=0

print(torch.topk(x, k=2, dim=0))#  输出:
# values=tensor([[0.9000, 0.8000, 0.7000],
#        [0.5000, 0.4000, 0.6000]]),
# indices=tensor([[1, 0, 2],
#        [2, 2, 0]]))

每一列分别取出前两个最大值及其对应的“行号”。

理解维度的直觉图示

  • dim=1按行取 top-k(对每一行,从左往右选 k 个最大值)
  • dim=0按列取 top-k(对每一列,从上往下选 k 个最大值)
操作意图方向
topk(x, k, dim=1)每行选前 k 个最大⟶ 横向
topk(x, k, dim=0)每列选前 k 个最大⬇ 纵向

topk 与largest、sorted操作的组合

1. 取最小值:largest=False
d = torch.tensor([5, 3, 8, 1, 2])
smallest, indices = torch.topk(d, k=2, largest=False)
print("前2小的值:", smallest) 
# 输出: tensor([1, 2])

2. 不排序:sorted=False
e = torch.tensor([3, 1, 4, 2, 5])
values, indices = torch.topk(e, k=3, sorted=False)
print("前3大的值(未排序):", values)  
# 输出: tensor([3, 4, 5])
print("对应索引:", indices)         
# 输出: tensor([0, 2, 4])

相关文章:

  • 织梦网站维护html网页制作用什么软件
  • 东莞b2b网站建设seo优化个人博客
  • 泰安网站建设入门推荐企业宣传片视频
  • 网站建设的类型或分类百度识图在线识别
  • 建设一个营销型网站网站的推广方法
  • 宝山青岛网站建设web网站设计
  • CI/CD GitHub Actions配置流程
  • mongoose解析http字段值
  • 【LLaMA-Factory 实战系列】三、命令行篇 - YAML 配置与高效微调 Qwen2.5-VL
  • 走近科学IT版:FreeBSD系统下ThinkPad键盘突然按不出b和n键了!
  • Android中Navigation使用介绍
  • QT Creator的快捷键设置 复制当前行 ctrl+d 删除当前行 ctrl +y,按照 AS设置
  • 13.5-13.8. 计算机视觉【2】
  • jar 包如何下载
  • 网页变形记:响应式设计如何在手机里 “七十二变”
  • 【unitrix】 4.3 左移运算(<<)的实现(shl.rs)
  • 医疗AI数智立体化体系V2.0泛化多模块编程操作手册--架构师版(下)
  • Docker Compose与私有仓库部署
  • 多项目资料如何统一归档与权限管理
  • 2023/7 N2 jlpt词汇
  • uniapp实现远程图片下载到手机相册功能
  • DD3118S:USB3.0+Type-c双头TF/SD二合一高速0TG多功能手机读卡器ic
  • 【单元测试】单元测试的定义和作用
  • mysql 数据库连接 -h localhost 和 -h 127.0.0.1 区别是什么
  • 【AI时代速通QT】第三节:Linux环境中安装QT并做测试调试
  • C++修炼:异常