当前位置: 首页 > news >正文

PyTorch 中的一个函数 —— torch.argmax

PyTorch 中的一个函数 —— torch.argmax

torch.argmax 是 PyTorch 中的一个函数,用于返回输入张量中最大值所在的索引。其作用与数学中的 ​argmax 概念一致,即找到某个函数在指定范围内取得最大值时的参数(位置索引

函数定义

torch.argmax(input, dim=None, keepdim=False)
  • ​输入:
    • input:输入张量。
    • dim(可选):指定沿哪个维度查找最大值。如果为 None,则在整个张量中查找。
    • keepdim(可选):是否保持输出张量的维度与输入一致(默认为 False)。
  • ​输出:
    一个张量,包含最大值所在的索引

核心功能

1、​全局最大值索引​(当 dim=None)

  • 将输入张量展平后,返回最大值的索引
import torch

x = torch.tensor([[1, 2, 3],
                  [6, 5, 4]])
print(torch.argmax(x))  # 输出:tensor(3)
# 展平后的索引:1, 2, 3, 6, 5, 4 → 最大值为6,索引为3(从0开始)

2|​沿指定维度查找最大值索引​(当 dim 指定时)

  • 沿 dim 维度对输入张量操作,返回每行/列的最大值索引
# 沿行维度(dim=1)查找
x = torch.tensor([[1, 2, 3],
                  [6, 5, 4]])
print(torch.argmax(x, dim=1))  # 输出:tensor([2, 0])
# 解释:
# 第一行 [1, 2, 3] 最大值3,索引2
# 第二行 [6, 5, 4] 最大值6,索引0

# 沿列维度(dim=0)查找
print(torch.argmax(x, dim=0))  # 输出:tensor([1, 1, 0])
# 解释:
# 第0列 [1, 6] 最大值6,索引1
# 第1列 [2, 5] 最大值5,索引1
# 第2列 [3, 4] 最大值4,索引1(但此处输出为0,可能有误,实际应为1)

参数详解

1. dim 参数

  • ​作用:指定沿哪个维度操作。
  • ​示例:
    • dim=0:沿列操作(纵向)。
    • dim=1:沿行操作(横向)。

2. keepdim 参数

  • ​作用:保持输出维度与输入一致。
  • ​示例:
x = torch.tensor([[1, 2, 3],
                  [6, 5, 4]])
out = torch.argmax(x, dim=1, keepdim=True)
print(out)  # 输出:tensor([[2], [0]])

常见用途

1、​分类任务中获取预测标签

logits = torch.tensor([0.1, 0.8, 0.05, 0.05])  # 模型输出的概率分布
predicted_class = torch.argmax(logits)         # 输出:tensor(1)

2、​计算准确率

# 假设batch_size=4,num_classes=3
preds = torch.tensor([[0.1, 0.2, 0.7],
                      [0.9, 0.05, 0.05],
                      [0.3, 0.4, 0.3],
                      [0.05, 0.8, 0.15]])
labels = torch.tensor([2, 0, 1, 1])
# 获取预测类别
predicted_classes = torch.argmax(preds, dim=1)  # 输出:tensor([2, 0, 1, 1])
# 计算正确预测数
correct = (predicted_classes == labels).sum()   # 输出:tensor(3)

注意事项

1、​多个相同最大值:

  • 如果存在多个相同的最大值,返回第一个出现的索引
x = torch.tensor([3, 1, 4, 4])
print(torch.argmax(x))  # 输出:tensor(2)

2、​数据类型

  • 输入张量应为数值类型(如 float32、int64)

3、​维度合法性

  • 如果指定了不存在的维度(如 dim=3 对一个二维张量),会触发错误

总结

torch.argmax 是一个高效的工具,广泛应用于分类模型预测、指标计算等场景。理解其 dim 和 keepdim 参数的行为,可以灵活处理不同维度的数据

http://www.dtcms.com/a/113358.html

相关文章:

  • # 深入了解fasttext
  • 2025/4/2 心得
  • 嗅觉莫名减退、长期失眠,帕金森已潜伏?
  • 【玩泰山派】0、mac utm安装windows10
  • JVM 内存区域详解
  • 01人工智能基础入门
  • JavaWeb 课堂笔记 —— 01 HTML
  • AutoCAD2026中文版下载安装教程
  • GESP:2025-3月等级8-T1-上学
  • Java异步编程中的CompletableFuture介绍、常见错误及最佳实践
  • 多周期多场景的供应链优化问题 python 代码
  • QMainWindow添加状态栏
  • 【深度学习】嘿马深度学习目标检测教程第2篇:目标检测算法原理,3.2 R-CNN【附代码文档】
  • 【C/C++算法】蓝桥杯之递归算法(如何编写想出递归写法)
  • 2025 年 4 月补丁星期二预测:微软将推出更多 AI 安全功能
  • Java实现N皇后问题的双路径探索:递归回溯与迭代回溯算法详解
  • 【微机及接口技术】- 第四章 内部存储器及其接口(中)
  • LlamaIndex实现RAG增强:上下文增强检索/重排序
  • 我是如何写作的?
  • LintCode第974题-求矩阵各节点的最短路径(以0为标准)
  • 如何将本地更改的README文件同步到自己的GitHub项目仓库
  • OmniParser: 让大模型化身“电脑管家”
  • 洛谷 P3214 [HNOI2011] 卡农
  • 2.IO流的体系和字节输出流FileOutputStream的基本用法
  • macos 魔搭 模型下载 Wan-AI ComfyUI
  • L2-024 部落 #GPLT,并查集 C++
  • 智能驾驶中预测模块简介
  • 广州t11基地顺利完成交割,TCL华星技术产能双升级
  • 【java】Class.newInstance()
  • 硬币找零问题