使用TensorFlow Object Detection API确定最大批量大小
Posted
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了使用TensorFlow Object Detection API确定最大批量大小相关的知识,希望对你有一定的参考价值。
TF Object Detection API默认抓取所有GPU内存,因此很难说我可以进一步增加批量大小。通常我会继续增加它,直到我收到CUDA OOM错误。
另一方面,PyTorch默认不会占用所有GPU内存,因此很容易看到我剩下的百分比,没有所有的试验和错误。
有没有更好的方法来确定我丢失的TF对象检测API的批量大小?像allow-growth
的model_main.py
旗帜?
我一直在寻找源代码,我发现没有与此相关的FLAG。
但是,在model_main.py
的文件https://github.com/tensorflow/models/blob/master/research/object_detection/model_main.py中,您可以找到以下主要函数定义:
def main(unused_argv):
flags.mark_flag_as_required('model_dir')
flags.mark_flag_as_required('pipeline_config_path')
config = tf.estimator.RunConfig(model_dir=FLAGS.model_dir)
train_and_eval_dict = model_lib.create_estimator_and_inputs(
run_config=config,
...
我们的想法是以类似的方式修改它,例如以下方式:
config_proto = tf.ConfigProto()
config_proto.gpu_options.allow_growth = True
config = tf.estimator.RunConfig(model_dir=FLAGS.model_dir, session_config=config_proto)
所以,添加config_proto
和改变config
,但保持所有其他事情相等。
此外,allow_growth
使程序使用尽可能多的GPU内存。所以,根据你的GPU,你最终可能会吃掉所有的内存。在这种情况下,您可能想要使用
config_proto.gpu_options.per_process_gpu_memory_fraction = 0.9
它定义了要使用的内存部分。
希望这有所帮助。
如果您不想修改文件,似乎应该打开一个问题,因为我没有看到任何FLAG。除非FLAG
flags.DEFINE_string('pipeline_config_path', None, 'Path to pipeline config '
'file.')
意味着与此相关的事情。但我不认为这是因为它似乎在model_lib.py
它与火车,评估和推断配置有关,而不是GPU使用配置。
以上是关于使用TensorFlow Object Detection API确定最大批量大小的主要内容,如果未能解决你的问题,请参考以下文章
TensorFlow object detection API应用一
TensorFlow Object Detection API使用问题小记
TensorFlow models - object detection API 安装
在使用 TensorFlow Object Detection API 训练 Mask RCNN 时,“损失”是啥?