当前位置: 首页 > news >正文

PyTorch使用(5)-张量索引操作

文章目录

  • 1. 简单行、列索引
    • 1.1. 基础用法
    • 1.2. 工程实践要点
  • 2. 列表索引
    • 2.1. 基础用法
    • 2.2. 高级用法
    • 2.3. 性能考虑
  • 3. 范围索引
    • 3.1. 基础用法
    • 3.2. 高级技巧
    • 3.3. 内存特性
  • 4. 布尔索引
    • 4.1. 基础用法
    • 4.2. 高级用法
    • 4.3. 性能注意事项
  • 5. 多维索引
    • 5.1. 基础用法
    • 5.2. 高级模式
    • 5.3. 工程实践
  • 6. 综合性能比较
  • 7. 最佳实践建议

1. 简单行、列索引

简单的行、列索引是最基本的索引操作,通过整数来访问张量中的元素。可以使用类似数组索引的方式来操作。

1.1. 基础用法

import torch

# 创建一个3x4的矩阵
x = torch.tensor([[1, 2, 3, 4],
                 [5, 6, 7, 8],
                 [9, 10, 11, 12]])

# 获取第2行(索引从0开始)
row = x[1]  # tensor([5, 6, 7, 8])

# 获取第3列
col = x[:, 2]  # tensor([3, 7, 11])

1.2. 工程实践要点

内存视图:简单索引返回的是原张量的视图,不复制数据

性能:O(1)时间复杂度,是最快的索引方式

GPU兼容:在CUDA张量上同样高效

# 获取连续多行/多列
rows = x[1:3]  # 第2-3行
cols = x[:, 1:3]  # 第2-3列

2. 列表索引

列表索引是通过一个列表或数组来选择张量中的多个元素。这种索引方式可以选择多个位置的元素,并返回一个新的张量。

2.1. 基础用法

# 使用列表选择特定行
selected_rows = x[[0, 2]]  # 第1和第3行

# 使用列表选择特定列
selected_cols = x[:, [1, 3]]  # 第2和第4列

# 选择特定元素
elements = x[[0, 1, 2], [1, 2, 3]]  # (0,1), (1,2), (2,3)位置的元素

2.2. 高级用法

# 创建索引张量(比Python列表更高效)
indices = torch.tensor([0, 2], device=x.device)
selected = x[indices]  # 第1和第3行

# 组合行列索引
x[[[0], [2]], [1, 3]]  # 第1/3行的第2/4列 → 2x2矩阵

2.3. 性能考虑

内存开销:列表索引会创建新张量,复制数据

替代方案:对于连续索引,优先使用切片

GPU优化:将索引张量放在与数据相同的设备上

3. 范围索引

范围索引允许你选择张量的一个切片,类似于 Python 列表的切片操作。通过起始索引和结束索引来选择一段连续的元素

3.1. 基础用法

# 基本范围切片
sub_matrix = x[0:2, 1:3]  # 第1-2行,第2-3列

# 带步长的范围切片
every_other = x[::2, ::3]  # 每隔一行/三列选取

# 反向索引
reversed_rows = x[::-1]  # 行顺序反转

3.2. 高级技巧

# 创建范围索引张量
range_idx = torch.arange(1, 3)  # tensor([1, 2])
selected = x[range_idx]  # 第2-3行

# 结合步长和偏移
strided = x[1::2, ::2]  # 从第2行开始每隔一行,所有列每隔一个

3.3. 内存特性

连续范围:返回视图,不复制数据

非连续范围:可能触发拷贝操作

最佳实践:尽量使用基础切片而非arange创建的范围

4. 布尔索引

布尔索引是根据条件来选择张量中的元素。它使用一个布尔数组或条件表达式来判断哪些元素符合条件,从而选择它们。

4.1. 基础用法

# 创建布尔掩码
mask = x > 5
# tensor([[False, False, False, False],
#        [False,  True,  True,  True],
#        [ True,  True,  True,  True]])

# 应用布尔索引
selected = x[mask]  # tensor([6, 7, 8, 9, 10, 11, 12])

# 条件赋值
x[x % 2 == 0] = 0  # 将所有偶数置0

4.2. 高级用法

# 多条件组合
mask = (x > 3) & (x < 9)
selected = x[mask]

# 按行/列条件索引
row_mask = torch.any(x > 10, dim=1)  # 选择包含大于10的元素的整行
selected_rows = x[row_mask]

4.3. 性能注意事项

掩码创建:布尔操作会创建临时张量

内存占用:大张量的布尔掩码会消耗大量内存

