KNN算法的实现

Posted 自嗨锅

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了KNN算法的实现相关的知识,希望对你有一定的参考价值。

KNN算法是机器学习经典十大算法之一,简单易懂。这里给出KNN的实现,由两个版本:

1.机器学习实战上作者的实现版本,我自己又敲了一遍感觉还是蛮有收获的;

2.用自己的理解的一个实现,主要的区别就是效率没有第一个高,因为第一个大量使用矩阵向量的运算,速度比较快,还有就是作者的代码比较简介好看。自己的代码就是比较好懂。

 

1.《机器学习实战》的代码

  1 # -*- coding: utf-8 -*-
  2 
  3 ‘‘‘
  4  function: 根据《实战》实现KNN算法,用于约会网站进行匹配
  5  date: 2017.8.5
  6 ‘‘‘
  7 
  8 from numpy import *
  9 import matplotlib.pyplot as plt
 10 import operator
 11 
 12 #产生数据集和标签
 13 def createDataSet():
 14     group = array([[1.0,1.1], [1.0,1.0], [0,0], [0,0.1]])
 15     labels = [A, A, B, B]
 16     return group, labels
 17 
 18 #knn分类算法
 19 def classify0(inX, dataSet, labels, k):
 20     dataSetSize = dataSet.shape[0] #返回m,n,shape[0] == m 行数,代表由几个训练数据
 21 
 22     #计算距离;将测试向量 inX 在咧方向上重复一次,行方向重复m次,与对应的训练数据相减
 23     diffMat = tile(inX, (dataSetSize, 1)) - dataSet 
 24 
 25     sqDiffMat = diffMat**2 #平方的意思,不能用^
 26     sqDistances = sqDiffMat.sum(axis=1) #axis=1表示按行相加,axis=0表示案列相加
 27     distances = sqDistances**0.5 #测试点距离所有训练点的距离向量集合
 28 
 29     sortedDistIndices = distances.argsort() #距离从小到大排序,返回的是排序后的下表组成的数组;
 30     classCount = {}
 31 
 32     for i in range(k):
 33         voteIlabel = labels[sortedDistIndices[i]]
 34         classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1 #get函数返回value,若不存在则返回0
 35     
 36     #items()返回key,value,itemgetter(1),按照第二个元素进行排序,从大到小
 37     sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1),reverse=True)
 38 
 39     return sortedClassCount[0][0]
 40 
 41 #将文本记录转换为nmpy的解析程序
 42 def file2matrix(filename):
 43     fr = open(filename)
 44     arrayOLines = fr.readlines()
 45     numberOfLines = len(arrayOLines)
 46     returnMat = zeros((numberOfLines,3)) #返回的矩阵 是: m*3,3个特征值
 47     classLabelVector = []
 48     index = 0
 49     for line in arrayOLines:
 50         line = line.strip() #去掉换行符
 51         listFromLine = line.split(\t) #将每一行用tab分割
 52         returnMat[index,:] = listFromLine[0:3] #每行前三个是特征
 53         classLabelVector.append(int(listFromLine[-1])) #最后一列是标签,并且告诉list是int类型
 54         index += 1
 55     return returnMat, classLabelVector
 56 
 57 #draw the point
 58 def drawPoint(datingDatamat, datingLabels):
 59     fig = plt.figure()
 60     ax = fig.add_subplot(111)
 61 
 62     #总共三个特征,可以选择其中任意两个来画图,看哪个区分效果好
 63     ax.scatter(datingDatamat[:,0],datingDatamat[:,1],15.0*array(datingLabels),15.0*array(datingLabels))
 64     #ax.scatter(datingDatamat[:,1],datingDatamat[:,2],15.0*array(datingLabels),15.0*array(datingLabels))
 65     plt.show()
 66 
 67 
 68 #对数据集归一化处理 autoNormanization
 69 def autoNorm(dataSet):
 70     minVals = dataSet.min(0) #get the minimal row
 71     maxVals = dataSet.max(0) #get the maximal row
 72     ranges = maxVals - minVals + ones_like(maxVals) #个人感觉这里可能由除〇的风险,加一效果可能好一点
 73     normDataSet = zeros(shape(dataSet)) #the same size of original dataset: m*n
 74     m = dataSet.shape[0]
 75     normDataSet = dataSet - tile(minVals, (m,1)) #repeat the minimal row m times on the row
 76     normDataSet = dataSet / tile(ranges, (m,1))
 77     return normDataSet, ranges, minVals
 78 
 79 #分类器针对约会网站的测试代码
 80 def datingClassTest():
 81     haRatio = 0.10
 82     datingDataMat, datingLabels = file2matrix(datingTestSet2.txt)
 83     normMat, ranges, minVals = autoNorm(datingDataMat)
 84     m = normMat.shape[0] #获得矩阵的行数,数据集的个数
 85     numTestVecs = int(m*haRatio) #用十分之一的数据来做测试,剩下的作为训练数据
 86     errorCount = 0.0
 87     for i in range(numTestVecs): #对每个测试用例开始计算
 88         #0~numtestvecs都是用来测试的,每次测试一行,从numtestvecs~m都是训练数据
 89         classifierResult = classify0(normMat[i,:],normMat[numTestVecs:m,:],datingLabels[numTestVecs:m],3)
 90         print(the classifier came back with : %d, the real answer is : %d 91          % (classifierResult,datingLabels[i]))
 92         if(classifierResult != datingLabels[i]):
 93             errorCount += 1
 94     print(the total error rate is: %f % (errorCount / float(numTestVecs)))
 95 
 96 
 97 #让用户输入自己的一些信息,用输入的信息进行分类
 98 def classifyPerson():
 99     resultList = [not at all, in small doses, in large doses]
