机器学习——KNN

Posted siplips

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了机器学习——KNN相关的知识,希望对你有一定的参考价值。

导入类库

 1 import numpy as np
 2 from sklearn.neighbors import KNeighborsClassifier
 3 from sklearn.model_selection import train_test_split
 4 from sklearn.preprocessing import StandardScaler
 5 from sklearn.linear_model import LinearRegression
 6 from sklearn.metrics import r2_score
 7 from sklearn.datasets import load_iris
 8 import matplotlib.pyplot as plt
 9 import pandas as pd
10 import seaborn as sns
# 熵增益
# 熵越大,信息量越大,蕴含的不确定性越大
KNN
1.计算待预测值到所有点的距离
2.对所有距离排序
3.找出前K个样本里面类别最多的类,作为待预测值的类别

代码

 1 A = np.array([[1, 1], [1, 1.5], [0.5, 1.5]])
 2 B = np.array([[3.0, 3.0], [3.0, 3.5], [2.8, 3.1]])
 3 
 4 
 5 def knn_pre_norm(point):
 6     a_len = np.linalg.norm(point - A, axis=1)
 7     b_len = np.linalg.norm(point - B, axis=1)
 8     print(a_len.min())
 9     print(b_len.min())
10 
11 
12 def knn_predict_rev(point):
13     X = np.array([[1, 1], [1, 1.5], [0.5, 1.5], [3.0, 3.0], [3.0, 3.5], [2.8, 3.1]])
14     Y = np.array([0, 0, 0, 1, 1, 1])
15 
16     knn = KNeighborsClassifier(n_neighbors=2)
17     knn.fit(X, Y)
18 
19     print(knn.predict(np.array([[1.0, 3.0]])))
20 
21 
22 def iris_linear():
23     # 加载iris数据
24     li = load_iris()
25     # 散点图
26     # plt.scatter(li.data[:, 0], li.data[:, 1], c=li.target)
27     # plt.scatter(li.data[:, 2], li.data[:, 3], c=li.target)
28     # plt.show()
29     # 分割测试集和训练集,测试集占整个数据集的比例是0.25
30     x_train, x_test, y_train, y_test = train_test_split(li.data, li.target, test_size=0.25)
31     # 创建KNN分类,使用最少5个邻居作为类别判断标准
32     knn = KNeighborsClassifier(n_neighbors=5)
33     # 训练数据
34     knn.fit(x_train, y_train)
35     # 预测测试集
36     # print(knn.predict(x_test))
37     # 预测np.array([[6.3, 3, 5.2, 2.3]])
38     print(knn.predict(np.array([[6.3, 3, 5.2, 2.3]])))
39     # 预测np.array([[6.3, 3, 5.2, 2.3]])所属各个类别的概率
40     print(knn.predict_proba(np.array([[6.3, 3, 5.2, 2.3]])))
41 
42 
43 if __name__ == __main__:
44     # knn_predict_rev(None)
45     # knn_pre_norm(np.array([2.3,2.3]))
46     iris_linear()

 




以上是关于机器学习——KNN的主要内容,如果未能解决你的问题,请参考以下文章

《机器学习实战》--KNN

机器学习实战-第二章代码+注释-KNN

机器学习实战一(kNN)

万字详解·附代码机器学习分类算法之K近邻(KNN)

python_mmdt:从1到2--实现基于KNN的机器学*恶意代码分类器

机器学习算法_knn(福利)