使用 LightGBM 进行多类分类

Posted

技术标签:

【中文标题】使用 LightGBM 进行多类分类【英文标题】:Multiclass Classification with LightGBM 【发布时间】:2018-05-02 09:17:12 【问题描述】:

我正在尝试在 Python 中使用 LightGBM 为多类分类问题(3 类)建模分类器。我使用了以下参数。

params = 'task': 'train',
    'boosting_type': 'gbdt',
    'objective': 'multiclass',
    'num_class':3,
    'metric': 'multi_logloss',
    'learning_rate': 0.002296,
    'max_depth': 7,
    'num_leaves': 17,
    'feature_fraction': 0.4,
    'bagging_fraction': 0.6,
    'bagging_freq': 17

数据集的所有分类特征都是用LabelEncoder 编码的标签。我在运行cveartly_stopping 后训练了模型,如下所示。

lgb_cv = lgbm.cv(params, d_train, num_boost_round=10000, nfold=3, shuffle=True, stratified=True, verbose_eval=20, early_stopping_rounds=100)

nround = lgb_cv['multi_logloss-mean'].index(np.min(lgb_cv['multi_logloss-mean']))
print(nround)

model = lgbm.train(params, d_train, num_boost_round=nround)

训练后,我用这样的模型进行预测,

preds = model.predict(test)
print(preds)             

我得到一个嵌套数组作为这样的输出。

[[  7.93856847e-06   9.99989550e-01   2.51164967e-06]
 [  7.26332978e-01   1.65316511e-05   2.73650491e-01]
 [  7.28564308e-01   8.36756769e-06   2.71427325e-01]
 ..., 
 [  7.26892634e-01   1.26915179e-05   2.73094674e-01]
 [  5.93217601e-01   2.07172044e-04   4.06575227e-01]
 [  5.91722491e-05   9.99883828e-01   5.69994435e-05]]

preds 中的每个列表都代表我使用 np.argmax() 来查找此类的概率。

predictions = []

for x in preds:
    predictions.append(np.argmax(x))

在分析预测时,我发现我的预测只包含 2 个类 - 0 和 1。第 2 类是训练集中的第二大类,但在预测中找不到它。在评估结果时,它给出了大约78% 的准确性。

那么,为什么我的模型没有预测任何情况下的第 2 类。?我使用的参数有什么问题吗?

这不是解释模型做出的预测的正确方法吗?我应该对参数进行任何更改吗??

【问题讨论】:

我不知道这段代码到底有什么问题,但我认为您的问题似乎是二进制分类,但您使用多类分类指标来提高准确性。我宁愿建议您使用 binary_logloss 来解决您的问题。你可以找到更多关于同一here 我的目标中有 3 个类。我已经检查过了 【参考方案1】:

解决办法是:

best_preds_svm = [np.argmax(line) for line in preds]

然后你可以打印出结果最合理的类。

【讨论】:

【参考方案2】:

从您提供的输出来看,预测似乎没有错。

该模型产生三个概率,正如您所展示的,仅从您提供的第一个输出 [7.93856847e-06 9.99989550e-01 2.51164967e-06] 类 2 具有更高的概率,所以我在这里看不到问题。

0 级是第一级,1 级实际上是 2 级第二级,2 是第三级。所以我想没有错。

【讨论】:

该模型不会预测任何输入样本的第 3 类,即使是在其训练的样本上也是如此。!!【参考方案3】:

尝试通过交换类 0 和 2 进行故障排除,然后重新运行训练和预测过程。

如果新预测仅包含第 1 类和第 2 类(很可能根据您提供的数据):

分类器可能没有学过第三类;也许它的特征与较大类的特征重叠,并且分类器默认为较大的类以最小化目标函数。尝试提供一个平衡的训练集(每个类的样本数量相同)并重试。

如果新预测确实包含所有 3 个类:

您的代码某处出现问题。需要更多信息来确定究竟出了什么问题。

希望这会有所帮助。

【讨论】:

【参考方案4】:
import pandas as pd

pd.DataFrame(preds).apply(lambda x: np.argmax(x), axis=1)

【讨论】:

以上是关于使用 LightGBM 进行多类分类的主要内容,如果未能解决你的问题,请参考以下文章

LightGBM-分类指标不能处理二进制和连续目标的混合

使用 Keras 进行分类:预测和多类

使用 Keras 稀疏分类交叉熵进行像素级多类分类

使用 Python API 进行逻辑回归多类分类

使用 scikit learn 训练逻辑回归进行多类分类

使用 Apache Spark 决策树分类器进行多类分类时出错