sklearn.plot_tree: how to visualize class_labels for a classification task?



Posted: 2022-01-13 08:49:42

Question:

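I recently finished my trial code for a decision tree. It works well except for one thing: the class names are not shown in the plotted tree. Am I doing something wrong?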

Please see the code below and the picture of the dataset.

#Import Data#

import pandas as pd

data_set = pd.read_excel(r"C:\Users\User\Desktop\Tree.xlsx")

print(data_set.head())

#Set Features and Training Targets#

features_names=["Money","Debt"]
target_names=["Mood1", "Mood2", "Mood3"]

features = data_set[features_names]
targets = data_set[target_names]

print(features)
print(targets)

#Set Training Set and Test Set#

train_features = features[:10]
train_targets = targets[:10]

test_features = features[10:]
test_targets = targets[10:]

print(train_features)
print(train_targets)

print(test_features)
print(test_targets)

#Estimating Tree#

from sklearn.tree import DecisionTreeRegressor

dt = DecisionTreeRegressor(max_depth = 3)
dt = dt.fit(train_features, train_targets)

print(dt.score(train_features, train_targets))
print(dt.score(test_features, test_targets))

#Plotting the Tree#

from sklearn import tree
import matplotlib.pyplot as plt

tree.plot_tree(dt, feature_names=features_names, class_names=target_names, filled = True)
plt.show()


Answer 1:

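In a regression task, visualizing class labels does not work: the documentation states that the class_names parameter is "only relevant for classification" (and also that it is not supported for multi-output). A DecisionTreeRegressor has no classes, so plot_tree simply leaves the class line out of every node instead of raising an error.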
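One way to check this is to look at the text plot_tree actually draws. The sketch below uses made-up random data (an assumption; the question's spreadsheet is not available) with two features and three numeric target columns, mirroring the question's setup, and the non-interactive Agg backend so it runs without a display. plot_tree returns the annotation artists it drew, so get_text() gives each node's label.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np
from sklearn import tree
from sklearn.tree import DecisionTreeRegressor

# Made-up toy data: two features, three numeric target columns (like Mood1..Mood3)
rng = np.random.RandomState(0)
X = rng.rand(20, 2)
Y = rng.rand(20, 3)

reg = DecisionTreeRegressor(max_depth=3, random_state=0).fit(X, Y)

plt.figure()
annotations = tree.plot_tree(
    reg,
    feature_names=["Money", "Debt"],
    class_names=["Mood1", "Mood2", "Mood3"],  # accepted, but ignored for a regressor
)
node_texts = [a.get_text() for a in annotations]

# Regressor nodes carry "value = ..." lines but no "class = ..." line
print(any("class =" in t for t in node_texts))
print(any("value =" in t for t in node_texts))
```

This matches what the question describes: no error, just no class names in the plot.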

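Here, your target variable Mood can be treated as categorical, with its values held in a single column (instead of the three columns Mood1, Mood2, Mood3), and the tree fit with a DecisionTreeClassifier. Once that is done, you can set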

tree.plot_tree(clf, class_names=True)

for a symbolic representation of the class names, or

class_names = ['setosa', 'versicolor', 'virginica']

tree.plot_tree(clf, class_names=class_names)

for specific class names.
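Applied to the question's setup, this could look roughly like the sketch below. Since the dataset picture is not available, it assumes that Mood1, Mood2 and Mood3 are one-hot indicator columns (exactly one 1 per row); the small DataFrame is a hypothetical stand-in for Tree.xlsx. DataFrame.idxmax(axis=1) collapses the three columns into a single Mood column holding the name of the column that contains the 1, and class_names is taken from clf.classes_ so the names follow the same sorted order the classifier uses internally.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch; omit it for an interactive window
import matplotlib.pyplot as plt
import pandas as pd
from sklearn import tree
from sklearn.tree import DecisionTreeClassifier

# Hypothetical stand-in for Tree.xlsx; assumes Mood1..Mood3 are one-hot columns
data_set = pd.DataFrame({
    "Money": [10, 50, 30, 80, 20, 60, 45, 15],
    "Debt":  [5, 1, 20, 2, 30, 10, 3, 25],
    "Mood1": [1, 0, 0, 1, 0, 0, 1, 0],
    "Mood2": [0, 1, 0, 0, 0, 1, 0, 0],
    "Mood3": [0, 0, 1, 0, 1, 0, 0, 1],
})

features_names = ["Money", "Debt"]
target_names = ["Mood1", "Mood2", "Mood3"]

# Collapse the one-hot columns into one categorical target column
data_set["Mood"] = data_set[target_names].idxmax(axis=1)

clf = DecisionTreeClassifier(max_depth=3, random_state=0)
clf.fit(data_set[features_names], data_set["Mood"])

plt.figure()
annotations = tree.plot_tree(
    clf,
    feature_names=features_names,
    class_names=list(clf.classes_),  # same order as the classifier's sorted labels
    filled=True,
)
# In an interactive session (without the Agg line above), call plt.show() here
```

If the real columns are not one-hot (for example, scores rather than 0/1 flags), idxmax would pick the largest column instead, so the target would need a different encoding.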

Full example:

import numpy as np
from matplotlib import pyplot as plt

from sklearn.model_selection import train_test_split
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn import tree

iris = load_iris()
X = iris.data
y = iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

clf = DecisionTreeClassifier(max_leaf_nodes=3, random_state=0)
clf.fit(X_train, y_train)

# Symbolic class name representation (class = y[0], y[1], ...)
plt.figure()
tree.plot_tree(clf, class_names=True)

# Specific class name representation
class_names = iris['target_names']

plt.figure()
tree.plot_tree(clf, class_names=class_names)
plt.show()
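To see exactly what each option changes, the node labels can be compared directly, since plot_tree returns the annotation artists it drew. The sketch below reuses the iris setup from the full example; the Agg backend is an assumption made only so it runs without a display, and the "samples =" filter keeps just the node boxes (every node label contains that line), in case a scikit-learn version draws extra edge text.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so node texts can be inspected without a window
import matplotlib.pyplot as plt
from sklearn import tree
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target, random_state=0)
clf = DecisionTreeClassifier(max_leaf_nodes=3, random_state=0).fit(X_train, y_train)


def node_labels(class_names):
    """Draw the tree on a fresh figure and return the text of each node box."""
    plt.figure()  # new figure, so annotations from earlier calls are not mixed in
    anns = tree.plot_tree(clf, class_names=class_names)
    return [a.get_text() for a in anns if "samples =" in a.get_text()]


symbolic = node_labels(True)              # nodes end with "class = y[k]"
named = node_labels(iris.target_names)    # nodes end with e.g. "class = setosa"
default = node_labels(None)               # no class line at all

print(symbolic[0])
print(named[0])
```

With class_names=True, k is the index into clf.classes_; passing the names replaces that index with the corresponding label, and leaving the default None omits the class line entirely.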

Comments:

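Correct, but the OP uses a DT regressor, not a classifier. If this also applies to the regression setting, please say so explicitly.

Well, I had started writing a comment on the original post saying "this option is not available for regression", but then I realized I wasn't sure about it, so I dropped the comment; but it seems you have indeed confirmed it :) I suggest you update your post with this clarification (including the relevant documentation link).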