100     percentTats = float(input(persentage of time spent playing video games?))
101     ffMiles = float(input(frequent fliers miles each year?))
102     iceCream = float(input(liters of ice cream consumed per year?))
103     datingDatamat, datingLabels = file2matrix(datingTestSet2.txt)
104     normMat, ranges, minVals = autoNorm(datingDatamat)
105     inArr = array([percentTats, ffMiles, iceCream])
106     #对输入的一个测试数据进行归一化然后分类
107     classifierResult = classify0((inArr - minVals) / ranges, normMat, datingLabels, 3)
108     print(you will probably like this person: , resultList[classifierResult - 1])
109     print(classifierResult)
110     
111 # #测试分类效果和加载数据的程序
112 # group, labels = createDataSet()
113 # result = classify0([0,0], group, labels, 3)
114 # print(result)
115 
116 # #get numpy metrix from file
117 # returnMat, classLabelVector = file2matrix(‘datingTestSet2.txt‘)
118 # print(returnMat)
119 # print(classLabelVector)
120 
121 # #auto normanization
122 # normDataSet, ranges, minVals = autoNorm(returnMat)
123 
124 # #测试我们的分类器
125 # datingClassTest()
126 
127 #请求用户输入几个特征,然后进行分类,判断喜爱程度
128 classifyPerson()
129 
130 #draw
131 #drawPoint(normDataSet, classLabelVector)

 

