K最近邻算法(kNN)

Posted 桓桓桓桓

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了K最近邻算法(kNN)相关的知识,希望对你有一定的参考价值。

“近朱者赤,近墨者黑”,这句话大家都听说过,意思是靠着朱砂的变红,靠着墨的变黑。比喻接近好人可以使人变好,接近坏人可以使人变坏,指客观环境对人有很大影响。在现实的生活中我们都知道,想要了解一个人,一个比较靠谱的办法就是从他身边的朋友来了解。

本文介绍的K最近邻(k-Nearest Neighbor,KNN)算法,其指导思想就是“近朱者赤,近墨者黑”,由你身边的朋友推断你的类别。kNN是一种分类方法,全称k-Nearest Neighbor,顾名思义,是对于给定的测试样本和基于某种度量距离的方式下,通过最靠近的k个训练样本来预测当前样本的分类结果。其中预测的方式也很简单,就是投票,至于投票方式可以根据工程中的实际情况来决定,可以单纯的计数,当然也可以加权。

kNN算法并不存在训练过程,其实就是单纯的记录训练样本,在需要预测的时候做计算即可,虽然很简单,但在历史上影响力也是极大的。


借用百度一张图来说明kNN算法过程,加入要预测图中Xu的分类结果。就预设一个距离值,我们只考虑以Xu为圆心以这个距离值为半径的圆内的已知训练样本,然后根据这些样本的投票结果来预测Xu属于w1类别,投票结果是4:1。

kNN算法不仅可以用于分类,还可以用于回归。通过找出一个样本的k个最近邻居,将这些邻居的属性的平均值赋给该样本,就可以得到该样本的属性。更有用的方法是将不同距离的邻居对该样本产生的影响给予不同的权值(weight),如权值与距离成正比(组合函数)。

kNN算法计算步骤如下:

1)算距离:给定测试对象,计算它与训练集中的每个对象的距离

2)找邻居:圈定距离最近的k个训练对象,作为测试对象的近邻

3)做分类:根据这k个近邻归属的主要类别,来对测试对象分类


从上图中我们可以看到,图中的有两个类型的样本数据,一类是蓝色的正方形,另一类是红色的三角形。而那个绿色的圆形是我们待分类的数据。

a、如果K=3,那么离绿色点最近的有2个红色三角形和1个蓝色的正方形,这3个点投票,于是绿色的这个待分类点属于红色的三角形。

b、如果K=5,那么离绿色点最近的有2个红色三角形和3个蓝色的正方形,这5个点投票,于是绿色的这个待分类点属于蓝色的正方形。

---------------------------------------------优雅的分割线-------------------------------------------------

下面是我实现的一个小案例,判断橙色点(待分类点)是属于哪个类型


我们假设k = 7,则待分类点属于B类型,即图中蓝色点的类别。

核心代码如下:

package KNN;

import java.util.ArrayList;

