当前位置: 首页 > news >正文

Pytorch中张量的索引和切片使用详解和代码示例

PyTorch 中张量索引与切片详解

使用前先导入:

import torch

1.基础索引(类似 Python / NumPy)

适用于低维张量:x[i]x[i, j]

x = torch.tensor([[10, 11, 12],[13, 14, 15],[16, 17, 18]])print(x[0])         # 第0行: tensor([10, 11, 12])
print(x[1][2])      # 第1行第2列: 15
print(x[2, 1])      # 第2行第1列: 17

2.切片(Slicing)

x = torch.arange(16).reshape(4, 4)
# tensor([[ 0,  1,  2,  3],
#         [ 4,  5,  6,  7],
#         [ 8,  9, 10, 11],
#         [12, 13, 14, 15]])print(x[:2])        # 前两行
print(x[:, 1:3])    # 所有行,第1~2列
print(x[::2, ::2])  # 行列间隔为2

3.负索引

print(x[-1])        # 最后一行
print(x[:, -2:])    # 每行最后两列

4.使用 ... (Ellipsis)

当维度很多时可简化操作。

x = torch.arange(2*3*4).reshape(2, 3, 4)# 等价于 x[0, :, 2]
print(x[0, ..., 2])

5.Noneunsqueeze 增加维度

x = torch.tensor([1, 2, 3])# 增加维度(等价于 unsqueeze)
print(x[None, :].shape)     # torch.Size([1, 3])
print(x[:, None].shape)     # torch.Size([3, 1])

6. 布尔索引(Boolean Indexing)

x = torch.tensor([10, 20, 30, 40])mask = x > 25
print(mask)         # tensor([False, False,  True,  True])
print(x[mask])      # tensor([30, 40])

7. 花式索引(Fancy Indexing)

使用索引列表访问多个非连续位置。

x = torch.tensor([10, 20, 30, 40, 50])idx = torch.tensor([0, 2, 4])
print(x[idx])       # tensor([10, 30, 50])

二维花式索引:

x = torch.arange(1, 10).reshape(3, 3)
# tensor([[1, 2, 3],
#         [4, 5, 6],
#         [7, 8, 9]])rows = torch.tensor([0, 1, 2])
cols = torch.tensor([2, 1, 0])
print(x[rows, cols])  # [3, 5, 7]

8. 条件赋值 / where

x = torch.tensor([1, 2, 3, 4, 5])
x[x > 3] = 100
print(x)            # tensor([  1,   2,   3, 100, 100])# 条件选择
a = torch.tensor([1, 2, 3])
b = torch.tensor([10, 20, 30])
cond = torch.tensor([True, False, True])print(torch.where(cond, a, b))  # -> [1, 20, 3]

9. 高维张量索引技巧

x = torch.arange(2*3*4).reshape(2, 3, 4)# 提取第1个 batch 所有通道第2列
print(x[0, :, 2])    # shape: (3,)

10. 实例:图像张量裁剪(HWC)

img = torch.rand((3, 256, 256))  # C, H, W 格式# 裁剪中心区域
crop = img[:, 100:200, 100:200]  # shape (3, 100, 100)

11. 总结图解(结构化索引方式)

张量索引方式:
├── 基础索引(x[i], x[i,j])
├── 切片(x[start:end], x[:, idx])
├── 高维省略(x[..., -1])
├── 增维/降维(x[None, :], x.squeeze())
├── 布尔索引(x[x>val])
├── 花式索引(x[[0, 2, 4]])
├── 条件赋值(x[x > a] = b)
└── torch.where(cond, a, b)

高级应用


1. 高级花式索引(Advanced Fancy Indexing)

基本复习:

花式索引是用整张或部分张量作为索引,获取非连续元素。进阶里,张量的形状组合、广播规则非常重要。

代码示例:

import torchx = torch.arange(27).reshape(3, 3, 3)
# x shape = (3, 3, 3)# 目标:同时选取不同 batch 不同通道的元素
idx_batch = torch.tensor([0, 1, 2])   # 每个 batch 索引
idx_channel = torch.tensor([2, 1, 0]) # 每个对应通道索引
idx_row = torch.tensor([0, 1, 2])     # 对应行索引# 三个索引张量自动广播,选出:
# x[0, 2, 0], x[1, 1, 1], x[2, 0, 2]
result = x[idx_batch, idx_channel, idx_row]print(result)  # tensor([ 6, 13, 24])
  • 关键是各个索引张量形状要匹配或可广播
  • 返回值的形状取决于索引张量的形状。

