从 Sparklyr 中提取和可视化模型树
Posted
技术标签:
【中文标题】从 Sparklyr 中提取和可视化模型树【英文标题】:Extract and Visualize Model Trees from Sparklyr 【发布时间】:2019-04-07 00:13:12 【问题描述】:有人对如何将 sparklyr 的 ml_decision_tree_classifier、ml_gbt_classifier 或 ml_random_forest_classifier 模型中的树信息转换为 a.) 其他 R 树相关库可以理解的格式和(最终)b.) 可视化有任何建议非技术消费的树木?这将包括从向量汇编期间生成的替换字符串索引值转换回实际特征名称的能力。
以下代码是从a sparklyr blog post大量复制的,用于提供示例:
library(sparklyr)
library(dplyr)
# If needed, install Spark locally via `spark_install()`
sc <- spark_connect(master = "local")
iris_tbl <- copy_to(sc, iris)
# split the data into train and validation sets
iris_data <- iris_tbl %>%
sdf_partition(train = 2/3, validation = 1/3, seed = 123)
iris_pipeline <- ml_pipeline(sc) %>%
ft_dplyr_transformer(
iris_data$train %>%
mutate(Sepal_Length = log(Sepal_Length),
Sepal_Width = Sepal_Width ^ 2)
) %>%
ft_string_indexer("Species", "label")
iris_pipeline_model <- iris_pipeline %>%
ml_fit(iris_data$train)
iris_vector_assembler <- ft_vector_assembler(
sc,
input_cols = setdiff(colnames(iris_data$train), "Species"),
output_col = "features"
)
random_forest <- ml_random_forest_classifier(sc,features_col = "features")
# obtain the labels from the fitted StringIndexerModel
iris_labels <- iris_pipeline_model %>%
ml_stage("string_indexer") %>%
ml_labels()
# IndexToString will convert the predicted numeric values back to class labels
iris_index_to_string <- ft_index_to_string(sc, "prediction", "predicted_label",
labels = iris_labels)
# construct a pipeline with these stages
iris_prediction_pipeline <- ml_pipeline(
iris_pipeline, # pipeline from previous section
iris_vector_assembler,
random_forest,
iris_index_to_string
)
# fit to data and make some predictions
iris_prediction_model <- iris_prediction_pipeline %>%
ml_fit(iris_data$train)
iris_predictions <- iris_prediction_model %>%
ml_transform(iris_data$validation)
iris_predictions %>%
select(Species, label:predicted_label) %>%
glimpse()
根据here 的建议反复试验后,我能够以“if/else”格式将底层决策树的公式转换为字符串:
model_stage <- iris_prediction_model$stages[[3]]
spark_jobj(model_stage) %>% invoke(., "toDebugString") %>% cat()
##print out below##
RandomForestClassificationModel (uid=random_forest_classifier_5c6a1934c8e) with 20 trees
Tree 0 (weight 1.0):
If (feature 2 <= 2.5)
Predict: 1.0
Else (feature 2 > 2.5)
If (feature 2 <= 4.95)
If (feature 3 <= 1.65)
Predict: 0.0
Else (feature 3 > 1.65)
If (feature 0 <= 1.7833559100698644)
Predict: 0.0
Else (feature 0 > 1.7833559100698644)
Predict: 2.0
Else (feature 2 > 4.95)
If (feature 2 <= 5.05)
If (feature 1 <= 6.505000000000001)
Predict: 2.0
Else (feature 1 > 6.505000000000001)
Predict: 0.0
Else (feature 2 > 5.05)
Predict: 2.0
Tree 1 (weight 1.0):
If (feature 3 <= 0.8)
Predict: 1.0
Else (feature 3 > 0.8)
If (feature 3 <= 1.75)
If (feature 1 <= 5.0649999999999995)
If (feature 3 <= 1.05)
Predict: 0.0
Else (feature 3 > 1.05)
If (feature 0 <= 1.8000241202036602)
Predict: 2.0
Else (feature 0 > 1.8000241202036602)
Predict: 0.0
Else (feature 1 > 5.0649999999999995)
If (feature 0 <= 1.8000241202036602)
Predict: 0.0
Else (feature 0 > 1.8000241202036602)
If (feature 2 <= 5.05)
Predict: 0.0
Else (feature 2 > 5.05)
Predict: 2.0
Else (feature 3 > 1.75)
Predict: 2.0
Tree 2 (weight 1.0):
If (feature 3 <= 0.8)
Predict: 1.0
Else (feature 3 > 0.8)
If (feature 0 <= 1.7664051342320237)
Predict: 0.0
Else (feature 0 > 1.7664051342320237)
If (feature 3 <= 1.45)
If (feature 2 <= 4.85)
Predict: 0.0
Else (feature 2 > 4.85)
Predict: 2.0
Else (feature 3 > 1.45)
If (feature 3 <= 1.65)
If (feature 1 <= 8.125)
Predict: 2.0
Else (feature 1 > 8.125)
Predict: 0.0
Else (feature 3 > 1.65)
Predict: 2.0
Tree 3 (weight 1.0):
If (feature 0 <= 1.6675287895788053)
If (feature 2 <= 2.5)
Predict: 1.0
Else (feature 2 > 2.5)
Predict: 0.0
Else (feature 0 > 1.6675287895788053)
If (feature 3 <= 1.75)
If (feature 3 <= 1.55)
If (feature 1 <= 7.025)
If (feature 2 <= 4.55)
Predict: 0.0
Else (feature 2 > 4.55)
Predict: 2.0
Else (feature 1 > 7.025)
Predict: 0.0
Else (feature 3 > 1.55)
If (feature 2 <= 5.05)
Predict: 0.0
Else (feature 2 > 5.05)
Predict: 2.0
Else (feature 3 > 1.75)
Predict: 2.0
Tree 4 (weight 1.0):
If (feature 2 <= 4.85)
If (feature 2 <= 2.5)
Predict: 1.0
Else (feature 2 > 2.5)
Predict: 0.0
Else (feature 2 > 4.85)
If (feature 2 <= 5.05)
If (feature 0 <= 1.8484238118815566)
Predict: 2.0
Else (feature 0 > 1.8484238118815566)
Predict: 0.0
Else (feature 2 > 5.05)
Predict: 2.0
Tree 5 (weight 1.0):
If (feature 2 <= 1.65)
Predict: 1.0
Else (feature 2 > 1.65)
If (feature 3 <= 1.65)
If (feature 0 <= 1.8325494627242664)
Predict: 0.0
Else (feature 0 > 1.8325494627242664)
If (feature 2 <= 4.95)
Predict: 0.0
Else (feature 2 > 4.95)
Predict: 2.0
Else (feature 3 > 1.65)
Predict: 2.0
Tree 6 (weight 1.0):
If (feature 2 <= 2.5)
Predict: 1.0
Else (feature 2 > 2.5)
If (feature 2 <= 5.05)
If (feature 3 <= 1.75)
Predict: 0.0
Else (feature 3 > 1.75)
Predict: 2.0
Else (feature 2 > 5.05)
Predict: 2.0
Tree 7 (weight 1.0):
If (feature 3 <= 0.55)
Predict: 1.0
Else (feature 3 > 0.55)
If (feature 3 <= 1.65)
If (feature 2 <= 4.75)
Predict: 0.0
Else (feature 2 > 4.75)
Predict: 2.0
Else (feature 3 > 1.65)
If (feature 2 <= 4.85)
If (feature 0 <= 1.7833559100698644)
Predict: 0.0
Else (feature 0 > 1.7833559100698644)
Predict: 2.0
Else (feature 2 > 4.85)
Predict: 2.0
Tree 8 (weight 1.0):
If (feature 3 <= 0.8)
Predict: 1.0
Else (feature 3 > 0.8)
If (feature 3 <= 1.85)
If (feature 2 <= 4.85)
Predict: 0.0
Else (feature 2 > 4.85)
If (feature 0 <= 1.8794359129669855)
Predict: 2.0
Else (feature 0 > 1.8794359129669855)
If (feature 3 <= 1.55)
Predict: 0.0
Else (feature 3 > 1.55)
Predict: 0.0
Else (feature 3 > 1.85)
Predict: 2.0
Tree 9 (weight 1.0):
If (feature 2 <= 2.5)
Predict: 1.0
Else (feature 2 > 2.5)
If (feature 2 <= 4.95)
Predict: 0.0
Else (feature 2 > 4.95)
Predict: 2.0
Tree 10 (weight 1.0):
If (feature 3 <= 0.8)
Predict: 1.0
Else (feature 3 > 0.8)
If (feature 2 <= 4.95)
Predict: 0.0
Else (feature 2 > 4.95)
If (feature 2 <= 5.05)
If (feature 3 <= 1.55)
Predict: 2.0
Else (feature 3 > 1.55)
If (feature 3 <= 1.75)
Predict: 0.0
Else (feature 3 > 1.75)
Predict: 2.0
Else (feature 2 > 5.05)
Predict: 2.0
Tree 11 (weight 1.0):
If (feature 3 <= 0.8)
Predict: 1.0
Else (feature 3 > 0.8)
If (feature 2 <= 5.05)
If (feature 2 <= 4.75)
Predict: 0.0
Else (feature 2 > 4.75)
If (feature 3 <= 1.75)
Predict: 0.0
Else (feature 3 > 1.75)
Predict: 2.0
Else (feature 2 > 5.05)
Predict: 2.0
Tree 12 (weight 1.0):
If (feature 3 <= 0.8)
Predict: 1.0
Else (feature 3 > 0.8)
If (feature 3 <= 1.75)
If (feature 3 <= 1.35)
Predict: 0.0
Else (feature 3 > 1.35)
If (feature 0 <= 1.695573522904327)
Predict: 0.0
Else (feature 0 > 1.695573522904327)
If (feature 1 <= 8.125)
Predict: 2.0
Else (feature 1 > 8.125)
Predict: 0.0
Else (feature 3 > 1.75)
If (feature 0 <= 1.7833559100698644)
Predict: 0.0
Else (feature 0 > 1.7833559100698644)
Predict: 2.0
Tree 13 (weight 1.0):
If (feature 3 <= 0.55)
Predict: 1.0
Else (feature 3 > 0.55)
If (feature 2 <= 4.95)
If (feature 2 <= 4.75)
Predict: 0.0
Else (feature 2 > 4.75)
If (feature 0 <= 1.8000241202036602)
If (feature 1 <= 9.305)
Predict: 2.0
Else (feature 1 > 9.305)
Predict: 0.0
Else (feature 0 > 1.8000241202036602)
Predict: 0.0
Else (feature 2 > 4.95)
Predict: 2.0
Tree 14 (weight 1.0):
If (feature 2 <= 2.5)
Predict: 1.0
Else (feature 2 > 2.5)
If (feature 3 <= 1.65)
If (feature 3 <= 1.45)
Predict: 0.0
Else (feature 3 > 1.45)
If (feature 2 <= 4.95)
Predict: 0.0
Else (feature 2 > 4.95)
Predict: 2.0
Else (feature 3 > 1.65)
If (feature 0 <= 1.7833559100698644)
If (feature 0 <= 1.7664051342320237)
Predict: 2.0
Else (feature 0 > 1.7664051342320237)
Predict: 0.0
Else (feature 0 > 1.7833559100698644)
Predict: 2.0
Tree 15 (weight 1.0):
If (feature 2 <= 2.5)
Predict: 1.0
Else (feature 2 > 2.5)
If (feature 3 <= 1.75)
If (feature 2 <= 4.95)
Predict: 0.0
Else (feature 2 > 4.95)
If (feature 1 <= 8.125)
Predict: 2.0
Else (feature 1 > 8.125)
If (feature 0 <= 1.9095150692894909)
Predict: 0.0
Else (feature 0 > 1.9095150692894909)
Predict: 2.0
Else (feature 3 > 1.75)
Predict: 2.0
Tree 16 (weight 1.0):
If (feature 3 <= 0.8)
Predict: 1.0
Else (feature 3 > 0.8)
If (feature 0 <= 1.7491620461964392)
Predict: 0.0
Else (feature 0 > 1.7491620461964392)
If (feature 3 <= 1.75)
If (feature 2 <= 4.75)
Predict: 0.0
Else (feature 2 > 4.75)
If (feature 0 <= 1.8164190316151556)
Predict: 2.0
Else (feature 0 > 1.8164190316151556)
Predict: 0.0
Else (feature 3 > 1.75)
Predict: 2.0
Tree 17 (weight 1.0):
If (feature 0 <= 1.695573522904327)
If (feature 2 <= 1.65)
Predict: 1.0
Else (feature 2 > 1.65)
Predict: 0.0
Else (feature 0 > 1.695573522904327)
If (feature 2 <= 4.75)
If (feature 2 <= 2.5)
Predict: 1.0
Else (feature 2 > 2.5)
Predict: 0.0
Else (feature 2 > 4.75)
If (feature 3 <= 1.75)
If (feature 1 <= 5.0649999999999995)
Predict: 2.0
Else (feature 1 > 5.0649999999999995)
If (feature 3 <= 1.65)
Predict: 0.0
Else (feature 3 > 1.65)
Predict: 0.0
Else (feature 3 > 1.75)
Predict: 2.0
Tree 18 (weight 1.0):
If (feature 3 <= 0.8)
Predict: 1.0
Else (feature 3 > 0.8)
If (feature 3 <= 1.65)
Predict: 0.0
Else (feature 3 > 1.65)
If (feature 0 <= 1.7833559100698644)
Predict: 0.0
Else (feature 0 > 1.7833559100698644)
Predict: 2.0
Tree 19 (weight 1.0):
If (feature 2 <= 2.5)
Predict: 1.0
Else (feature 2 > 2.5)
If (feature 2 <= 4.95)
If (feature 1 <= 8.705)
Predict: 0.0
Else (feature 1 > 8.705)
If (feature 2 <= 4.85)
Predict: 0.0
Else (feature 2 > 4.85)
If (feature 0 <= 1.8164190316151556)
Predict: 2.0
Else (feature 0 > 1.8164190316151556)
Predict: 0.0
Else (feature 2 > 4.95)
Predict: 2.0
如您所见,这种格式不太适合传递给我见过的许多可视化决策树图形的漂亮方法之一(例如 revolution analytics 或 statmethods)
【问题讨论】:
【参考方案1】:截至今天(Spark 2.4.0 版本已经批准并等待官方公告)您最好的选择*,不涉及复杂的 3rd 方工具(例如,您可以看看 MLeap),可能是save the model 和回读the specification:
ml_stage(iris_prediction_model, "random_forest") %>%
ml_save("/tmp/model")
rf_spec <- spark_read_parquet(sc, "rf", "/tmp/model/data/")
结果将是具有以下架构的 Spark DataFrame
:
rf_spec %>%
spark_dataframe() %>%
invoke("schema") %>% invoke("treeString") %>%
cat(sep = "\n")
root
|-- treeID: integer (nullable = true)
|-- nodeData: struct (nullable = true)
| |-- id: integer (nullable = true)
| |-- prediction: double (nullable = true)
| |-- impurity: double (nullable = true)
| |-- impurityStats: array (nullable = true)
| | |-- element: double (containsNull = true)
| |-- gain: double (nullable = true)
| |-- leftChild: integer (nullable = true)
| |-- rightChild: integer (nullable = true)
| |-- split: struct (nullable = true)
| | |-- featureIndex: integer (nullable = true)
| | |-- leftCategoriesOrThreshold: array (nullable = true)
| | | |-- element: double (containsNull = true)
| | |-- numCategories: integer (nullable = true)
提供有关所有节点和拆分的信息。
可以使用列元数据检索特征映射:
meta <- iris_predictions %>%
select(features) %>%
spark_dataframe() %>%
invoke("schema") %>% invoke("apply", 0L) %>%
invoke("metadata") %>%
invoke("getMetadata", "ml_attr") %>%
invoke("getMetadata", "attrs") %>%
invoke("json") %>%
jsonlite::fromJSON() %>%
dplyr::bind_rows() %>%
copy_to(sc, .) %>%
rename(featureIndex = idx)
meta
# Source: spark<?> [?? x 2]
featureIndex name
* <int> <chr>
1 0 Sepal_Length
2 1 Sepal_Width
3 2 Petal_Length
4 3 Petal_Width
以及您已经检索到的标签映射:
labels <- tibble(prediction = seq_along(iris_labels) - 1, label = iris_labels) %>%
copy_to(sc, .)
最后你可以结合所有这些:
full_rf_spec <- rf_spec %>%
spark_dataframe() %>%
invoke("selectExpr", list("treeID", "nodeData.*", "nodeData.split.*")) %>%
sdf_register() %>%
select(-split, -impurityStats) %>%
left_join(meta, by = "featureIndex") %>%
left_join(labels, by = "prediction")
full_rf_spec
# Source: spark<?> [?? x 12]
treeID id prediction impurity gain leftChild rightChild featureIndex
* <int> <int> <dbl> <dbl> <dbl> <int> <int> <int>
1 0 0 1 0.636 0.379 1 2 2
2 0 1 1 0 -1 -1 -1 -1
3 0 2 0 0.440 0.367 3 8 2
4 0 3 0 0.0555 0.0269 4 5 3
5 0 4 0 0 -1 -1 -1 -1
6 0 5 0 0.5 0.5 6 7 0
7 0 6 0 0 -1 -1 -1 -1
8 0 7 2 0 -1 -1 -1 -1
9 0 8 2 0.111 0.0225 9 12 2
10 0 9 2 0.375 0.375 10 11 1
# ... with more rows, and 4 more variables: leftCategoriesOrThreshold <list>,
# numCategories <int>, name <chr>, label <chr>
由treeID
收集和分隔,应该提供足够的信息** 来模拟树状对象(您可以通过查看rpart::rpart.object
文档和/或unclass
来很好地理解所需的结构rpart
模型。tree::tree
将需要更少的工作,但它的绘图实用程序远非令人印象深刻),并建立一个体面的情节。
另一种方法是使用 Sparklyr2PMML 将数据导出到 PMML 并使用此表示。
您也可以查看How do I visualise / plot a decision tree in Apache Spark (PySpark 1.4.1)?,它建议使用第三方 Python 包来解决同样的问题。
如果您不需要任何花哨的东西,您可以使用igraph
创建一个粗略的情节:
library(igraph)
gframe <- full_rf_spec %>%
filter(treeID == 0) %>% # Take the first tree
mutate(
leftCategoriesOrThreshold = ifelse(
size(leftCategoriesOrThreshold) == 1,
# Continuous variable case
concat("<= ", round(concat_ws("", leftCategoriesOrThreshold), 3)),
# Categorical variable case. Decoding variables might be involved
# but can be achieved if needed, using column metadata or indexer labels
concat("in ", concat_ws(",", leftCategoriesOrThreshold), "")
),
name = coalesce(name, label)) %>%
select(
id, label, impurity, gain,
leftChild, rightChild, leftCategoriesOrThreshold, name) %>%
collect()
vertices <- gframe %>% rename(label = name, name = id)
edges <- gframe %>%
transmute(from = id, to = leftChild, label = leftCategoriesOrThreshold) %>%
union_all(gframe %>% select(from = id, to = rightChild)) %>%
filter(to != -1)
g <- igraph::graph_from_data_frame(edges, vertices = vertices)
plot(
g, layout = layout_as_tree(g, root = c(1)),
vertex.shape = "rectangle", vertex.size = 45)
* 它应该会在不久的将来得到改进,新引入的格式无关 ML 编写器 API(它已经支持选定模型的 PMML 编写器。希望新的模型和格式会随之而来)。
** 如果您使用分类特征,您可能希望将 leftCategoriesOrThreshold
映射到相应的索引级别。
如果特征向量包含分类变量,jsonlite::fromJSON()
的输出将包含nominal
组。例如,如果您为具有三个级别的列 foo
编制索引,在第一个位置组装,它将是这样的:
$nominal
vals idx name
1 a, b, c 1 foo
其中vals
列是可变长度向量的列表。
length(meta$nominal$vals[[1]])
[1] 3
标签对应此结构的索引,因此在示例中:
a
有标签 0.0(不是标签是双精度浮点数,编号从 0.0 开始)
b
的标签为 1.0
以此类推,如果您使用leftCategoriesOrThreshold
拆分等于c(0.0, 2.0)
,则表示拆分位于标签"a", "c"
上。
还请注意,如果存在分类数据,您可能必须在调用 copy_to
之前对其进行处理 - 目前看来它不支持复杂字段。
在 Spark purrr 应该可以正常工作)。在 Spark 2.4(sparklyr
AFAIK 尚不支持)中,使用 Spark 的 JSON 阅读器直接读取元数据并使用其高阶函数进行映射可能更容易。
【讨论】:
这看起来很棒!谢谢!事实上,我正在使用分类特征。这显然是我问的,但如果您对如何使用列元数据或索引器按结果类别着色有一些建议,我将非常感激。 对于着色,您可以使用V(g)$color <- ...
设置节点的颜色。
我还添加了一些关于名义元数据结构的注释。以上是关于从 Sparklyr 中提取和可视化模型树的主要内容,如果未能解决你的问题,请参考以下文章
使用 SparklyR“选择”从 Spark 数组列中提取元素