热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

【3D目标检测】PointNet

目录概述细节相关工作网络结构点云及点云的性质针对点云性质的处理简单实现概述首先,本文是基于点云,并且直接处理点云数据的3D目标检测网络本文提出了一个简




目录


  • 概述
  • 细节
    • 相关工作
    • 网络结构
    • 点云及点云的性质
    • 针对点云性质的处理

  • 简单实现



概述

首先,本文是基于点云,并且直接处理点云数据的3D目标检测网络
本文提出了一个简单的网络,考虑点云的特性,直接处理点云数据,用于分类或者分割等任务。


细节

相关工作

传统方法:手工提取点云特征,存在很多局限性,如对于特定任务,不容易找到优质的特征,并且模型的泛化能力较差。
深度学习方法:1、将点云体素化,通过3D卷积处理,体素化会导致元素的稀疏(很多体素是空的),并且3D卷积计算量消耗大。2、将点云转换成2D数据,然后采用2D卷积以及相应的方法处理,这个方法在特定的任务上(分类、检索)取得了比较好的结果,但是对另一些任务(分割、形状补全)就不太适用。


网络结构


  1. 输入一个样本对应的点云数据,这里是n*3也就是n个点,每个点有三个特征
  2. 输入数据经过一个T-Net进行输入空间的对齐,保证点云的变换不变性。相当于是做了一次变换,所以输入输出的shape是不变的
  3. 经过多个MLP进行特征提取,再次使用T-Net进行特征空间的对齐
  4. 对于每个特征维度使用max pool得到全局特征,也就是进行了特征聚合操作
  5. 对于分类问题,经过MLP直接输出k个类别的类别概率,对于分割问题,先进行全局信息与局部分析的融合,然后输出n*m的矩阵,表示n个点对应m个分类类别的概率。

在这里插入图片描述


点云及点云的性质

点云:点云就是一系列的点的集合,每个点可以包含若干个属性,最基本的就是空间坐标(x,y,z),当然也可以有一些其他的属性。第一张图是点云的原始数据,这里是包括了6个属性,前三个是坐标,后三个是其他属性;第二张图是这份点云数据的可视化。
在这里插入图片描述
在这里插入图片描述

点云的性质:


  • 非结构化:点云是非结构的,所以无法直接使用卷积神经网络处理。
  • 无序性:点云是无序的,假设我们的点云有N个点,这N个点的任意排序都能正常表示点云,这也要求我们的网络对于任意顺序的点的输入都有一个相同的结果。
  • 变换不变性:对点云进行平移、旋转等变换,应该是不改变点云分类或者分割的结果的。
  • 点的交互:点云中的点不是孤立的,他和他周围的点是一个局部,是存在联系的,所以这个联系,以及局部与局部之间的联系都是需要网络考虑的。

针对点云性质的处理

针对非结构化的特点:作者直接使用了MLP处理

针对无序性的特点: 首先无序背景采用主流的方法会有什么问题呢?我们打个比方,有三个数据,每个数据都有三个特征,我们要将他们聚合成一个数据,一个简单的方法就是每个数据都取一部分,但是每个数据取多少是我们学习得到的,而这个数据是定死的,假如说第一个数据取20%,第二个数据取50%,第三个数据取30%,那么三种数据可以有6种排列方式,对应我们的聚合方法就会产生6种结果,这不是一个无序数据应该有的,也就是我们的聚合方式是存在问题的。无序背景下特征提取主流的解决方案有三种:1、排序;2、RNN处理;3、使用一个简单的对称函数。作者实验下来,第三种效果最好。主要解释下第三种,什么是对称函数。对称函数其实就是对于输入顺序无所谓的函数,比如max函数,min函数,sum函数之类的,如采用max函数,对于每一维度的特征,我们都取最大的一个值作为聚合数据的特征,无论输入数据的顺序怎么变换,都可以得到相同的结果。

针对变换不变性的特点:一个直观的想法就是在进行特征提取之前进行标准空间的对齐,这样的话,不管输入经过了怎样的变换,我们统一将它转换成标准的状态,然后再进行特征提取。作者这边还额外添加了一个特征空间的对齐操作,这两个对齐网络是额外训练的简单网络(T-Net)。

针对点交互的特点:作者这边采用的方式是全局信息与个体信息的融合,首先使用max pool得到全局的点云向量,然后将全局的点云向量与每个点的特征进行融合,接着使用一个MLP加深这种融合(减小语义沟壑),最后用融合之后的特征去做分割。但是,这个操作其实是因为分割任务的特殊性所以需要的,也就是说,从全局的角度来看,是没有将当前点和周围点做融合,没有利用到局部信息的。


简单实现

