K-近邻算法入门

Posted zjq-115

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了K-近邻算法入门相关的知识,希望对你有一定的参考价值。

  K-近邻算法的直观理解就是:给定一个训练集合,对于新的实例,在训练集合中找到k个与该实例最近的邻居,然后根据“少数服从多数”原则判断该实例归属于哪一类,又称“随大流”

K-近邻算法的三大要素:K值得选取,邻居距离度量,分类决策的制定。

(1)K值选取:通常采用交叉验证选取最优的K值(自己了解)

(2)邻居距离度量:根据不同的应用场景选取相应的距离度量。常见的距离度量有欧几里得距离、曼哈顿距离、马氏距离。同时要注意的是归一化机制。

(3)分类决策制定:一般分为平等投票表决原则和加权投票原则。

import operator
import csv
import math
import random

def loadDataSet(filename,split,trainingSet=[],testSet=[]):
    #读取本地数据#
    with open(filename,r) as csvfile:
        lines=csv.reader(csvfile)
        dataset=list(lines)
        for x in range(len(dataset)-1):
            for y in range (4):
                dataset[x][y]=float(dataset[x][y])
            if random.random()<split:
                trainingSet.append(dataset[x])
            else:
                testSet.append(dataset[x])

def EuclidDist(instance1,instance2,len):
    #求欧几里得距离#
    distance=0.0
    for x in range(len):
        distance+=pow((instance1[x]-instance2[x]),2)
    return math.sqrt(distance)


def getNeighbors(trainSet,testInstance,k):
    #获取最近邻居#
    distance=[]
    length=len(testInstance)-1
    for x in range(len(trainSet)):
        dist=EuclidDist(testInstance,trainSet[x],length)
        distance.append((trainSet[x],dist))
    distance.sort(key=operator.itemgetter(1))
    #列表的sort(key)方法用来根据关键字排序
    neighbors=[]
    for x in range(k):
        neighbors.append(distance[x][0])
    return neighbors

def getClass(neighbors):
    #分类与评估函数#
     classVotes={}
     for x in range(len(neighbors)):
         instance_class=neighbors[x][-1]
         if instance_class in classVotes:
             classVotes[instance_class]+=1
         else:
             classVotes[instance_class]=1
         sortedVotes=sorted(classVotes.items(),key=operator.itemgetter(1),reverse=True)
     return sortedVotes[0][0]

def getAccuracy(testSet,predictions):
    #预测正确率计算#
    correct=0
    for x in range(len(testSet)):
        if testSet[x][-1]==predictions[x]:
            correct+=1
    return (correct/float(len(testSet)))*100.0

def main():
    trainingSet=[]
    testSet=[]
    split=0.7
    loadDataSet(iris.data.csv,split,trainingSet,testSet)
    print(训练集合:+repr(len(trainingSet)))
    print(测试集合:+repr(len(testSet)))
    predictions=[]
    k=3
    for x in range(len(testSet)):
        neighbors=getNeighbors(trainingSet,testSet[x],k)
        result=getClass(neighbors)
        predictions.append(result)
        print(>预测=+repr(result)+,实际=+repr(testSet[x][-1]))
    accuracy=getAccuracy(testSet,predictions)
    print(精确度为:+repr(accuracy)+%)

main()

针对此代码中的数据来源为UCI机器学习库中的鸢尾花卉数据集,可以直接获取(https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data),也可以下载我转换好的CSV文件(链接:https://pan.baidu.com/s/1YSLhrPMn3RflGE8VDGGbHQ 提取码:42se )

本次范例属于“自己动手丰衣足食”,每个函数都自己实现,可以在入门阶段对K-近邻算法流程有个初步认识,在有了一定基础之后,我们就没有必要重造轮子,可以使用常见的机器学习算法,毕竟其专业性远远目前超过我们自己的程序。例如scikit-learn模块。

以上是关于K-近邻算法入门的主要内容,如果未能解决你的问题,请参考以下文章

机器学习入门之K近邻法

机器学习入门之K近邻法

手写数字识别的k-近邻算法实现

web安全之机器学习入门——3.1 KNN/k近邻算法

机器学习实战—— k-近邻算法

基本分类方法——KNN(K近邻)算法