Pytorch中张量的索引和切片使用详解和代码示例
PyTorch 中张量索引与切片详解
使用前先导入:
import torch
1.基础索引(类似 Python / NumPy)
适用于低维张量:x[i]
、x[i, j]
x = torch.tensor([[10, 11, 12],[13, 14, 15],[16, 17, 18]])print(x[0]) # 第0行: tensor([10, 11, 12])
print(x[1][2]) # 第1行第2列: 15
print(x[2, 1]) # 第2行第1列: 17
2.切片(Slicing)
x = torch.arange(16).reshape(4, 4)
# tensor([[ 0, 1, 2, 3],
# [ 4, 5, 6, 7],
# [ 8, 9, 10, 11],
# [12, 13, 14, 15]])print(x[:2]) # 前两行
print(x[:, 1:3]) # 所有行,第1~2列
print(x[::2, ::2]) # 行列间隔为2
3.负索引
print(x[-1]) # 最后一行
print(x[:, -2:]) # 每行最后两列
4.使用 ...
(Ellipsis)
当维度很多时可简化操作。
x = torch.arange(2*3*4).reshape(2, 3, 4)# 等价于 x[0, :, 2]
print(x[0, ..., 2])
5.None
和 unsqueeze
增加维度
x = torch.tensor([1, 2, 3])# 增加维度(等价于 unsqueeze)
print(x[None, :].shape) # torch.Size([1, 3])
print(x[:, None].shape) # torch.Size([3, 1])
6. 布尔索引(Boolean Indexing)
x = torch.tensor([10, 20, 30, 40])mask = x > 25
print(mask) # tensor([False, False, True, True])
print(x[mask]) # tensor([30, 40])
7. 花式索引(Fancy Indexing)
使用索引列表访问多个非连续位置。
x = torch.tensor([10, 20, 30, 40, 50])idx = torch.tensor([0, 2, 4])
print(x[idx]) # tensor([10, 30, 50])
二维花式索引:
x = torch.arange(1, 10).reshape(3, 3)
# tensor([[1, 2, 3],
# [4, 5, 6],
# [7, 8, 9]])rows = torch.tensor([0, 1, 2])
cols = torch.tensor([2, 1, 0])
print(x[rows, cols]) # [3, 5, 7]
8. 条件赋值 / where
x = torch.tensor([1, 2, 3, 4, 5])
x[x > 3] = 100
print(x) # tensor([ 1, 2, 3, 100, 100])# 条件选择
a = torch.tensor([1, 2, 3])
b = torch.tensor([10, 20, 30])
cond = torch.tensor([True, False, True])print(torch.where(cond, a, b)) # -> [1, 20, 3]
9. 高维张量索引技巧
x = torch.arange(2*3*4).reshape(2, 3, 4)# 提取第1个 batch 所有通道第2列
print(x[0, :, 2]) # shape: (3,)
10. 实例:图像张量裁剪(HWC)
img = torch.rand((3, 256, 256)) # C, H, W 格式# 裁剪中心区域
crop = img[:, 100:200, 100:200] # shape (3, 100, 100)
11. 总结图解(结构化索引方式)
张量索引方式:
├── 基础索引(x[i], x[i,j])
├── 切片(x[start:end], x[:, idx])
├── 高维省略(x[..., -1])
├── 增维/降维(x[None, :], x.squeeze())
├── 布尔索引(x[x>val])
├── 花式索引(x[[0, 2, 4]])
├── 条件赋值(x[x > a] = b)
└── torch.where(cond, a, b)
高级应用
1. 高级花式索引(Advanced Fancy Indexing)
基本复习:
花式索引是用整张或部分张量作为索引,获取非连续元素。进阶里,张量的形状组合、广播规则非常重要。
代码示例:
import torchx = torch.arange(27).reshape(3, 3, 3)
# x shape = (3, 3, 3)# 目标:同时选取不同 batch 不同通道的元素
idx_batch = torch.tensor([0, 1, 2]) # 每个 batch 索引
idx_channel = torch.tensor([2, 1, 0]) # 每个对应通道索引
idx_row = torch.tensor([0, 1, 2]) # 对应行索引# 三个索引张量自动广播,选出:
# x[0, 2, 0], x[1, 1, 1], x[2, 0, 2]
result = x[idx_batch, idx_channel, idx_row]print(result) # tensor([ 6, 13, 24])
- 关键是各个索引张量形状要匹配或可广播。
- 返回值的形状取决于索引张量的形状。
2. 坐标映射索引(Indexing with Coordinate Tensors)
常用在点云、图像坐标映射,手工给定索引位置批量取值。
代码示例:
x = torch.arange(16).reshape(4, 4)
# tensor([[ 0, 1, 2, 3],
# [ 4, 5, 6, 7],
# [ 8, 9, 10, 11],
# [12, 13, 14, 15]])# 给定坐标点
coords = torch.tensor([[0, 1], [2, 3], [3, 0]]) # 三个点的坐标rows = coords[:, 0]
cols = coords[:, 1]vals = x[rows, cols]
print(vals) # tensor([ 1, 11, 12])
torch.gather
— 按索引沿指定维度收集数据
x = torch.arange(12).reshape(3, 4)
# tensor([[ 0, 1, 2, 3],
# [ 4, 5, 6, 7],
# [ 8, 9, 10, 11]])indices = torch.tensor([[0, 3], [2, 1], [1, 0]])
result = torch.gather(x, dim=1, index=indices)
print(result)
# tensor([[ 0, 3],
# [ 6, 5],
# [ 9, 8]])
torch.gather
需要索引张量与输入同形状,但索引值表示该维度的选取位置。
3. 高维图像张量处理技巧
假设图像张量格式为 (Batch, Channels, Height, Width)
,称为 BCHW。
常用操作示例:
(a) 批量裁剪 (Crop)
img = torch.randn(5, 3, 256, 256) # 5张RGB图像# 取中心128x128块
h_start = (256 - 128) // 2
w_start = (256 - 128) // 2crop = img[:, :, h_start:h_start+128, w_start:w_start+128] # shape (5, 3, 128, 128)
(b) 改变通道顺序
# BCHW -> BHWC
img_bhwc = img.permute(0, 2, 3, 1)
print(img_bhwc.shape) # (5, 256, 256, 3)
© 按坐标索引批量像素点
batch_size = 2
img = torch.arange(batch_size*3*4*4).reshape(batch_size, 3, 4, 4)# 取每张图(0,1)通道,指定像素点坐标
coords = torch.tensor([[1, 2], [3, 0]]) # (batch_size, 2) 像素坐标 (H, W)batch_indices = torch.arange(batch_size)
channels = torch.tensor([0, 1]) # 不同图不同通道pixels = img[batch_indices, channels, coords[:, 0], coords[:, 1]]
print(pixels)
总结:
技巧类别 | 适用场景 | 关键函数/概念 |
---|---|---|
高级花式索引 | 多维非连续索引,索引张量广播 | 多张量索引广播 |
坐标映射索引 | 点云坐标、图像点批量索引 | torch.gather , 坐标张量索引 |
高维图像张量处理 | 批量裁剪、通道转换、批量像素选取 | permute 、reshape 、多维切片 |
4.综合示例
下面以一个综合示例代码,涵盖 高级花式索引、坐标映射索引,以及 高维图像张量处理,注释详尽,方便大家理解和直接跑起来。
import torchdef advanced_fancy_indexing():print("=== 高级花式索引示例 ===")x = torch.arange(27).reshape(3, 3, 3)idx_batch = torch.tensor([0, 1, 2])idx_channel = torch.tensor([2, 1, 0])idx_row = torch.tensor([0, 1, 2])# 选出 x[0,2,0], x[1,1,1], x[2,0,2]result = x[idx_batch, idx_channel, idx_row]print(result) # tensor([ 6, 13, 24])print()def coordinate_mapping_indexing():print("=== 坐标映射索引示例 ===")x = torch.arange(16).reshape(4, 4)coords = torch.tensor([[0, 1], [2, 3], [3, 0]]) # 3个坐标点rows = coords[:, 0]cols = coords[:, 1]vals = x[rows, cols]print(f"从坐标 {coords.tolist()} 取值: {vals.tolist()}")# torch.gather示例x2 = torch.arange(12).reshape(3, 4)indices = torch.tensor([[0, 3], [2, 1], [1, 0]])gathered = torch.gather(x2, dim=1, index=indices)print(f"torch.gather 结果:\n{gathered}")print()def high_dim_image_tensor_processing():print("=== 高维图像张量处理示例 ===")# 生成一个 5张RGB图像 BCHW 格式img = torch.randn(5, 3, 256, 256)# 裁剪中心128x128h_start = (256 - 128) // 2w_start = (256 - 128) // 2crop = img[:, :, h_start:h_start+128, w_start:w_start+128]print(f"裁剪后的形状: {crop.shape}")# 通道顺序变换 BCHW -> BHWCimg_bhwc = img.permute(0, 2, 3, 1)print(f"通道转换后形状: {img_bhwc.shape}")# 批量取像素点batch_size = 2img_small = torch.arange(batch_size*3*4*4).reshape(batch_size, 3, 4, 4)coords = torch.tensor([[1, 2], [3, 0]]) # 每张图像的像素坐标 (H, W)batch_indices = torch.arange(batch_size)channels = torch.tensor([0, 1]) # 两张图不同通道pixels = img_small[batch_indices, channels, coords[:, 0], coords[:, 1]]print(f"批量像素值: {pixels.tolist()}")if __name__ == "__main__":advanced_fancy_indexing()coordinate_mapping_indexing()high_dim_image_tensor_processing()
代码说明
-
advanced_fancy_indexing()
演示多张量广播索引从三维张量中选取不规则元素。 -
coordinate_mapping_indexing()
演示给定坐标点批量取值 + 用torch.gather
沿某维度收集。 -
high_dim_image_tensor_processing()
展示了高维图像张量裁剪、通道排列变换和批量像素点采样。