KNN算法

Posted xiaoxineryi

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了KNN算法相关的知识,希望对你有一定的参考价值。

机器学习实战之K-近邻算法:

  KNN算法,就是在已知数据集中,计算出离输入的需要预测的点最接近的K个点,然后通过这最近的K个点中哪种分类所占比最高,该预测点就是哪一种分类。

from numpy import *
import operator
import matplotlib
import matplotlib.pyplot as plt
import os
def createDataSet():
    group = array([[1.0,1.1],[1.0,1.0],[0,0],[0,0.1]])
    labels = [A,A,B,B]
    return group,labels

def classify0(inX,dataSet,labels,k):
    # 获得数据集的个数:有多少个数据
    dataSetSize  = dataSet.shape[0]
    # tile:可以用来重复数据
    # tile(inX,(dataSetSize,1)) 就是让inX这个数据重复dataSetSize遍 每次都单独一行
    diffMat = tile(inX,(dataSetSize,1))-dataSet
    # 计算距离
    sqDiffMat = diffMat**2
    sqDiatance = sqDiffMat.sum(axis=1)
    distances  = sqDiatance**0.5
    # argsort():对数组进行排序 并且返回排序的下标 默认是从小到大
    sortedDisIndicies = distances.argsort()
    classCount ={}
    for i in range(k):
        votelLabel  =labels[sortedDisIndicies[i]]
        classCount[votelLabel] = classCount.get(votelLabel,0)+1
    sortedClassCount = sorted(classCount.items(),key=operator.itemgetter(1),reverse=True)
    return sortedClassCount[0][0]

def file2mat(filename):
    # 将文件转化成矩阵
    # 打开文件
    fr = open(filename)
    # 读取文件 获取文件行数:也就是数据集的个数
    arrayOfLines =fr.readlines()
    numberOfLines = len(arrayOfLines)
    # 因为测试数据中有三列数据,一列标签 所以后面是3 这里应该改成输入进参数 有更好的复用性
    returnMat = zeros((numberOfLines,3))
    classLabelVector=[]
    index = 0
    for line in arrayOfLines:
        # 通过分隔符来判断有几列 并记录对应数值
        line = line.strip()
        listFromLine = line.split(	)
        returnMat[index,:] = listFromLine[0:3]
        classLabelVector.append(int(listFromLine[-1]))
        index +=1
    return returnMat,classLabelVector

def autoNorm(dataSet):
    # 归一化处理
    minVals = dataSet.min(0)
    maxVals = dataSet.max(0)
    ranges = maxVals-minVals
    normDataSet = zeros(shape(dataSet))
    m = dataSet.shape[0]
    normDataSet = dataSet - tile(minVals,(m,1))
    normDataSet = normDataSet/tile(ranges,(m,1))
    return normDataSet,ranges,minVals

def datingTest():
    hoRatio = 0.1
    datingDataMat,datingLabels = file2mat("datingTestSet2.txt")
    normMat,ranges,minVals = autoNorm(datingDataMat)
    m  =normMat.shape[0]
    numTestVec = int(m*hoRatio)
    errorCount = 0.0
    for i in range(numTestVec):
        classidierResult = classify0(normMat[i,:],normMat[numTestVec:m,:],datingLabels[numTestVec:m],3)
        print("the classifier came back with : %d , the real answer is %d" %(classidierResult,datingLabels[i]))
        if(classidierResult != datingLabels[i]):
            errorCount+=1.0

    print("the total error rate is : %f" %(errorCount/float(numTestVec)))


def img2Mat(filename):
    returnVect = zeros((1,1024))
    fr = open(filename)
    for i in range(32):
        lineStr = fr.readline()
        for j in range(32):
            returnVect[0,32*i+j] = int (lineStr[j])
    return returnVect


def handWritingTest():
    hwLabels = []
    trainingFileList = os.listdir(trainingDigits)
    m = len(trainingFileList)
    trainingMat = zeros((m,1024))
    for i in range(m):
        fileNameStr = trainingFileList[i]
        fileStr = fileNameStr.split(.)[0]
        classNumStr = int(fileStr.split(_)[0])
        hwLabels.append(classNumStr)
        trainingMat[i,:]= img2Mat(trainingDigits/%s %fileNameStr)
    testFileList = os.listdir(testDigits)
    errorCount = 0.0
    mTest = len(testFileList)
    for i in range(mTest):
        fileNameStr = testFileList[i]
        fileStr = fileNameStr.split(.)[0]
        classNumStr = int(fileStr.split(_)[0])
        vectorUnderTest = img2Mat(testDigits/%s %fileNameStr)
        classifierResult = classify0(vectorUnderTest,trainingMat,hwLabels,3)

        print("the classifier came back with : %d , the real answer is %d" %(classifierResult,classNumStr))
        if(classifierResult != classNumStr ):
            errorCount+=1.0
    print("
 the error rate is : %f" %(errorCount/float(mTest)))

 

    对应的代码和注解

 

以上是关于KNN算法的主要内容,如果未能解决你的问题,请参考以下文章

监督学习算法_k-近邻(kNN)分类算法_源代码

(理论和代码相结合)KNN(最近邻)算法⭐

分类-KNN算法(代码复现和可视化)

⭐ (理论和代码相结合)KNN(最近邻)算法——分类问题和回归问题都能做的算法

模式识别实验二:K近邻算法(KNN)

万字详解·附代码机器学习分类算法之K近邻(KNN)