【点云】pointnet网络梳理
every blog every motto: You can do more than you think.
https://blog.csdn.net/weixin_39190382?type=blog
0. 前言
pointnet
1. 正文
1.1 原因
在图像领域可以使用CNN,但是在点云中,由于点云的无序性,即,交换两个点,不会改变点云的性质。
同时,点云的分布不是均匀的,有的地方会密集,有的地方稀疏。
如何解决对于无序性:
PointNet的作者想:我需要找到一种方法,无论点的顺序怎么打乱,我最后得到的特征都是一样的。这种特性在数学上叫做“对称函数” (Symmetric Function)。
常见的对称函数有哪些?
求和 (Summation): a+b+c 和 c+b+a 的结果是一样的。
求平均 (Average): (a+b+c)/3 和 (c+b+a)/3 的结果是一样的。
求最大值 (Max-Pooling): max(a, b, c) 和 max(c, b, a) 的结果是一样的。
PointNet最终选择了 Max-Pooling(最大池化) 作为它的核心对称函数。
1.2 网络结构
T-Net后续改进版的网络好像没用到,所以这里就不讲了。
整体的话就是一个MLP堆叠的过程,其中用maxpooling来保证无序性。
分类的话,就是最用有一个k维进行分类;
分割的话,会将局部特征和全局特征进行合并操作,然后进行分割。
1.3 流程
简单来说:
先升维(n,3)-> (n,1024)
再maxpooling (n,1024)-> (1,1024)
最后分类
- 独立特征提取:
输入是一堆点的坐标(一个大小为 N x 3 的矩阵,N是点的数量,3是XYZ坐标)。
PointNet对 每一个点 单独进行特征学习,把它从3维映射到更高维度的空间(比如1024维)。你可以把它想象成给每个点“画像”,让它的信息更丰富。这一步是通过几个共享参数的多层感知机(MLP)完成的。
- 全局特征聚合(关键一步):
现在我们有N个1024维的特征向量了。
PointNet在这些特征向量的 每一个维度上 做一次Max-Pooling。也就是说,在第一个维度上,从N个点中选出最大值;在第二个维度上,也选出最大值……以此类推。
做完之后,N个点的特征就被“压”成了一个1024维的 全局特征向量。这个向量代表了整个点云的“样子”。因为Max-Pooling是无序的,所以无论输入点的顺序如何,这个全局特征都是不变的!
- 输出结果:
最后,用这个全局特征向量去做具体的任务,比如接一个分类器判断这个点云是什么物体(桌子?椅子?),或者做一个分割器判断每个点属于物体的哪个部分。
1.4 代码
1.4.1 特征提取部分
特征提取部分,就是MLP堆叠的过程。
其中的T-Net对应代码中的STN3d,后续的网络中没用到,就不讲了。
最后返回两种结果,一种是(1,1024)用于后续的分类,另一种的和前面的特征拼接,用于后续的分割。
class PointNetEncoder(nn.Module):def __init__(self, global_feat=True, feature_transform=False, channel=3):super(PointNetEncoder, self).__init__()self.stn = STN3d(channel)self.conv1 = torch.nn.Conv1d(channel, 64, 1)self.conv2 = torch.nn.Conv1d(64, 128, 1)self.conv3 = torch.nn.Conv1d(128, 1024, 1)self.bn1 = nn.BatchNorm1d(64)self.bn2 = nn.BatchNorm1d(128)self.bn3 = nn.BatchNorm1d(1024)self.global_feat = global_featself.feature_transform = feature_transformif self.feature_transform:self.fstn = STNkd(k=64)def forward(self, x):B, D, N = x.size()trans = self.stn(x)x = x.transpose(2, 1)if D > 3:feature = x[:, :, 3:]x = x[:, :, :3]x = torch.bmm(x, trans)if D > 3:x = torch.cat([x, feature], dim=2)x = x.transpose(2, 1)x = F.relu(self.bn1(self.conv1(x)))if self.feature_transform:trans_feat = self.fstn(x)x = x.transpose(2, 1)x = torch.bmm(x, trans_feat)x = x.transpose(2, 1)else:trans_feat = Nonepointfeat = x # (b,n,64)x = F.relu(self.bn2(self.conv2(x))) # 64-> 128x = self.bn3(self.conv3(x)) # 128->1024 x = torch.max(x, 2, keepdim=True)[0] # maxx = x.view(-1, 1024) # 1024# 分类最后的返回结果if self.global_feat:return x, trans, trans_feat# 分割最后的返回结果else:x = x.view(-1, 1024, 1).repeat(1, 1, N)# 中间的(n,64)和global feature(1,1024)拼接, ,return torch.cat([x, pointfeat], 1), trans, trans_feat
1.4.2 分类
堆叠mlp ,通道层的变化:1024->512->256->128->k
class get_model(nn.Module):def __init__(self, k=40, normal_channel=True):super(get_model, self).__init__()if normal_channel:channel = 6else:channel = 3self.feat = PointNetEncoder(global_feat=True, feature_transform=True, channel=channel)self.fc1 = nn.Linear(1024, 512)self.fc2 = nn.Linear(512, 256)self.fc3 = nn.Linear(256, k)self.dropout = nn.Dropout(p=0.4)self.bn1 = nn.BatchNorm1d(512)self.bn2 = nn.BatchNorm1d(256)self.relu = nn.ReLU()def forward(self, x):# 1024 -> 512 --> 256 --> kx, trans, trans_feat = self.feat(x)x = F.relu(self.bn1(self.fc1(x)))x = F.relu(self.bn2(self.dropout(self.fc2(x))))x = self.fc3(x)x = F.log_softmax(x, dim=1)return x, trans_feat
1.4.3 分割
class get_model(nn.Module):def __init__(self, num_class):super(get_model, self).__init__()self.k = num_classself.feat = PointNetEncoder(global_feat=False, feature_transform=True, channel=9)self.conv1 = torch.nn.Conv1d(1088, 512, 1)self.conv2 = torch.nn.Conv1d(512, 256, 1)self.conv3 = torch.nn.Conv1d(256, 128, 1)self.conv4 = torch.nn.Conv1d(128, self.k, 1)self.bn1 = nn.BatchNorm1d(512)self.bn2 = nn.BatchNorm1d(256)self.bn3 = nn.BatchNorm1d(128)def forward(self, x):batchsize = x.size()[0]n_pts = x.size()[2]x, trans, trans_feat = self.feat(x) # ,这里x最后经过了拼接:1024+64 =1088x = F.relu(self.bn1(self.conv1(x))) # 1088->512x = F.relu(self.bn2(self.conv2(x))) # 512->256x = F.relu(self.bn3(self.conv3(x))) # 256->128x = self.conv4(x) # 128->kx = x.transpose(2,1).contiguous()x = F.log_softmax(x.view(-1,self.k), dim=-1)x = x.view(batchsize, n_pts, self.k)return x, trans_feat
参考
- https://www.cnblogs.com/SkyXZ/p/19138265
- https://blog.csdn.net/Yong_Qi2015/article/details/128509981
- https://zhuanlan.zhihu.com/p/264627148
- https://binaryoracle.github.io/3DVL/%E7%AE%80%E6%9E%90PointNet.html#%E7%BC%BA%E9%99%B7
- https://www.hubtools.cn/2025/PointNet.html