paddle detection 配置文件怎么实例化的 代码梳理
Posted 东东就是我
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了paddle detection 配置文件怎么实例化的 代码梳理相关的知识,希望对你有一定的参考价值。
我们发现paddledetection只是修改配置文件就可以训练,在代码是怎么实现的,yaml为什么可以自动实例
1.代码梳理
train.py 132行 开始加载配置文件
cfg = load_config(FLAGS.config)
paddet/core/workpace.py
def load_config(file_path):
"""
Load config from file.
Args:
file_path (str): Path of the config file to be loaded.
Returns: global config
"""
_, ext = os.path.splitext(file_path)
assert ext in ['.yml', '.yaml'], "only support yaml files for now"
# load config from file and merge into global config
cfg = _load_config_with_base(file_path)
cfg['filename'] = os.path.splitext(os.path.split(file_path)[-1])[0]
merge_config(cfg)
return global_config
这段代码很简单,就是检测配置文件的扩展名,然后加载,加载是写在下面的函数的
# parse and load _BASE_ recursively
def _load_config_with_base(file_path):
with open(file_path) as f:
file_cfg = yaml.load(f, Loader=yaml.Loader)
# NOTE: cfgs outside have higher priority than cfgs in _BASE_
if BASE_KEY in file_cfg:
all_base_cfg = AttrDict()
base_ymls = list(file_cfg[BASE_KEY])
for base_yml in base_ymls:
if base_yml.startswith("~"):
base_yml = os.path.expanduser(base_yml)
if not base_yml.startswith('/'):
base_yml = os.path.join(os.path.dirname(file_path), base_yml)
with open(base_yml) as f:
base_cfg = _load_config_with_base(base_yml)
all_base_cfg = merge_config(base_cfg, all_base_cfg)
del file_cfg[BASE_KEY]
return merge_config(file_cfg, all_base_cfg)
return file_cfg
这段代码也很简单,就是循环加载配置文件,因为paddle的配置文件之中是包含多个配置文件的。
最主要的是这句
yaml.load(f, Loader=yaml.Loader)
然后把加载的配置文件,其实就是生成的实例放在全局的字典里
all_base_cfg = merge_config(base_cfg, all_base_cfg)
这个函数很简单
def merge_config(config, another_cfg=None):
"""
Merge config into global config or another_cfg.
Args:
config (dict): Config to be merged.
Returns: global config
"""
global global_config
dct = another_cfg or global_config
return dict_merge(dct, config)
就是放在全局字典
2.本文重点解析为什么yaml.load 可以生成实例
首先要知道yaml是可以通过实例生成配置文件,也可以通过配置文件生成实例,这篇文章讲的很好
https://www.jb51.net/article/242838.htm#_lab2_1_3
了解怎么生成后,我们模仿写一个例子,通过配置文件生成实例
import yaml
class Person(object):
def __init__(self, name, age):
self.name = name
self.age = age
def __repr__(self):
return '%s(name=%s, age=%d)' % (self.__class__.__name__, self.name, self.age)
def person_cons(loader, node):
value = loader.construct_mapping(node) # mapping构造器,用于dict
name = value['name']
age = value['age']
return Person(name, age)
yaml.add_constructor(u'!person', person_cons) # 用add_constructor方法为指定yaml标签添加构造器
lily = yaml.load('!person name: Lily, age: 19') #生成实力
print (lily)
file_path="1.yml"
with open(file_path) as f:
file_cfg = yaml.load(f, Loader=yaml.Loader)
print(file_cfg)
1.yaml
TrainDataset:
!person
name: train
age: 11
那么在看paddle,他的注册在哪里做的呢,在PADET\\DATA\\SOURCE\\COCO.PY 中 COCODataSet上有个
@serializable
这是啥意思,装饰器,装饰器不懂的看看这个
http://c.biancheng.net/view/2270.html
那么这个装饰器函数执行了什么,在ppdet\\core\\config\\yaml_helper.py中
def serializable(cls):
"""
Add loader and dumper for given class, which must be
"trivially serializable"
Args:
cls: class to be serialized
Returns: cls
"""
yaml.add_constructor(u'!'.format(cls.__name__),
_make_python_constructor(cls))
yaml.add_representer(cls, _make_python_representer(cls))
return cls
他把类注册进去了,是不是和我们例子的添加构造器一样。我们在看看注册函数
def _make_python_constructor(cls):
def python_constructor(loader, node):
if isinstance(node, yaml.SequenceNode):
args = loader.construct_sequence(node, deep=True)
return cls(*args)
else:
kwargs = loader.construct_mapping(node, deep=True)
try:
return cls(**kwargs)
except Exception as ex:
print("Error when construct instance from yaml config".
format(cls.__name__))
raise ex
return python_constructor
他干啥了,他把参数传进去,生成类的实例,然后返回实例 cls(**kwargs)
3.总结
看到这,我们已经弄明白整个流程了,首先train.py加载的时候会执行init.py ,然后把其中包含的类都加载一遍,如果类上又装饰器,那么装饰器也执行。所以就会把COCODataSet这个类的实例注册到yaml的构造器,然后我们在load的时候直接在构造器里找就好了,那么相应的其他的类也是相同的方式注册的。
4.register
我们发现代码中还有这个装饰器,那么这个装饰器是干嘛的。说白了也是在Init.py初始化把类的实例注册到一个全局的字典里,然后用到那个实例,我们直接去字典里取。和serializable的区别就是,yaml自带的通过配置文件生成类的实例,不懂register的,看看mmdetection中的register,讲的好
https://zhuanlan.zhihu.com/p/355271993
以上是关于paddle detection 配置文件怎么实例化的 代码梳理的主要内容,如果未能解决你的问题,请参考以下文章
paddle detection 配置文件怎么实例化的 代码梳理 -----(yaml)
openVINO+paddle基于 Paddle2ONNX实现Paddle-Detection/OCR/Seg导出
`Segmentation fault` is detected by the operating system
`Segmentation fault` is detected by the operating system