如何在不同的数据帧上生成交叉验证以进行监督分类?
Posted
技术标签:
【中文标题】如何在不同的数据帧上生成交叉验证以进行监督分类?【英文标题】:How to generate cross validation over different dataframes for supervised classification? 【发布时间】:2020-06-03 18:37:58 【问题描述】:想象一下,我有 4 个行长不同但列数相同的数据框,如下所示:df1(200 行,4 列)、df2(100, 4)、df3(300, 4) 和 df4(250, 4) .
我想在这些数据帧之间进行监督分类(始终使用 3 进行训练,使用 1 进行测试/验证)并发现哪种组合可以让我获得更好的准确度分数。 这是一个更大数据量的例子,我想通过交叉验证来自动化它。
我认为我可以尝试为每个具有特定名称的数据框创建一个新列,然后将它们全部连接起来。然后,也许,创建一个掩码,通过这些新列区分训练集和测试集。但是我仍然不知道如何在它们之间进行这种交叉验证。
数据框是这样的:
concatenated_dfs:
feat1 feat2 feat3 feat4 name
0 4 6 57 78 df1
1 1 2 50 78 df1
2 1 1 57 78 df1
. . . . . .
. . . . . .
. . . . . .
849 3 10 50 80 df4
谁能告诉我如何用一些代码来做到这一点?如果需要,您可以使用任何 scikit-learn 分类算法。谢谢!
【问题讨论】:
【参考方案1】:您可以使用 scikit learn 的 cross_val_score
和自定义迭代来生成数据中训练-测试拆分的索引。这是一个例子:
# Setup - creating fake data to match your description
df = pd.DataFrame(data="name":[x for l in [[f"dfi"]*c for i, c in enumerate(counts, 1)] for x in l])
for i in range(1, 5):
df[f"feati"] = np.random.randn(len(df))
X = df[[c for c in df.columns if c != "name"]]
y = np.random.randint(0, 2, len(df))
# Iterable to generate the training-test splits:
indices = list()
for name in df["name"].unique():
train = df.loc[df["name"] != name].index
test = df.loc[df["name"] == name].index
indices.append((train, test))
# Example model - logistic regression
from sklearn.linear_model import LogisticRegression
model = LogisticRegression()
# Using cross-val score with the custom indices:
from sklearn.model_selection import cross_val_score
scores = cross_val_score(model, X, y, cv=indices)
【讨论】:
太好了,托比。谢谢您的帮助。但是我还有一个问题:在我运行代码并得到分数之后,我怎么知道哪个组合代表了那个分数? 分数的顺序与for name in df["name"].unique()
行中名称的迭代顺序相同。因此,如果名字是“df1”,那么第一个分数将使用“df1”作为测试集,使用 dfs 2/3/4 作为训练集。如果您想确定而不依赖于unique
对数据的排序方式,可以更改此行以明确说明顺序。
其实我还有一个问题。这些分数并没有给我预测的真实准确度分数......我的意思是,它们与我使用“accuracy_score”检查一个数据帧中的预测而使用其他四个作为训练时不同。所以如果我想要这些准确度分数,我需要另一种方法,对吧?
不完全确定您的意思,但请查看cross_val_score
的用户指南以了解如何指定要使用的评分指标:scikit-learn.org/stable/modules/generated/… 无论您选择哪个评分指标,它都会被计算为性能在交叉验证的每次迭代中应用于测试集的训练集上训练的模型。
没关系。我只需要重置连接数据帧的索引,因为该方法可能会混淆不同 dfs 的索引,因为这些数据帧在某些部分具有相同的索引。现在它运作良好。谢谢:)以上是关于如何在不同的数据帧上生成交叉验证以进行监督分类?的主要内容,如果未能解决你的问题,请参考以下文章