Spark | ML | Random Forest | Load trained model from .txt of RandomForestClassificationModel.toDebugString
【Posted】: 2017-05-01 20:34:21
【Question】: Using Spark 1.6 and the ML library, I am saving the output of toDebugString() for a trained RandomForestClassificationModel:
val rfModel = model.stages(2).asInstanceOf[RandomForestClassificationModel]
val stringModel = rfModel.toDebugString
// save stringModel to a .txt file on the driver
So my idea is to read the .txt file later and load the trained random forest from it. Is that possible?
Thanks!
【Discussion】:
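【Answer 1】: That won't work. toDebugString only produces debug information, meant to help you understand how the model arrives at its result.
If you want to keep the model for later use, you can do what we do (although we work in plain Java): simply serialize the RandomForestModel object. Default Java serialization can run into version incompatibilities, so we use Hessian instead. It has survived version upgrades: we started with Spark 1.6.1 and it still works with Spark 2.0.2.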
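The serialize-and-restore idea described above can be sketched roughly as follows. This is a minimal sketch that uses plain `java.io` serialization rather than Hessian, so it does not address the version-compatibility concern; `ToyModel` and `ModelSerialization` are hypothetical names, and `ToyModel` merely stands in for the trained model object.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.io.UncheckedIOException;

/** Hypothetical stand-in for a trained model; any Serializable object is handled the same way. */
class ToyModel implements Serializable {
    private static final long serialVersionUID = 1L;
    final int numTrees;
    final double[] thresholds;

    ToyModel(final int numTrees, final double[] thresholds) {
        this.numTrees = numTrees;
        this.thresholds = thresholds;
    }
}

class ModelSerialization {
    /** Serializes the object to bytes; writing these bytes to a file persists the model. */
    static byte[] toBytes(final Serializable model) {
        final ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (ObjectOutputStream out = new ObjectOutputStream(bytes)) {
            out.writeObject(model);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return bytes.toByteArray();
    }

    /** Restores an object previously produced by toBytes. */
    static Object fromBytes(final byte[] data) {
        try (ObjectInputStream in = new ObjectInputStream(new ByteArrayInputStream(data))) {
            return in.readObject();
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        } catch (ClassNotFoundException e) {
            throw new IllegalStateException("Model class is not on the classpath", e);
        }
    }
}
```

A round trip would then look like `ToyModel restored = (ToyModel) ModelSerialization.fromBytes(ModelSerialization.toBytes(model));`. Switching to Hessian would mean replacing the two object streams with that library's equivalents, which is not shown here.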
【Discussion】:
【Answer 2】: If you do not have to stick with ml, use the mllib implementation: the RandomForestModel you get from mllib has a save function.
【Discussion】:
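【Answer 3】: At least for Spark 2.1.0, you can do this with the following Java code (sorry, no Scala). However, relying on an undocumented format that may change without notice is probably not the wisest idea.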
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.*;
import java.net.URL;
import java.util.*;
import java.util.function.Predicate;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import static java.nio.charset.StandardCharsets.US_ASCII;

/**
 * RandomForest.
 */
public abstract class RandomForest {
    private static final Logger LOG = LoggerFactory.getLogger(RandomForest.class);
    protected final List<Node> trees = new ArrayList<>();

    /**
     * @param model model file (format is Spark's RandomForestClassificationModel toDebugString())
     * @throws IOException if the file cannot be read or contains no trees
     */
    public RandomForest(final URL model) throws IOException {
        try (final BufferedReader reader = new BufferedReader(new InputStreamReader(model.openStream(), US_ASCII))) {
            Node node;
            while ((node = load(reader)) != null) {
                trees.add(node);
            }
        }
        if (trees.isEmpty()) throw new IOException("Failed to read trees from " + model);
        if (LOG.isDebugEnabled()) LOG.debug("Found " + trees.size() + " trees.");
    }

    /** Reads the next tree from the reader; returns null once no tree is left. */
    private static Node load(final BufferedReader reader) throws IOException {
        final Pattern ifPattern = Pattern.compile("If \\(feature (\\d+) (in|not in|<=|>) (.*)\\)");
        // also accepts negative values and positive exponents, which regression leaves can contain
        final Pattern predictPattern = Pattern.compile("Predict: (-?\\d+\\.\\d+(E-?\\d+)?)");
        Node root = null;
        final List<Node> stack = new ArrayList<>();
        String line;
        while ((line = reader.readLine()) != null) {
            final String trimmed = line.trim();
            if (trimmed.startsWith("RandomForest")) {
                // skip the header line of the first tree
                reader.readLine();
            } else if (trimmed.startsWith("Tree")) {
                // header of the next tree: the current tree is complete
                break;
            } else if (trimmed.startsWith("If")) {
                // extract feature index, operator and operand
                final Matcher m = ifPattern.matcher(trimmed);
                if (!m.matches()) throw new IOException("Unexpected split line: " + trimmed);
                final int featureIndex = Integer.parseInt(m.group(1));
                final String operator = m.group(2);
                final String operand = m.group(3);
                final Predicate<Float> predicate;
                if ("<=".equals(operator)) {
                    predicate = new LessOrEqual(Float.parseFloat(operand));
                } else if (">".equals(operator)) {
                    predicate = new Greater(Float.parseFloat(operand));
                } else if ("in".equals(operator)) {
                    predicate = new In(parseFloatArray(operand));
                } else if ("not in".equals(operator)) {
                    predicate = new NotIn(parseFloatArray(operand));
                } else {
                    // unreachable with the pattern above; fail loudly instead of passing null to Node
                    throw new IOException("Unsupported operator: " + operator);
                }
                final Node node = new Node(featureIndex, predicate);
                if (stack.isEmpty()) {
                    root = node;
                } else {
                    insert(stack, node);
                }
                stack.add(node);
            } else if (trimmed.startsWith("Predict")) {
                final Matcher m = predictPattern.matcher(trimmed);
                if (!m.matches()) throw new IOException("Unexpected leaf line: " + trimmed);
                final Object node = Float.parseFloat(m.group(1));
                insert(stack, node);
            }
            // "Else (...)" lines carry no extra information and are ignored
        }
        return root;
    }

    /** Attaches the node to the deepest node on the stack that still has a free child slot. */
    private static void insert(final List<Node> stack, final Object node) {
        Node parent = stack.get(stack.size() - 1);
        while (parent.getLeftChild() != null && parent.getRightChild() != null) {
            stack.remove(stack.size() - 1);
            parent = stack.get(stack.size() - 1);
        }
        if (parent.getLeftChild() == null) parent.setLeftChild(node);
        else parent.setRightChild(node);
    }

    private static float[] parseFloatArray(final String set) {
        // categorical sets are printed as {a,b,c}, so braces are delimiters too
        final StringTokenizer st = new StringTokenizer(set, "{,}");
        final float[] floats = new float[st.countTokens()];
        for (int i = 0; st.hasMoreTokens(); i++) {
            floats[i] = Float.parseFloat(st.nextToken());
        }
        return floats;
    }

    public abstract float predict(final float[] features);

    public String toDebugString() {
        try {
            final StringWriter sw = new StringWriter();
            for (int i = 0; i < trees.size(); i++) {
                sw.write("Tree " + i + ":\n");
                // print the i-th tree (the original always printed tree 0)
                print(sw, "", trees.get(i));
            }
            return sw.toString();
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    private static void print(final Writer w, final String indent, final Object object) throws IOException {
        if (object instanceof Number) {
            w.write(indent + "Predict: " + object + "\n");
        } else if (object instanceof Node) {
            final Node node = (Node) object;
            // left subtree
            w.write(indent + node + "\n");
            print(w, indent + " ", node.getLeftChild());
            // right subtree
            w.write(indent + "Else\n");
            print(w, indent + " ", node.getRightChild());
        }
    }

    @Override
    public String toString() {
        return getClass().getSimpleName() + "{numTrees=" + trees.size() + "}";
    }

    /**
     * Node. A child is either another Node or a Float holding the prediction.
     */
    protected static class Node {
        private final int featureIndex;
        private final Predicate<Float> predicate;
        private Object leftChild;
        private Object rightChild;

        public Node(final int featureIndex, final Predicate<Float> predicate) {
            Objects.requireNonNull(predicate);
            this.featureIndex = featureIndex;
            this.predicate = predicate;
        }

        public void setLeftChild(final Object leftChild) {
            this.leftChild = leftChild;
        }

        public void setRightChild(final Object rightChild) {
            this.rightChild = rightChild;
        }

        public Object getLeftChild() {
            return leftChild;
        }

        public Object getRightChild() {
            return rightChild;
        }

        /** Walks down from this node until a leaf (a non-Node value) is reached. */
        public Object eval(final float[] features) {
            Object result = this;
            do {
                final Node node = (Node) result;
                result = node.predicate.test(features[node.featureIndex]) ? node.leftChild : node.rightChild;
            } while (result instanceof Node);
            return result;
        }

        @Override
        public String toString() {
            return "If (feature " + featureIndex + " " + predicate + ")";
        }
    }

    private static class LessOrEqual implements Predicate<Float> {
        private final float value;

        public LessOrEqual(final float value) {
            this.value = value;
        }

        @Override
        public boolean test(final Float f) {
            return f <= value;
        }

        @Override
        public String toString() {
            return "<= " + value;
        }
    }

    private static class Greater implements Predicate<Float> {
        private final float value;

        public Greater(final float value) {
            this.value = value;
        }

        @Override
        public boolean test(final Float f) {
            return f > value;
        }

        @Override
        public String toString() {
            return "> " + value;
        }
    }

    private static class In implements Predicate<Float> {
        private final float[] array;

        public In(final float[] array) {
            this.array = array;
        }

        @Override
        public boolean test(final Float f) {
            for (int i = 0; i < array.length; i++) {
                if (array[i] == f) return true;
            }
            return false;
        }

        @Override
        public String toString() {
            return "in " + Arrays.toString(array);
        }
    }

    private static class NotIn implements Predicate<Float> {
        private final float[] array;

        public NotIn(final float[] array) {
            this.array = array;
        }

        @Override
        public boolean test(final Float f) {
            for (int i = 0; i < array.length; i++) {
                if (array[i] == f) return false;
            }
            return true;
        }

        @Override
        public String toString() {
            return "not in " + Arrays.toString(array);
        }
    }
}
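To make the parsing step more concrete, the following sketch isolates the regular expression that the loader above applies to each split line. `SplitLineParser` is a hypothetical helper that is not part of the answer, and the sample lines in the usage note are assumptions about what toDebugString prints, not output captured from Spark.

```java
import java.util.regex.Matcher;
import java.util.regex.Pattern;

class SplitLineParser {
    // Same pattern as in the loader: group 1 = feature index, group 2 = operator, group 3 = operand.
    private static final Pattern IF_PATTERN =
            Pattern.compile("If \\(feature (\\d+) (in|not in|<=|>) (.*)\\)");

    /** Returns {featureIndex, operator, operand}, or throws if the line has an unexpected shape. */
    static String[] parse(final String line) {
        final Matcher m = IF_PATTERN.matcher(line.trim());
        if (!m.matches()) {
            throw new IllegalArgumentException("Unrecognized split line: " + line);
        }
        return new String[] {m.group(1), m.group(2), m.group(3)};
    }
}
```

Under these assumptions, `SplitLineParser.parse("If (feature 3 <= 1.5)")` yields the three strings `3`, `<=` and `1.5`, while a categorical line such as `If (feature 2 not in {0.0,1.0})` yields `2`, `not in` and `{0.0,1.0}`. Checking the result of `matches()` before reading the groups turns a format change into a clear error instead of an `IllegalStateException`.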
To use the class for classification:
import java.io.IOException;
import java.net.URL;
import java.util.HashMap;
import java.util.Map;

/**
 * RandomForestClassifier.
 */
public class RandomForestClassifier extends RandomForest {
    public RandomForestClassifier(final URL model) throws IOException {
        super(model);
    }

    @Override
    public float predict(final float[] features) {
        // count how many trees predict each label
        final Map<Object, Integer> counts = new HashMap<>();
        trees.stream().map(node -> node.eval(features))
            .forEach(result -> {
                Integer count = counts.get(result);
                if (count == null) {
                    counts.put(result, 1);
                } else {
                    counts.put(result, count + 1);
                }
            });
        // return the label with the most votes
        return (Float) counts.entrySet()
            .stream()
            .sorted((o1, o2) -> Integer.compare(o2.getValue(), o1.getValue()))
            .map(Map.Entry::getKey)
            .findFirst().get();
    }
}
For regression:
import java.io.IOException;
import java.net.URL;

/**
 * RandomForestRegressor.
 */
public class RandomForestRegressor extends RandomForest {
    public RandomForestRegressor(final URL model) throws IOException {
        super(model);
    }

    @Override
    public float predict(final float[] features) {
        // the forest output is the mean of the individual tree outputs
        return (float) trees
            .stream()
            .mapToDouble(node -> ((Number) node.eval(features)).doubleValue())
            .average()
            .getAsDouble();
    }
}
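The two predict methods above differ only in how they combine the per-tree outputs. The sketch below shows that difference on plain arrays; `ForestAggregation` is a hypothetical helper, and it mirrors the classes in this answer without making any claim about how Spark itself combines trees internally.

```java
import java.util.Arrays;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;

class ForestAggregation {
    /** Classification: the label predicted by the most trees wins. */
    static double majorityVote(final double[] treeOutputs) {
        return Arrays.stream(treeOutputs)
                .boxed()
                // count the votes per label
                .collect(Collectors.groupingBy(Function.identity(), Collectors.counting()))
                .entrySet()
                .stream()
                .max(Map.Entry.comparingByValue())
                .map(Map.Entry::getKey)
                .orElseThrow(() -> new IllegalArgumentException("no tree outputs"));
    }

    /** Regression: the forest output is the mean of the tree outputs. */
    static double average(final double[] treeOutputs) {
        return Arrays.stream(treeOutputs)
                .average()
                .orElseThrow(() -> new IllegalArgumentException("no tree outputs"));
    }
}
```

With three trees that output 1.0, 0.0 and 1.0, `majorityVote` picks the label 1.0, whereas `average` treats the same numbers as regression values and returns their mean. Which label wins a tie is not defined in this sketch.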
【Discussion】: