聚类算法(K-Means)
Posted Young的编程日记
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了聚类算法(K-Means)相关的知识,希望对你有一定的参考价值。
K-Means
聚类算法属于无监督学习,这里复习下什么是监督学习与无监督学习。
监督式学习:由已有的数据包括输入输出,训练模型函数;然后把新的输入数据带入模型函数,预测数据输出;函数的输出可以是一个连续的值(称为回归分析),或是预测一个分类标签(称作分类)
无监督学习:无监督学习就是给你一些不知道输出的数据,然后给这些数据打上标签。
聚类就是属于无监督学习,我们输入一些数据,然后算法给这些数据的特征进行自动的分类。
K-Means算法的步骤是:首先我们输入想要分类的的个数,比如说我们想将所有的数据分为3类,那么就是输入3,之后算法会随机的生成3个初始化的中心点作为3个类的中心,这些初始化的中心点不一定是给的样本,可以是数据集没有的点。然后就会计算每个数据点到中心点的距离,距离最近就分到哪一类中。之后每一个类会根据当前的数据重新的找到新的中心点,最后就是重复以上的步骤,知道分类结果不再变化为止。
看下下面这个步骤图就知道了:
需要注意的是我们要手动的输入多少个类,我们该如何选择类的数量呢。
这里要引入一个概念就是组内平方和,这个应该是很好理解的,就是每个分类中的SS之和。
不同的类的个数,这个WCSS也会不同,从直觉上来说,WCSS应该是在类为1的时候最大,当类有数据点的个数时最小,为0。
那么WCSS根据类的个数的变化趋势如上,我们选择就是夹角最小的这个点,就是说这个点之前的速率和之后的速率的变化时最大的这一点,也叫手肘法则。在这里这个点就是3。
代码实现(这是一个根据商城用户的收入与消费评分进行的样本聚类问题):
# Importing the libraries
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import os
os.chdir(r'C:/Users/Yang/Desktop/Machine Learn A-Z/Machine Learning A-Z Chinese Template Folder/Part 4 - Clustering/Section 19 - K-Means Clustering/Section 19 - K-Means Clustering')
# Importing the dataset
dataset = pd.read_csv('Mall_Customers.csv')
X = dataset.iloc[:,-2:].values
#计算分类K值
from sklearn.cluster import KMeans
wcss = []
for i in range(1,10):
kmeans = KMeans(n_clusters=i,init='k-means++',max_iter=300,random_state=0)
kmeans.fit(X)
wcss.append(kmeans.inertia_)
plt.plot(range(1,10),wcss)
plt.ylabel('WCSS')
plt.xlabel('Number of Clusters')
plt.title('The Elbow Method')
plt.savefig('The Elbow Method',dpi=600)
plt.show()
我们选择K值为5,然后使用算法聚类并且画出来。
kmeans = KMeans(n_clusters=5,random_state=0)
y_pred = kmeans.fit_predict(X)
plt.scatter(x = X[:,0][y_pred==0],y = X[:,1][y_pred==0],c='red',label='Careful')
plt.scatter(x = X[:,0][y_pred==1],y = X[:,1][y_pred==1],c='blue',label='Standard')
plt.scatter(x = X[:,0][y_pred==2],y = X[:,1][y_pred==2],c='green',label='Target')
plt.scatter(x = X[:,0][y_pred==3],y = X[:,1][y_pred==3],c='cyan',label='Careless')
plt.scatter(x = X[:,0][y_pred==4],y = X[:,1][y_pred==4],c='magenta',label='Sensible')
plt.title('clusters of clients')
plt.xlabel('Annual Income')
plt.ylabel('Spending Score')
plt.legend()
plt.savefig('clusters of clients',dpi=600)
plt.show()
以上是关于聚类算法(K-Means)的主要内容,如果未能解决你的问题,请参考以下文章
为啥使用 k-means(来自 Scipy)聚类到两个片段的图像会显示两个以上不同的像素值?