当前位置: 首页 > news >正文

PyTorch中torch.eq()、torch.argmax()函数的详解和代码示例

下面对 PyTorch 中常用的两个函数 torch.eq()torch.argmax()详解,包括 功能、参数说明、返回值、注意事项 以及 代码示例


一、torch.eq(input, other)

功能:

按元素逐位判断 input == other,返回 bool 类型的张量(值为 TrueFalse)。

参数:

参数类型含义
inputTensor被比较的第一个张量
otherTensor 或数值被比较的第二个张量或标量

要求 inputother 的维度可广播(broadcastable)。

返回值:

一个 布尔型张量,形状与广播后的 input 相同。

示例 1:两个张量比较

import torcha = torch.tensor([1, 2, 3])
b = torch.tensor([1, 0, 3])
result = torch.eq(a, b)
print(result)
# 输出: tensor([True, False, True])

示例 2:张量与标量比较

a = torch.tensor([[3, 5], [5, 5]])
result = torch.eq(a, 5)
print(result)
# 输出: tensor([[False, True], [True, True]])

二、torch.argmax(input, dim=None, keepdim=False)

功能:

返回 input 张量在指定维度上最大值的索引(索引位置)。

参数:

参数类型说明
inputTensor输入张量
dimintNone指定在哪个维度上返回最大值索引。默认 None,表示返回 扁平化后的最大值索引
keepdimbool是否保持维度。默认为 False。若为 True,输出将保留原始维度大小为1。

返回值:

张量索引(int64),表示每个位置上最大值的索引。


示例 1:不指定 dim(返回扁平化最大索引)

a = torch.tensor([[1, 9], [3, 7]])
idx = torch.argmax(a)
print(idx)
# 输出: tensor(1)(值 9 在展平后第1个位置)

示例 2:指定 dim=1(按行取最大值索引)

a = torch.tensor([[1, 9], [3, 7]])
idx = torch.argmax(a, dim=1)
print(idx)
# 输出: tensor([1, 1]) (每行最大值索引)

示例 3:保留维度

a = torch.tensor([[1, 9], [3, 7]])
idx = torch.argmax(a, dim=1, keepdim=True)
print(idx)
# 输出: tensor([[1], [1]])

注意事项对比

函数名返回类型说明
torch.eq()BoolTensor返回逐元素相等的布尔张量
torch.argmax()IntTensor返回指定维度或展平后最大值的位置索引

实战小技巧:

判断两个结果张量是否完全相同:

a = torch.tensor([1, 2, 3])
b = torch.tensor([1, 2, 3])
if torch.eq(a, b).all():print("完全相同")

多分类预测取最大概率类:

logits = torch.tensor([[0.1, 0.3, 0.6],[0.2, 0.5, 0.3]])
pred = torch.argmax(logits, dim=1)
print(pred)  # 输出: tensor([2, 1])

http://www.dtcms.com/a/272487.html

相关文章:

  • 多线程交替打印ABC
  • Windows安装DevEco Studio
  • 解决问题:在cmd中能查看到pnpm版本,在vscode终端中却报错
  • [5种方法] 如何将iPhone短信保存到电脑
  • 搜索算法在前端的实践
  • G5打卡——Pix2Pix算法
  • Vue前端导出页面为PDF文件
  • 【HDLBits习题 2】Circuit - Sequential Logic(4)More Circuits
  • AI驱动的业务系统智能化转型:从静态配置到动态认知的范式革命
  • 基础 IO
  • Spring Boot中的中介者模式:终结对象交互的“蜘蛛网”困境
  • JAVA JVM的内存区域划分
  • Redis的常用命令及`SETNX`实现分布式锁、幂等操作
  • Redis Stack扩展功能
  • K8S数据流核心底层逻辑剖析
  • AI进化论06:连接主义的复兴——神经网络的“蛰伏”与“萌动”
  • k8s集群--证书延期
  • 项目进度管控依赖Excel,如何提升数字化能力
  • 调度器与闲逛进程详解,(操作系统OS)
  • UI前端与数字孪生结合案例分享:智慧城市的智慧能源管理系统
  • 数据结构笔记10:排序算法
  • Windows 本地 使用mkcert 配置HTTPS 自签名证书
  • Java并发 - 阻塞队列详解
  • XSS(ctfshow)
  • 文心大模型4.5开源测评:保姆级部署教程+多维度测试验证
  • 图书管理系统(完结版)
  • PyCharm 中 Python 解释器的添加选项及作用
  • 创始人IP如何进阶?三次关键突破实现高效转化
  • QT解析文本框数据——详解
  • pycharm中自动补全方法返回变量