使用PointNet实现分类,参考代码:链接,conv1d的理解参考:链接

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
# 构造输入变换矩阵
class STN3d(nn.Layer):
def __init__(self,channel=3):
super(STN3d, self).__init__()
self.conv1 = nn.Conv1D(channel, 64, 1)
self.conv2 = nn.Conv1D(64, 128, 1)
self.conv3 = nn.Conv1D(128, 1024, 1)
self.fc1=nn.Linear(1024,512)
self.fc2=nn.Linear(512,256)
self.fc3=nn.Linear(256,9)
self.bn1 = nn.BatchNorm1D(64)
self.bn2 = nn.BatchNorm1D(128)
self.bn3 = nn.BatchNorm1D(1024)
self.bn4 = nn.BatchNorm1D(512)
self.bn5 = nn.BatchNorm1D(256)
self.relu=nn.ReLU()
def forward(self,x):
# x:1,3,point_nums
B, D, N = x.shape
x = self.relu(self.bn1(self.conv1(x))) # x:1,3,point_nums ->x:1,64,point_nums
x = self.relu(self.bn2(self.conv2(x))) # x:1,64,point_nums ->x:1,128,point_nums
x = self.relu(self.bn3(self.conv3(x))) # x:1,128,point_nums ->x:1,1024,point_nums
x = paddle.max(x, 2, keepdim=True) # x:1,1024,point_nums ->x:1,1024,1
x = paddle.flatten(x, 1) # x:1,1024,1 -> x:1,1024
x=self.relu(self.bn4(self.fc1(x))) # x:1,1024 -> x:1,512
x=self.relu(self.bn5(self.fc2(x))) # x:1,512 -> x:1,256
x=self.fc3(x) # x:1,256 -> x:1,9
iden = paddle.to_tensor([1, 0, 0, 0, 1, 0, 0, 0, 1]).astype(paddle.float32).reshape([1,9]).expand([B,9]) # iden:1,9-> iden:B,9
x=x+iden
x=x.reshape([-1,3,3])
return x
# 构造特征变换矩阵
class STNkd(nn.Layer):
def __init__(self, k=64):
super(STNkd, self).__init__()
self.conv1 = nn.Conv1D(k, 64, 1)
self.conv2 = nn.Conv1D(64, 128, 1)
self.conv3 = nn.Conv1D(128, 1024, 1)
self.fc1 = nn.Linear(1024, 512)
self.fc2 = nn.Linear(512, 256)
self.fc3 = nn.Linear(256, k*k)
self.bn1 = nn.BatchNorm1D(64)
self.bn2 = nn.BatchNorm1D(128)
self.bn3 = nn.BatchNorm1D(1024)
self.bn4 = nn.BatchNorm1D(512)
self.bn5 = nn.BatchNorm1D(256)
self.relu = nn.ReLU()
self.k=k
def forward(self, x):
# x:1,64,point_nums
B, D, N = x.shape
x = self.relu(self.bn1(self.conv1(x))) # x:1,64,point_nums ->x:1,64,point_nums
x = self.relu(self.bn2(self.conv2(x))) # x:1,64,point_nums ->x:1,128,point_nums
x = self.relu(self.bn3(self.conv3(x))) # x:1,128,point_nums ->x:1,1024,point_nums
x = paddle.max(x, 2, keepdim=True) # x:1,1024,point_nums ->x:1,1024,1
x = paddle.flatten(x, 1) # x:1,1024,1 -> x:1,1024
x = self.relu(self.bn4(self.fc1(x))) # x:1,1024 -> x:1,512
x = self.relu(self.bn5(self.fc2(x))) # x:1,512 -> x:1,256
x = self.fc3(x) # x:1,256 -> x:1,9
iden = paddle.eye(self.k,self.k).astype(paddle.float32).reshape([1, self.k*self.k]).expand(
[B, self.k*self.k]) # iden:1,64-> iden:B,64
x = x + iden
x = x.reshape([-1, self.k,self.k])
return x
class PointNet(nn.Layer):
# 特征数和分类的类别
def __init__(self,channel=3,classes=3):
super(PointNet, self).__init__()
self.stn = STN3d(channel=channel)
self.fstn = STNkd(k=64)
self.conv1 = nn.Conv1D(channel,64,1)
self.conv2 = nn.Conv1D(64, 128, 1)
self.conv3 = nn.Conv1D(128, 1024, 1)
self.fc1 = nn.Linear(1024, 512)
self.fc2 = nn.Linear(512, 256)
self.fc3 = nn.Linear(256, classes)
self.bn1 = nn.BatchNorm1D(64)
self.bn2 = nn.BatchNorm1D(128)
self.bn3 = nn.BatchNorm1D(1024)
self.bn4 = nn.BatchNorm1D(512)
self.bn5 = nn.BatchNorm1D(256)
self.relu=nn.ReLU()
def forward(self,x):
# x:1,3,point_nums
B,D,N=x.shape
# 输入空间对齐
trans = self.stn(x)
x=paddle.transpose(x,[0,2,1]) # x:1,3,point_nums ->x:1,point_nums,3
x=paddle.matmul(x,trans)
x = paddle.transpose(x, [0,2, 1]) # x:1,point_nums,3 ->x:1,3,point_nums
x=self.relu(self.bn1(self.conv1(x))) # x:1,3,point_nums ->x:1,64,point_nums
# 特征空间对齐
trans_feat = self.fstn(x)
x = paddle.transpose(x, [0, 2, 1]) # x:1,3,point_nums ->x:1,point_nums,3
x = paddle.matmul(x, trans_feat)
x = paddle.transpose(x, [0, 2, 1]) # x:1,point_nums,3 ->x:1,3,point_nums
x=self.relu(self.bn2(self.conv2(x))) # x:1,64,point_nums ->x:1,128,point_nums
x=self.relu(self.bn3(self.conv3(x))) # x:1,128,point_nums ->x:1,1024,point_nums
# 沿着点云数方向做求最大值,相当于是将N个样本聚合成一个,得到一个全局的点云特征
x = paddle.max(x, 2, keepdim=True) # x:1,1024,point_nums ->x:1,1024,1
# 得到全局的特征向量
x=paddle.flatten(x,1) # x:1,1024,1 -> x:1,1024
# 得到分类结果
x=self.relu(self.bn4(self.fc1(x))) # x:1,1024 -> x:1,512
x=self.relu(self.bn5(self.fc2(x))) # x:1,512 -> x:1,256
x=self.fc3(x)
x=F.softmax(x,1)
return x

def main():
x = paddle.randn([1, 3, 100])
pointnet = PointNet(3,3)
print(pointnet(x))
if __name__ == '__main__':
main()






推荐阅读
author-avatar
扫地僧2502896033
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有