public class KNN 

	// K值类实例
	private KClass kc;

	int countclassA = 0;
	int countclassB = 0;
	int countclassC = 0;
	ArrayList<DistanceSum> distance_sum = new ArrayList<DistanceSum>();
	ArrayList<DistanceClassA> distance_classA = new ArrayList<DistanceClassA>();
	ArrayList<DistanceClassB> distance_classB = new ArrayList<DistanceClassB>();
	ArrayList<DistanceClassC> distance_classC = new ArrayList<DistanceClassC>();

	ArrayList<Node> nodes = new ArrayList<Node>();
	int Nodex[] =  2, 1, 3, 2, 5, 6,12, 6, 7, 2, 4, 8, 8, 5, 3, 
					2, 1, 8, 9,11, 8, 9, 9, 7, 7, 8 ,3, 3, 2, 8;
	
	int Nodey[] =  1, 1, 1, 4, 6, 2, 3, 7, 8, 4, 8, 9, 4, 3, 4, 
					7, 8, 2, 3, 4, 7, 9, 8, 3, 1, 1 ,5, 6, 5, 7;
	String Nodetype[] =  "A", "B", "B", "B", "A", "A", "A", "B", "C", "A", "B", "A", "C", "B", "B",
			 			  "A", "B", "B", "A", "A", "C", "C", "C", "A", "A", "A", "B", "B", "B", "C" ;
	int distance[] = new int[Nodex.length];
	// 需要一个数据结构来存储所有的距离种类;
	int distance_kind[] = new int[Nodex.length];
	int x, y;

	// 构造函数
	KNN(int x, int y, int k) 
		this.x = x;
		this.y = y;
		kc = new KClass();
		kc.setK(k);

		getDistance(); // 得到目标节点与训练集节点的距离数组
		initNode(); // 初始化节点列表
		nodeSort(); // 节点链表排序

		getDistanceArray(); // 存储数据集到目标节点的距离种类
		initDistanceList(); // 初始化距离-个数链表
	

	void KNNCompute() throws Throwable 
		knn(kc);
	

	// 初始化节点列表
	void initNode() 
		for (int i = 0; i < Nodex.length; i++) 
			Node node = new Node();
			node.setX(Nodex[i]);
			node.setY(Nodey[i]);
			node.setType(Nodetype[i]);
			node.setDistance(distance[i]);
			nodes.add(node);
		
	

	// 得到目标节点与训练集节点的距离数组
	void getDistance() 
		for (int i = 0; i < Nodex.length; i++) 
			try 
				distance[i] = (x - Nodex[i]) * (x - Nodex[i]) + (y - Nodey[i])
						* (y - Nodey[i]);
			 catch (Exception e) 
				e.printStackTrace();
			

		
	

	// 用选择排序对节点链表按属性距离升序排序,方便后面取出节点
	void nodeSort() 
		Node tmpnode;
		for (int i = 0; i < nodes.size() - 1; i++) 
			for (int j = i + 1; j > 0; j--) 
				if (nodes.get(j).getDistance() < nodes.get(j - 1).getDistance()) 
					tmpnode = nodes.get(j);
					nodes.remove(j);
					nodes.add(j - 1, tmpnode);
				
			
		
	

	// 实现用数组存储距离种类的方法,并在这里初始化距离-个数链表
	void getDistanceArray() 
		try 
			distance_kind[0] = nodes.get(0).getDistance();
		 catch (NullPointerException e) 
			System.out.println("该链表为空!");
		

		for (int i = 1; i < nodes.size(); i++) 
			int distance = nodes.get(i).getDistance();
			int length = Tool.getActualLength(distance_kind);
			for (int j = length; j >= 0;) 
				if (0 == j) 
					distance_kind[length] = distance;
					break;
				 else 
					if (distance != distance_kind[j - 1]) 
						j--;
					 else 
						break;
					
				
			
		
	

	// 距离-个数链表初始化方法
	void initDistanceList() 

		DistanceSum ds;
		DistanceClassA dc1;
		DistanceClassB dc2;
		DistanceClassC dc3;
		int length = Tool.getActualLength(distance_kind);

		// 初始化总的 距离-个数链表
		for (int i = 0; i < length; i++) 
			int count_sum = 0;
			for (int j = 0; j < nodes.size(); j++) 
				if (nodes.get(j).getDistance() == distance_kind[i]) 
					count_sum++;
				
			
			ds = new DistanceSum();
			ds.setDistance(distance_kind[i]);
			ds.setCount(count_sum);
			distance_sum.add(ds);
		

		// 初始化类型为1的 距离-个数链表
		for (int i = 0; i < length; i++) 
			int count_class1 = 0;
			for (int j = 0; j < nodes.size(); j++) 
				if ("A".endsWith(nodes.get(j).getType())) 
					if (nodes.get(j).getDistance() == distance_kind[i]) 
						count_class1++;
					
				
			
			if (count_class1 > 0) 
				dc1 = new DistanceClassA();
				dc1.setDistance(distance_kind[i]);
				dc1.setCount(count_class1);
				distance_classA.add(dc1);
			
		

		// 初始化类型为2的 距离-个数链表
		for (int i = 0; i < length; i++) 
			int count_class2 = 0;
			for (int j = 0; j < nodes.size(); j++) 
				if ("B".equals(nodes.get(j).getType())) 
					if (nodes.get(j).getDistance() == distance_kind[i]) 
						count_class2++;
					
				
			
			if (count_class2 > 0) 
				dc2 = new DistanceClassB();
				dc2.setDistance(distance_kind[i]);
				dc2.setCount(count_class2);
				distance_classB.add(dc2);
			

		

		// 初始化类型为3的 距离-个数链表
		for (int i = 0; i < length; i++) 
			int count_class3 = 0;
			for (int j = 0; j < nodes.size(); j++) 
				if ("C".equals(nodes.get(j).getType())) 
					if (nodes.get(j).getDistance() == distance_kind[i]) 
						count_class3++;
					
				
			
			if (count_class3 > 0) 
				dc3 = new DistanceClassC();
				dc3.setDistance(distance_kind[i]);
				dc3.setCount(count_class3);
				distance_classC.add(dc3);
			
		
	

	// 对某个距离的不用类别的个数进行统计判断
	String statClassWinner(int class1_count, int class2_count, int class3_count) 

		if (class1_count == class2_count) 
			if (class1_count < class3_count) 
				return "C";
			 else 
				return "";
			
		 else if (class2_count == class3_count) 
			if (class2_count < class1_count) 
				return "A";
			 else 
				return "";
			
		 else if (class3_count == class1_count) 
			if (class3_count < class2_count) 
				return "B";
			 else 
				return "";
			
		 else if (class1_count > class2_count) 
			if (class1_count > class3_count) 
				return "A";
			 else 
				return "C";
			
		 else 
			if (class2_count > class3_count) 
				return "B";
			 else 
				return "C";
			
		
	

	// 统计指定距离上给类别的数目
	void statClassCount(int distance) 

		int length = Tool.getActualLength(distance_kind);
		for (int i = 0; i < length; i++) 
			// 找到该距离的位置
			if (distance == distance_kind[i]) 
				// 从第一种距离统计到第i种距离
				for (int j = 0; j <= i; j++) 
					// 统计第一种类型的有这个距离的数目
					for (int c1 = 0; c1 < distance_classA.size(); c1++) 
						if (distance_classA.get(c1).getDistance() == distance_kind[j]) 
							countclassA += distance_classA.get(c1).getCount();
						
					

					// 统计第二种类型的有这些距离的数目
					for (int c2 = 0; c2 < distance_classB.size(); c2++) 
						if (distance_classB.get(c2).getDistance() == distance_kind[j]) 
							countclassB += distance_classB.get(c2).getCount();
						
					

					// 统计第三种类型的有这些距离的数目
					for (int c3 = 0; c3 < distance_classC.size(); c3++) 
						if (distance_classC.get(c3).getDistance() == distance_kind[j]) 
							countclassC += distance_classC.get(c3).getCount();
						
					
				
			
		
	

	// KNN算法实现
	private void knn(KClass kc) throws Throwable 

		// 首先判断距离最近的节点数是否超出K个,如果超出k个,对这一距离的节点的类别进行统计,
		// 否则就继续进行下一个距离的节点数判断
		// 得到k个距离目标节点最近的节点
		int k = kc.getK();
		String type = "";
		int distance_count = 0;
		int count = 0;
		int length = Tool.getActualLength(distance_kind);

		for (; count < length; count++) 
			int difs = 0;
			int distance = distance_kind[count];
			distance_count += distance_sum.get(count).getCount();

			// 距离个数满足k值
			if (distance_count >= k) 
				// 先统计该距离上各类别的数目
				statClassCount(distance);
				// 判断该目标节点的类别
				type = statClassWinner(countclassA, countclassB, countclassC);
				if ("".equals(type)) 
					System.out.println("在K为" + k + "下不适合划分目标节点!");
					countclassA = 0;
					countclassB = 0;
					countclassC = 0;
					kc.setK(k + 1);
					knn(kc);
				 else 
					System.out.println("在K值为" + k + "下,该节点被划分到类型" + type);
					countclassA = 0;
					countclassB = 0;
					countclassC = 0;
					System.exit(0);
				
			
		
		if (length == count) 
			System.out.println("该k值过大!");
		
	

	public static void main(String[] args) throws Throwable 
		KNN knn = new KNN(4, 4, 7);
		knn.KNNCompute();
	

代码下载地址:http://download.csdn.net/detail/u013043346/9834687

以上是关于K最近邻算法(kNN)的主要内容,如果未能解决你的问题,请参考以下文章

KNN(最近邻)分类算法

KNN近邻算法

K-近邻算法(KNN)

K-近邻算法简介

k-近邻(KNN) 算法预测签到位置

K-近邻算法(KNN)