rk3588移植部署pointnet
目录
pointnet 网络模型修改
pointnet训练
pointnet 导出pth, onnx模型
onnx模型推理测试
onnx模型转换rknn模型
rknn模型部署推理测试rk3588点云分类
推理结果
优化方向
onnx, rknn模型下载
pointnet 网络模型修改
模型修改比较简单,其目的在于适配rknn支持的输入张量格式,由于rknn目前只支持图像格式的张量输入,nhwc,nchw 其他格式不支持。然而pointnet的模型输入张量是16x2048x3, 其中16是batch,2048是固定点云数,3是通道数,因此需要做张量转换,维度扩展,把16x2048x3,扩展为1x16x2048x3以适配nhwc格式。
代码如下:
需要修改model.py中网络模型定义部分的前向传播,假设输入的张量维度是1x16x2048x3。
去掉第一维,交换2,3维度。转为推理模型可以识别的张量格式。
class PointNetCls(nn.Module):
    """PointNet classification head adapted for RKNN's NHWC input layout.

    RKNN only accepts image-like 4-D input tensors (NHWC/NCHW), so this model
    takes a tensor of shape (1, batch, num_points, 3).  The forward pass
    strips the leading dummy dimension and moves channels in front of the
    points axis to recover the (batch, 3, num_points) layout that
    PointNetfeat expects.
    """

    def __init__(self, k=2, feature_transform=False):
        super(PointNetCls, self).__init__()
        self.feature_transform = feature_transform
        self.feat = PointNetfeat(global_feat=True, feature_transform=feature_transform)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, k)
        self.dropout = nn.Dropout(p=0.3)
        self.bn1 = nn.BatchNorm1d(512)
        self.bn2 = nn.BatchNorm1d(256)
        self.relu = nn.ReLU()

    def forward(self, x):
        # x: (1, batch, num_points, 3) — NHWC-style wrapper required by RKNN.
        x = x.squeeze(0)        # -> (batch, num_points, 3)
        x = x.permute(0, 2, 1)  # -> (batch, 3, num_points)
        x, trans, trans_feat = self.feat(x)
        x = F.relu(self.bn1(self.fc1(x)))
        x = F.relu(self.bn2(self.dropout(self.fc2(x))))
        x = self.fc3(x)
        # Log-probabilities over the k classes, plus the two transform
        # matrices needed for the feature-transform regularizer.
        return F.log_softmax(x, dim=1), trans, trans_feat
pointnet训练
训练代码如下:test_classification_nhwc.py
其中需要注意的部分,在前向传播前,将数据集的点云张量格式转换维度,适配nhwc模型输入如下:
points = transform_points(points) # (batch,3,1024) ->(1,batch,1024,3)
其中转换点维度函数如下:
def transform_points(points):
    """Reshape a point-cloud batch into the 4-D NHWC layout RKNN expects.

    Args:
        points: tensor of shape (batch, channels, num_points),
            e.g. (32, 3, 1024).

    Returns:
        Contiguous tensor of shape (1, batch, num_points, channels),
        e.g. (1, 32, 1024, 3).
    """
    points = points.unsqueeze(0)         # -> (1, batch, channels, num_points)
    points = points.permute(0, 1, 3, 2)  # -> (1, batch, num_points, channels)
    # permute() only changes strides; make the memory layout match so later
    # export/inference steps see a dense tensor.
    return points.contiguous()
把原始的点云输入张量 batch x channels x point_num 转换为 1 x batch x point_num x channels(即NHWC格式:先扩展出第一维,再交换最后两个维度)
from __future__ import print_function
import argparse
import os
import random
import torch
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
from pointnet.dataset import ShapeNetDataset, ModelNetDataset
from pointnet.model import PointNetCls, feature_transform_regularizer
import torch.nn.functional as F
from tqdm import tqdmdef parse_args():'''PARAMETERS'''parser = argparse.ArgumentParser()parser.add_argument('--batchSize', type=int, default=16, help='input batch size')parser.add_argument('--num_points', type=int, default=2500, help='input batch size')parser.add_argument('--workers', type=int, help='number of data loading workers', default=4)parser.add_argument('--nepoch', type=int, default=250, help='number of epochs to train for')parser.add_argument('--outf', type=str, default='cls', help='output folder')parser.add_argument('--model', type=str, default='', help='model path')parser.add_argument('--dataset', type=str, required=True, help="dataset path")parser.add_argument('--dataset_type', type=str, default='shapenet', help="dataset type shapenet|modelnet40")parser.add_argument('--feature_transform', action='store_true', help="use feature transform")return parser.parse_args()def transform_points(points):# points形状: (32,3,1024)points = points.unsqueeze(0) # -> (1,32,3,1024)points = points.permute(0,1,3,2) # -> (1,32,1024,3)return points.contiguous()def main(args):opt = parse_args()print(opt)blue = lambda x: '\033[94m' + x + '\033[0m'opt.manualSeed = random.randint(1, 10000) # fix seedprint("Random Seed: ", opt.manualSeed)random.seed(opt.manualSeed)torch.manual_seed(opt.manualSeed)if opt.dataset_type == 'shapenet':dataset = ShapeNetDataset(root=opt.dataset,classification=True,npoints=opt.num_points)test_dataset = ShapeNetDataset(root=opt.dataset,classification=True,split='test',npoints=opt.num_points,data_augmentation=False)elif opt.dataset_type == 'modelnet40':dataset = ModelNetDataset(root=opt.dataset,npoints=opt.num_points,split='trainval')test_dataset = ModelNetDataset(root=opt.dataset,split='test',npoints=opt.num_points,data_augmentation=False)else:exit('wrong dataset type')dataloader = torch.utils.data.DataLoader(dataset,batch_size=opt.batchSize,shuffle=True,num_workers=int(opt.workers))testdataloader = 
torch.utils.data.DataLoader(test_dataset,batch_size=opt.batchSize,shuffle=True,num_workers=int(opt.workers))print(len(dataset), len(test_dataset))num_classes = len(dataset.classes)print('classes', num_classes)try:os.makedirs(opt.outf)except OSError:passclassifier = PointNetCls(k=num_classes, feature_transform=opt.feature_transform)if opt.model != '':classifier.load_state_dict(torch.load(opt.model))optimizer = optim.Adam(classifier.parameters(), lr=0.001, betas=(0.9, 0.999))scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)classifier.cuda()num_batch = len(dataset) / opt.batchSizefor epoch in range(opt.nepoch):scheduler.step()for i, data in enumerate(dataloader, 0):points, target = datatarget = target[:, 0]points = points.transpose(2, 1) # (batch,3,1024) points = transform_points(points) # (batch,3,1024) ->(1,batch,1024,3)poi