Scikit Learn - 通过加载 CSV 识别目标

Posted

技术标签:

【中文标题】Scikit Learn - 通过加载 CSV 识别目标【英文标题】:Scikit Learn - Identifying target from loading a CSV 【发布时间】:2015-08-31 04:13:20 【问题描述】:

我正在使用 Numpy 加载 csv 作为数据集,以在 Python 中创建决策树模型。使用下面的提取将第 0-7 列放在 X 中,最后一列作为 Y 中的目标。

#load and set data
data = np.loadtxt("data/tmp.csv", delimiter=",")
X = data[:,0:7] #identify columns as data sets
Y = data[:,8] #identfy last column as target

#create model
clf = tree.DecisionTreeClassifier()
clf = clf.fit(X, Y)

我想知道的是是否可以在任何列中使用分类器。例如,如果它在第四列中,以下代码是否仍能正确拟合模型,或者在预测时会产生错误?

#load and set data
data = np.loadtxt("data/tmp.csv", delimiter=",")
X = data[:,0:8] #identify columns as data sets
Y = data[:,3] #identfy fourth column as target

#create model
clf = tree.DecisionTreeClassifier()
clf = clf.fit(X, Y)

【问题讨论】:

我不这么认为。它会知道输出总是等于第四个特征:D X = data[:,0:8]Y = data[:,3] 表示您将在特征中包含目标! 【参考方案1】:

如果您有 >4 列,并且第 4 列是目标,而其他列是特征,这里是加载它们的一种方法(在众多方法中):

# load data

X = np.hstack([data[:, :3], data[:, 5:]]) # features
Y = data[:,4] # target

# process X & Y

(感谢@omerbp 提醒我的迟到hstack接受元组/列表,而不是裸参数!)

【讨论】:

嗨,我相信你的意思是X = np.hstack(data[:, :3], data[:, 5:]),但X = np.hstack([data[:, :3], data[:, 5:]])...虽然很好的答案,+1。我在回答时间比较中添加了您建议的方法和this one 谢谢@omerbp,当我整合你的修复程序时,我太忙了,没有礼貌。编辑以在信用到期时给予信用。 哈哈,你已经走得太远了 :) 在这种情况下,赞成就足够了 :) (正如我所说,我赞成你的,很好的答案)【参考方案2】:

首先,正如@mescalinum 在对该问题的评论中所建议的那样,考虑一下这种情况:

.... 4th_feature ...    label
....      1      ...      1
....      0      ...      0
....      1      ...      1
............................

在此示例中,分类器(任何分类器,尤其是 DecisionTreeClassifier 除外)将了解到第 4 个特征可以最好地预测标签,因为第 4 个特征标签。不幸的是,这个问题经常发生(我的意思是偶然)。

其次,如果您想将第 4 个特征作为输入标签,您可以 swap 列:

arr[:,[frm, to]] = arr[:,[to, frm]]

@Ahemed Fasih's answer 也可以做到这一点,但是它慢了大约 10 倍:

import timeit


setup_code = """
import numpy as np
i, j = 400000, 200
my_array = np.arange(i*j).reshape(i, j)
"""

swap_cols = """
def swap_cols(arr, frm, to):
    arr[:,[frm, to]] = arr[:,[to, frm]]
"""

stack ="np.hstack([my_array[:, :3], my_array[:, 5:]])"
swap ="swap_cols(my_array, 4, 8)"

print "hstack - total time:", min(timeit.repeat(stmt=stack,setup=setup_code,number=20,repeat=3))
#hstack - total time: 3.29988478635
print "swap - total time:", min(timeit.repeat(stmt=swap,setup=setup_code+swap_cols,number=20,repeat=3))
#swap - total time: 0.372791106328

【讨论】:

以上是关于Scikit Learn - 通过加载 CSV 识别目标的主要内容,如果未能解决你的问题,请参考以下文章

Numpy/scipy 加载巨大的稀疏矩阵以在 scikit-learn 中使用

如何将 csv 数据文件导入 scikit-learn?

如何在 Scikit-Learn 中重用 LabelBinarizer 进行输入预测

使用 Pandas 为 Scikit-Learn 准备 CSV 文件数据?

Scikit-Learn:如何处理不可排序的类型错误?

机器学习实验scikit-learn的主要模块和基本使用