GPU优势:布尔索引在CUDA上并行化效果极佳

5. 多维索引

多维索引可以是混合多种索引方式,包括整数索引、切片索引、布尔索引等。它让你能够根据复杂的条件或结构对张量进行切片和访问

5.1. 基础用法

# 创建3D张量
y = torch.randn(2, 3, 4)  # batch=2, seq_len=3, features=4

# 各维度单独索引
elem = y[1, 2, 3]  # 第2个batch,第3个序列,第4个特征

# 混合索引方式
sub_tensor = y[1, :, [0, 2]]  # 第2个batch,所有序列,第1和3个特征

5.2. 高级模式

# 使用Ellipsis(...)简化索引
first_batch_all_features = y[0, ...]  # 等价于 y[0, :, :]

# 使用None增加维度
expanded = y[:, None, :, :]  # 在第二维增加一个维度

# 跨维度索引
diag = y.diagonal(dim1=1, dim2=2)  # 获取每个batch的特征对角线

5.3. 工程实践

维度顺序:注意PyTorch的通道优先约定(N, C, H, W)

广播机制:了解索引操作中的广播规则

视图与拷贝:复杂索引可能触发意外拷贝

6. 综合性能比较

操作类型返回视图内存效率GPU加速比适用场景
简单索引10-100x常规子矩阵提取
列表索引5-20x非连续元素选择
范围索引通常10-50x连续区块操作
布尔索引20-100x条件筛选
多维索引有时不定10-50x高维数据操作

使用总结

  • 简单行列索引:基础的整数索引,用来访问单个元素。
  • 列表索引:通过提供一个索引列表或数组来选择多个元素。
  • 范围索引:通过切片来选择张量的一个区间。
  • 布尔索引:通过布尔条件来选择符合条件的元素。
  • 多维索引:通过混合使用不同的索引方式,进行复杂的索引操作。

7. 最佳实践建议

优先使用简单索引:性能最佳,内存最友好

避免频繁的小规模索引:合并多个操作为一个

注意设备一致性:索引张量应与数据在同一设备

利用原地操作:对于大张量修改,使用_后缀方法

预分配内存:对于已知大小的结果,先创建目标张量

# 高效索引操作示例
def efficient_indexing(x, row_indices, col_indices):
    # 预分配结果张量
    result = torch.empty(len(row_indices), 
                        len(col_indices),
                        device=x.device)
    
    # 批量索引操作
    torch.index_select(x, 0, row_indices, out=result)
    torch.index_select(result, 1, col_indices, out=result)
    
    return result

相关文章:

  • uniapp小程序生成海报/图片并保存分享
  • 集合学习内容总结
  • Chrome 135 版本新特性
  • YUESAI应急4G网络广播成功应用于绍兴市钱塘江观潮预警提示项目
  • 【9】搭建k8s集群系列(二进制部署)之安装work-node节点组件(kube-proxy)和网络组件calico
  • QT ARM开发板调试
  • 《从零搭建Vue3项目实战》(AI辅助搭建Vue3+ElemntPlus后台管理项目)零基础入门系列第二篇:项目创建和初始化
  • Linux时间函数3-strftime时间格式转换、asctime时间固定格式、asctime_r线程安全、strftime/asctime/ctime区别
  • 组合与括号生成(回溯)
  • 开源模型应用落地-Qwen2.5-Omni-7B模型-Gradio-部署 “光速” 指南(二)
  • 2012年-全国大学生数学建模竞赛(CUMCM)试题速浏、分类及浅析
  • React-04React组件状态(state),构造器初始化state以及数据读取,添加点击事件并更改state状态值
  • 深度学习篇---Prophet时间序列预测工具
  • 使用stm32cubeide stm32f407 lan8720a freertos lwip 实现udp client网络数据转串口数据过程详解
  • Scala相关知识学习总结5
  • 简述Unity对多线程的支持限制和注意事项
  • 【橘子大模型】使用streamlit来构建自己的聊天机器人(下)
  • echarts生成3D立体地图react组件
  • T-SQL语言的压力测试
  • Redis 面经
  • 北斗系统全面进入11个国际组织的标准体系
  • 缅甸发生5.0级地震
  • 学者三年实地调查被判AI代笔,论文AI率检测如何避免“误伤”
  • 侵害孩子者,必严惩不贷!3名性侵害未成年人罪犯被执行死刑
  • 国税总局上海市税务局通报:收到王某对刘某某及相关企业涉税问题举报,正依法依规办理
  • 美将解除对叙利亚制裁,外交部:中方一贯反对非法单边制裁