Point Transformer V3(PTv3)【3:上采样unpooling】
PTV3专题目录
序列化编码
降采样SerializedPooling
上采样SerializedUnpooling
文章目录
- 序列化编码
- 降采样SerializedPooling
- 上采样SerializedUnpooling
- 背景
- 基本功能
- 具体思路
- 1. 总体目标
- 2. 核心原理:利用池化时保存的信息
- 3. `__init__` (初始化)
- 4. `forward` (前向传播) - 详细步骤
- QA
- `一、point.feat[inverse]的实现,其实就是根据inverse把特征进行复制操作吧`
背景
点云分割的原始代码
class SerializedUnpooling(PointModule):def __init__(self,in_channels,skip_channels,out_channels,norm_layer=None,act_layer=None,traceable=False, # record parent and cluster):super().__init__()self.proj = PointSequential(nn.Linear(in_channels, out_channels))self.proj_skip = PointSequential(nn.Linear(skip_channels, out_channels))if norm_layer is not None:self.proj.add(norm_layer(out_channels))self.proj_skip.add(norm_layer(out_channels))if act_layer is not None:self.proj.add(act_layer())self.proj_skip.add(act_layer())self.traceable = traceabledef forward(self, point):assert "pooling_parent" in point.keys()assert "pooling_inverse" in point.keys()parent = point.pop("pooling_parent")inverse = point.pop("pooling_inverse")point = self.proj(point)parent = self.proj_skip(parent)parent.feat = parent.feat + point.feat[inverse]if self.traceable:parent["unpooling_parent"] = pointreturn parent
基本功能
SerializedUnpooling
的原理可以概括为:
利用池化时记录的“父子关系”(pooling_parent
和 pooling_inverse
),将低分辨率点云的特征精准地“广播”回高分辨率的空间结构中,并与该层的跳跃连接特征相加,实现信息的上采样和融合。
这个过程完全依赖于 SerializedPooling
提供的溯源信息,两者构成了一个高效且完全可逆的编解码对。
具体思路
我们来详细解析 SerializedUnpooling
的工作原理。它是 SerializedPooling
的逆操作,在典型的 U-Net 架构中扮演着上采样(Upsampling)和特征融合的关键角色。
1. 总体目标
SerializedUnpooling
的目标是将在网络深层、低分辨率的点云特征图(点数少,但语义信息丰富)恢复到其在网络浅层时的高分辨率(点数多,但空间细节丰富),并融合来自浅层的特征。
简单来说,它要回答这个问题:如何将一个点的特征“分配”回当初合并成它的那 N 个点?
2. 核心原理:利用池化时保存的信息
SerializedUnpooling
的“魔法”完全依赖于 SerializedPooling
在执行下采样时,有先见之明地保存了两个关键信息:
pooling_parent
: 这是一个指向池化前的、高分辨率的Point
对象的引用。它包含了原始的点云坐标、特征(也就是跳跃连接 (Skip Connection) 的特征)以及所有序列化信息。pooling_inverse
: 这是一个索引张量。如果池化前的点云有 N 个点,池化后有 M 个点,那么pooling_inverse
就是一个长度为 N 的张量。它的第i
个元素的值j
表示:原始的第i
个点在池化时被合并到了新的第j
个点中。
3. __init__
(初始化)
在创建 SerializedUnpooling
实例时,会定义几个关键部分:
in_channels
: 输入的低分辨率点云的特征维度。skip_channels
: 来自pooling_parent
(跳跃连接) 的高分辨率点云的特征维度。out_channels
: 最终输出的高分辨率点云的特征维度。self.proj
: 一个线性层,用于处理来自深层网络(低分辨率)的特征,将其维度从in_channels
映射到out_channels
。self.proj_skip
: 另一个线性层,用于处理来自跳跃连接(高分辨率)的特征,将其维度从skip_channels
映射到out_channels
。
4. forward
(前向传播) - 详细步骤
我们用一个具体的例子来贯穿整个流程。假设 SerializedPooling
将一个 14 个点的点云(我们称之为 P_high
)池化成了一个 2 个点的点云(我们称之为 P_low
)。
现在,SerializedUnpooling
的输入 point
就是 P_low
。
-
断言和信息恢复:
assert "pooling_parent" in point.keys()
: 检查P_low
是否保存了指向P_high
的引用。parent = point.pop("pooling_parent")
: 取出这个引用,现在parent
就是P_high
(包含14个点及其原始特征)。inverse = point.pop("pooling_inverse")
: 取出那个长度为 14 的索引张量,其内容类似[0,0,0,0,0,0,0,0, 1,1,1,1,1,1]
。
-
特征投影:
point = self.proj(point)
: 将P_low
的特征(2个点)通过线性层进行变换。假设维度从(2, in_channels)
变为(2, out_channels)
。parent = self.proj_skip(parent)
: 将P_high
的跳跃连接特征(14个点)也通过线性层进行变换。维度从(14, skip_channels)
变为(14, out_channels)
。- 现在,深层特征和浅层特征都有了相同的维度
out_channels
,为相加做好了准备。
-
核心操作:特征广播 (Broadcast) 与融合:
- 这是最关键的一步:
parent.feat = parent.feat + point.feat[inverse]
- 我们来分解
point.feat[inverse]
:point.feat
是一个(2, out_channels)
的张量,包含两个低分辨率点的特征。inverse
是一个(14,)
的索引张量。point.feat[inverse]
是一个高级索引操作。它会根据inverse
中的值,为parent
中的每一个点(共14个)从point.feat
中“拾取”对应的特征。- 由于
inverse
的前8个值是0,所以新生成张量的前8行都会是point.feat
的第0行。 - 由于
inverse
的后6个值是1,所以新生成张量的后6行都会是point.feat
的第1行。 - 最终,
point.feat[inverse]
的结果是一个(14, out_channels)
的张量。这个操作高效地将低分辨率的特征广播回了原始高分辨率点云的每个点上。
parent.feat + ...
: 将广播后的特征与parent
自身经过投影的特征进行逐元素相加。这就完成了深层语义信息和浅层空间细节的融合。
- 这是最关键的一步:
-
返回结果:
- 函数返回更新后的
parent
对象。它现在是一个高分辨率的点云(14个点),但其特征已经融合了来自网络深层的信息。
- 函数返回更新后的
QA
一、point.feat[inverse]的实现,其实就是根据inverse把特征进行复制操作吧
是的,理解非常准确!
point.feat[inverse]
的实现,本质上就是一次高效的**复制和广播(Broadcast)**操作。它利用 inverse
张量作为索引,将低分辨率的特征“复制”并“扩展”到高分辨率的点云结构上。
用之前的例子来具体说明这个过程:
-
输入状态:
point.feat
: 这是低分辨率点云的特征张量。假设池化后剩下 2 个点,特征维度为 64。那么point.feat
的形状是(2, 64)
。我们可以把它看作[特征_A, 特征_B]
。inverse
: 这是在池化时生成的索引张量,记录了每个原始点属于哪个新聚类。它的长度是高分辨率点云的点数,比如 14。它的内容是[0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]
。
-
point.feat[inverse]
操作:- 这是一个 PyTorch 的高级索引(Advanced Indexing)操作。
- 它会创建一个新的形状为
(14, 64)
的张量。 - 它遍历
inverse
中的每一个元素:- 对于
inverse
的前 8 个元素,值都是0
。所以,它会取出point.feat
中索引为0
的行(也就是特征_A
),并把这个特征复制 8 次,作为新张量的前 8 行。 - 对于
inverse
的后 6 个元素,值都是1
。所以,它会取出point.feat
中索引为1
的行(也就是特征_B
),并把这个特征复制 6 次,作为新张量的后 6 行。
- 对于
-
结果:
- 最终
point.feat[inverse]
的结果是一个(14, 64)
的张量,其内容看起来像:[特征_A, // 对应原始点1特征_A, // 对应原始点2... (共8行)特征_A, // 对应原始点8特征_B, // 对应原始点9... (共6行)特征_B // 对应原始点14 ]
- 最终
总结来说,point.feat[inverse]
这行代码用一种极其高效和向量化的方式,完成了“将一个点的特征广播给所有属于它的子点”这一任务,是实现上采样的核心步骤。