如何从 scikit-learn 玩具数据集中预测数据

Posted

技术标签:

【中文标题】如何从 scikit-learn 玩具数据集中预测数据【英文标题】:How to predict data from scikit-learn toy dataset 【发布时间】:2021-07-21 20:46:25 【问题描述】:

我正在学习机器学习,并且正在尝试分析 scikit 糖尿病玩具数据库。在这种情况下,我想将默认的 Bunch 对象更改为 pandas DataFrame 对象。我尝试使用参数 as_frame=True,它确实将对象类型更改为 DataFrame。

所以在那之后,我对数据进行了训练,当我尝试绘制它时,问题就来了:

import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from sklearn import datasets, linear_model
from sklearn.model_selection import train_test_split

dataset = datasets.load_diabetes(as_frame=True)

X = dataset.data
y = dataset.target

y = y.to_frame()

X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=0.8, random_state=42)

regressor = linear_model.LinearRegression()
regressor.fit(X_train, y_train)

plt.scatter(X_train, y_train, color='blue')
plt.plot(X_train, regressor.predict(X_test), color='red')

问题是当我尝试使用 matplotlib 绘制它时,因为 as_frame=True 返回 (data, target),其中数据是 DataFrame 对象,目标是 Series。

Traceback (most recent call last):
  File "C:/Users/Kelvin/OneDrive/Documents/analytics/diabetes-sklearn/test.py", line 19, in <module>
    plt.scatter(X_train, y_train, color='blue')
  File "C:\Users\Kelvin\OneDrive\Desktop\analytics\lib\site-packages\matplotlib\pyplot.py", line 3037, in scatter
    __ret = gca().scatter(
  File "C:\Users\Kelvin\OneDrive\Desktop\analytics\lib\site-packages\matplotlib\__init__.py", line 1352, in inner
    return func(ax, *map(sanitize_sequence, args), **kwargs)
  File "C:\Users\Kelvin\OneDrive\Desktop\analytics\lib\site-packages\matplotlib\axes\_axes.py", line 4478, in scatter
    raise ValueError("x and y must be the same size")
ValueError: x and y must be the same size

所以,我的问题是,是否有办法可以将整个数据更改为 DataFrame,就像我们使用 pd.read_csv() 获取数据的方式一样?

【问题讨论】:

您不能绘制 X_train 与 y_train,因为 X_train 有多个列。如果它只是一列,你可以拥有 【参考方案1】:

这已经是一个数据框,您会遇到错误,因为您正在用 y_train 绘制 X_train 并且 X_train 有多个列。

但如果您希望数据集保存在 csv 文件中,您可以使用此代码。

X.to_csv('train_data.csv')

这会将数据集保存到工作目录中的 csv 文件中。现在您可以在train_data.csv 上使用pd.read_csv

【讨论】:

以上是关于如何从 scikit-learn 玩具数据集中预测数据的主要内容,如果未能解决你的问题,请参考以下文章

如何使用 scikit-learn for python 分析和预测(机器学习)时间序列数据集

用 scikit-learn 拟合一维数据来预测线

机器学习笔记:常用数据集之scikit-learn内置玩具数据集

如何从 scikit-learn 中与 predict_proba 一起使用的 cross_val_predict 获取类标签

如何使用 scikit-learn 的 LinearRegression() 捕获时间序列数据的趋势以进行预测

scikit-learn:如何缩减“y”预测结果