2. 坐标映射索引(Indexing with Coordinate Tensors)

常用在点云、图像坐标映射,手工给定索引位置批量取值。

代码示例:

x = torch.arange(16).reshape(4, 4)
# tensor([[ 0,  1,  2,  3],
#         [ 4,  5,  6,  7],
#         [ 8,  9, 10, 11],
#         [12, 13, 14, 15]])# 给定坐标点
coords = torch.tensor([[0, 1], [2, 3], [3, 0]])  # 三个点的坐标rows = coords[:, 0]
cols = coords[:, 1]vals = x[rows, cols]
print(vals)  # tensor([ 1, 11, 12])

torch.gather — 按索引沿指定维度收集数据

x = torch.arange(12).reshape(3, 4)
# tensor([[ 0,  1,  2,  3],
#         [ 4,  5,  6,  7],
#         [ 8,  9, 10, 11]])indices = torch.tensor([[0, 3], [2, 1], [1, 0]])
result = torch.gather(x, dim=1, index=indices)
print(result)
# tensor([[ 0,  3],
#         [ 6,  5],
#         [ 9,  8]])
  • torch.gather 需要索引张量与输入同形状,但索引值表示该维度的选取位置。

3. 高维图像张量处理技巧

假设图像张量格式为 (Batch, Channels, Height, Width),称为 BCHW。

常用操作示例:

(a) 批量裁剪 (Crop)
img = torch.randn(5, 3, 256, 256)  # 5张RGB图像# 取中心128x128块
h_start = (256 - 128) // 2
w_start = (256 - 128) // 2crop = img[:, :, h_start:h_start+128, w_start:w_start+128]  # shape (5, 3, 128, 128)
(b) 改变通道顺序
# BCHW -> BHWC
img_bhwc = img.permute(0, 2, 3, 1)
print(img_bhwc.shape)  # (5, 256, 256, 3)
© 按坐标索引批量像素点
batch_size = 2
img = torch.arange(batch_size*3*4*4).reshape(batch_size, 3, 4, 4)# 取每张图(0,1)通道,指定像素点坐标
coords = torch.tensor([[1, 2], [3, 0]])  # (batch_size, 2) 像素坐标 (H, W)batch_indices = torch.arange(batch_size)
channels = torch.tensor([0, 1])  # 不同图不同通道pixels = img[batch_indices, channels, coords[:, 0], coords[:, 1]]
print(pixels)

总结:

技巧类别适用场景关键函数/概念
高级花式索引多维非连续索引,索引张量广播多张量索引广播
坐标映射索引点云坐标、图像点批量索引torch.gather, 坐标张量索引
高维图像张量处理批量裁剪、通道转换、批量像素选取permutereshape、多维切片

4.综合示例

下面以一个综合示例代码,涵盖 高级花式索引坐标映射索引,以及 高维图像张量处理,注释详尽,方便大家理解和直接跑起来。

import torchdef advanced_fancy_indexing():print("=== 高级花式索引示例 ===")x = torch.arange(27).reshape(3, 3, 3)idx_batch = torch.tensor([0, 1, 2])idx_channel = torch.tensor([2, 1, 0])idx_row = torch.tensor([0, 1, 2])# 选出 x[0,2,0], x[1,1,1], x[2,0,2]result = x[idx_batch, idx_channel, idx_row]print(result)  # tensor([ 6, 13, 24])print()def coordinate_mapping_indexing():print("=== 坐标映射索引示例 ===")x = torch.arange(16).reshape(4, 4)coords = torch.tensor([[0, 1], [2, 3], [3, 0]])  # 3个坐标点rows = coords[:, 0]cols = coords[:, 1]vals = x[rows, cols]print(f"从坐标 {coords.tolist()} 取值: {vals.tolist()}")# torch.gather示例x2 = torch.arange(12).reshape(3, 4)indices = torch.tensor([[0, 3], [2, 1], [1, 0]])gathered = torch.gather(x2, dim=1, index=indices)print(f"torch.gather 结果:\n{gathered}")print()def high_dim_image_tensor_processing():print("=== 高维图像张量处理示例 ===")# 生成一个 5张RGB图像 BCHW 格式img = torch.randn(5, 3, 256, 256)# 裁剪中心128x128h_start = (256 - 128) // 2w_start = (256 - 128) // 2crop = img[:, :, h_start:h_start+128, w_start:w_start+128]print(f"裁剪后的形状: {crop.shape}")# 通道顺序变换 BCHW -> BHWCimg_bhwc = img.permute(0, 2, 3, 1)print(f"通道转换后形状: {img_bhwc.shape}")# 批量取像素点batch_size = 2img_small = torch.arange(batch_size*3*4*4).reshape(batch_size, 3, 4, 4)coords = torch.tensor([[1, 2], [3, 0]])  # 每张图像的像素坐标 (H, W)batch_indices = torch.arange(batch_size)channels = torch.tensor([0, 1])  # 两张图不同通道pixels = img_small[batch_indices, channels, coords[:, 0], coords[:, 1]]print(f"批量像素值: {pixels.tolist()}")if __name__ == "__main__":advanced_fancy_indexing()coordinate_mapping_indexing()high_dim_image_tensor_processing()

