CVPR2020 RandLA-Net Code Reading
Posted by jiangxiaoju
Paper: RandLA-Net: Efficient Semantic Segmentation of Large-Scale Point Clouds
Code: GitHub
These are my notes from reading the RandLA-Net code; if anything is wrong, discussion is welcome.
1. RandLA-Net Network Structure
(Figure: the local feature aggregation module.)

(Figure: the network architecture in detail.) The whole network is built by stacking this local feature aggregation module.
2. Project Structure
The project layout is very compact.
| File | Purpose |
| --- | --- |
| `helper_tf_util.py` | Wraps common convolution and pooling ops. |
| `helper_tool.py` | Per-dataset training hyperparameters, plus some data-preprocessing utilities. |
| `main_*.py` | Main training script for the corresponding dataset. |
| `RandLANet.py` | Defines the main structure of the network. |
| `tester_*.py` | Testing script for the corresponding dataset; called from `main_*.py`. |
| `utils` | This folder holds the dataset-preprocessing modules and the KNN module. |
3. Code Walkthrough
3.1 main_S3DIS
The `main_*.py` files all share essentially the same structure, so I take `main_S3DIS.py` as the example; the focus of these notes is the code in `RandLANet.py`.
`main_S3DIS.py` defines an `S3DIS` class. Its constructor takes `test_area_idx`, the index of the area to hold out for validation, and initializes a number of variables such as the label mappings:
```python
class S3DIS:
    def __init__(self, test_area_idx):
        self.name = 'S3DIS'
        self.path = '/data/S3DIS'
        self.label_to_names = {0: 'ceiling', 1: 'floor', 2: 'wall', 3: 'beam', 4: 'column', 5: 'window',
                               6: 'door', 7: 'table', 8: 'chair', 9: 'sofa', 10: 'bookcase', 11: 'board',
                               12: 'clutter'}
        self.num_classes = len(self.label_to_names)
        self.label_values = np.sort([k for k, v in self.label_to_names.items()])
        self.label_to_idx = {l: i for i, l in enumerate(self.label_values)}
        self.ignored_labels = np.array([])

        self.val_split = 'Area_' + str(test_area_idx)
        self.all_files = glob.glob(join(self.path, 'original_ply', '*.ply'))

        # initialize
        self.val_proj = []
        self.val_labels = []
        self.possibility = {}
        self.min_possibility = {}
        self.input_trees = {'training': [], 'validation': []}
        self.input_colors = {'training': [], 'validation': []}
        self.input_labels = {'training': [], 'validation': []}
        self.input_names = {'training': [], 'validation': []}
        self.load_sub_sampled_clouds(cfg.sub_grid_size)
```
After that, `dataset.init_input_pipeline()` is called:
```python
def init_input_pipeline(self):
    print('Initiating input pipelines')
    cfg.ignored_label_inds = [self.label_to_idx[ign_label] for ign_label in self.ignored_labels]
    gen_function, gen_types, gen_shapes = self.get_batch_gen('training')
    gen_function_val, _, _ = self.get_batch_gen('validation')
    self.train_data = tf.data.Dataset.from_generator(gen_function, gen_types, gen_shapes)
    self.val_data = tf.data.Dataset.from_generator(gen_function_val, gen_types, gen_shapes)

    self.batch_train_data = self.train_data.batch(cfg.batch_size)
    self.batch_val_data = self.val_data.batch(cfg.val_batch_size)
    map_func = self.get_tf_mapping2()

    self.batch_train_data = self.batch_train_data.map(map_func=map_func)
    self.batch_val_data = self.batch_val_data.map(map_func=map_func)

    self.batch_train_data = self.batch_train_data.prefetch(cfg.batch_size)
    self.batch_val_data = self.batch_val_data.prefetch(cfg.val_batch_size)

    iter = tf.data.Iterator.from_structure(self.batch_train_data.output_types, self.batch_train_data.output_shapes)
    self.flat_inputs = iter.get_next()
    self.train_init_op = iter.make_initializer(self.batch_train_data)
    self.val_init_op = iter.make_initializer(self.batch_val_data)
```
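The key point here is `tf.data.Iterator.from_structure`: a single reinitializable iterator feeds both the training and validation datasets, and `train_init_op` / `val_init_op` switch between them without rebuilding the graph. A minimal, self-contained TF1 sketch of the same pattern (toy datasets, not the repo's):

```python
import tensorflow as tf  # TF 1.x

# Toy stand-ins for batch_train_data / batch_val_data.
train_ds = tf.data.Dataset.from_tensor_slices(tf.range(6)).batch(2)
val_ds = tf.data.Dataset.from_tensor_slices(tf.range(100, 104)).batch(2)

# One iterator, two initializers -- the same pattern as init_input_pipeline.
it = tf.data.Iterator.from_structure(train_ds.output_types, train_ds.output_shapes)
next_batch = it.get_next()
train_init_op = it.make_initializer(train_ds)
val_init_op = it.make_initializer(val_ds)

with tf.Session() as sess:
    for init_op in (train_init_op, val_init_op):
        sess.run(init_op)                    # switch the data source
        while True:
            try:
                print(sess.run(next_batch))  # consume until exhausted
            except tf.errors.OutOfRangeError:
                break
```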
3.2 RandLANet
The initialization logic:
```python
def __init__(self, dataset, config):
    # training hyperparameters / configuration
    flat_inputs = dataset.flat_inputs
    self.config = config
    # Path of the result folder
    if self.config.saving:
        if self.config.saving_path is None:
            self.saving_path = time.strftime('results/Log_%Y-%m-%d_%H-%M-%S', time.gmtime())
        else:
            self.saving_path = self.config.saving_path
        makedirs(self.saving_path) if not exists(self.saving_path) else None

    with tf.variable_scope('inputs'):
        #### ******

    with tf.variable_scope('layers'):
        self.logits = self.inference(self.inputs, self.is_training)

    with tf.variable_scope('loss'):
        #### ******
        self.loss = self.get_loss(valid_logits, valid_labels, self.class_weights)

    with tf.variable_scope('optimizer'):
        #### ******

    with tf.variable_scope('results'):
        #### ******
        tf.summary.scalar('learning_rate', self.learning_rate)
        tf.summary.scalar('loss', self.loss)
        tf.summary.scalar('accuracy', self.accuracy)

    my_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
    self.saver = tf.train.Saver(my_vars, max_to_keep=100)
    c_proto = tf.ConfigProto()
    c_proto.gpu_options.allow_growth = True
    self.sess = tf.Session(config=c_proto)
    self.merged = tf.summary.merge_all()
    self.train_writer = tf.summary.FileWriter(config.train_sum_dir, self.sess.graph)
    self.sess.run(tf.global_variables_initializer())
```
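The `get_loss` call above is not expanded in this post. For reference, my reconstruction of it from the official repo (a sketch; check the repo for the authoritative version): it computes a class-weighted softmax cross-entropy, where `pre_cal_weights` are the precomputed inverse-frequency class weights.

```python
def get_loss(self, logits, labels, pre_cal_weights):
    # weight each point's cross-entropy by its class's inverse frequency
    class_weights = tf.convert_to_tensor(pre_cal_weights, dtype=tf.float32)
    one_hot_labels = tf.one_hot(labels, depth=self.config.num_classes)
    weights = tf.reduce_sum(class_weights * one_hot_labels, axis=1)
    unweighted_losses = tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=one_hot_labels)
    weighted_losses = unweighted_losses * weights
    output_loss = tf.reduce_mean(weighted_losses)
    return output_loss
```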
The definition of the network structure. The shape of `feature` at each stage is given in the comments:
```python
def inference(self, inputs, is_training):
    d_out = self.config.d_out
    # (?, ?, 6)
    feature = inputs['features']
    # (?, ?, 8)
    feature = tf.layers.dense(feature, 8, activation=None, name='fc0')
    feature = tf.nn.leaky_relu(tf.layers.batch_normalization(feature, -1, 0.99, 1e-6, training=is_training))
    # (?, ?, 1, 8)
    feature = tf.expand_dims(feature, axis=2)

    # ###########################Encoder############################
    f_encoder_list = []
    # config.num_layers = 5; config.d_out = [16, 64, 128, 256, 512]
    for i in range(self.config.num_layers):
        f_encoder_i = self.dilated_res_block(feature, inputs['xyz'][i], inputs['neigh_idx'][i], d_out[i],
                                             'Encoder_layer_' + str(i), is_training)
        f_sampled_i = self.random_sample(f_encoder_i, inputs['sub_idx'][i])
        feature = f_sampled_i
        if i == 0:
            f_encoder_list.append(f_encoder_i)
        f_encoder_list.append(f_sampled_i)
    # ###########################Encoder############################

    feature = helper_tf_util.conv2d(f_encoder_list[-1], f_encoder_list[-1].get_shape()[3].value, [1, 1],
                                    'decoder_0', [1, 1], 'VALID', True, is_training)

    # ###########################Decoder############################
    f_decoder_list = []
    for j in range(self.config.num_layers):
        f_interp_i = self.nearest_interpolation(feature, inputs['interp_idx'][-j - 1])
        f_decoder_i = helper_tf_util.conv2d_transpose(tf.concat([f_encoder_list[-j - 2], f_interp_i], axis=3),
                                                      f_encoder_list[-j - 2].get_shape()[-1].value, [1, 1],
                                                      'Decoder_layer_' + str(j), [1, 1], 'VALID', bn=True,
                                                      is_training=is_training)
        feature = f_decoder_i
        f_decoder_list.append(f_decoder_i)
    # ###########################Decoder############################

    f_layer_fc1 = helper_tf_util.conv2d(f_decoder_list[-1], 64, [1, 1], 'fc1', [1, 1], 'VALID', True, is_training)
    f_layer_fc2 = helper_tf_util.conv2d(f_layer_fc1, 32, [1, 1], 'fc2', [1, 1], 'VALID', True, is_training)
    f_layer_drop = helper_tf_util.dropout(f_layer_fc2, keep_prob=0.5, is_training=is_training, scope='dp1')
    f_layer_fc3 = helper_tf_util.conv2d(f_layer_drop, self.config.num_classes, [1, 1], 'fc', [1, 1], 'VALID',
                                        False, is_training, activation_fn=None)
    f_out = tf.squeeze(f_layer_fc3, [2])
    return f_out
```
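`inference` also relies on two helpers not listed in this post: `random_sample` (down-samples by max-pooling each kept point's neighbourhood) and `nearest_interpolation` (up-samples by copying each point's nearest kept neighbour's feature back to full resolution). My reconstruction of them from the official repo, as a sketch:

```python
@staticmethod
def random_sample(feature, pool_idx):
    # feature: [B, N, 1, d]; pool_idx: [B, N', k] neighbour indices of the kept points
    feature = tf.squeeze(feature, axis=2)
    num_neigh = tf.shape(pool_idx)[-1]
    d = feature.get_shape()[-1]
    batch_size = tf.shape(pool_idx)[0]
    pool_idx = tf.reshape(pool_idx, [batch_size, -1])
    pool_features = tf.batch_gather(feature, pool_idx)
    pool_features = tf.reshape(pool_features, [batch_size, -1, num_neigh, d])
    # max-pool over each kept point's neighbourhood -> [B, N', 1, d]
    pool_features = tf.reduce_max(pool_features, axis=2, keepdims=True)
    return pool_features

@staticmethod
def nearest_interpolation(feature, interp_idx):
    # feature: [B, N', 1, d]; interp_idx: [B, N, 1] index of each point's nearest kept point
    feature = tf.squeeze(feature, axis=2)
    batch_size = tf.shape(interp_idx)[0]
    up_num_points = tf.shape(interp_idx)[1]
    interp_idx = tf.reshape(interp_idx, [batch_size, up_num_points])
    # copy the nearest kept neighbour's feature back to full resolution -> [B, N, 1, d]
    interpolated_features = tf.batch_gather(feature, interp_idx)
    interpolated_features = tf.expand_dims(interpolated_features, axis=2)
    return interpolated_features
```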
Dilated residual block: dilated_res_block

Parameter notes:

- `feature`: the input point features
- `xyz`: as I understand it, the xyz coordinates of the points
- `neigh_idx`: indices of each point's k nearest neighbours
- `d_out`: number of output channels
- `is_training`: whether we are training

The relevant details are in the comments; compare with the Dilated Residual Block diagram in the paper.
```python
def dilated_res_block(self, feature, xyz, neigh_idx, d_out, name, is_training):
    # Shared MLP (N, d_out/2)
    f_pc = helper_tf_util.conv2d(feature, d_out // 2, [1, 1], name + 'mlp1', [1, 1], 'VALID', True, is_training)
    # local feature aggregation module (LocSE + attentive pooling)
    f_pc = self.building_block(xyz, f_pc, neigh_idx, d_out // 2, name + 'LFA', is_training)
    # Shared MLP (N, 2*d_out)
    f_pc = helper_tf_util.conv2d(f_pc, d_out * 2, [1, 1], name + 'mlp2', [1, 1], 'VALID', True, is_training,
                                 activation_fn=None)
    # Shared MLP (N, 2*d_out) on the shortcut branch
    shortcut = helper_tf_util.conv2d(feature, d_out * 2, [1, 1], name + 'shortcut', [1, 1], 'VALID',
                                     activation_fn=None, bn=True, is_training=is_training)
    # sum + leaky ReLU
    return tf.nn.leaky_relu(f_pc + shortcut)
```
building_block

Its parameters are similar to those of dilated_res_block above. Note that it chains two rounds of LocSE plus attentive pooling, which is what makes the residual block "dilated": each round aggregates over the k neighbours, so after two rounds each point has effectively seen up to k² points.
```python
def building_block(self, xyz, feature, neigh_idx, d_out, name, is_training):
    # LocSE
    d_in = feature.get_shape()[-1].value
    # relative point position encoding
    f_xyz = self.relative_pos_encoding(xyz, neigh_idx)
    f_xyz = helper_tf_util.conv2d(f_xyz, d_in, [1, 1], name + 'mlp1', [1, 1], 'VALID', True, is_training)
    f_neighbours = self.gather_neighbour(tf.squeeze(feature, axis=2), neigh_idx)
    f_concat = tf.concat([f_neighbours, f_xyz], axis=-1)
    # attentive pooling
    f_pc_agg = self.att_pooling(f_concat, d_out // 2, name + 'att_pooling_1', is_training)

    # second round of LocSE + attentive pooling
    f_xyz = helper_tf_util.conv2d(f_xyz, d_out // 2, [1, 1], name + 'mlp2', [1, 1], 'VALID', True, is_training)
    f_neighbours = self.gather_neighbour(tf.squeeze(f_pc_agg, axis=2), neigh_idx)
    f_concat = tf.concat([f_neighbours, f_xyz], axis=-1)
    f_pc_agg = self.att_pooling(f_concat, d_out, name + 'att_pooling_2', is_training)
    return f_pc_agg
```
relative_pos_encoding

Parameter notes:

- `xyz`: the xyz coordinates of the points
- `neigh_idx`: indices of each point's neighbours
```python
def relative_pos_encoding(self, xyz, neigh_idx):
    neighbor_xyz = self.gather_neighbour(xyz, neigh_idx)
    xyz_tile = tf.tile(tf.expand_dims(xyz, axis=2), [1, 1, tf.shape(neigh_idx)[-1], 1])
    relative_xyz = xyz_tile - neighbor_xyz
    # Euclidean distance to each neighbour
    relative_dis = tf.sqrt(tf.reduce_sum(tf.square(relative_xyz), axis=-1, keepdims=True))
    relative_feature = tf.concat([relative_dis, relative_xyz, xyz_tile, neighbor_xyz], axis=-1)
    return relative_feature
```
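As a quick sanity check on the encoding's layout: for each of the k neighbours we concatenate [distance (1), relative xyz (3), centre xyz (3), neighbour xyz (3)], i.e. 10 channels per neighbour. A small NumPy sketch of the same computation (illustrative, not repo code):

```python
import numpy as np

N, k = 4, 3
xyz = np.random.rand(N, 3).astype(np.float32)       # one batch element's points
neigh_idx = np.random.randint(0, N, size=(N, k))    # k neighbour indices per point

neighbor_xyz = xyz[neigh_idx]                       # (N, k, 3)
xyz_tile = np.repeat(xyz[:, None, :], k, axis=1)    # (N, k, 3) centre point, tiled
relative_xyz = xyz_tile - neighbor_xyz
relative_dis = np.sqrt((relative_xyz ** 2).sum(-1, keepdims=True))  # (N, k, 1)
relative_feature = np.concatenate(
    [relative_dis, relative_xyz, xyz_tile, neighbor_xyz], axis=-1)
print(relative_feature.shape)  # (4, 3, 10) -> 10 channels per neighbour
```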
gather_neighbour

Gathers the data (coordinates or features) of each point's neighbours.
```python
@staticmethod
def gather_neighbour(pc, neighbor_idx):
    # gather the coordinates or features of neighboring points
    batch_size = tf.shape(pc)[0]
    num_points = tf.shape(pc)[1]
    d = pc.get_shape()[2].value
    index_input = tf.reshape(neighbor_idx, shape=[batch_size, -1])
    features = tf.batch_gather(pc, index_input)
    features = tf.reshape(features, [batch_size, num_points, tf.shape(neighbor_idx)[-1], d])
    return features
```
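`tf.batch_gather` indexes into `pc` row-wise per batch element; flattening `neighbor_idx` to [B, N*k] and reshaping the result back yields a [B, N, k, d] tensor of neighbour features. A NumPy equivalent of the same gather (illustrative sketch):

```python
import numpy as np

pc = np.arange(12, dtype=np.float32).reshape(1, 4, 3)        # (B=1, N=4, d=3)
neighbor_idx = np.array([[[0, 1], [1, 2], [2, 3], [3, 0]]])  # (1, 4, k=2)

flat_idx = neighbor_idx.reshape(1, -1)                          # (1, N*k), like tf.reshape
gathered = np.take_along_axis(pc, flat_idx[..., None], axis=1)  # like tf.batch_gather
features = gathered.reshape(1, 4, 2, 3)                         # back to (B, N, k, d)
print(features[0, 1])  # neighbours of point 1: rows 1 and 2 of pc
```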
att_pooling

Parameter notes:

- `feature_set`: the output of the LocSE module
- `d_out`: number of output channels
```python
@staticmethod
def att_pooling(feature_set, d_out, name, is_training):
    batch_size = tf.shape(feature_set)[0]
    num_points = tf.shape(feature_set)[1]
    num_neigh = tf.shape(feature_set)[2]
    d = feature_set.get_shape()[3].value
    f_reshaped = tf.reshape(feature_set, shape=[-1, num_neigh, d])
    # learn a score per neighbour and channel, softmax over the neighbour axis
    att_activation = tf.layers.dense(f_reshaped, d, activation=None, use_bias=False, name=name + 'fc')
    att_scores = tf.nn.softmax(att_activation, axis=1)
    # attention-weighted sum over the neighbourhood, then a shared MLP to d_out
    f_agg = f_reshaped * att_scores
    f_agg = tf.reduce_sum(f_agg, axis=1)
    f_agg = tf.reshape(f_agg, [batch_size, num_points, 1, d])
    f_agg = helper_tf_util.conv2d(f_agg, d_out, [1, 1], name + 'mlp', [1, 1], 'VALID', True, is_training)
    return f_agg
```
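So attentive pooling learns a score per neighbour and channel, normalises the scores with a softmax over the k neighbours, and aggregates by a weighted sum rather than max or mean pooling. A toy NumPy illustration for a single point's neighbourhood (not repo code; `W` stands in for the learned `fc` layer):

```python
import numpy as np

rng = np.random.default_rng(0)
k, d = 4, 8                                          # neighbours, feature channels
f = rng.standard_normal((k, d)).astype(np.float32)   # one point's k neighbour features
W = rng.standard_normal((d, d)).astype(np.float32)   # stand-in for the learned 'fc' layer

att_activation = f @ W                               # (k, d) unnormalised scores
e = np.exp(att_activation - att_activation.max(axis=0))
att_scores = e / e.sum(axis=0)                       # softmax over the neighbour axis
f_agg = (f * att_scores).sum(axis=0)                 # (d,) attention-weighted sum
print(att_scores.sum(axis=0))                        # each channel's weights sum to 1
```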