2016-12-16 115 views
1

使用 Spark 1.6 和 ML 庫,我用 toDebugString() 保存了訓練好的 RandomForestClassificationModel 的結果(問題標題:Spark | ML | 隨機森林 | 從 RandomForestClassificationModel.toDebugString 的 .txt 輸出加載訓練好的模型):

// Take the pipeline stage at index 2 and cast it to the random-forest classification model.
val rfModel = model.stages(2).asInstanceOf[RandomForestClassificationModel] 
// Render the trained forest as its human-readable debug string.
val stringModel =rfModel.toDebugString 
// Save stringModel to a .txt file on the driver.

所以我的想法是:將來讀取這個 .txt 文件並加載訓練好的隨機森林,這有可能嗎?

謝謝!

回答

0

這是行不通的。toDebugString 的輸出僅僅是調試信息,用來幫助瞭解模型是如何計算的。

如果你想把模型保留下來以備以後使用,可以像我們一樣(雖然我們用的是純 Java)直接序列化 RandomForestModel 對象。默認的 Java 序列化可能存在版本不兼容的問題,所以我們使用 Hessian 來做。它在版本升級後依然有效——我們從 Spark 1.6.1 開始使用,到 Spark 2.0.2 仍然適用。

0

如果你不堅持使用 ml,可以改用 mllib 的實現:mllib 的 RandomForestModel 提供了 save 函數。

0

至少在 Spark 2.1.0 中,你可以用下面的 Java 代碼做到這一點(抱歉,沒有 Scala 版本)。不過,依賴一種沒有文檔說明的格式可能並不是最明智的做法。

import org.slf4j.Logger; 
import org.slf4j.LoggerFactory; 

import java.io.*; 
import java.net.URL; 
import java.util.*; 
import java.util.function.Predicate; 
import java.util.regex.Matcher; 
import java.util.regex.Pattern; 

import static java.nio.charset.StandardCharsets.US_ASCII; 

/**
 * Random forest rebuilt from the text produced by Spark's
 * {@code RandomForestClassificationModel.toDebugString()}.
 *
 * <p>Each tree is stored as a {@link Node} whose children are either further
 * {@link Node}s (splits) or {@link Float}s (leaf predictions). Subclasses decide
 * how the per-tree results are combined in {@link #predict(float[])}.
 *
 * <p>Thresholds and predictions are parsed with {@link Float#parseFloat(String)},
 * so values are held at {@code float} precision.
 */
public abstract class RandomForest {

    private static final Logger LOG = LoggerFactory.getLogger(RandomForest.class);

    /** Matches a split line such as {@code If (feature 3 <= 1.5)} or {@code If (feature 0 in {1.0,2.0})}. */
    private static final Pattern IF_PATTERN =
            Pattern.compile("If \\(feature (\\d+) (in|not in|<=|>) (.*)\\)");

    /** Matches a leaf line such as {@code Predict: 1.0}, {@code Predict: -0.25} or {@code Predict: 1.0E-4}. */
    private static final Pattern PREDICT_PATTERN =
            Pattern.compile("Predict: (-?\\d+\\.\\d+(E-?\\d+)?)");

    /** Root split of every tree, in the order the trees appear in the model file. */
    protected final List<Node> trees = new ArrayList<>();

    /**
     * Reads all trees from the given model file.
     *
     * @param model model file (format is Spark's RandomForestClassificationModel toDebugString())
     * @throws IOException if the model cannot be read, contains a malformed line, or contains no trees
     */
    public RandomForest(final URL model) throws IOException {
        try (final BufferedReader reader =
                     new BufferedReader(new InputStreamReader(model.openStream(), US_ASCII))) {
            Node node;
            while ((node = load(reader)) != null) {
                trees.add(node);
            }
        }
        if (trees.isEmpty()) {
            throw new IOException("Failed to read trees from " + model);
        }
        LOG.debug("Found {} trees.", trees.size());
    }

    /**
     * Reads the next tree from the reader.
     *
     * <p>Lines starting with {@code If} open a split, lines starting with {@code Predict}
     * add a leaf; every other line (model header, {@code Else ...}) is ignored. A line
     * starting with {@code Tree} ends the current tree once at least one split was read;
     * before that it is treated as the header of the tree about to be read, so input with
     * or without the leading {@code RandomForest...} line is accepted.
     *
     * @return the root split of the tree, or {@code null} when no further tree is available
     * @throws IOException on read failure or malformed input
     */
    private static Node load(final BufferedReader reader) throws IOException {
        Node root = null;
        // Path of splits from the root down to the split currently being filled.
        final Deque<Node> stack = new ArrayDeque<>();
        String line;
        while ((line = reader.readLine()) != null) {
            final String trimmed = line.trim();
            if (trimmed.startsWith("Tree")) {
                if (root != null) {
                    // Header of the NEXT tree: the current one is complete.
                    break;
                }
            } else if (trimmed.startsWith("If")) {
                final Node node = parseSplit(trimmed);
                if (stack.isEmpty()) {
                    root = node;
                } else {
                    insert(stack, node);
                }
                stack.push(node);
            } else if (trimmed.startsWith("Predict")) {
                insert(stack, parseLeaf(trimmed));
            }
        }
        return root;
    }

