k近邻算法的Python实现
Posted pkuimyy
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了k近邻算法的Python实现相关的知识,希望对你有一定的参考价值。
k近邻算法的Python实现
0. 写在前面
这篇小教程适合对Python与NumPy有一定了解的朋友阅读,如果在阅读本文的源代码时感到吃力,请及时参照相关的教程或者文档。
1. 算法原理
k近邻算法(k Nearest Neighbor)可以简称为kNN。kNN是一个简单直观的算法,也是机器学习从业者入门首选的算法。先看一个简单的应用场景。
小例子
设有下表,命名为为表1
电影名称 | 打斗镜头数量 | 接吻镜头数量 | 电影类型 |
---|---|---|---|
foo1 | 3 | 104 | 爱情片 |
foo2 | 2 | 100 | 爱情片 |
foo3 | 1 | 81 | 爱情片 |
foo4 | 101 | 10 | 动作片 |
foo5 | 99 | 5 | 动作片 |
foo6 | 98 | 2 | 动作片 |
一个朴素的愿望是,能够根据打斗镜头与接吻镜头的数量来推测一部电影是属于爱情片还是动作片。具体而言,如果有一部电影的相关信息如下,命名为表2:
电影名称 | 打斗镜头数量 | 接吻镜头数量 |
---|---|---|
foo7 | 18 | 90 |
我们能否给出这部电影的类型?
解决方案
表1可以抽象为一个矩阵A与一个列向量x如下:
矩阵A
foo1 3 104
foo2 2 100
foo3 1 81
foo4 101 10
foo5 99 5
foo6 98 2
列向量x
爱情片
爱情片
爱情片
动作片
动作片
动作片
表2可以抽象为一个行向量a如下:
行向量a
foo7 18 90
显然,可以求矩阵A中每一个行向量与行向量a的欧式距离(本例中计算欧式距离时只考虑打斗镜头数量与接吻镜头数量两个分量),并按照距离由小到大排序,结果如下表,命名为表3:
电影名称 | 与未知电影之的距离 |
---|---|
foo2 | 18.7 |
foo3 | 19.2 |
foo1 | 20.5 |
foo4 | 115.3 |
foo5 | 117.4 |
foo6 | 118.9 |
此时选择前k个距离最小的电影及其所属的类型,结果如下表,命名为表4:
电影名称 | 类型 |
---|---|
foo2 | 爱情片 |
foo3 | 爱情片 |
foo4 | 爱情片 |
找出表4中出现次数最多的类型——“爱情片”,即kNN认为行向量a所属的类型为爱情片。
2. Python实现
代码的核心部分是如下函数,将其保存在文件中mykNN.py中。
import numpy as np
import operator as op
from collections import defaultdict
def classify(vec, dataSet, labels, k):
"""
要求dataSet为NumPy的array类型
vec: 参照行向量a
dataSet: 参照矩阵A
labels: 参照列向量x
k: kNN中选择前k小的行
"""
size = dataSet.shape[0]
assert size == len(labels) #断言,确保输入正确
tmp = (dataSet - vec)**2 #使用了NumPy的广播机制
tmp = tmp.sum(axis=1)
tmp = tmp.argsort()
tmpDict = defaultdict(int) #简化用于分组的代码
for i in range(k):
tmpDict[labels[tmp[i]]] += 1
return max(tmpDict.items(),key=op.itemgetter(1))[0]
3. 练手案例
我们使用第2小节的代码解决第1小节的问题。下面的代码文件保存为test.py,请确保test.py与mykNN.py文件位于同一个路径下。
import numpy as np
import mykNN as knn
if __name__ == "__main__":
dataSet = np.array([
[3, 104],
[2, 100],
[1, 81],
[101, 10],
[99, 5],
[98, 2]
])
labels = ["爱情片", "爱情片", "爱情片",
"动作片", "动作片", "动作片"]
k = 3
vec = [18, 90]
res = knn.classify(vec,dataSet,labels,k)
print(res)
4. 补充说明
真实的分类任务不会像我们的案例那样简单。
一般来说,第3小节中的dataSet与labels都会放在文件或者数据库中,并且未必是NumPy可以处理的数据类型。这时需要增加读文件或者读数据库并解析转换数据的一系列代码。
有时需要考虑对表格的不同字段归一化的问题。
以数据驱动的应用的开发需要关注kNN算法的正确率,这时需要增加判断正确率或者错误率的代码。
以上是关于k近邻算法的Python实现的主要内容,如果未能解决你的问题,请参考以下文章