代码说明

  • advanced_fancy_indexing()
    演示多张量广播索引从三维张量中选取不规则元素。

  • coordinate_mapping_indexing()
    演示给定坐标点批量取值 + 用 torch.gather 沿某维度收集。

  • high_dim_image_tensor_processing()
    展示了高维图像张量裁剪、通道排列变换和批量像素点采样。



文章转载自:
http://caenogenesis.hyyxsc.cn
http://aright.hyyxsc.cn
http://canid.hyyxsc.cn
http://biocatalyst.hyyxsc.cn
http://axial.hyyxsc.cn
http://ascensionist.hyyxsc.cn
http://asynapsis.hyyxsc.cn
http://anatomist.hyyxsc.cn
http://bookmaker.hyyxsc.cn
http://bridlewise.hyyxsc.cn
http://cadency.hyyxsc.cn
http://armchair.hyyxsc.cn
http://bowhunt.hyyxsc.cn
http://burnous.hyyxsc.cn
http://certiorari.hyyxsc.cn
http://athrocyte.hyyxsc.cn
http://alcoholize.hyyxsc.cn
http://centre.hyyxsc.cn
http://anthem.hyyxsc.cn
http://bud.hyyxsc.cn
http://cataclastic.hyyxsc.cn
http://baconian.hyyxsc.cn
http://chrysalides.hyyxsc.cn
http://brutalitarian.hyyxsc.cn
http://ayuntamiento.hyyxsc.cn
http://aerophore.hyyxsc.cn
http://billfold.hyyxsc.cn
http://blissfully.hyyxsc.cn
http://beret.hyyxsc.cn
http://anicut.hyyxsc.cn
http://www.dtcms.com/a/280359.html

相关文章:

  • CSS的初步学习
  • 用语音识别芯片驱动TFT屏幕还有链接蓝牙功能?
  • cursor使用mcp连接mysql数据库,url方式
  • java截取视频帧
  • c#进阶之数据结构(字符串篇)----String
  • C++中list各种基本接口的模拟实现
  • 【Java代码审计(2)】MyBatis XML 注入审计
  • 153.在 Vue 3 中使用 OpenLayers + Cesium 实现 2D/3D 地图切换效果
  • java中的接口
  • JavaScript 动态访问嵌套对象属性问题记录
  • HarmonyOS-ArkUI: Web组件加载流程1
  • 暴力破解:攻破系统的终极密钥
  • Rust指针选择
  • 安装带GPU的docker环境
  • 20250715使用荣品RD-RK3588开发板在Android13下接入USB3.0接口的红外相机
  • 【I3D 2024】Deblur-GS: 3D Gaussian Splatting from Camera Motion Blurred Images
  • 记录一条面试sql题目
  • JS中async/await功能介绍和使用演示
  • 普通字符类型和new String有什么区别
  • 使用JS编写动态表格
  • 【env环境】rtthread5.1.0使用fal组件
  • AI的外挂知识库,RAG检索增强生成技术
  • 【PTA数据结构 | C语言版】将表达式树转换成中缀表达式
  • 数仓面试题
  • 2025最新国产用例管理工具评测:Gitee Test、禅道、蓝凌测试、TestOps 哪家更懂研发协同?
  • docker停止所有容器和删除所有镜像
  • 从一道题目(阿里2014 Crackme_2)开启unidbg还原算法入门(转载)
  • 强化学习书籍
  • vscode 打开c++文件注释乱码
  • 分布式存储之Ceph使用指南--部署篇(未完待续)