数据挖掘干货(k-NN)
Posted 生信媛
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了数据挖掘干货(k-NN)相关的知识,希望对你有一定的参考价值。
what is k-NN ?
k-nearest neighbors algorithm (k-NN)是通过测量不同特征值之间的距离进行分类。它的的思路是:如果一个样本在特征空间中的k个最相似(即特征空间中最邻近)的样本中的大多数属于某一个类别,则该样本也属于这个类别。K通常是不大于20的整数。KNN算法中,所选择的邻居都是已经正确分类的对象。该方法在分类决策上只依据最邻近的一个或者几个样本的类别来决定待分样本所属的类别。
其实早在南北朝时期,我国古人就提出了该算法的核心 "近朱者赤,近墨者黑"。
简单地举个例子,假如我们要确定上图中的蓝色的点真正的颜色是什么,我们就划定一个范围,找到与它最近的9个邻居,在这9个邻居中有5个是绿色的4个是红色的,那么我们就可以说K=9时,X更接近于绿色。与它最近的27个点中14个是红色13个是绿色,X更接近于红色。由此看来,KNN算法的一般步骤:
- 计算测试数据与各个训练数据之间的距离;
- 按照距离的递增关系进行排序;
- 选取距离最小的K个点;
- 确定前K个点所在类别的出现频率;
- 返回前K个点中出现频率最高的类别作为测试数据的预测分类。
值得注意的是,在距离当中我们一般采用的是欧氏几何距离,如果说有特殊需求,我们也可以采取曼哈顿距离,还可以看到的是X的预测分类与K的取值有很大的关系。
Using
#!python
#coding:utf-8
#author:kim
#copyrights 2017 www.lowpitch.cn all rights reserved.
"""
You can find the original Code from the Offcial Site
http://scikit-learn.org/stable/auto_examples/neighbors/plot_classification.html
"""
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from sklearn import neighbors, datasets
n_neighbors = 7
# import some data to play with
iris = datasets.load_iris()
# we only take the first two features. We could avoid this ugly
# slicing by using a two-dim dataset
X = iris.data[:, :2]
y = iris.target
h = 0.2 # step size in the mesh
# Create color maps
cmap_light = ListedColormap(['#fffaaa', '#aaffaa', '#ccaaff'])
cmap_bold = ListedColormap(['#00ffcc','#ff00cc', '#0099ff'])
# we create an instance of Neighbours Classifier and fit the data.
clf = neighbors.KNeighborsClassifier(n_neighbors, weights='distance')
clf.fit(X, y)
# Plot the decision boundary. For that, we will assign a color to each
# point in the mesh [x_min, x_max]x[y_min, y_max].
x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, h),
np.arange(y_min, y_max, h))
Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])
# Put the result into a color plot
Z = Z.reshape(xx.shape)
plt.figure()
plt.pcolormesh(xx, yy, Z, cmap=cmap_light)
# Plot also the training points
plt.scatter(X[:, 0], X[:, 1], c=y, cmap=cmap_bold,
edgecolor='k', s=20)
plt.xlim(xx.min(), xx.max())
plt.ylim(yy.min(), yy.max())
plt.title("3-Class classification (k = %i, weights = '%s')"
% (n_neighbors, 'distance'))
plt.show()
Thanks
- Sklearn powered by Google
- wiki
以上是关于数据挖掘干货(k-NN)的主要内容,如果未能解决你的问题,请参考以下文章