PointNet: 3D point cloud object classification and recognition based on PointNet, MATLAB simulation

Posted by fpga和matlab


1. Software Version

matlab2021a

2. System Overview

The PointNet network structure used here is shown in the figure below:

In the overall network structure, set abstraction is performed first. This stage partitions the point cloud into local regions and extracts features from them. As shown in the figure, a set abstraction level consists of three layers: a Sampling layer, a Grouping layer, and a PointNet layer. The Sampling layer selects the region centroids using the FPS algorithm; the Grouping layer forms the local groups around those centroids, using the MRG or MSG strategy; and the PointNet layer then extracts features from the points in each group. In the MSG configuration, the first set abstraction level takes 512 centroids with radii of 0.1, 0.2, and 0.4, and the maximum number of points per ball is 16, 32, and 128, respectively.

Sampling layer

The sampling layer selects a set of points from the input point cloud; these points define the centers of the local regions. The sampling algorithm is iterative farthest point sampling (FPS): a point is chosen at random, the point farthest from it is taken as the starting point, and the process then iterates, each time adding the point farthest from the points already selected, until the required number of points has been chosen. Compared with random sampling, FPS covers the whole point cloud more completely through the chosen region centers.
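To make the sampling step concrete, the following is a minimal MATLAB sketch of FPS, assuming pts is an N-by-3 matrix of point coordinates and numCentroids does not exceed N. The function name fpsSample is only an illustration and is not part of the core program listed in Section 3.

function centroidIdx = fpsSample(pts,numCentroids)
% Iterative farthest point sampling over an N-by-3 point matrix.
N = size(pts,1);
centroidIdx = zeros(numCentroids,1);
centroidIdx(1) = randi(N);              % start from a random point
minDist = inf(N,1);                     % distance from each point to its nearest selected centroid
for k = 2:numCentroids
    last = pts(centroidIdx(k-1),:);
    d = sum((pts - last).^2,2);         % squared distance to the newest centroid
    minDist = min(minDist,d);           % update nearest-centroid distance
    [~,centroidIdx(k)] = max(minDist);  % pick the point farthest from all selected centroids
end
end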

Grouping layer

The goal is to build the local regions from which features will then be extracted. The idea is to gather neighboring points around each centroid; the paper uses a neighborhood ball rather than KNN because a ball query guarantees a fixed region scale, and the grouping criterion is simply metric distance.
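As a companion to the MSG parameters mentioned above (for example radius 0.2 and at most 32 points per ball), here is a minimal sketch of a neighborhood-ball (ball query) grouping, assuming pts is N-by-3 and centroids is M-by-3, e.g. produced by FPS. The name ballQueryGroup is hypothetical and not a function from the core program in Section 3.

function groups = ballQueryGroup(pts,centroids,radius,maxPts)
% Collect, for each centroid, the indices of points inside a ball of fixed radius.
M = size(centroids,1);
groups = cell(M,1);
for m = 1:M
    d2 = sum((pts - centroids(m,:)).^2,2);   % squared distances to centroid m
    idx = find(d2 <= radius^2);              % points inside the neighborhood ball
    if numel(idx) > maxPts
        idx = idx(1:maxPts);                 % cap the group size, as in MSG
    end
    groups{m} = idx;
end
end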

PointNet layer

For extracting local features from a point cloud, the original PointNet already works well. In PointNet++, the original PointNet network therefore becomes a sub-network that is applied to each local region, and features are extracted hierarchically by iterating the set abstraction levels.
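The hierarchy can be illustrated with the two sketches above: each set abstraction level samples centroids, groups their neighborhoods, and aggregates each group with a symmetric function; the centroids and their aggregated features then become the smaller point set of the next level. The snippet below uses a simple per-group max pooling as a stand-in for the PointNet sub-network; it is only a conceptual sketch, not part of the core program in Section 3.

pts0 = rand(1024,3);                                  % toy input point cloud
idx1 = fpsSample(pts0,512);                           % level-1 centroids
grp1 = ballQueryGroup(pts0,pts0(idx1,:),0.2,32);      % level-1 local regions
% Stand-in for the PointNet layer: max pooling over each local region.
feat1 = cell2mat(cellfun(@(g) max(pts0(g,:),[],1),grp1,'UniformOutput',false));
% The next level repeats the same steps on the 512 centroids and feat1.
idx2 = fpsSample(pts0(idx1,:),128);                   % level-2 centroids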

3. Core Program (excerpt)

clc;
clear;
close all;
warning off;
addpath(genpath(pwd));
rng('default')
dsTrain = PtCloudClassificationDatastore('train');
dsVal = PtCloudClassificationDatastore('test');

ptCloud = pcread('Chair.ply');
label = 'Chair';
figure;pcshow(ptCloud)
xlabel("X");ylabel("Y");zlabel("Z");title(label)

dsLabelCounts = transform(dsTrain,@(data){data{2} data{1}.Count});
labelCounts = readall(dsLabelCounts);
labels = vertcat(labelCounts{:,1});
counts = vertcat(labelCounts{:,2});
figure;histogram(labels);title('class distribution')


rng(0)
[G,classes] = findgroups(labels);
numObservations = splitapply(@numel,labels,G);
desiredNumObservationsPerClass = max(numObservations);
filesOverSample=[];
for i=1:numel(classes)
if i==1
    targetFiles = dsTrain.Files(1:numObservations(i));
else
    % Files of class i, assuming the file list is ordered by class
    targetFiles = dsTrain.Files(sum(numObservations(1:i-1))+1:sum(numObservations(1:i)));
end
% Randomly replicate the point clouds belonging to the infrequent classes
files = randReplicateFiles(targetFiles,desiredNumObservationsPerClass);
filesOverSample = vertcat(filesOverSample,files');
end
dsTrain.Files=filesOverSample;

 

dsTrain.Files = dsTrain.Files(randperm(length(dsTrain.Files)));



dsTrain.MiniBatchSize = 32;
dsVal.MiniBatchSize = dsTrain.MiniBatchSize;


dsTrain = transform(dsTrain,@augmentPointCloud);

data = preview(dsTrain);
ptCloud = data{1,1};
label = data{1,2};

figure;pcshow(ptCloud.Location,[0 0 1],"MarkerSize",40,"VerticalAxisDir","down")
xlabel("X");ylabel("Y");zlabel("Z");title(label)


minPointCount = splitapply(@min,counts,G);
maxPointCount = splitapply(@max,counts,G);
meanPointCount = splitapply(@(x)round(mean(x)),counts,G);
stats = table(classes,numObservations,minPointCount,maxPointCount,meanPointCount)

numPoints = 1000;
dsTrain = transform(dsTrain,@(data)selectPoints(data,numPoints));
dsVal = transform(dsVal,@(data)selectPoints(data,numPoints));

dsTrain = transform(dsTrain,@preprocessPointCloud);
dsVal = transform(dsVal,@preprocessPointCloud);

data = preview(dsTrain);
figure;pcshow(data{1,1},[0 0 1],"MarkerSize",40,"VerticalAxisDir","down");
xlabel("X");ylabel("Y");zlabel("Z");title(data{1,2})


inputChannelSize = 3;
hiddenChannelSize1 = [64,128];
hiddenChannelSize2 = 256;
[parameters.InputTransform, state.InputTransform] = initializeTransform(inputChannelSize,hiddenChannelSize1,hiddenChannelSize2);

inputChannelSize = 3;
hiddenChannelSize = [64 64];
[parameters.SharedMLP1,state.SharedMLP1] = initializeSharedMLP(inputChannelSize,hiddenChannelSize);

inputChannelSize = 64;
hiddenChannelSize1 = [64,128];
hiddenChannelSize2 = 256;
[parameters.FeatureTransform, state.FeatureTransform] = initializeTransform(inputChannelSize,hiddenChannelSize1,hiddenChannelSize2);

inputChannelSize = 64;
hiddenChannelSize = 64;
[parameters.SharedMLP2,state.SharedMLP2] = initializeSharedMLP(inputChannelSize,hiddenChannelSize);


inputChannelSize = 64;
hiddenChannelSize = [512,256];
numClasses = numel(classes);
[parameters.ClassificationMLP, state.ClassificationMLP] = initializeClassificationMLP(inputChannelSize,hiddenChannelSize,numClasses);

numEpochs = 60;
learnRate = 0.001;
l2Regularization = 0.1;
learnRateDropPeriod = 15;
learnRateDropFactor = 0.5;

gradientDecayFactor = 0.9;
squaredGradientDecayFactor = 0.999;
avgGradients = [];
avgSquaredGradients = [];

[lossPlotter, trainAccPlotter,valAccPlotter] = initializeTrainingProgressPlot;
% Number of classes
numClasses = numel(classes);
% Initialize the iterations
iteration = 0;
% To calculate the time for training
start = tic;
% Loop over the epochs
for epoch = 1:numEpochs
    
    % Reset training and validation datastores.
    reset(dsTrain);
    reset(dsVal);
    
    % Iterate through data set.
    while hasdata(dsTrain) % if no data to read, exit the loop to start the next epoch
        iteration = iteration + 1;        
        % Read data.
        data = read(dsTrain);        
        % Create batch.
        [XTrain,YTrain] = batchData(data,classes);        
        % Evaluate the model gradients and loss using dlfeval and the
        % modelGradients function.
        [gradients, loss, state, acc] = dlfeval(@modelGradients,XTrain,YTrain,parameters,state);
        % L2 regularization.
        gradients = dlupdate(@(g,p) g + l2Regularization*p,gradients,parameters);
        % Update the network parameters using the Adam optimizer.
        [parameters, avgGradients, avgSquaredGradients] = adamupdate(parameters, gradients, ...
            avgGradients, avgSquaredGradients, iteration,learnRate,gradientDecayFactor, squaredGradientDecayFactor);
        % Update the training progress.
        D = duration(0,0,toc(start),"Format","hh:mm:ss");
        title(lossPlotter.Parent,"Epoch: " + epoch + ", Elapsed: " + string(D))
        addpoints(lossPlotter,iteration,double(gather(extractdata(loss))))
        addpoints(trainAccPlotter,iteration,acc);
        drawnow
    end
    
    % Create confusion matrix 
    cmat = sparse(numClasses,numClasses);
    % Classify the validation data to monitor the training process
    while hasdata(dsVal)                
        data = read(dsVal); % Get the next batch of data.
        [XVal,YVal] = batchData(data,classes);% Create batch.        
        % Compute label predictions.
        isTrainingVal = 0; %Set at zero for validation data
        YPred = pointnetClassifier(XVal,parameters,state,isTrainingVal);
        
        % Choose the prediction with the highest score as the class
        % label for XVal.
        [~,YValLabel] = max(YVal,[],1);
        [~,YPredLabel] = max(YPred,[],1);
        cmat = aggreateConfusionMetric(cmat,YValLabel,YPredLabel);% Update the confusion matrix
    end
    % Update training progress plot with average classification accuracy.
    acc = sum(diag(cmat))./sum(cmat,"all");
    addpoints(valAccPlotter,iteration,acc);
    % Update the learning rate
    if mod(epoch,learnRateDropPeriod) == 0
        learnRate = learnRate * learnRateDropFactor;
    end   
    reset(dsTrain); % Reset the training data since all the training data were already read 
    % Shuffle the data at every epoch
    dsTrain.UnderlyingDatastore.Files = dsTrain.UnderlyingDatastore.Files(randperm(length(dsTrain.UnderlyingDatastore.Files)));
    reset(dsVal);
end


cmat = sparse(numClasses,numClasses); % Initialize an all-zero confusion matrix
reset(dsVal); % Reset the validation data
data = readall(dsVal); % Read all validation data
[XVal,YVal] = batchData(data,classes); % Create batch.
% Classify the validation data using the helper function pointnetClassifier
YPred = pointnetClassifier(XVal,parameters,state,isTrainingVal);
% Choose the prediction with the highest score as the class label for XVal.
[~,YValLabel] = max(YVal,[],1);
[~,YPredLabel] = max(YPred,[],1);

% Collect confusion metrics.
cmat = aggreateConfusionMetric(cmat,YValLabel,YPredLabel);
figure;chart = confusionchart(cmat,classes);

acc = sum(diag(cmat))./sum(cmat,"all")






4. Simulation Results
5. References

[1] Qi C R, Su H, Mo K, et al. PointNet: Deep Learning on Point Sets for 3D Classification and Segmentation[C]// 2017 IEEE Conference on Computer Vision and Pattern Recognition (CVPR). IEEE, 2017.
