当前位置: 首页 > news >正文

【pyTorch】关于PyTorch的高级索引机制理解

y_hat[[0, 1], y]

看起来简短,其实包含了 PyTorch 的高级索引(advanced indexing) 机制,
常用于分类任务中,从模型输出中取出正确类别对应的预测概率(或分数)

我们来一步步拆解 👇


🧩 一、先看变量定义

y = torch.tensor([0, 2])
y_hat = torch.tensor([[0.1, 0.3, 0.6],[0.3, 0.2, 0.5]
])

可以理解为:

样本编号模型预测概率(3类)正确类别
第 0 个样本[0.1, 0.3, 0.6]0
第 1 个样本[0.3, 0.2, 0.5]2

🧮 二、代码含义

y_hat[[0, 1], y]

这是一个行列双索引操作,等价于:

torch.tensor([y_hat[0, y[0]], y_hat[1, y[1]]])

也就是:

= [y_hat[0, 0], y_hat[1, 2]]

📊 三、一步步替换求值

从矩阵中取出对应元素:

索引取出的元素
y_hat[0, 0]第 0 行第 0 列0.1
y_hat[1, 2]第 1 行第 2 列0.5

所以:

y_hat[[0, 1], y] = tensor([0.1, 0.5])

🧠 四、这个操作在深度学习中的用途

这个技巧非常常见!尤其是在 交叉熵损失 (cross-entropy) 计算中。

假设 y_hat 是每个样本对所有类别的预测概率分布,y 是正确标签,
那这行代码可以提取出:

每个样本 “对应正确类别” 的预测概率。

例如:

y_hat[[0, 1], y]
# → tensor([预测第0个样本属于真实类别的概率, 
#           预测第1个样本属于真实类别的概率])

然后就可以算交叉熵:

loss = -torch.log(y_hat[[0, 1], y])

✅ 五、总结一句话

代码含义
y_hat[[0, 1], y]按行索引 [0, 1] 和对应列索引 y 同时取元素
= [y_hat[0, y[0]], y_hat[1, y[1]]]等价的展开写法
结果取出每个样本的正确类别预测值(常用于分类任务)

💡 口诀记忆:

y_hat[行索引, 列索引]
会“对齐”地从每一行中取出对应列的元素。
常用来拿到“每个样本在真实类别上的预测概率”。


具体理解

🧩 数据回顾

y = torch.tensor([0, 2])
y_hat = torch.tensor([[0.1, 0.3, 0.6],[0.3, 0.2, 0.5]
])

🎨 图示说明

          类别0    类别1    类别2
样本0 →   0.1      0.3      0.6
样本1 →   0.3      0.2      0.5

以及标签:

y = [0, 2]

表示:

  • 样本 0 的真实类别是 0
  • 样本 1 的真实类别是 2

🔍 执行这句代码:

y_hat[[0, 1], y]

等价于:

取第 0 行的第 y[0]=0 列 → 0.1
取第 1 行的第 y[1]=2 列 → 0.5

✅ 可视化标注结果

          类别0    类别1    类别2
样本0 →  [0.1]*    0.3      0.6
样本1 →   0.3      0.2    [0.5]*星号 * 表示被选中的元素

最终输出:

tensor([0.1, 0.5])

🧠 用途回顾

在分类任务里:

loss = -torch.log(y_hat[[0, 1], y])

就是取出模型对真实标签类别的预测概率,再取负对数计算交叉熵损失。

http://www.dtcms.com/a/461503.html

相关文章:

  • c++ bug 函数定义和声明不一致导致出bug
  • 网站建设需求分析文档手机上制作ppt的软件
  • 推广网站怎么做能增加咨询南宁企业官网seo
  • MATLAB的无线传感器网络(WSN)算法仿真
  • k8s opa集成
  • Nginx 负载均衡通用方案
  • 我的世界怎么做神器官方网站dw网站设计与制作
  • ubuntu22.04发布QT程序步骤
  • Spring Boot:分布式事务高阶玩法
  • 做网站开什么端口网址格式
  • 白云区建设局网站建筑工程网教
  • react native android设置邮箱,进行邮件发送
  • Java面试场景:从Spring Boot到Kubernetes的技术问答
  • 从潜在空间到实际应用:Embedding模型架构与训练范式的综合解析
  • Vue3 provide/inject 详细组件关系说明
  • php的网站架构建设框架嘉兴网站设计
  • Redis(四)——Redis主从同步与对象模型
  • 2016年网站建设总结培训学校
  • 网站最下端怎么做动画设计培训机构
  • 用python制作相册浏览小工具
  • 字节跳动ByteDance前端考前总结
  • codex使用chrome-devtools-mcp最佳实践
  • 【Linux命令从入门到精通系列指南】export 命令详解:环境变量管理的核心利器
  • python 自动化采集 ChromeDriver 安装
  • 苏州招聘网站建设推广费
  • java8提取list中对象有相同属性值的对象或属性值
  • cuda编程笔记(26)-- 核函数使用任务队列
  • 存储芯片核心产业链研发实力:兆易创新、北京君正、澜起科技、江波龙、长电科技、佰维存储,6家龙头公司研发实力深度数据
  • 《Seq2Time: Sequential Knowledge Transfer for Video LLMTemporal Grounding》
  • 山东省建设部网站官网网站备案审核通过后