    /** Parses an {@code If (feature N op operand)} line into a split node without children. */
    private static Node parseSplit(final String line) throws IOException {
        final Matcher m = IF_PATTERN.matcher(line);
        if (!m.matches()) {
            throw new IOException("Malformed split line: " + line);
        }
        final String operator = m.group(2);
        final String operand = m.group(3);
        try {
            final int featureIndex = Integer.parseInt(m.group(1));
            final Predicate<Float> predicate;
            switch (operator) {
                case "<=":
                    predicate = new LessOrEqual(Float.parseFloat(operand));
                    break;
                case ">":
                    predicate = new Greater(Float.parseFloat(operand));
                    break;
                case "in":
                    predicate = new In(parseFloatArray(operand));
                    break;
                case "not in":
                    predicate = new NotIn(parseFloatArray(operand));
                    break;
                default:
                    // Unreachable while IF_PATTERN only admits the four operators above.
                    throw new IOException("Unsupported operator '" + operator + "' in line: " + line);
            }
            return new Node(featureIndex, predicate);
        } catch (NumberFormatException e) {
            throw new IOException("Malformed number in split line: " + line, e);
        }
    }

    /** Parses a {@code Predict: value} line into the boxed leaf value. */
    private static Float parseLeaf(final String line) throws IOException {
        final Matcher m = PREDICT_PATTERN.matcher(line);
        if (!m.matches()) {
            throw new IOException("Malformed prediction line: " + line);
        }
        try {
            return Float.parseFloat(m.group(1));
        } catch (NumberFormatException e) {
            throw new IOException("Malformed number in prediction line: " + line, e);
        }
    }

    /**
     * Attaches {@code child} to the deepest split on the stack that still has a free slot,
     * filling the left child first and discarding splits that are already complete.
     *
     * @throws IOException if no split with a free slot is left, i.e. the input has a leaf
     *                     before any split (single-leaf tree) or more nodes than the tree can hold
     */
    private static void insert(final Deque<Node> stack, final Object child) throws IOException {
        Node parent = stack.peek();
        while (parent != null && parent.getLeftChild() != null && parent.getRightChild() != null) {
            stack.pop();
            parent = stack.peek();
        }
        if (parent == null) {
            throw new IOException("Malformed model: no open split to attach '" + child
                    + "' to (single-leaf trees are not supported)");
        }
        if (parent.getLeftChild() == null) {
            parent.setLeftChild(child);
        } else {
            parent.setRightChild(child);
        }
    }

    /** Parses a category set such as <code>{1.0,2.0}</code> into its values. */
    private static float[] parseFloatArray(final String set) {
        final StringTokenizer st = new StringTokenizer(set, "{,}");
        final float[] floats = new float[st.countTokens()];
        for (int i = 0; st.hasMoreTokens(); i++) {
            floats[i] = Float.parseFloat(st.nextToken());
        }
        return floats;
    }

    /**
     * Combines the results of all trees for the given feature vector.
     *
     * @param features feature values, indexed by the feature indices used in the model
     * @return the prediction of the forest
     */
    public abstract float predict(final float[] features);

    /**
     * Renders every tree as indented {@code If}/{@code Else}/{@code Predict} lines,
     * each preceded by a {@code Tree i:} header.
     *
     * @return textual dump of the forest
     */
    public String toDebugString() {
        final StringBuilder out = new StringBuilder();
        for (int i = 0; i < trees.size(); i++) {
            out.append("Tree ").append(i).append(":\n");
            print(out, "", trees.get(i));
        }
        return out.toString();
    }

    /** Recursively appends {@code object} (a split node or a numeric leaf) to {@code out}. */
    private static void print(final StringBuilder out, final String indent, final Object object) {
        if (object instanceof Number) {
            out.append(indent).append("Predict: ").append(object).append('\n');
        } else if (object instanceof Node) {
            final Node node = (Node) object;
            out.append(indent).append(node).append('\n');
            print(out, indent + " ", node.getLeftChild());
            out.append(indent).append("Else\n");
            print(out, indent + " ", node.getRightChild());
        }
    }

    @Override
    public String toString() {
        return getClass().getSimpleName() + "{numTrees=" + trees.size() + "}";
    }

    /**
     * A split in a decision tree: tests one feature and continues with the left child
     * when the test passes, otherwise with the right child. Children are either
     * {@code Node}s or leaf values.
     */
    protected static class Node {

        private final int featureIndex;
        private final Predicate<Float> predicate;
        private Object leftChild;
        private Object rightChild;

