深度学习深度解析TensorFlow组件Estimator:构建自定义Estimator

Posted 神机喵算

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了深度学习深度解析TensorFlow组件Estimator:构建自定义Estimator相关的知识,希望对你有一定的参考价值。

你是否思考过TensorFlow的tutorial和其背后的“魔力”?希望这篇文章至少能给你思考的正确方向。

TensorFlow的基本概念可以去查看TensorFlow官方文档。这里将帮你更好的理解中estimator的工作原理,并指导你构建适合自己特定应用的estimator。

BaseEstimator和Estimator的理解

BaseEstimator是TensorFlow训练和评估模块的抽象和基类。它利用graph_actions.py的隐藏逻辑,提供像fit()partial_fit()evaluate()predict()的基本功能,处理不同类型的输入数据批量拉取(Note:未来learn.DataFrame将替代DataFeeder)。它通过dtypes来检查输入数据的兼容,考虑输入数据是否稀疏需要使用estimators.tensor_signature

BaseEstimator为monitors,checkpointing等初始化设置,并提供了构建和评估自定义模块的大部分逻辑。_get_train_ops()_get_eval_ops()_get_predict_ops()放在子类中实现,给Estimator自定义带来了更大的自由。BaseEstimator也是分布式的。

TensorFlow模块中Estimator的实现给我们重写BaseEstimator子类提供了很好的范本。
例如,Estimator中的_get_train_ops()载入featurestargets作为输入,返回训练Operation和损失Tensor的一个tuple。如果你想完成自己的estimator,并且用于非监督机器学习训练,这时你就可以自由决定targets是否可忽略。

类似地,子类中的_get_eval_ops()可自定义metric来评估每步的训练。在TensorFlow的high-level模块中可发现一打适用的metric。它们会返回Tensor对象的字典,表示指定metric的评价ops。

_get_predict_ops()可实现自定义的prediction,例如 概率 v.s. 实际预测输出。它将返回一个Tensor或者Tensor对象的字典,表示预测ops。你可以很轻松的使用父类的predict()函数实现像transform()的功能。

Estimator示例

逻辑回归(LogisticRegressor)

Estimator已经提供了自定义estimator大部分实现。例如,LogisticRegressor仅需实现自己的metric即可,比如AUC,accuracy,precision和recall。开发者使用LogisticRegressor子类即可实现二值分类问题。

随机森林(TensorForestEstimator)

TensorForestEstimator已经增加到TensorFlow Learn。contrib.tensor_forest详细的实现了随机森林算法(Random Forests)评估器,并对外提供high-level API使得开发者构建随机森林评估器更简单。

例如,开发者只需传入params到构造器,params使用params.fill()来填充,而不用传入所有的超参数,Tensor Forest自己的RandomForestGraphs使用这些参数来构建整幅图。

class TensorForestEstimator(estimator.BaseEstimator):
  """An estimator that can train and evaluate a random forest."""

  def __init__(self, params, device_assigner=None, model_dir=None,               graph_builder_class=tensor_forest.RandomForestGraphs,               master='', accuracy_metric=None,               tf_random_seed=None, verbose=1,               config=None):
    self.params = params.fill()

随机森林算法的接口实现有许多细节,_get_predict_ops()利用tensor_forest.RandomForestGraphs来构建随机森林图,调用graph_builder.inference_graph来获取预测ops。

def _get_predict_ops(self, features):
    graph_builder = self.graph_builder_class(
        self.params, device_assigner=self.device_assigner, training=False,
        **self.construction_args)
    features, spec = data_ops.ParseDataTensorOrDict(features)
   return graph_builder.inference_graph(features, data_spec=spec)

类似地,使用graph_builder.training_loss来实现_get_train_ops()。注意,TensorForestEstimator使用了tensor_forest.data.data_ops的模块功能,比如 ParseDataTensorOrDictParseLabelTensorOrDict解析输入特征和标签。

其它用例

K-means聚类的estimator刚加入项目,放在contrib.factorization.python.ops.kmeans。 更多的例子可以在learn.estimators中找到。

强烈推荐你领悟代码整体结构,开始实现自己的estimator之旅!

参考:http://terrytangyuan.github.io/2016/03/14/scikit-flow-intro


若发现以上文章有任何不妥,请联系我。


以上是关于深度学习深度解析TensorFlow组件Estimator:构建自定义Estimator的主要内容,如果未能解决你的问题,请参考以下文章

『深度长文』Tensorflow代码解析

TensorFlow2 深度学习实战(十四):YOLOv4目标检测算法解析

TensorFlow2 深度学习实战(十四):YOLOv4目标检测算法解析

阿里PAI深度学习组件:Tensorflow实现图片智能分类实验

代码解析深度学习系统编程模型:TensorFlow vs. CNTK

学习TF:《TensorFlow技术解析与实战》PDF+代码