如何使用作为迭代器的训练数据进行 scikit-learn 网格搜索
Posted
技术标签:
【中文标题】如何使用作为迭代器的训练数据进行 scikit-learn 网格搜索【英文标题】:How can I do a scikit-learn grid search with training data that is an iterator 【发布时间】:2016-08-19 11:39:27 【问题描述】:我正在处理一个文本分类问题,使用如下所示的管道:
self.full_classifier = Pipeline([
('vectorize', CountVectorizer()),
('tf-idf', TfidfTransformer()),
('classifier', SVC(kernel='linear', class_weight='balanced'))
])
完整的语料库太大而无法放入内存,但足够小以至于在矢量化步骤之后我没有内存问题。我可以通过使用成功拟合分类器
self.full_classifier.fit(
self._all_data (max_samples=train_data_length),
self.dataset.head(train_data_length)['target'].values
)
其中 self._all_data 是一个迭代器,它为每个训练示例生成文档(而 self.dataset 仅包含文档 ID 和目标)。在这里,max_samples 是可选的,我使用它对训练/测试数据进行拆分。我现在想使用 gridsearch 来优化参数,为此我使用以下代码:
parameters =
'vectorize__stop_words': (None, 'english'),
'tfidf__use_idf': (True, False),
'classifier__class_weight': (None, 'balanced')
gridsearch_classifier = GridSearchCV(self.full_classifier, parameters, n_jobs=-1)
gridsearch_classifier.fit(self._all_data(), self.dataset['target'].values)
我的问题是这会产生以下错误:
TypeError: Expected sequence or array-like, got <type 'generator'>
回溯指向 gridsearch_classifier.fit 方法(然后进入 scikit 的代码,在 _num_samples(x) 中引发错误。由于可以使用生成器作为输入,我想知道是否还有一种方法可以使用我目前缺少的网格搜索来执行此操作。 任何帮助表示赞赏!
【问题讨论】:
【参考方案1】:除非将生成器具体化为列表。虽然各种拟合方法通常可以构造为一次消耗一个项目,从而接受迭代器,但网格搜索还执行交叉验证并通过索引已实现的集合来生成数据的 cv 拆分。
【讨论】:
谢谢,有道理。我将通过实现一个命中数据库的 getitem 来研究伪造列表以上是关于如何使用作为迭代器的训练数据进行 scikit-learn 网格搜索的主要内容,如果未能解决你的问题,请参考以下文章