PyTorch使用(5)-张量索引操作
文章目录
- 1. 简单行、列索引
- 1.1. 基础用法
- 1.2. 工程实践要点
- 2. 列表索引
- 2.1. 基础用法
- 2.2. 高级用法
- 2.3. 性能考虑
- 3. 范围索引
- 3.1. 基础用法
- 3.2. 高级技巧
- 3.3. 内存特性
- 4. 布尔索引
- 4.1. 基础用法
- 4.2. 高级用法
- 4.3. 性能注意事项
- 5. 多维索引
- 5.1. 基础用法
- 5.2. 高级模式
- 5.3. 工程实践
- 6. 综合性能比较
- 7. 最佳实践建议
1. 简单行、列索引
简单的行、列索引是最基本的索引操作,通过整数来访问张量中的元素。可以使用类似数组索引的方式来操作。
1.1. 基础用法
import torch
# 创建一个3x4的矩阵
x = torch.tensor([[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 10, 11, 12]])
# 获取第2行(索引从0开始)
row = x[1] # tensor([5, 6, 7, 8])
# 获取第3列
col = x[:, 2] # tensor([3, 7, 11])
1.2. 工程实践要点
内存视图:简单索引返回的是原张量的视图,不复制数据
性能:O(1)时间复杂度,是最快的索引方式
GPU兼容:在CUDA张量上同样高效
# 获取连续多行/多列
rows = x[1:3] # 第2-3行
cols = x[:, 1:3] # 第2-3列
2. 列表索引
列表索引是通过一个列表或数组来选择张量中的多个元素。这种索引方式可以选择多个位置的元素,并返回一个新的张量。
2.1. 基础用法
# 使用列表选择特定行
selected_rows = x[[0, 2]] # 第1和第3行
# 使用列表选择特定列
selected_cols = x[:, [1, 3]] # 第2和第4列
# 选择特定元素
elements = x[[0, 1, 2], [1, 2, 3]] # (0,1), (1,2), (2,3)位置的元素
2.2. 高级用法
# 创建索引张量(比Python列表更高效)
indices = torch.tensor([0, 2], device=x.device)
selected = x[indices] # 第1和第3行
# 组合行列索引
x[[[0], [2]], [1, 3]] # 第1/3行的第2/4列 → 2x2矩阵
2.3. 性能考虑
内存开销:列表索引会创建新张量,复制数据
替代方案:对于连续索引,优先使用切片
GPU优化:将索引张量放在与数据相同的设备上
3. 范围索引
范围索引允许你选择张量的一个切片,类似于 Python 列表的切片操作。通过起始索引和结束索引来选择一段连续的元素
3.1. 基础用法
# 基本范围切片
sub_matrix = x[0:2, 1:3] # 第1-2行,第2-3列
# 带步长的范围切片
every_other = x[::2, ::3] # 每隔一行/三列选取
# 反向索引
reversed_rows = x[::-1] # 行顺序反转
3.2. 高级技巧
# 创建范围索引张量
range_idx = torch.arange(1, 3) # tensor([1, 2])
selected = x[range_idx] # 第2-3行
# 结合步长和偏移
strided = x[1::2, ::2] # 从第2行开始每隔一行,所有列每隔一个
3.3. 内存特性
连续范围:返回视图,不复制数据
非连续范围:可能触发拷贝操作
最佳实践:尽量使用基础切片而非arange创建的范围
4. 布尔索引
布尔索引是根据条件来选择张量中的元素。它使用一个布尔数组或条件表达式来判断哪些元素符合条件,从而选择它们。
4.1. 基础用法
# 创建布尔掩码
mask = x > 5
# tensor([[False, False, False, False],
# [False, True, True, True],
# [ True, True, True, True]])
# 应用布尔索引
selected = x[mask] # tensor([6, 7, 8, 9, 10, 11, 12])
# 条件赋值
x[x % 2 == 0] = 0 # 将所有偶数置0
4.2. 高级用法
# 多条件组合
mask = (x > 3) & (x < 9)
selected = x[mask]
# 按行/列条件索引
row_mask = torch.any(x > 10, dim=1) # 选择包含大于10的元素的整行
selected_rows = x[row_mask]
4.3. 性能注意事项
掩码创建:布尔操作会创建临时张量
内存占用:大张量的布尔掩码会消耗大量内存
GPU优势:布尔索引在CUDA上并行化效果极佳
5. 多维索引
多维索引可以是混合多种索引方式,包括整数索引、切片索引、布尔索引等。它让你能够根据复杂的条件或结构对张量进行切片和访问
5.1. 基础用法
# 创建3D张量
y = torch.randn(2, 3, 4) # batch=2, seq_len=3, features=4
# 各维度单独索引
elem = y[1, 2, 3] # 第2个batch,第3个序列,第4个特征
# 混合索引方式
sub_tensor = y[1, :, [0, 2]] # 第2个batch,所有序列,第1和3个特征
5.2. 高级模式
# 使用Ellipsis(...)简化索引
first_batch_all_features = y[0, ...] # 等价于 y[0, :, :]
# 使用None增加维度
expanded = y[:, None, :, :] # 在第二维增加一个维度
# 跨维度索引
diag = y.diagonal(dim1=1, dim2=2) # 获取每个batch的特征对角线
5.3. 工程实践
维度顺序:注意PyTorch的通道优先约定(N, C, H, W)
广播机制:了解索引操作中的广播规则
视图与拷贝:复杂索引可能触发意外拷贝
6. 综合性能比较
操作类型 | 返回视图 | 内存效率 | GPU加速比 | 适用场景 |
---|---|---|---|---|
简单索引 | 是 | 高 | 10-100x | 常规子矩阵提取 |
列表索引 | 否 | 中 | 5-20x | 非连续元素选择 |
范围索引 | 通常 | 高 | 10-50x | 连续区块操作 |
布尔索引 | 否 | 低 | 20-100x | 条件筛选 |
多维索引 | 有时 | 不定 | 10-50x | 高维数据操作 |
使用总结
- 简单行列索引:基础的整数索引,用来访问单个元素。
- 列表索引:通过提供一个索引列表或数组来选择多个元素。
- 范围索引:通过切片来选择张量的一个区间。
- 布尔索引:通过布尔条件来选择符合条件的元素。
- 多维索引:通过混合使用不同的索引方式,进行复杂的索引操作。
7. 最佳实践建议
优先使用简单索引:性能最佳,内存最友好
避免频繁的小规模索引:合并多个操作为一个
注意设备一致性:索引张量应与数据在同一设备
利用原地操作:对于大张量修改,使用_后缀方法
预分配内存:对于已知大小的结果,先创建目标张量
# 高效索引操作示例
def efficient_indexing(x, row_indices, col_indices):
# 预分配结果张量
result = torch.empty(len(row_indices),
len(col_indices),
device=x.device)
# 批量索引操作
torch.index_select(x, 0, row_indices, out=result)
torch.index_select(result, 1, col_indices, out=result)
return result