多标签分类的交叉验证错误

Posted

技术标签:

【中文标题】多标签分类的交叉验证错误【英文标题】:Error with cross validation on a multilabel classification 【发布时间】:2015-07-07 04:10:09 【问题描述】:

我正在使用“multiclass.OneVsRestClassifier”和“cross_validation.StratifiedKFold”。当我对多标签问题进行交叉验证时,它失败了。 是否可以对多标签问题 scikit-learn 执行交叉验证?

我认为问题出在类标签列表的元组中,例如 ([1], [2], [2], [1], [1,2], [3], [1,2,3] ..)

我认为此错误的代码如下:

n_samples = X.shape[0]
Y_list = [value for value in Y.T]
print 'Y_list[0].shape:', Y_list[0].shape, 'len(Y_list):', len(Y_list)
cv = cross_validation.StratifiedKFold(Y_list, 3)

【问题讨论】:

【参考方案1】:

如果你打算用 scikit-learn 解决多标签任务,建议先 使用MultilabelBinarizer 将您的输出转换为标签二进制指示符。

分层 k 折叠不支持多标签格式,因为它需要平衡每个标签的正数比例。相反,您可以使用K-folds 或shuffle split 交叉验证策略。

【讨论】:

谢谢,这正是我需要的。 @Arnaud Joly 如果我有多标签分类和不平衡的数据怎么办?那么我应该怎么做才能在每个标签中进行平衡训练和测试集呢?谢谢:) @sariii 使用 scikit-multilearn,其中包含多标签分层(在包中查找迭代分层)

以上是关于多标签分类的交叉验证错误的主要内容,如果未能解决你的问题,请参考以下文章

具有高度不平衡的多标签分类中的损失曲线

分类交叉熵和标签编码

GridSearch用于Scikit-learn中的多标签分类

多标签分类损失函数

使用 SKlearn 进行多标签分类 - 如何使用验证集?

多标签分类中的损失函数与评价指标