动态编辑用于 TensorFlow 对象检测的管道配置
Posted
技术标签:
【中文标题】动态编辑用于 TensorFlow 对象检测的管道配置【英文标题】:Dynamically Editing Pipeline Config for Tensorflow Object Detection 【发布时间】:2019-08-14 20:58:54 【问题描述】:我正在使用 tensorflow 对象检测 API,我希望能够在 python 中动态编辑配置文件,如下所示。我想在 python 中使用协议缓冲区库,但我不知道该怎么做。
model
ssd
num_classes: 1
image_resizer
fixed_shape_resizer
height: 300
width: 300
feature_extractor
type: "ssd_inception_v2"
depth_multiplier: 1.0
min_depth: 16
conv_hyperparams
regularizer
l2_regularizer
weight: 3.99999989895e-05
initializer
truncated_normal_initializer
mean: 0.0
stddev: 0.0299999993294
activation: RELU_6
batch_norm
decay: 0.999700009823
center: true
scale: true
epsilon: 0.0010000000475
train: true
...
...
是否有一种简单/简便的方法可以将 image_resizer -> fixed_shape_resizer 中的高度等字段的特定值从 300 更改为 500?并用修改后的值写回文件而不更改其他任何内容?
编辑: 尽管@DmytroPrylipko 提供的答案适用于配置中的大多数参数,但我在“复合字段”方面遇到了一些问题..
也就是说,如果我们有这样的配置:
train_input_reader:
label_map_path: "/tensorflow/data/label_map.pbtxt"
tf_record_input_reader
input_path: "/tensorflow/models/data/train.record"
我添加这一行来编辑 input_path:
pipeline_config.train_input_reader.tf_record_input_reader.input_path = "/tensorflow/models/data/train100.record"
它抛出错误:
TypeError: Can't set composite field
【问题讨论】:
【参考方案1】:我发现这是一种有用的方法来覆盖对象检测pipeline.config
:
from object_detection.utils import config_util
from object_detection import model_lib_v2
PIPELINE_CONFIG_PATH = 'path_to_your_pipeline.config'
# Load the pipeline config as a dictionary
pipeline_config_dict = config_util.get_configs_from_pipeline_file(PIPELINE_CONFIG_PATH)
# OVERRIDE EXAMPLES
# Example 1: Override the train tfrecord path
pipeline_config_dict['train_input_config'].tf_record_input_reader.input_path[0] = 'your/override/path/to/train.record'
# Example 2: Override the eval tfrecord path
pipeline_config_dict['eval_input_config'].tf_record_input_reader.input_path[0] = 'your/override/path/to/test.record'
# Convert the pipeline dict back to a protobuf object
pipeline_config = config_util.create_pipeline_proto_from_configs(pipeline_config_dict)
# EXAMPLE USAGE:
# Example 1: Run the object detection train loop with your overrides (has to be string representation)
model_lib_v2.train_loop(config_override=str(pipeline_config))
# Example 2: Save the pipeline config to disk
config_util.save_pipeline_config(config, 'path/to/save/new/pipeline.config)
【讨论】:
【参考方案2】:这与上面的代码相同,只是做了一些小改动以适应 tensorflow V2。
import argparse
import tensorflow as tf
from google.protobuf import text_format
from object_detection.protos import pipeline_pb2
def parse_arguments():
parser = argparse.ArgumentParser(description='')
parser.add_argument('pipeline')
parser.add_argument('output')
return parser.parse_args()
def main():
args = parse_arguments()
pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
with tf.io.gfile.GFile(args.pipeline, "r") as f:
proto_str = f.read()
text_format.Merge(proto_str, pipeline_config)
pipeline_config.model.ssd.image_resizer.fixed_shape_resizer.height = 300
pipeline_config.model.ssd.image_resizer.fixed_shape_resizer.width = 300
config_text = text_format.MessageToString(pipeline_config)
with tf.io.gfile.GFile(args.output, "wb") as f:
f.write(config_text)
if __name__ == '__main__':
main()
【讨论】:
【参考方案3】:是的,使用 Protobuf Python API 非常简单:
edit_pipeline.py:
import argparse
import tensorflow as tf
from google.protobuf import text_format
from object_detection.protos import pipeline_pb2
def parse_arguments():
parser = argparse.ArgumentParser(description='')
parser.add_argument('pipeline')
parser.add_argument('output')
return parser.parse_args()
def main():
args = parse_arguments()
pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
with tf.gfile.GFile(args.pipeline, "r") as f:
proto_str = f.read()
text_format.Merge(proto_str, pipeline_config)
pipeline_config.model.ssd.image_resizer.fixed_shape_resizer.height = 300
pipeline_config.model.ssd.image_resizer.fixed_shape_resizer.width = 300
config_text = text_format.MessageToString(pipeline_config)
with tf.gfile.Open(args.output, "wb") as f:
f.write(config_text)
if __name__ == '__main__':
main()
我调用脚本的方式:
TOOL_DIR=tool/tf-models/research
(
cd $TOOL_DIR
protoc object_detection/protos/*.proto --python_out=.
)
export PYTHONPATH=$PYTHONPATH:$TOOL_DIR:$TOOL_DIR/slim
python3 edit_pipeline.py pipeline.config pipeline_new.config
复合字段
如果出现重复字段,则必须将它们视为数组(例如使用extend()
、append()
方法):
pipeline_config.train_input_reader.tf_record_input_reader.input_path[0] = '/tensorflow/models/data/train100.record'
Eval 输入阅读器错误
这是尝试编辑复合字段的常见错误。 ("no attribute tf_record_input_reader found" in case of eval_input_reader)
@latida 的回答中提到了这一点。 通过将其设置为数组字段来解决此问题。
pipeline_config.eval_input_reader[0].label_map_path = label_map_full_path
pipeline_config.eval_input_reader[0].tf_record_input_reader.input_path[0] = val_record_path
【讨论】:
非常感谢您的回答,但我遇到了“复合字段”的问题。您能否建议我如何解决这个问题?我已经用更多细节更新了这个问题..再次感谢..pipeline_config.train_input_reader.tf_record_input_reader.input_path[0] = '/tensorflow/models/data/train100.record'
。请参阅更新的答案。更多关于复合字段:***.com/questions/18376190/…
你太棒了!最后复合字段问题解决了!非常感谢!!
对我来说,更改 input_path 和 label_map_path 对 train_input_reader 效果很好,但在 eval_input_reader 的情况下显示“找不到属性 tf_record_input_reader”的错误
如果您想用 object_detection.protos.image_resizer_pb2.FixedShapeResizer 替换为 object_detection.protos.image_resizer_pb2.KeepAspectRatioResizer,那么复合字段呢?【参考方案4】:
pipeline_config.eval_input_reader[0].label_map_path = label_map_full_path
pipeline_config.eval_input_reader[0].tf_record_input_reader.input_path[0] = val_record_path
【讨论】:
这不是问题的答案。您可以对现有问题添加评论以上是关于动态编辑用于 TensorFlow 对象检测的管道配置的主要内容,如果未能解决你的问题,请参考以下文章
TensorFlow对象检测管道配置中data_augmentation_options的可能值是什么?
TensorFlow对象检测配置文件中的“num_examples:2000”是啥意思?