如何从 sklearn AgglomerativeClustering 遍历树?
Posted
技术标签:
【中文标题】如何从 sklearn AgglomerativeClustering 遍历树?【英文标题】:How to traverse a tree from sklearn AgglomerativeClustering? 【发布时间】:2015-02-07 19:11:17 【问题描述】:我有一个 numpy 文本文件数组位于:https://github.com/alvations/anythingyouwant/blob/master/WN_food.matrix
这是术语之间的距离矩阵,我的术语列表如下:http://pastebin.com/2xGt7Xjh
我使用下面的代码生成了一个层次聚类:
import numpy as np
from sklearn.cluster import AgglomerativeClustering
matrix = np.loadtxt('WN_food.matrix')
n_clusters = 518
model = AgglomerativeClustering(n_clusters=n_clusters,
linkage="average", affinity="cosine")
model.fit(matrix)
要获得每个术语的集群,我可以这样做:
for term, clusterid in enumerate(model.labels_):
print term, clusterid
但是如何遍历 AgglomerativeClustering 输出的树?
是否可以将其转换为 scipy dendrogram (http://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.cluster.hierarchy.dendrogram.html)?之后我如何遍历树状图?
【问题讨论】:
documentation 建议查看model
的children_
属性。
我使用了 children_,它给了我两个节点的列表,它不是遍历而是返回子节点,我不知道那是什么,并且子节点的节点数超出了我的节点数.. .
n
对象的完整层次聚类生成具有2n - 1
节点的树。正如文档所说:“小于 n_samples 的值是指树的叶子。更大的值 i 表示具有子 children_[i - n_samples] 的节点”。这应该是遍历树的足够信息。
@jme,请原谅小白,但是文档是什么意思?
树中的每个节点都分配有一个 ID,i
。如果 ID 小于输入对象的数量n_samples
,则该节点是叶子。否则,它是一个内部节点,它会连接另外两个节点。由节点i
连接的两个节点在children_[i - n_samples]
中找到。顺便说一句,如果您的目标是将其转换为 scipy 树状图,为什么不直接使用 scipy.cluster.hierarchy.linkage
而不是 sklearn
?
【参考方案1】:
除了 A.P. 的回答之外,这里的代码将为您提供一个会员字典。 member[node_id] 给出所有数据点索引(从零到 n)。
on_split 是对 A.P 集群的简单重新格式化,它给出了在 node_id 拆分时形成的两个集群。
up_merge 告诉 node_id 合并到什么以及必须合并什么 node_id 才能合并。
ii = itertools.count(data_x.shape[0])
clusters = ['node_id': next(ii), 'left': x[0], 'right':x[1] for x in fit_cluster.children_]
import copy
n_points = data_x.shape[0]
members = i:[i] for i in range(n_points)
for cluster in clusters:
node_id = cluster["node_id"]
members[node_id] = copy.deepcopy(members[cluster["left"]])
members[node_id].extend(copy.deepcopy(members[cluster["right"]]))
on_split = c["node_id"]: [c["left"], c["right"]] for c in clusters
up_merge = c["left"]: "into": c["node_id"], "with": c["right"] for c in clusters
up_merge.update(c["right"]: "into": c["node_id"], "with": c["left"] for c in clusters)
【讨论】:
【参考方案2】:我已经为 sklearn.cluster.ward_tree 回答了一个类似的问题: How do you visualize a ward tree from sklearn.cluster.ward_tree?
AgglomerativeClustering 在 children_ 属性中以相同的方式输出树。这是对 AgglomerativeClustering 的病房树问题中代码的改编。它为树的每个节点以 (node_id, left_child, right_child) 的形式输出树的结构。
import numpy as np
from sklearn.cluster import AgglomerativeClustering
import itertools
X = np.concatenate([np.random.randn(3, 10), np.random.randn(2, 10) + 100])
model = AgglomerativeClustering(linkage="average", affinity="cosine")
model.fit(X)
ii = itertools.count(X.shape[0])
['node_id': next(ii), 'left': x[0], 'right':x[1] for x in model.children_]
https://***.com/a/26152118
【讨论】:
有没有办法知道节点中有哪些项目? 它与集群标签model.labels_
有什么关系?
节点号也是树的每个叶子的数据向量的索引。例如 'left': 1, 'right': 2, 'node_id': 10 节点 10 有叶子 1 和 2 作为子节点。 X[1] 是叶 1 的数据向量。
您也可以使用dict(enumerate(model.children_, model.n_leaves_))
,它会为您提供一个字典,其中每个键是节点的 ID,值是其子节点的 ID 对。以上是关于如何从 sklearn AgglomerativeClustering 遍历树?的主要内容,如果未能解决你的问题,请参考以下文章
如何从管道中的 sklearn TFIDF Vectorizer 返回数据帧?