当前位置: 首页 > news >正文

rk3588移植部署pointnet

目录

pointnet 网络模型修改

pointnet训练

pointnet 导出pth, onnx模型

onnx模型推理测试

onnx模型转换rknn模型

rknn模型部署推理测试rk3588点云分类

推理结果

优化方向

onnx, rknn模型下载


pointnet 网络模型修改

模型修改比较简单,其目的在于适配rknn支持的输入张量格式,由于rknn目前只支持图像格式的张量输入,nhwc,nchw 其他格式不支持。然而pointnet的模型输入张量是16x2048x3, 其中16是batch,2048是固定点云数,3是通道数,因此需要做张量转换,维度扩展,把16x2048x3,扩展为1x16x2048x3以适配nhwc格式。

代码如下:

需要修改model.py中网络模型定义部分的前向传播,假设输入的张量维度是1x16x2048x3。

去掉第一维,交换2,3维度。转为推理模型可以识别的张量格式。

class PointNetCls(nn.Module):def __init__(self, k=2, feature_transform=False):super(PointNetCls, self).__init__()self.feature_transform = feature_transformself.feat = PointNetfeat(global_feat=True, feature_transform=feature_transform)self.fc1 = nn.Linear(1024, 512)self.fc2 = nn.Linear(512, 256)self.fc3 = nn.Linear(256, k)self.dropout = nn.Dropout(p=0.3)self.bn1 = nn.BatchNorm1d(512)self.bn2 = nn.BatchNorm1d(256)self.relu = nn.ReLU()def forward(self, x):#print("input x.shape")#print(x.shape)x = x.squeeze(0)  # -> (batch, 1024, 3)x = x.permute(0, 2, 1)  # -> (batch, 3, 1024)#print("trans x.shape")#print(x.shape)x, trans, trans_feat = self.feat(x)x = F.relu(self.bn1(self.fc1(x)))x = F.relu(self.bn2(self.dropout(self.fc2(x))))x = self.fc3(x)return F.log_softmax(x, dim=1), trans, trans_feat

pointnet训练

训练代码如下:test_classification_nhwc.py

其中需要注意的部分,在前向传播前,将数据集的点云张量格式转换维度,适配nhwc模型输入如下:

 points = transform_points(points) # (batch,3,1024) ->(1,batch,1024,3)
 其中转换点维度函数如下:

def transform_points(points):
    # points形状: (32,3,1024)
    points = points.unsqueeze(0)  # -> (1,32,3,1024)
    points = points.permute(0,1,3,2)  # -> (1,32,1024,3)
    return points.contiguous()

把原始的点云输入张量batchxchannelsxpoint_num转换为 1xbatchxchannelsxpoint_num

