k近邻算法
Posted qi-lin
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了k近邻算法相关的知识,希望对你有一定的参考价值。
介绍
k近邻算法(KNN)属于监督学习的分类算法,通过测量不同特征值之间的距离进行分类,算法过程如下
- 计算数据点与已知数据集中每个点的距离
- 对距离从小到大进行排序
- 选取前k个距离值
- 确定前k个距离值所在类别的出现的概率
- 将前k个点出现频率最高的类别作为当前数据的预测分类
主要代码如下
def classfiy(inData, dataSet, labels, k):
dataSize = dataSet.shape[0] # 得到数组的行维度,即数据的个数
# 先通过tile将输入的数据扩展为与dataSet相同维度的数组,再通过距离公式计算距离
distance = (((tile(inData, (dataSize, 1)) - dataSet) ** 2).sum(axis=1)) ** 0.5
sortIndex = distance.argsort() # 返回数组值从小到大的索引值
classCount = {}
for i in range(k): # 只对前k个计数
headLabel = labels[sortIndex[i]]
classCount[headLabel] = classCount.get(headLabel, 0) + 1 # 统计前k个中出现标签的次数
# 对字典按照第二个值(即出现的次数)进行排序,用reverse指定从大到小排
sortCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
return sortCount[0][0] # 返回第一个的标签
其中距离计算,通过公式,如((x_{1},y_{1})(x_{2},y_{2}))两点的距离d为(d=sqrt{(x_{1}-x_{2})^2+(y_{1}-y_{2})^2})
用KNN识别数字图片中的数字
只是个玩具程序
收集数据
每个数字准备了10张图片,分别存在digit中的以各个数字命名的文件夹下
又为每个数据准备了5张图片,以同样的规则存在digit2的各个文件夹下
准备数据
缩放图像
采用了pillow中的resize函数,同一将图像缩放为50*50
newImg = img.resize((50, 50))
二值化图像
开始想直接通过convet(‘1‘)直接将图像二值化,但出现了很多噪音
所以通过以下程序将图像二值化。其中230为设定的阀值,多次尝试,发现230效果较好
for i in range(rows):
for j in range(cols):
if (imgArray[i, j] <= 230):
imgArray[i, j] = 0
else:
imgArray[i, j] = 255
转化为一维向量
将读取的处理后的图片的像素值转化为一维向量
测试
通过读取测试集中的数据,进行预测,和实际的类别比对,查看正确率
程序
from PIL import Image
from numpy import *
import os
import operator
#缩放为相同大小
def toSame(img):
newImg = img.resize((50, 50))
return newImg
#二值化处理
def toBinarry(img):
imgArray = array(img)
rows, cols = imgArray.shape
for i in range(rows):
for j in range(cols):
if (imgArray[i, j] <= 230):
imgArray[i, j] = 0
else:
imgArray[i, j] = 255
return imgArray
#读取每个文件夹下的每张图片
def readImage(filePath):
dataList = []
labels = []
for i in range(10):
imagePath = filePath + '/' + str(i)
files = os.listdir(imagePath)
for j in files:
labels.append(j.split('_')[0])#因为每张图片采用‘数字_第几张的命名方式’,所以通过下横线分割,取得前面的作为图片的分类标签
img = Image.open(imagePath + '/' + j).convert('L')#先灰度化处理
imgArray = toBinarry(toSame(img))
dataList.append(imgArray.ravel())#转变为一维后加入列表
dataSet = array(dataList)
return dataSet, labels
#分类算法
def classfiy(inData, dataSet, labels, k):
dataSize = dataSet.shape[0] # 得到数组的行维度,即数据的个数
# 先通过tile将输入的数据扩展为与dataSet相同维度的数组,再通过距离公式计算距离
distance = (((tile(inData, (dataSize, 1)) - dataSet) ** 2).sum(axis=1)) ** 0.5
sortIndex = distance.argsort() # 返回数组值从小到大的索引值
classCount = {}
for i in range(k): # 只对前k个计数
headLabel = labels[sortIndex[i]]
classCount[headLabel] = classCount.get(headLabel, 0) + 1 # 统计前k个中出现标签的次数
# 对字典按照第二个值(即出现的次数)进行排序,用reverse指定从大到小排
sortCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
return sortCount[0][0] # 返回第一个的标签
# 进行测试
dataSet, labels = readImage('./digit')
dataSet2, labels2 = readImage('./digit2')
n = 0
for i in range(len(dataSet2)):
predict = classfiy(dataSet2[i], dataSet, labels, 10)
print(predict + ' ' + labels2[i])
if (predict == labels2[i]):
n = n + 1
# 查看准确率
print(n / len(dataSet2))
运行结果
发现准确率只有0.62
总结
- 准确率如此低,可能是数据不足,也可能对图像处理不好。在二值化时,效果其实并不完美。也可能需要对图像进行一些裁剪。在二值化时,本程序也只适合一些浅色底子的数字图片
- 采用不同的k,预测的效果也是不同,也需要找到一个最佳的k
其它
- 在处理数据时,通常用到的归一化
def toNormal(dataSet):
# 归一化
min = dataSet.min(0)
max = dataSet.max(0)
# 公式normal=(x-min)/(max-min)
normalArray = (dataSet - tile(min, (dataSet.shape[0], 1))) / tile(max - min, (dataSet.shape[0], 1))
return normalArray
- 开始二值化时,想通过降噪进一步处理图像,后来没用到
- 识别验证码的文章,里面提到了降噪算法https://www.jb51.net/article/141428.htm
- 二值化图像算法的总结https://www.cnblogs.com/Zhi-Z/p/8906426.html
def toClear(imgArray):
rows, cols = imgArray.shape
for y in range(1, cols - 1):
for x in range(1, rows - 1):
count = 0
if imgArray[x, y - 1] == 255: # 上
count = count + 1
if imgArray[x, y + 1] == 255: # 下
count = count + 1
if imgArray[x - 1, y] == 255: # 左
count = count + 1
if imgArray[x + 1, y] == 255: # 右
count = count + 1
if imgArray[x - 1, y - 1] == 255: # 左上
count = count + 1
if imgArray[x - 1, y + 1] == 255: # 左下
count = count + 1
if imgArray[x + 1, y - 1] == 255: # 右上
count = count + 1
if imgArray[x + 1, y + 1] == 255: # 右下
count = count + 1
if count > 4:
imgArray[x, y] = 255
return imgArray
- 关于对图片的裁剪,切割https://www.bbsmax.com/A/E35pl2B5vX/
- 一篇不错的K-近邻(KNN)算法实现手写数字识别的文章https://blog.csdn.net/zzZ_CMing/article/details/78938107
以上是关于k近邻算法的主要内容,如果未能解决你的问题,请参考以下文章