        /**
         * @param featureIndex index into the feature vector of the value to test
         * @param predicate    test applied to that value; must not be {@code null}
         */
        public Node(final int featureIndex, final Predicate<Float> predicate) {
            Objects.requireNonNull(predicate, "predicate");
            this.featureIndex = featureIndex;
            this.predicate = predicate;
        }

        public void setLeftChild(final Object leftChild) {
            this.leftChild = leftChild;
        }

        public void setRightChild(final Object rightChild) {
            this.rightChild = rightChild;
        }

        public Object getLeftChild() {
            return leftChild;
        }

        public Object getRightChild() {
            return rightChild;
        }

        /**
         * Walks down from this node, following the left child when a split's predicate
         * accepts the feature value and the right child otherwise.
         *
         * @param features feature vector to evaluate
         * @return the first non-{@code Node} child reached (the leaf value)
         */
        public Object eval(final float[] features) {
            Object result = this;
            do {
                final Node node = (Node) result;
                result = node.predicate.test(features[node.featureIndex]) ? node.leftChild : node.rightChild;
            } while (result instanceof Node);
            return result;
        }

        @Override
        public String toString() {
            return "If (feature " + featureIndex + " " + predicate + ")";
        }
    }

    /** Accepts values less than or equal to the threshold. */
    private static class LessOrEqual implements Predicate<Float> {
        private final float value;

        public LessOrEqual(final float value) {
            this.value = value;
        }

        @Override
        public boolean test(final Float f) {
            return f <= value;
        }

        @Override
        public String toString() {
            return "<= " + value;
        }
    }

    /** Accepts values strictly greater than the threshold. */
    private static class Greater implements Predicate<Float> {
        private final float value;

        public Greater(final float value) {
            this.value = value;
        }

        @Override
        public boolean test(final Float f) {
            return f > value;
        }

        @Override
        public String toString() {
            return "> " + value;
        }
    }

    /** Accepts values contained in the given set. */
    private static class In implements Predicate<Float> {
        private final float[] array;

        public In(final float[] array) {
            this.array = array;
        }

        @Override
        public boolean test(final Float f) {
            for (final float candidate : array) {
                if (candidate == f) {
                    return true;
                }
            }
            return false;
        }

        @Override
        public String toString() {
            return "in " + Arrays.toString(array);
        }
    }

    /** Accepts values not contained in the given set. */
    private static class NotIn implements Predicate<Float> {
        private final float[] array;

        public NotIn(final float[] array) {
            this.array = array;
        }

        @Override
        public boolean test(final Float f) {
            for (final float candidate : array) {
                if (candidate == f) {
                    return false;
                }
            }
            return true;
        }

        @Override
        public String toString() {
            return "not in " + Arrays.toString(array);
        }
    }
}

用於分類的類,使用方法:

import java.io.IOException; 
import java.net.URL; 
import java.util.HashMap; 
import java.util.Map; 

/**
 * Random forest used for classification: every tree votes for one label and
 * the label with the most votes is returned.
 */
public class RandomForestClassifier extends RandomForest {

    /**
     * @param model model file, passed unchanged to the {@code RandomForest} constructor
     * @throws IOException if the superclass constructor fails to load the model
     */
    public RandomForestClassifier(final URL model) throws IOException {
        super(model);
    }

    /**
     * Evaluates every tree and returns the label predicted most often.
     *
     * @param features feature vector handed to each tree
     * @return the label with the highest vote count
     */
    @Override
    public float predict(final float[] features) {
        // Tally one vote per tree, keyed by the leaf value the tree produced.
        final Map<Object, Integer> votes = new HashMap<>();
        for (final Node tree : trees) {
            votes.merge(tree.eval(features), 1, Integer::sum);
        }
        // Pick the entry with the largest tally.
        return (Float) votes.entrySet()
                .stream()
                .max(Map.Entry.comparingByValue())
                .get()
                .getKey();
    }
}

對於迴歸:

import java.io.IOException; 
import java.net.URL; 

/**
 * Random forest used for regression: the prediction is the mean of the
 * values produced by the individual trees.
 */
public class RandomForestRegressor extends RandomForest {

    /**
     * @param model model file, passed unchanged to the {@code RandomForest} constructor
     * @throws IOException if the superclass constructor fails to load the model
     */
    public RandomForestRegressor(final URL model) throws IOException {
        super(model);
    }

    /**
     * Evaluates every tree and averages the numeric results.
     *
     * @param features feature vector handed to each tree
     * @return the mean of all tree results, narrowed to {@code float}
     */
    @Override
    public float predict(final float[] features) {
        final double mean = trees
                .stream()
                .map(tree -> (Number) tree.eval(features))
                .mapToDouble(Number::doubleValue)
                .average()
                .getAsDouble();
        return (float) mean;
    }
}
相關問題