2.自己的实现

  1 # -*- coding: utf-8 -*-
  2 
  3 ‘‘‘
  4 function: 使用二维数据点集(x,y)来实现knn算法核心功能
  5 note: 计算欧式距离还可以用 numpy.linalg.norm(array1 - array2) 必须是array
  6 date: 2017.07.23
  7 ‘‘‘
  8 from numpy import *
  9 from random import *
 10 import matplotlib.pyplot as plt
 11 
 12 #创建用来训练的数据集和用来测试的数据集,及标注信息,先用100个点训练,20个点测试
 13 def createData():
 14     trainingList0 = [[(1.0,1.0) for j in range(1)] for i in range(50)]
 15     trainingList1 = [[(1.0,1.0) for j in range(1)] for i in range(50)]
 16     trainingLabel,trainingLabe0 = [], []
 17     for i in range(50):
 18         x = randint(1,10)   #产生1,10之间的随机浮点数,类别0
 19         y = randint(1,10)
 20         trainingList0[i] = [(x,y)]
 21         trainingLabel.append(0) #属于类别0
 22     for j in range(50):
 23         x = randint(10,20) #产生十位数,10-20,类别1
 24         y = randint(10,20)
 25         trainingList1[j] = [(x,y)]
 26         trainingLabel.append(1)   #属于类别1
 27     trainingList = trainingList0 + trainingList1
 28 
 29     #产生测试的数据集
 30     testList0 = [[(1.0,1.0) for j in range(1)] for i in range(10)]
 31     testList1 = [[(1.0,1.0) for j in range(1)] for i in range(10)]
 32     testLabel = []
 33     for i in range(10):
 34         x = randint(4,9)
 35         y = randint(2,10)
 36         testList0[i] = [(x,y)]
 37         testLabel.append(0)
 38     for j in range(10):
 39         x = randint(11,19)
 40         y = randint(10,18)
 41         testList1[j] = [(x,y)]
 42         testLabel.append(1)
 43     testList = testList0 + testList1
 44     print(trainingList)
 45     return trainingList, trainingLabel, testList, testLabel
 46 
 47 #对测试数据集和训练数据集之间的距离进行计算,采用欧式距离 d = ((x1-x2)^2 + (y1-y2)^2)^0.5
 48 def calculateKNN(trainingList, trainingLabel, testList, testLabel, k):
 49     #20行,100列距离,每行代表一个测试点距离100个训练点的距离,初始化为0
 50     d = [[0 for j in range(100)] for i in range(20)] 
 51     #kPointDisList = [1000] * k #初始化k个最近距离为1000
 52     for i in range(20):
 53         for j in range(100):
 54             d[i][j] = getDistance(testList[i], trainingList[j]) #计算距离
 55             #这里判断求得的距离是否是最小的K个之一,若是则记录下j,更新此测试点的标签类别
 56             #kPointDisList = updateKPoint(d[i][j], kPointDisList)
 57     testLable = getKMin(d,k,trainingLabel)
 58     print(testLabel)
 59     return testLabel
 60 
 61 #计算列表里面的K个最小值,找出对应的标签类别号
 62 def getKMin(dList,k,trainingLabel):
 63     testLabel = []
 64     sortedIndexList = argsort(dList)  #返回dLlist从小到大排列的索引的数组
 65     print(sortedIndexList) 
 66     print(len(sortedIndexList))
 67     for x in range(20):
 68         type1 = 0 
 69         type0 = 0
 70         for i in range(k):  #计算最近的K个点,按照类别的多少进行投票表决
 71             if trainingLabel[sortedIndexList[x][i]] == 1:
 72                 type1 += 1
 73             else:
 74                 type0 += 1
 75         if type1 > type0:
 76             testLabel.append(1)
 77         else:
 78             testLabel.append(0)
 79     return testLabel
 80 
 81 #跟新最近的K个点坐标,标签
 82 def updateKPoint(d, kPointDisList):
 83     if d < max(kPointDisList):
 84         kPointDisList[kPointDisList.index(max(kPointDisList))] = d
 85     else:
 86         pass
 87     return kPointDisList
 88 
 89 #计算两个点的欧式距离,因为有float所以只能用**
 90 def getDistance(a, b):
 91     return ((a[0][0] - b[0][0])**2 + (a[0][1] - b[0][1])**2)**0.5
 92 
 93 #计算KNN分类结果的正确率
 94 def getCorrectRate(testLabel, resultLabel):
 95     correctNum = 0
 96     for i in range(len(testLabel)):
 97         if testLabel[i] == resultLabel[i]:
 98             correctNum += 1
 99     return correctNum / float(len(testLabel))
100 
101 #把分类结果用图形画出来
102 def drawKNN(trainingList, testList):
103     #产生画图数据
104     x1,x2,y1,y2 = [],[],[],[]
105     for i in range(len(trainingList)):
106         x1.append(trainingList[i][0][0])
107         y1.append(trainingList[i][0][1])
108     for i in range(len(testList)):
109         x2.append(testList[i][0][0])
110         y2.append(testList[i][0][1])
111     #创建一个绘图对象
112     fig = plt.figure()
113     ax1 = fig.add_subplot(111)
114     #添加坐标图修饰 
115     plt.xlabel(times(m))
116     plt.ylabel(money(y))
117     plt.title(first pic)
118     #画散点图,maker代表散点图形状,c代表颜色
119     pl1 = ax1.scatter(x1,y1,c = r,marker = .)
120     pl2 = ax1.scatter(x2,y2,c = b,marker = +)
121     #设置图标
122     plt.legend([pl1,pl2],(train,test))
123     #显示图形
124     plt.show()
125     #保存图形到本地
126     plt.savefig(test.png)
127 #系统的总体控制逻辑
128 def testKNN():
129     #设定系统的K值
130     k = 3
131     trainingList, trainingLabel, testList, testLabel = createData()
132     resultLabel = calculateKNN(trainingList, trainingLabel, testList, testLabel, k)
133     correctRate = getCorrectRate(testLabel, resultLabel)
134     print(correctRate is  =   + str(correctRate))
135     drawKNN(trainingList, testList)
136 
137 testKNN()
138 #drawKNN()

 

以上是关于KNN算法的实现的主要内容,如果未能解决你的问题,请参考以下文章

python实现简单knn算法

数据挖掘——KNN算法的实现

自己实现的简易的knn算法

2. KNN和KdTree算法实现

Python实现KNN算法

机器学习实战笔记(Python实现)-01-K近邻算法(KNN)