如何从 scikit-learn 玩具数据集中预测数据
Posted
技术标签:
【中文标题】如何从 scikit-learn 玩具数据集中预测数据【英文标题】:How to predict data from scikit-learn toy dataset 【发布时间】:2021-07-21 20:46:25 【问题描述】:我正在学习机器学习,并且正在尝试分析 scikit 糖尿病玩具数据库。在这种情况下,我想将默认的 Bunch 对象更改为 pandas DataFrame 对象。我尝试使用参数 as_frame=True,它确实将对象类型更改为 DataFrame。
所以在那之后,我对数据进行了训练,当我尝试绘制它时,问题就来了:
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from sklearn import datasets, linear_model
from sklearn.model_selection import train_test_split
dataset = datasets.load_diabetes(as_frame=True)
X = dataset.data
y = dataset.target
y = y.to_frame()
X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=0.8, random_state=42)
regressor = linear_model.LinearRegression()
regressor.fit(X_train, y_train)
plt.scatter(X_train, y_train, color='blue')
plt.plot(X_train, regressor.predict(X_test), color='red')
问题是当我尝试使用 matplotlib 绘制它时,因为 as_frame=True 返回 (data, target),其中数据是 DataFrame 对象,目标是 Series。
Traceback (most recent call last):
File "C:/Users/Kelvin/OneDrive/Documents/analytics/diabetes-sklearn/test.py", line 19, in <module>
plt.scatter(X_train, y_train, color='blue')
File "C:\Users\Kelvin\OneDrive\Desktop\analytics\lib\site-packages\matplotlib\pyplot.py", line 3037, in scatter
__ret = gca().scatter(
File "C:\Users\Kelvin\OneDrive\Desktop\analytics\lib\site-packages\matplotlib\__init__.py", line 1352, in inner
return func(ax, *map(sanitize_sequence, args), **kwargs)
File "C:\Users\Kelvin\OneDrive\Desktop\analytics\lib\site-packages\matplotlib\axes\_axes.py", line 4478, in scatter
raise ValueError("x and y must be the same size")
ValueError: x and y must be the same size
所以,我的问题是,是否有办法可以将整个数据更改为 DataFrame,就像我们使用 pd.read_csv() 获取数据的方式一样?
【问题讨论】:
您不能绘制 X_train 与 y_train,因为 X_train 有多个列。如果它只是一列,你可以拥有 【参考方案1】:这已经是一个数据框,您会遇到错误,因为您正在用 y_train 绘制 X_train 并且 X_train 有多个列。
但如果您希望数据集保存在 csv 文件中,您可以使用此代码。
X.to_csv('train_data.csv')
这会将数据集保存到工作目录中的 csv 文件中。现在您可以在train_data.csv
上使用pd.read_csv
。
【讨论】:
以上是关于如何从 scikit-learn 玩具数据集中预测数据的主要内容,如果未能解决你的问题,请参考以下文章
如何使用 scikit-learn for python 分析和预测(机器学习)时间序列数据集
机器学习笔记:常用数据集之scikit-learn内置玩具数据集
如何从 scikit-learn 中与 predict_proba 一起使用的 cross_val_predict 获取类标签