Sklearn 线性回归 - “IndexError:元组索引超出范围”
Posted
技术标签:
【中文标题】Sklearn 线性回归 - “IndexError:元组索引超出范围”【英文标题】:Sklearn Linear Regression - "IndexError: tuple index out of range" 【发布时间】:2015-01-22 07:06:39 【问题描述】:我有一个“.dat”文件,其中保存了 X 和 Y 的值(所以是一个元组 (n,2),其中 n 是行数)。
import numpy as np
import matplotlib.pyplot as plt
import scipy.interpolate as interp
from sklearn import linear_model
in_file = open(path,"r")
text = np.loadtxt(in_file)
in_file.close()
x = np.array(text[:,0])
y = np.array(text[:,1])
我为linear_model.LinearRegression()
创建了一个实例,但是当我调用.fit(x,y)
方法时,我得到了
IndexError: 元组索引超出范围
regr = linear_model.LinearRegression()
regr.fit(x,y)
我做错了什么?
【问题讨论】:
对不起,我完全误读了你的问题:(我已经删除了答案,如果我能得到修复,那么我将取消删除编辑后的答案。但是你能提供更多信息吗?比如你的完整代码? 这是你需要的代码,没有其他重要的了。 真的吗?linear_model
是什么?你是怎么得到它的?
现在就这些了,感谢您的帮助。
x 和 Y 的长度是否相同?
【参考方案1】:
线性回归预计X
是一个二维数组,内部需要X.shape[1]
来初始化np.ones
数组。因此,将 X
转换为 nx1 数组 就可以了。所以,替换:
regr.fit(x,y)
作者:
regr.fit(x[:,np.newaxis],y)
这将解决问题。演示:
>>> from sklearn import datasets
>>> from sklearn import linear_model
>>> clf = linear_model.LinearRegression()
>>> iris=datasets.load_iris()
>>> X=iris.data[:,3]
>>> Y=iris.target
>>> clf.fit(X,Y) # This will throw an error
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/usr/lib/python2.7/dist-packages/sklearn/linear_model/base.py", line 363, in fit
X, y, self.fit_intercept, self.normalize, self.copy_X)
File "/usr/lib/python2.7/dist-packages/sklearn/linear_model/base.py", line 103, in center_data
X_std = np.ones(X.shape[1])
IndexError: tuple index out of range
>>> clf.fit(X[:,np.newaxis],Y) # This will work properly
LinearRegression(copy_X=True, fit_intercept=True, normalize=False)
要绘制回归线,请使用以下代码:
>>> from matplotlib import pyplot as plt
>>> plt.scatter(X, Y, color='red')
<matplotlib.collections.PathCollection object at 0x7f76640e97d0>
>>> plt.plot(X, clf.predict(X[:,np.newaxis]), color='blue')
<matplotlib.lines.Line2D object at 0x7f7663f9eb90>
>>> plt.show()
【讨论】:
非常感谢您的帮助!另一个问题:现在我只从线性回归中得到一个系数是否正常?如何绘制它的线? @JackLametta,这绝对正常。这些系数用于在给定 Y 值的情况下预测 X 值。我已将代码上传到情节线。以上是关于Sklearn 线性回归 - “IndexError:元组索引超出范围”的主要内容,如果未能解决你的问题,请参考以下文章