使用 SVM 模型和 scikit-learn 进行预测的 AttributeError

Posted

技术标签:

【中文标题】使用 SVM 模型和 scikit-learn 进行预测的 AttributeError【英文标题】:AttributeError with prediction using SVM model and scikit-learn 【发布时间】:2020-11-05 05:33:29 【问题描述】:

我正在使用经过训练的 SVM 模型来预测新数据样本的类别标签。数据的处理方式与训练数据完全相同(相同的特征)。尝试预测时出现 AttributeError:

data_to_predict = feature_row
prediction = trained_model.predict(data_to_predict)

data_to_predict 是一个单行的 pandas DataFrame。我收到此错误:

AttributeError: 'SVC' object has no attribute 'probA_'

我认为(也许是错误的)问题在于数据的形式,因为当我打印 feature_row 时它看起来是正确的。到目前为止,这是我尝试过的:

    确保使用probability=True 训练模型 尝试将数据重塑为[feature_row],错误更改为KeyError: 0 尝试将数据更改为数组np.array(feature_row),返回原始错误 尝试重塑数组np.array(feature_row).reshape((-1, 1)),错误为ValueError: Number of features of the input must be equal to or greater than that of the fitted transformer. Transformer n_features is 31 and input n_features is 1. 尝试以相反的方式重塑 np.array(feature_row).reshape((1, -1)),回到原来的错误 终于,尝试了[np.array(feature_row).reshape(-1, 1)],得到了ValueError: Found array with dim 3. Estimator expected <= 2.

额外信息:我不确定它是否对解决此问题有任何影响,但该模型涉及:

一列上带有 TfidfVectorizer 的 ColumnTransformer(带有 remainder='passthrough') 后跟 SVC 步骤 SVC(probability=True, class_weight='balanced') scikit-learn 版本 0.22.1

编辑:以防以后有人遇到此线程。在帮助下 Victor Luu,我们设法将问题与泡菜文件本身隔离开来。经过进一步探索,我发现我正在两个不同的环境中训练和测试模型,这些环境的 scikit-learn 版本(0.23.1 和 0.22.1)略有不同,一旦我确定它们匹配,错误就消失了。

【问题讨论】:

您好,您的模型安装了吗?如果还没有,需要先安装它 我已经装好了,它已经被训练过了,我正在从一个泡菜文件中加载它 【参考方案1】:

根据我的经验,从 pickle 文件加载可能会导致这样的问题。我的建议是在训练完模型后立即进行预测,不要使用保存的模型。

【讨论】:

很遗憾,我做不到,它是 ML 管道的一部分,训练好的模型需要存储在服务器上并定期调用以进行预测 我的意思是你先试试看它是否有效。然后您可以确定问题是由于使用pickle文件中的模型造成的,并且可以继续从那里找到解决方案。另外,检查你是否正确保存了模型,见model persistence 我可以尝试一下,当然,重新训练模型需要时间,所以一旦我完成了这个,我会回到这个线程。 另外,基于sklearn中的模型持久化,最好使用joblib然后pickle转储模型,你应该尝试joblib 不是很相关,但仅供参考 SVC 没有 predict_proba()。另外,只需先在小样本上重新训练模型,然后调用 predict 来查看 bug 是否已修复,这样会更快。

以上是关于使用 SVM 模型和 scikit-learn 进行预测的 AttributeError的主要内容,如果未能解决你的问题,请参考以下文章

将 scikit-learn SVM 模型转换为 LibSVM

Python 元组和列表操作(作为 scikit-learn 中 SVM 模型的输入)

使用 scikit-learn 重新拟合 SVM

将经过训练的 SVM 从 scikit-learn 导入到 OpenCV

使用 scikit-learn 线性 SVM 提取决策边界

使用 scikit-learn python 的线性 SVM 时出现 ValueError