为啥sklearn kNN分类器运行得这么快,而我的训练样本和测试样本的数量很大
Posted
技术标签:
【中文标题】为啥sklearn kNN分类器运行得这么快,而我的训练样本和测试样本的数量很大【英文标题】:Why sklearn's kNN classifer runs so fast while the number of my training samples and test samples are large为什么sklearn kNN分类器运行得这么快,而我的训练样本和测试样本的数量很大 【发布时间】:2021-04-25 15:02:02 【问题描述】:据我了解,对于每个测试样本,kNN分类器算法计算当前测试样本与所有训练样本的距离,并选择一定数量的最近邻,并确定测试样本的标签,然后下一步测试样品将完成。
我的代码类似于以下超链接中的示例 kNN 分类器代码,非常简单:
https://tutorialspoint.dev/computer-science/machine-learning/multiclass-classification-using-scikit-learn
我的训练样本数是8000,测试样本数是1500,样本维度是12。
我跑sklearn kNN分类器代码时,只用了2秒,准确率不错。
我怀疑sklearn kNN算法所花费的时间,所以我写了一个简单的代码来计算测试样本和所有训练样本之间的距离,发现这是一个耗时的过程,甚至不包括排序算法。距离计算代码如下:
for i in range(X_test.shape[0]):
for j in range(X_train.shape[0]):
## calculate distances between a test sample and all train samples
Distance[j,0] = (X_test.iloc[i,0]-X_train.iloc[j,0])*(X_test.iloc[i,0]-X_train.iloc[j,0]) + \
(X_test.iloc[i,1]-X_train.iloc[j,1])*(X_test.iloc[i,1]-X_train.iloc[j,1]) + \
(X_test.iloc[i,2]-X_train.iloc[j,2])*(X_test.iloc[i,2]-X_train.iloc[j,2]) + \
(X_test.iloc[i,3]-X_train.iloc[j,3])*(X_test.iloc[i,3]-X_train.iloc[j,3]) + \
(X_test.iloc[i,4]-X_train.iloc[j,4])*(X_test.iloc[i,4]-X_train.iloc[j,4]) + \
(X_test.iloc[i,5]-X_train.iloc[j,5])*(X_test.iloc[i,5]-X_train.iloc[j,5]) + \
(X_test.iloc[i,6]-X_train.iloc[j,6])*(X_test.iloc[i,6]-X_train.iloc[j,6]) + \
(X_test.iloc[i,7]-X_train.iloc[j,7])*(X_test.iloc[i,7]-X_train.iloc[j,7]) + \
(X_test.iloc[i,8]-X_train.iloc[j,8])*(X_test.iloc[i,8]-X_train.iloc[j,8]) + \
(X_test.iloc[i,9]-X_train.iloc[j,9])*(X_test.iloc[i,9]-X_train.iloc[j,9]) + \
(X_test.iloc[i,10]-X_train.iloc[j,10])*(X_test.iloc[i,10]-X_train.iloc[j,10]) + \
(X_test.iloc[i,11]-X_train.iloc[j,11])*(X_test.iloc[i,11]-X_train.iloc[j,11])
我不确定 sklearn 是否使用整个训练数据集来计算 k 个最近邻。如果有,sklearn 使用的是什么优化算法?
提前谢谢你。
【问题讨论】:
【参考方案1】:你说得对,最近邻搜索确实很耗时。如果你天真地做,你的运行时间是 O(n^2)。好消息是,sklearn 使用了一些巧妙的算法来规避计算所有距离。
查看docs,您会发现其中一个参数是用于最近邻搜索的algorithm
,例如BallTree。这些算法大大加快了计算速度。
另一方面,您的代码效率有点低。您可以这样做,而不是手动计算每个维度:
((X_test.iloc[i,:] - X_train.iloc[j,:]) ** 2).sum()
这利用了 pandas 的矢量化函数,使其速度更快。
【讨论】:
hi tilman151,我分别尝试了“algorithm=ball_tree, kd_tree and brute”,运行时间没有显着差异。我用你的代码来计算距离,它比我的代码快一点。我还有一个问题:给定一个测试样本,kNN 算法需要计算测试样本与 8000 个训练样本之间的距离,然后选择 k 个最近邻,即使使用了 balltree,这也是很耗时的。我仍然怀疑 sklearn 没有使用整个训练样本。 使用 BallTree,它不需要在预测时计算所有距离。它在训练时构建树,然后只检查它必须去的树的哪个叶子。叶只包含少量靠近测试样本的训练样本,因此 sklearn 只计算它们的距离。 谢谢解释,看来我需要研究一下balltree算法了。以上是关于为啥sklearn kNN分类器运行得这么快,而我的训练样本和测试样本的数量很大的主要内容,如果未能解决你的问题,请参考以下文章
机器学习 sklearn 监督学习 分类算法 KNN K-NearestNeighbor