K-近邻算法入门
Posted zjq-115
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了K-近邻算法入门相关的知识,希望对你有一定的参考价值。
K-近邻算法的直观理解就是:给定一个训练集合,对于新的实例,在训练集合中找到k个与该实例最近的邻居,然后根据“少数服从多数”原则判断该实例归属于哪一类,又称“随大流”
K-近邻算法的三大要素:K值得选取,邻居距离度量,分类决策的制定。
(1)K值选取:通常采用交叉验证选取最优的K值(自己了解)
(2)邻居距离度量:根据不同的应用场景选取相应的距离度量。常见的距离度量有欧几里得距离、曼哈顿距离、马氏距离。同时要注意的是归一化机制。
(3)分类决策制定:一般分为平等投票表决原则和加权投票原则。
import operator import csv import math import random def loadDataSet(filename,split,trainingSet=[],testSet=[]): #读取本地数据# with open(filename,‘r‘) as csvfile: lines=csv.reader(csvfile) dataset=list(lines) for x in range(len(dataset)-1): for y in range (4): dataset[x][y]=float(dataset[x][y]) if random.random()<split: trainingSet.append(dataset[x]) else: testSet.append(dataset[x]) def EuclidDist(instance1,instance2,len): #求欧几里得距离# distance=0.0 for x in range(len): distance+=pow((instance1[x]-instance2[x]),2) return math.sqrt(distance) def getNeighbors(trainSet,testInstance,k): #获取最近邻居# distance=[] length=len(testInstance)-1 for x in range(len(trainSet)): dist=EuclidDist(testInstance,trainSet[x],length) distance.append((trainSet[x],dist)) distance.sort(key=operator.itemgetter(1)) #列表的sort(key)方法用来根据关键字排序 neighbors=[] for x in range(k): neighbors.append(distance[x][0]) return neighbors def getClass(neighbors): #分类与评估函数# classVotes={} for x in range(len(neighbors)): instance_class=neighbors[x][-1] if instance_class in classVotes: classVotes[instance_class]+=1 else: classVotes[instance_class]=1 sortedVotes=sorted(classVotes.items(),key=operator.itemgetter(1),reverse=True) return sortedVotes[0][0] def getAccuracy(testSet,predictions): #预测正确率计算# correct=0 for x in range(len(testSet)): if testSet[x][-1]==predictions[x]: correct+=1 return (correct/float(len(testSet)))*100.0 def main(): trainingSet=[] testSet=[] split=0.7 loadDataSet(‘iris.data.csv‘,split,trainingSet,testSet) print(‘训练集合:‘+repr(len(trainingSet))) print(‘测试集合:‘+repr(len(testSet))) predictions=[] k=3 for x in range(len(testSet)): neighbors=getNeighbors(trainingSet,testSet[x],k) result=getClass(neighbors) predictions.append(result) print(‘>预测=‘+repr(result)+‘,实际=‘+repr(testSet[x][-1])) accuracy=getAccuracy(testSet,predictions) print(‘精确度为:‘+repr(accuracy)+‘%‘) main()
针对此代码中的数据来源为UCI机器学习库中的鸢尾花卉数据集,可以直接获取(https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data),也可以下载我转换好的CSV文件(链接:https://pan.baidu.com/s/1YSLhrPMn3RflGE8VDGGbHQ 提取码:42se )
本次范例属于“自己动手丰衣足食”,每个函数都自己实现,可以在入门阶段对K-近邻算法流程有个初步认识,在有了一定基础之后,我们就没有必要重造轮子,可以使用常见的机器学习算法,毕竟其专业性远远目前超过我们自己的程序。例如scikit-learn模块。
以上是关于K-近邻算法入门的主要内容,如果未能解决你的问题,请参考以下文章