使用 matplotlib 标记 K-means 聚类数据点

Posted

技术标签:

【中文标题】使用 matplotlib 标记 K-means 聚类数据点【英文标题】:Labeling K-means cluster data points with matplotlib 【发布时间】:2017-09-01 10:06:39 【问题描述】:

我从 .csv 文件 (databoth.csv) 中提取了以下数据,并使用 matplotlib 执行了 k-means 聚类。数据为 3 列(国家、出生率、预期寿命)。

我需要帮助才能输出: 属于每个集群的国家/地区的数量。 属于每个集群的国家/地区列表。 每个集群的平均预期寿命和出生率。

这是我的代码:

import csv
import matplotlib.pyplot as plt
import sys
import pylab as plt
import numpy as np
plt.ion()


#K-Means clustering implementation
# data = set of data points
# k = number of clusters
# maxIters = maximum number of iterations executed k-means
def kMeans(data, K, maxIters = 10, plot_progress = None):

    centroids = data[np.random.choice(np.arange(len(data)), K), :]
    for i in range(maxIters):
        # Cluster Assignment step
        C = np.array([np.argmin([np.dot(x_i-y_k, x_i-y_k) for y_k in 
        centroids]) for x_i in data])
        # Move centroids step
        centroids = [data[C == k].mean(axis = 0) for k in range(K)]
        if plot_progress != None: plot_progress(data, C, np.array(centroids))
    return np.array(centroids) , C


# Calculates euclidean distance between
# a data point and all the available cluster
# centroids.
def euclidean_dist(data, centroids, clusters):
    for instance in data:
        mu_index = min([(i[0], np.linalg.norm(instance-centroids[i[0]])) \
                        for i in enumerate(centroids)], key=lambda t:t[1])[0]
    try:
        clusters[mu_index].append(instance)
    except KeyError:
        clusters[mu_index] = [instance]

# If any cluster is empty then assign one point
# from data set randomly so as to not have empty
# clusters and 0 means.
for cluster in clusters:
    if not cluster:
        cluster.append(data[np.random.randint(0, len(data), size=1)].flatten().tolist())

return clusters


# this function reads the data from the specified files
def csvRead(file):
    np.genfromtxt('dataBoth.csv', delimiter=',')




# function to show the results on the screen in form of 3 clusters
def show(X, C, centroids, keep = False):
    import time
    time.sleep(0.5)
    plt.cla()
    plt.plot(X[C == 0, 0], X[C == 0, 1], '*b',
     X[C == 1, 0], X[C == 1, 1], '*r',
     X[C == 2, 0], X[C == 2, 1], '*g')
plt.plot(centroids[:,0],centroids[:,1],'*m',markersize=20)
plt.draw()
if keep :
    plt.ioff()
    plt.show()

# generate 3 cluster data
data = csvRead('dataBoth.csv')
m1, cov1 = [9, 8], [[1.5, 2], [1, 2]]
m2, cov2 = [5, 13], [[2.5, -1.5], [-1.5, 1.5]]
m3, cov3 = [3, 7], [[0.25, 0.5], [-0.1, 0.5]]
data1 = np.random.multivariate_normal(m1, cov1, 250)
data2 = np.random.multivariate_normal(m2, cov2, 180)
data3 = np.random.multivariate_normal(m3, cov3, 100)
X = np.vstack((data1,np.vstack((data2,data3))))
np.random.shuffle(X)


# calls to the functions
# first to find centroids using k-means
centroids, C = kMeans(X, K = 3, plot_progress = show)
#second to show the centroids on the graph
show(X, C, centroids, True)

【问题讨论】:

【参考方案1】:

也许你可以使用annotate: http://matplotlib.org/api/pyplot_api.html#matplotlib.pyplot.annotate

更多示例: http://matplotlib.org/users/annotations.html#plotting-guide-annotation

这将允许在每个点附近都有一个文本标签。

或者你可以使用post中的颜色

【讨论】:

嗨@Dadep 我编辑了我的问题,以便更清楚地了解所需的帮助。 所以你应该阅读你在clusters 中的每个cluster 中输入的所有instance,并对它们进行统计。我稍后会尝试编辑我的帖子 请帮我大忙!

以上是关于使用 matplotlib 标记 K-means 聚类数据点的主要内容,如果未能解决你的问题,请参考以下文章

k-means 聚类数据:如何标记新传入的数据

matplotlib 笔记:使用TeX标记

K-Means算法:基于聚类的无监督机器学习算法

python使用matplotlib可视化查看matplotlib中常用的线条形式(line style)和数据点标记形状(marker)

Matplotlib:从头开始制作彩色标记图例

Matplotlib - 标记每个 bin