带有索引的 Scikit-learn train_test_split
Posted
技术标签:
【中文标题】带有索引的 Scikit-learn train_test_split【英文标题】:Scikit-learn train_test_split with indices 【发布时间】:2015-10-09 20:26:15 【问题描述】:使用train_test_split()时如何获取数据的原始索引?
我拥有的是以下
from sklearn.cross_validation import train_test_split
import numpy as np
data = np.reshape(np.randn(20),(10,2)) # 10 training examples
labels = np.random.randint(2, size=10) # 10 labels
x1, x2, y1, y2 = train_test_split(data, labels, size=0.2)
但这并没有给出原始数据的索引。
一种解决方法是将索引添加到数据中(例如data = [(i, d) for i, d in enumerate(data)]
),然后将它们传递到train_test_split
中,然后再次展开。
有没有更清洁的解决方案?
【问题讨论】:
另请注意sklearn.model_selection.ShuffleSplit 和sklearn.model_selection.StratifiedShuffleSplit。 【参考方案1】:您可以像 Julien 所说的那样使用 pandas 数据帧或系列,但如果您想将自己限制为 numpy,您可以传递一个额外的索引数组:
from sklearn.model_selection import train_test_split
import numpy as np
n_samples, n_features, n_classes = 10, 2, 2
data = np.random.randn(n_samples, n_features) # 10 training examples
labels = np.random.randint(n_classes, size=n_samples) # 10 labels
indices = np.arange(n_samples)
(
data_train,
data_test,
labels_train,
labels_test,
indices_train,
indices_test,
) = train_test_split(data, labels, indices, test_size=0.2)
【讨论】:
从 NumPt v1.1 开始,第 3 行应该是data = np.reshape(np.random.randn(20),(10,2))
;最后一行应该是... train_test_split(data, labels, indices, test_size=0.2)
其实这应该是接受的响应,因为它没有使用任何额外的包,而是使用 sklearn。它可以更好地控制 pandas 的情况。
@ogrisel 你好我有一个类似的问题,你能检查一下***.com/questions/48734942/…
嘿,你第三行的 n_class 是什么?班级人数,班级是什么意思??提前致谢。
这是在labels
变量中生成随机分类标签的目标标签类数。【参考方案2】:
如果您使用的是 pandas,您可以通过调用您希望模拟的任何数组的 .index 来访问索引。 train_test_split 将 pandas 索引传递给新的数据帧。
在您的代码中,您只需使用
x1.index
并且返回的数组是与 x 中原始位置相关的索引。
【讨论】:
【参考方案3】:Scikit learn 非常适合 Pandas,所以我建议你使用它。这是一个例子:
In [1]:
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
data = np.reshape(np.random.randn(20),(10,2)) # 10 training examples
labels = np.random.randint(2, size=10) # 10 labels
In [2]: # Giving columns in X a name
X = pd.DataFrame(data, columns=['Column_1', 'Column_2'])
y = pd.Series(labels)
In [3]:
X_train, X_test, y_train, y_test = train_test_split(X, y,
test_size=0.2,
random_state=0)
In [4]: X_test
Out[4]:
Column_1 Column_2
2 -1.39 -1.86
8 0.48 -0.81
4 -0.10 -1.83
In [5]: y_test
Out[5]:
2 1
8 1
4 1
dtype: int32
您可以直接调用 DataFrame/Series 上的任何 scikit 函数,它会起作用。
假设你想做一个 LogisticRegression,下面是你如何以一种很好的方式检索系数:
In [6]:
from sklearn.linear_model import LogisticRegression
model = LogisticRegression()
model = model.fit(X_train, y_train)
# Retrieve coefficients: index is the feature name (['Column_1', 'Column_2'] here)
df_coefs = pd.DataFrame(model.coef_[0], index=X.columns, columns = ['Coefficient'])
df_coefs
Out[6]:
Coefficient
Column_1 0.076987
Column_2 -0.352463
【讨论】:
另外,您的问题中的代码似乎有问题,或者您可能正在使用已弃用的 scikit 和 numpy 版本(我的 np.randn 不存在,test_size
是在train_test_split中使用而不是size
)
我编辑了我的答案,以展示如何使用来自 pandas 数据帧的特征名称检索系数。将来可能会为您节省一点时间。
嗨@Julien Marrec 我尝试应用您的解决方案,但没有成功,您可以在这里查看***.com/questions/48734942/…
如果我没有先创建索引就已经拆分了数据怎么办?
这如何回答原始问题?看来您只是在创建一个新数据框并将原始数据框的索引应用于新创建的数组。当您使用训练测试拆分随机化数据时,您正在对行进行洗牌,并且简单地将来自先前数据帧的索引应用于洗牌数据不允许您准确地访问原始数据中的索引。我错过了什么吗?【参考方案4】:
这是最简单的解决方案(Jibwa 在另一个答案中让它看起来很复杂),无需自己生成索引 - 只需使用 ShuffleSplit 对象生成 1 个拆分。
import numpy as np
from sklearn.model_selection import ShuffleSplit # or StratifiedShuffleSplit
sss = ShuffleSplit(n_splits=1, test_size=0.1)
data_size = 100
X = np.reshape(np.random.rand(data_size*2),(data_size,2))
y = np.random.randint(2, size=data_size)
sss.get_n_splits(X, y)
train_index, test_index = next(sss.split(X, y))
X_train, X_test = X[train_index], X[test_index]
y_train, y_test = y[train_index], y[test_index]
【讨论】:
【参考方案5】:docs 提到的 train_test_split 只是 shuffle split 之上的一个便利功能。
我只是重新排列了他们的一些代码来制作我自己的示例。请注意,实际的解决方案是中间的代码块。剩下的就是导入,并设置一个可运行的示例。
from sklearn.model_selection import ShuffleSplit
from sklearn.utils import safe_indexing, indexable
from itertools import chain
import numpy as np
X = np.reshape(np.random.randn(20),(10,2)) # 10 training examples
y = np.random.randint(2, size=10) # 10 labels
seed = 1
cv = ShuffleSplit(random_state=seed, test_size=0.25)
arrays = indexable(X, y)
train, test = next(cv.split(X=X))
iterator = list(chain.from_iterable((
safe_indexing(a, train),
safe_indexing(a, test),
train,
test
) for a in arrays)
)
X_train, X_test, train_is, test_is, y_train, y_test, _, _ = iterator
print(X)
print(train_is)
print(X_train)
现在我有了实际的索引:train_is, test_is
【讨论】:
以上是关于带有索引的 Scikit-learn train_test_split的主要内容,如果未能解决你的问题,请参考以下文章
Scikit-learn cross val得分:数组的索引太多了
带有索引的 scikit-learn StratifiedShuffleSplit KeyError
K-means 仅使用带有 scikit-learn 的特定数据框列