from __future__ import print_function
import argparse
import os
import random
import torch
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
from pointnet.dataset import ShapeNetDataset, ModelNetDataset
from pointnet.model import PointNetCls, feature_transform_regularizer
import torch.nn.functional as F
from tqdm import tqdmdef parse_args():'''PARAMETERS'''parser = argparse.ArgumentParser()parser.add_argument('--batchSize', type=int, default=16, help='input batch size')parser.add_argument('--num_points', type=int, default=2500, help='input batch size')parser.add_argument('--workers', type=int, help='number of data loading workers', default=4)parser.add_argument('--nepoch', type=int, default=250, help='number of epochs to train for')parser.add_argument('--outf', type=str, default='cls', help='output folder')parser.add_argument('--model', type=str, default='', help='model path')parser.add_argument('--dataset', type=str, required=True, help="dataset path")parser.add_argument('--dataset_type', type=str, default='shapenet', help="dataset type shapenet|modelnet40")parser.add_argument('--feature_transform', action='store_true', help="use feature transform")return parser.parse_args()def transform_points(points):# points形状: (32,3,1024)points = points.unsqueeze(0)  # -> (1,32,3,1024)points = points.permute(0,1,3,2)  # -> (1,32,1024,3)return points.contiguous()def main(args):opt = parse_args()print(opt)blue = lambda x: '\033[94m' + x + '\033[0m'opt.manualSeed = random.randint(1, 10000)  # fix seedprint("Random Seed: ", opt.manualSeed)random.seed(opt.manualSeed)torch.manual_seed(opt.manualSeed)if opt.dataset_type == 'shapenet':dataset = ShapeNetDataset(root=opt.dataset,classification=True,npoints=opt.num_points)test_dataset = ShapeNetDataset(root=opt.dataset,classification=True,split='test',npoints=opt.num_points,data_augmentation=False)elif opt.dataset_type == 'modelnet40':dataset = ModelNetDataset(root=opt.dataset,npoints=opt.num_points,split='trainval')test_dataset = ModelNetDataset(root=opt.dataset,split='test',npoints=opt.num_points,data_augmentation=False)else:exit('wrong dataset type')dataloader = torch.utils.data.DataLoader(dataset,batch_size=opt.batchSize,shuffle=True,num_workers=int(opt.workers))testdataloader = torch.utils.data.DataLoader(test_dataset,batch_size=opt.batchSize,shuffle=True,num_workers=int(opt.workers))print(len(dataset), len(test_dataset))num_classes = len(dataset.classes)print('classes', num_classes)try:os.makedirs(opt.outf)except OSError:passclassifier = PointNetCls(k=num_classes, feature_transform=opt.feature_transform)if opt.model != '':classifier.load_state_dict(torch.load(opt.model))optimizer = optim.Adam(classifier.parameters(), lr=0.001, betas=(0.9, 0.999))scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)classifier.cuda()num_batch = len(dataset) / opt.batchSizefor epoch in range(opt.nepoch):scheduler.step()for i, data in enumerate(dataloader, 0):points, target = datatarget = target[:, 0]points = points.transpose(2, 1) # (batch,3,1024) points = transform_points(points) # (batch,3,1024) ->(1,batch,1024,3)poi
http://www.dtcms.com/a/436117.html

相关文章:

  • 网站制作 网站建设 杭州上海名企
  • 谁用fun域名做网站了襄樊网站制作公司
  • php学建网站摄影标志logo设计欣赏
  • 男的怎么做直播网站wordpress json接口
  • 服装生产工厂管理系统是什么?主要有哪几种核心功能?
  • 浙江省建筑诚信平台查询系统网站meta 优化建议
  • 免费做课设的网站一个商城网站多少钱
  • 开网站流程wordpress.备份
  • JDK1.8下载安装使用教程,图文教程(超详细)
  • 个人网站建设方法和过程聊城专业网站制作公司
  • 合肥网站建设方案id怎么自动导入wordpress
  • Matlab通过GUI实现点云的GICP配准
  • 数字化ERP“一图四清单”战略执行体系
  • 每日一练【约瑟夫环问题】
  • 公司网站推广计划书怎么做网络工程好就业吗
  • 找网站网站防止镜像
  • 监理网站网站ipv6改造怎么做 网页代码
  • 新公司如何做网站wordpress文本块表格
  • 无锡富通电力建设有限公司网站html个人主页制作
  • 重庆微信营销网站建设wordpress用户导入数据库表
  • 网站分类导航代码企业建设网站哪家好
  • ID 生成方案
  • h5开发环境搭建重庆网站seo搜索引擎优化
  • 小程序api的使用搜索引擎排名优化方案
  • 网站怎么才能被搜到微网站排版
  • 建设银行积分网站宿迁网站推广公司
  • 企业电子商务网站的建设方式logo设计公司标志
  • [论文阅读] AI+SRE(网站可靠性工程) | 字节跳动ErrorPrism:微服务错误追踪准确率97%!告别日志“一团乱麻”
  • 打开无忧管理后台网站网站建设 中企动力 常州
  • 装饰公司网站模板17网一起做网店网站