特征不匹配:通过 scikit-learn 管道进行预测

Posted

技术标签:

【中文标题】特征不匹配:通过 scikit-learn 管道进行预测【英文标题】:Feature mismatch: Prediction through scikit-learn Pipeline 【发布时间】:2021-08-24 17:57:47 【问题描述】:

我在一个名为 build.py 的文件中实现了以下 scikit-learn 管道,后来,成功地对其进行了腌制。

preprocessor = ColumnTransformer(transformers=[
        ('target', TargetEncoder(), COL_TO_TARGET),
        ('one_hot', OneHotEncoder(drop_invariant=False, handle_missing='value',
              handle_unknown='value', return_df=True, use_cat_names=True,
              verbose=0), COL_TO_DUM),
        ('construction', OrdinalEncoder(mapping=mapping),['ConstructionPeriod'])
      ], remainder='passthrough')

test_pipeline = Pipeline(steps=[
            ('preprocessor', preprocessor),
            ('std_scale', StandardScaler()),
            ('XGB_model', 
                xgb.XGBRegressor(
                    booster = 'gbtree', colsample_bylevel=0.75,colsample_bytree=0.75,
                    max_depth = 20,grow_policy = 'depthwise',learning_rate = 0.1
                 )
             )
        ])
test_pipeline.fit(X_train, y_train)

import pickle
pickle.dump(open('final_pipeline.pkl','wb'), test_pipeline)

然后在另一个文件 app.py 中读取腌制管道,该文件接受用户数据以通过未腌制管道进行预测。

pipeline = pickle.load(open('final_pipeline.pkl', 'rb'))

# data is the coming from the user via frontend
input_df = pd.DataFrame(data.dict(), index=[0])

# using the pipeline to predict 
prediction = pipeline.predict(input_df)

我遇到的挑战是未腌制的管道期望传入的测试数据具有类似于用于训练管道 (X_train) 的列结构。

为了解决这个问题,我需要对传入的测试数据列进行排序以匹配 X_train 的列。

肮脏的解决方案,将 X_train 列名称导出到文件中,然后在 app.py 中读取它以重新排列传入测试数据的列。

关于如何以python方式解决此问题的任何建议?

【问题讨论】:

【参考方案1】:

您的列顺序不应该很重要,但如果是,那么为什么不在管道中对列进行排序,然后在其他代码文件中对其进行排序。这样您就不必进行任何本地存储。

df = df.reindex(sorted(df.columns), axis=1)

【讨论】:

感谢您的回复。我查了一下,发现很多文章都引用了类似的问题。特征顺序很重要的原因是管道将特征空间转换为矩阵,因此标签变得过时了。你的想法对我有用。

以上是关于特征不匹配:通过 scikit-learn 管道进行预测的主要内容,如果未能解决你的问题,请参考以下文章

如何在 Scikit-learn 的管道中创建我们的自定义特征提取器函数并将其与 countvectorizer 一起使用

使用 scikit-learn 管道与手动操作时的不同分数

如何将功能管道从 scikit-learn V0.21 移植到 V0.24

使用 Scikit-Learn 在管道中包含预测器

当最后一个估计器不是转换器时,如何使用 scikit-learn 管道进行转换?

scikit-learn,线性回归中的分类(但数字)特征