Spark|ML|Random Forest|Load a trained model from the .txt output of RandomForestClassificationModel.toDebugString

Posted: 2017-05-01 20:34:21

Question:

Using Spark 1.6 and the ML library, I save the result of a trained RandomForestClassificationModel with toDebugString():

 val rfModel = model.stages(2).asInstanceOf[RandomForestClassificationModel]
 val stringModel = rfModel.toDebugString
 // save stringModel into a .txt file on the driver
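The save step in the comment above can be done on the driver with plain file I/O. Below is a minimal Java sketch using only java.nio; the class name DebugStringFile is hypothetical. It stores only the human-readable text, not anything Spark can turn back into a model object.

```java
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;

// Hypothetical helper for the "save stringModel into a .txt file" step.
// It only stores the debug text; nothing here rebuilds a Spark model from it.
final class DebugStringFile {
    private DebugStringFile() {}

    // Creates the file, or overwrites it if it already exists, with the debug string.
    static void save(Path file, String debugString) throws IOException {
        Files.write(file, debugString.getBytes(StandardCharsets.UTF_8));
    }

    // Reads the whole file back as a single string.
    static String load(Path file) throws IOException {
        return new String(Files.readAllBytes(file), StandardCharsets.UTF_8);
    }
}
```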

My idea is to read the .txt file later and load the trained random forest from it. Is that possible?

Thanks!

Answer 1:

That won't work. toDebugString only produces debugging information, meant to help you understand how the model was computed.

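If you want to keep the model for later use, you can do what we do (although we work in plain Java): simply serialize the RandomForestModel object. Default Java serialization can break between versions, so we use Hessian for it instead. That has survived version upgrades: we started with Spark 1.6.1, and it still works with Spark 2.0.2.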
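A minimal sketch of the "serialize the model object" approach, using the JDK's default ObjectOutputStream rather than Hessian, so it carries exactly the cross-version risk described above. It assumes the model object implements java.io.Serializable; the ModelStore class and its methods are hypothetical.

```java
import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;

// Hypothetical helper: persists a trained model object with default Java serialization.
// Assumes the model implements java.io.Serializable. The answer uses Hessian instead,
// because this default format can break when class versions change (e.g. after an upgrade).
final class ModelStore {
    private ModelStore() {}

    // Writes the whole object graph to the given file.
    static void save(Serializable model, File file) throws IOException {
        try (ObjectOutputStream out = new ObjectOutputStream(
                new BufferedOutputStream(new FileOutputStream(file)))) {
            out.writeObject(model);
        }
    }

    // Reads the object back and checks its type; the model's classes must be on the classpath.
    static <T> T load(File file, Class<T> type) throws IOException, ClassNotFoundException {
        try (ObjectInputStream in = new ObjectInputStream(
                new BufferedInputStream(new FileInputStream(file)))) {
            return type.cast(in.readObject());
        }
    }
}
```

Switching to Hessian would mean replacing the two object streams with Hessian's own input and output classes; check the Hessian documentation for the exact API.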

Answer 2:

If you don't have to stick with ml, use the mllib implementation instead: the RandomForestModel you get from mllib has a save function.

Answer 3:

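At least with Spark 2.1.0, you can do this with the following Java code (sorry, no Scala). However, relying on an undocumented format that may change without notice is probably not the wisest idea.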

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.*;
import java.net.URL;
import java.util.*;
import java.util.function.Predicate;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import static java.nio.charset.StandardCharsets.US_ASCII;

/**
 * RandomForest.
 */
public abstract class RandomForest {

    private static final Logger LOG = LoggerFactory.getLogger(RandomForest.class);

    protected final List<Node> trees = new ArrayList<>();

    /**
     * @param model model file (format is Spark's RandomForestClassificationModel toDebugString())
     * @throws IOException if the file cannot be read or contains no trees
     */
    public RandomForest(final URL model) throws IOException {
        try (final BufferedReader reader = new BufferedReader(new InputStreamReader(model.openStream(), US_ASCII))) {
            Node node;
            // load() returns one tree per call and null once the input is exhausted
            while ((node = load(reader)) != null) {
                trees.add(node);
            }
        }
        if (trees.isEmpty()) throw new IOException("Failed to read trees from " + model);
        if (LOG.isDebugEnabled()) LOG.debug("Found " + trees.size() + " trees.");
    }

    private static Node load(final BufferedReader reader) throws IOException {
        final Pattern ifPattern = Pattern.compile("If \\(feature (\\d+) (in|not in|<=|>) (.*)\\)");
        // optional minus signs so negative (regression) predictions and exponents like 1.0E-4 parse
        final Pattern predictPattern = Pattern.compile("Predict: (-?\\d+\\.\\d+(E-?\\d+)?)");
        Node root = null;
        final List<Node> stack = new ArrayList<>();
        String line;
        // "Else (...)" lines match none of the branches below and are skipped:
        // the right child is simply the next item inserted under the same parent
        while ((line = reader.readLine()) != null) {
            final String trimmed = line.trim();
            if (trimmed.startsWith("RandomForest")) {
                // header line; also skip the first "Tree ..." line that follows it
                reader.readLine();
            } else if (trimmed.startsWith("Tree")) {
                // start of the next tree: the current one is complete
                break;
            } else if (trimmed.startsWith("If")) {
                // extract feature index, operator and operand
                final Matcher m = ifPattern.matcher(trimmed);
                if (!m.matches()) throw new IOException("Unexpected line: " + trimmed);
                final int featureIndex = Integer.parseInt(m.group(1));
                final String operator = m.group(2);
                final String operand = m.group(3);
                final Predicate<Float> predicate;
                if ("<=".equals(operator)) {
                    predicate = new LessOrEqual(Float.parseFloat(operand));
                } else if (">".equals(operator)) {
                    predicate = new Greater(Float.parseFloat(operand));
                } else if ("in".equals(operator)) {
                    predicate = new In(parseFloatArray(operand));
                } else if ("not in".equals(operator)) {
                    predicate = new NotIn(parseFloatArray(operand));
                } else {
                    predicate = null;
                }
                final Node node = new Node(featureIndex, predicate);

                if (stack.isEmpty()) {
                    root = node;
                } else {
                    insert(stack, node);
                }
                stack.add(node);
            } else if (trimmed.startsWith("Predict")) {
                final Matcher m = predictPattern.matcher(trimmed);
                if (!m.matches()) throw new IOException("Unexpected line: " + trimmed);
                final Object node = Float.parseFloat(m.group(1));
                insert(stack, node);
            }
        }
        return root;
    }

    private static void insert(final List<Node> stack, final Object node) {
        // attach to the most recent split that still has a free child slot
        Node parent = stack.get(stack.size() - 1);
        while (parent.getLeftChild() != null && parent.getRightChild() != null) {
            stack.remove(stack.size() - 1);
            parent = stack.get(stack.size() - 1);
        }
        if (parent.getLeftChild() == null) parent.setLeftChild(node);
        else parent.setRightChild(node);
    }

    private static float[] parseFloatArray(final String set) {
        // category sets look like {0.0,1.0} in Spark's output ([0.0, 1.0] in this class's toDebugString)
        final StringTokenizer st = new StringTokenizer(set, "{}[], ");
        final float[] floats = new float[st.countTokens()];
        for (int i = 0; st.hasMoreTokens(); i++) {
            floats[i] = Float.parseFloat(st.nextToken());
        }
        return floats;
    }

    public abstract float predict(final float[] features);

    public String toDebugString() {
        try {
            final StringWriter sw = new StringWriter();
            for (int i = 0; i < trees.size(); i++) {
                sw.write("Tree " + i + ":\n");
                print(sw, "", trees.get(i));
            }
            return sw.toString();
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    private static void print(final Writer w, final String indent, final Object object) throws IOException {
        if (object instanceof Number) {
            w.write(indent + "Predict: " + object + "\n");
        } else if (object instanceof Node) {
            final Node node = (Node) object;
            // the split itself, then its left branch, then "Else" and the right branch
            w.write(indent + node + "\n");
            print(w, indent + " ", node.getLeftChild());
            w.write(indent + "Else\n");
            print(w, indent + " ", node.getRightChild());
        }
    }

    @Override
    public String toString() {
        return getClass().getSimpleName() + "{numTrees=" + trees.size() + "}";
    }

    /**
     * Node.
     */
    protected static class Node {

        private final int featureIndex;
        private final Predicate<Float> predicate;
        private Object leftChild;
        private Object rightChild;

        public Node(final int featureIndex, final Predicate<Float> predicate) {
            Objects.requireNonNull(predicate);
            this.featureIndex = featureIndex;
            this.predicate = predicate;
        }

        public void setLeftChild(final Object leftChild) {
            this.leftChild = leftChild;
        }

        public void setRightChild(final Object rightChild) {
            this.rightChild = rightChild;
        }

        public Object getLeftChild() {
            return leftChild;
        }

        public Object getRightChild() {
            return rightChild;
        }

        public Object eval(final float[] features) {
            Object result = this;
            do {
                final Node node = (Node) result;
                result = node.predicate.test(features[node.featureIndex]) ? node.leftChild : node.rightChild;
            } while (result instanceof Node);

            return result;
        }

        @Override
        public String toString() {
            return "If (feature " + featureIndex + " " + predicate + ")";
        }

    }

    private static class LessOrEqual implements Predicate<Float> {
        private final float value;

        public LessOrEqual(final float value) {
            this.value = value;
        }

        @Override
        public boolean test(final Float f) {
            return f <= value;
        }

        @Override
        public String toString() {
            return "<= " + value;
        }
    }

    private static class Greater implements Predicate<Float> {
        private final float value;

        public Greater(final float value) {
            this.value = value;
        }

        @Override
        public boolean test(final Float f) {
            return f > value;
        }

        @Override
        public String toString() {
            return "> " + value;
        }
    }

    private static class In implements Predicate<Float> {
        private final float[] array;

        public In(final float[] array) {
            this.array = array;
        }

        @Override
        public boolean test(final Float f) {
            for (int i = 0; i < array.length; i++) {
                if (array[i] == f) return true;
            }
            return false;
        }

        @Override
        public String toString() {
            return "in " + Arrays.toString(array);
        }
    }

    private static class NotIn implements Predicate<Float> {
        private final float[] array;

        public NotIn(final float[] array) {
            this.array = array;
        }

        @Override
        public boolean test(final Float f) {
            for (int i = 0; i < array.length; i++) {
                if (array[i] == f) return false;
            }
            return true;
        }

        @Override
        public String toString() {
            return "not in " + Arrays.toString(array);
        }
    }
}
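The load/insert pair above relies on toDebugString() printing each tree in pre-order (split, left subtree, Else, right subtree), so every new split or leaf belongs to the most recent split that still has a free child slot. The stripped-down sketch below isolates just that idea. It is a simplification, not the answer's code: it handles only "<=" splits and numeric leaves, skips Else and header lines, and uses a hypothetical TinyTree class that does not depend on the classes above.

```java
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

// Hypothetical, simplified rebuild of one tree from toDebugString()-style lines.
// Handles only "<=" splits and numeric leaves; "Else (...)" and header lines are skipped,
// because in pre-order the right child is implied by its position.
final class TinyTree {
    private static final Pattern IF = Pattern.compile("If \\(feature (\\d+) <= (\\S+)\\)");
    private static final Pattern PREDICT = Pattern.compile("Predict: (\\S+)");

    static final class Split {
        final int feature;
        final double threshold;
        Object left;   // a Split or a Double leaf
        Object right;  // a Split or a Double leaf

        Split(int feature, double threshold) {
            this.feature = feature;
            this.threshold = threshold;
        }
    }

    static Object parse(List<String> lines) {
        Deque<Split> open = new ArrayDeque<>(); // splits that may still need a child
        Object root = null;
        for (String raw : lines) {
            String line = raw.trim();
            Matcher m;
            Object item;
            if ((m = IF.matcher(line)).matches()) {
                item = new Split(Integer.parseInt(m.group(1)), Double.parseDouble(m.group(2)));
            } else if ((m = PREDICT.matcher(line)).matches()) {
                item = Double.parseDouble(m.group(1));
            } else {
                continue;
            }
            if (root == null) {
                root = item;
            } else {
                // pop splits whose two child slots are already filled
                while (open.peek().left != null && open.peek().right != null) {
                    open.pop();
                }
                Split parent = open.peek();
                if (parent.left == null) {
                    parent.left = item;
                } else {
                    parent.right = item;
                }
            }
            if (item instanceof Split) {
                open.push((Split) item);
            }
        }
        return root;
    }

    // Walk from the root: "<=" true goes left, false goes right, until a leaf is reached.
    static double eval(Object node, double[] features) {
        while (node instanceof Split) {
            Split s = (Split) node;
            node = features[s.feature] <= s.threshold ? s.left : s.right;
        }
        return (Double) node;
    }
}
```

For example, for the lines "If (feature 0 <= 0.5)", "If (feature 1 <= 2.0)", "Predict: 0.0", "Else (feature 1 > 2.0)", "Predict: 1.0", "Else (feature 0 > 0.5)", "Predict: 2.0", eval(root, new double[]{0.2, 3.0}) returns 1.0 and eval(root, new double[]{0.9, 0.0}) returns 2.0.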

To use the class for classification:

import java.io.IOException;
import java.net.URL;
import java.util.HashMap;
import java.util.Map;

/**
 * RandomForestClassifier.
 */
public class RandomForestClassifier extends RandomForest {

    public RandomForestClassifier(final URL model) throws IOException {
        super(model);
    }

    @Override
    public float predict(final float[] features) {
        // count how many trees predicted each class label
        final Map<Object, Integer> counts = new HashMap<>();
        trees.stream().map(node -> node.eval(features))
                .forEach(result -> {
                    Integer count = counts.get(result);
                    if (count == null) {
                        counts.put(result, 1);
                    } else {
                        counts.put(result, count + 1);
                    }
                });
        // return the label with the most votes
        return (Float) counts.entrySet()
                .stream()
                .sorted((o1, o2) -> Integer.compare(o2.getValue(), o1.getValue()))
                .map(Map.Entry::getKey)
                .findFirst().get();
    }
}
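The counting loop in predict can also be written with Collectors.groupingBy. This standalone sketch (the MajorityVote class is hypothetical) takes the per-tree outputs as a list instead of calling Node.eval, so it runs without the classes above; as in the original, ties are resolved arbitrarily because HashMap iteration order is unspecified.

```java
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;

// Hypothetical helper: hard majority vote over the class labels predicted by each tree.
final class MajorityVote {
    private MajorityVote() {}

    static float vote(List<Float> treePredictions) {
        // label -> number of trees that predicted it
        Map<Float, Long> counts = treePredictions.stream()
                .collect(Collectors.groupingBy(Function.identity(), Collectors.counting()));
        // pick the label with the highest count (on a tie, the first one the map yields)
        return counts.entrySet().stream()
                .max(Map.Entry.comparingByValue())
                .map(Map.Entry::getKey)
                .orElseThrow(() -> new IllegalArgumentException("no tree predictions"));
    }
}
```

One caveat, hedged: Spark ML's own RandomForestClassificationModel appears to combine per-tree class probabilities rather than count hard votes, and toDebugString prints only each leaf's predicted label, so a reconstruction like this may disagree with Spark on close calls.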

For regression:

import java.io.IOException;
import java.net.URL;

/**
 * RandomForestRegressor.
 */
public class RandomForestRegressor extends RandomForest {

    public RandomForestRegressor(final URL model) throws IOException {
        super(model);
    }

    @Override
    public float predict(final float[] features) {
        // average the leaf values returned by all trees
        return (float) trees
                .stream()
                .mapToDouble(node -> ((Number) node.eval(features)).doubleValue())
                .average()
                .getAsDouble();
    }
}
