深度学习深度解析TensorFlow组件Estimator:构建自定义Estimator
Posted 神机喵算
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了深度学习深度解析TensorFlow组件Estimator:构建自定义Estimator相关的知识,希望对你有一定的参考价值。
你是否思考过TensorFlow的tutorial和其背后的“魔力”?希望这篇文章至少能给你思考的正确方向。
TensorFlow的基本概念可以去查看TensorFlow官方文档。这里将帮你更好的理解中estimator的工作原理,并指导你构建适合自己特定应用的estimator。
BaseEstimator和Estimator的理解
BaseEstimator是TensorFlow训练和评估模块的抽象和基类。它利用graph_actions.py的隐藏逻辑,提供像fit(),partial_fit(),evaluate()和predict()的基本功能,处理不同类型的输入数据批量拉取(Note:未来learn.DataFrame将替代DataFeeder)。它通过dtypes来检查输入数据的兼容,考虑输入数据是否稀疏需要使用estimators.tensor_signature。
BaseEstimator为monitors,checkpointing等初始化设置,并提供了构建和评估自定义模块的大部分逻辑。_get_train_ops(), _get_eval_ops()和_get_predict_ops()放在子类中实现,给Estimator自定义带来了更大的自由。BaseEstimator也是分布式的。
TensorFlow模块中Estimator的实现给我们重写BaseEstimator子类提供了很好的范本。
例如,Estimator中的_get_train_ops()载入features和targets作为输入,返回训练Operation和损失Tensor的一个tuple。如果你想完成自己的estimator,并且用于非监督机器学习训练,这时你就可以自由决定targets是否可忽略。
类似地,子类中的_get_eval_ops()可自定义metric来评估每步的训练。在TensorFlow的high-level模块中可发现一打适用的metric。它们会返回Tensor对象的字典,表示指定metric的评价ops。
_get_predict_ops()可实现自定义的prediction,例如 概率 v.s. 实际预测输出。它将返回一个Tensor或者Tensor对象的字典,表示预测ops。你可以很轻松的使用父类的predict()函数实现像transform()的功能。
Estimator示例
逻辑回归(LogisticRegressor)
Estimator已经提供了自定义estimator大部分实现。例如,LogisticRegressor仅需实现自己的metric即可,比如AUC,accuracy,precision和recall。开发者使用LogisticRegressor子类即可实现二值分类问题。
随机森林(TensorForestEstimator)
TensorForestEstimator已经增加到TensorFlow Learn。contrib.tensor_forest详细的实现了随机森林算法(Random Forests)评估器,并对外提供high-level API使得开发者构建随机森林评估器更简单。
例如,开发者只需传入params到构造器,params使用params.fill()来填充,而不用传入所有的超参数,Tensor Forest自己的RandomForestGraphs使用这些参数来构建整幅图。
class TensorForestEstimator(estimator.BaseEstimator):
"""An estimator that can train and evaluate a random forest."""
def __init__(self, params, device_assigner=None, model_dir=None, graph_builder_class=tensor_forest.RandomForestGraphs, master='', accuracy_metric=None, tf_random_seed=None, verbose=1, config=None):
self.params = params.fill()
随机森林算法的接口实现有许多细节,_get_predict_ops()利用tensor_forest.RandomForestGraphs来构建随机森林图,调用graph_builder.inference_graph来获取预测ops。
def _get_predict_ops(self, features):
graph_builder = self.graph_builder_class(
self.params, device_assigner=self.device_assigner, training=False,
**self.construction_args)
features, spec = data_ops.ParseDataTensorOrDict(features)
return graph_builder.inference_graph(features, data_spec=spec)
类似地,使用graph_builder.training_loss来实现_get_train_ops()。注意,TensorForestEstimator使用了tensor_forest.data.data_ops的模块功能,比如 ParseDataTensorOrDict和ParseLabelTensorOrDict解析输入特征和标签。
其它用例
K-means聚类的estimator刚加入项目,放在contrib.factorization.python.ops.kmeans。 更多的例子可以在learn.estimators中找到。
强烈推荐你领悟代码整体结构,开始实现自己的estimator之旅!
参考:http://terrytangyuan.github.io/2016/03/14/scikit-flow-intro
若发现以上文章有任何不妥,请联系我。
以上是关于深度学习深度解析TensorFlow组件Estimator:构建自定义Estimator的主要内容,如果未能解决你的问题,请参考以下文章
TensorFlow2 深度学习实战(十四):YOLOv4目标检测算法解析
TensorFlow2 深度学习实战(十四):YOLOv4目标检测算法解析
阿里PAI深度学习组件:Tensorflow实现图片智能分类实验