无法使用 sklearn 的 GridSearchCV 运行 tflearn

Posted

技术标签:

【中文标题】无法使用 sklearn 的 GridSearchCV 运行 tflearn【英文标题】:Cannot run tflearn with sklearn's GridSearchCV 【发布时间】:2016-12-22 10:27:58 【问题描述】:

我打算对 tflearn 模型的超参数执行网格搜索。看来tflearn.DNN生成的模型不符合sklearn的GridSearchCV期望:

from sklearn.grid_search import GridSearchCV
import tflearn
import tflearn.datasets.mnist as mnist
import numpy as np

X, Y, testX, testY = mnist.load_data(one_hot=True)

encoder = tflearn.input_data(shape=[None, 784])
encoder = tflearn.fully_connected(encoder, 256)
encoder = tflearn.fully_connected(encoder, 64)

# Building the decoder
decoder = tflearn.fully_connected(encoder, 256)
decoder = tflearn.fully_connected(decoder, 784)

# Regression, with mean square error
net = tflearn.regression(decoder, optimizer='adam', learning_rate=0.01,
                         loss='mean_square', metric=None)

model = tflearn.DNN(net, tensorboard_verbose=0)

grid_hyperparams = 'optimizer': ['adam', 'sgd', 'rmsprop'], 'learning_rate': np.logspace(-4, -1, 4)
grid = GridSearchCV(model, param_grid=grid_hyperparams, scoring='mean_squared_error', cv=2)
grid.fit(X, X)

我得到错误:

TypeError                                 Traceback (most recent call last)
<ipython-input-3-fd63245cd0a3> in <module>()
     22 grid_hyperparams = 'optimizer': ['adam', 'sgd', 'rmsprop'], 'learning_rate': np.logspace(-4, -1, 4)
     23 grid = GridSearchCV(model, param_grid=grid_hyperparams, scoring='mean_squared_error', cv=2)
---> 24 grid.fit(X, X)
     25 
     26 

/home/deeplearning/anaconda3/lib/python3.5/site-packages/sklearn/grid_search.py in fit(self, X, y)
    802 
    803         """
--> 804         return self._fit(X, y, ParameterGrid(self.param_grid))
    805 
    806 

/home/deeplearning/anaconda3/lib/python3.5/site-packages/sklearn/grid_search.py in _fit(self, X, y, parameter_iterable)
    539                                          n_candidates * len(cv)))
    540 
--> 541         base_estimator = clone(self.estimator)
    542 
    543         pre_dispatch = self.pre_dispatch

/home/deeplearning/anaconda3/lib/python3.5/site-packages/sklearn/base.py in clone(estimator, safe)
     45                             "it does not seem to be a scikit-learn estimator "
     46                             "as it does not implement a 'get_params' methods."
---> 47                             % (repr(estimator), type(estimator)))
     48     klass = estimator.__class__
     49     new_object_params = estimator.get_params(deep=False)

TypeError: Cannot clone object '<tflearn.models.dnn.DNN object at 0x7fead09948d0>' (type <class 'tflearn.models.dnn.DNN'>): it does not seem to be a scikit-learn estimator as it does not implement a 'get_params' methods.

知道如何获得适合 GridSearchCV 的对象吗?

【问题讨论】:

【参考方案1】:

我没有使用 tflearn 的经验,但我确实有一些 Python 和 sklearn 的基本背景。从 *** 屏幕截图中的错误来看,tflearn **models **没有与 scikit-learn 估计器相同的方法或属性。这是可以理解的,因为它们不是 scikit-learn 估计器。

Sklearn 的网格搜索 CV 仅适用于与 scikit-learn 估计器具有相同方法和属性的对象(例如具有 fit() 和 predict() 方法)。如果您打算使用 sklearn 的网格搜索,则必须围绕 tflearn 模型编写自己的包装器,以使其作为 sklearn 估计器的替代品,这意味着您必须编写自己的具有相同的类方法与任何其他 scikit-learn 估计器一样,但使用 tflearn 库来实际实现这些方法。

为此,请了解基本 scikit-learn 估计器的代码(最好是您熟悉的那个),并查看 fit()、predict()、get_params() 等方法对对象及其对象的实际作用内部结构。然后使用 tflearn 库编写您自己的类。

首先,快速的 Google 搜索显示此存储库是“用于 tensorflow 框架的精简 scikit-learn 样式包装器”:DSLituiev/tflearn (https://github.com/DSLituiev/tflearn)。我不知道这是否可以替代 Grid Search,但值得一看。

【讨论】:

以上是关于无法使用 sklearn 的 GridSearchCV 运行 tflearn的主要内容,如果未能解决你的问题,请参考以下文章

从命令行运行脚本时忽略 sklearn Gridsearch 中 n_jobs = -1 的警告不使用 warnings.simplefilter('ignore')

ValueError:使用 GridSearch 参数时估计器 CountVectorizer 的参数模型无效

如何在CDH中使用PySpark分布式运行GridSearch算法

与 RFECV 结合时如何在 Gridsearch 中使用“max_features”?

sklearn.pipeline.Pileline

我无法在 gridsearch 中添加优化器参数