至少在 Spark 2.1.0 中,你可以用下面的 Java 代碼做到這一點(抱歉,沒有 Scala 版本)。然而,依賴這種沒有正式文檔說明的格式可能並不是最明智的做法。
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.*;
import java.net.URL;
import java.util.*;
import java.util.function.Predicate;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import static java.nio.charset.StandardCharsets.US_ASCII;
/**
 * Base class for evaluating a random forest that was exported from Spark as text, i.e. the output
 * of {@code RandomForestClassificationModel.toDebugString()} (or the regression counterpart).
 *
 * <p>The debug-string format is not a documented interchange format; this parser was written
 * against Spark 2.1.0 output and may not work with other versions. Subclasses decide how the
 * per-tree leaf values are combined into a prediction.
 */
public abstract class RandomForest {
    private static final Logger LOG = LoggerFactory.getLogger(RandomForest.class);

    /** Split line, e.g. {@code If (feature 3 <= 0.5)} or {@code If (feature 2 in {0.0,1.0})}. */
    private static final Pattern IF_PATTERN =
            Pattern.compile("If \\(feature (\\d+) (in|not in|<=|>) (.*)\\)");

    /**
     * Leaf line, e.g. {@code Predict: 1.0}. Negative values (regression) and exponents such as
     * {@code 1.0E-4} or {@code 1.0E10} are accepted.
     */
    private static final Pattern PREDICT_PATTERN =
            Pattern.compile("Predict: (-?\\d+(?:\\.\\d+)?(?:[Ee][+-]?\\d+)?)");

    /** Root node of every parsed tree, in file order. */
    protected final List<Node> trees = new ArrayList<>();

    /**
     * Parses all trees from the given model text.
     *
     * @param model model file (format is Spark's RandomForestClassificationModel toDebugString())
     * @throws IOException if the model cannot be read, contains an unrecognized split/leaf line,
     *     is structurally malformed, or yields no trees
     */
    public RandomForest(final URL model) throws IOException {
        try (BufferedReader reader =
                new BufferedReader(new InputStreamReader(model.openStream(), US_ASCII))) {
            Node node;
            while ((node = load(reader)) != null) {
                trees.add(node);
            }
        }
        if (trees.isEmpty()) {
            throw new IOException("Failed to read trees from " + model);
        }
        LOG.debug("Found {} trees.", trees.size());
    }

    /**
     * Reads the next tree from {@code reader}.
     *
     * <p>Reading stops at the next {@code Tree ...} header line (which is consumed) or at end of
     * input. The {@code RandomForest...} model header and the {@code Tree 0} line after it are
     * skipped.
     *
     * @return the root of the tree, or {@code null} if no split line was read
     * @throws IOException on read failure or unparseable/malformed tree text
     */
    private static Node load(final BufferedReader reader) throws IOException {
        Node root = null;
        // Path of split nodes from the root to the most recently added split.
        final List<Node> stack = new ArrayList<>();
        String line;
        while ((line = reader.readLine()) != null) {
            final String trimmed = line.trim();
            if (trimmed.startsWith("RandomForest")) {
                // Model header: also consume the "Tree 0" line so that the first tree's body
                // follows directly.
                reader.readLine();
            } else if (trimmed.startsWith("Tree")) {
                // Header of the next tree: the current tree is complete.
                break;
            } else if (trimmed.startsWith("If")) {
                final Node node = parseSplit(trimmed);
                if (stack.isEmpty()) {
                    root = node;
                } else {
                    insert(stack, node);
                }
                stack.add(node);
            } else if (trimmed.startsWith("Predict")) {
                insert(stack, parseLeaf(trimmed));
            }
            // "Else" lines need no handling: insert() fills the right child once the left is set.
        }
        return root;
    }

    /**
     * Parses an {@code If (feature i OP operand)} line into a split node.
     *
     * @throws IOException if the line does not have the expected shape
     */
    private static Node parseSplit(final String line) throws IOException {
        final Matcher m = IF_PATTERN.matcher(line);
        if (!m.matches()) {
            throw new IOException("Unrecognized split line: " + line);
        }
        final int featureIndex = Integer.parseInt(m.group(1));
        final String operator = m.group(2);
        final String operand = m.group(3);
        final Predicate<Float> predicate;
        switch (operator) {
            case "<=":
                predicate = new LessOrEqual(Float.parseFloat(operand));
                break;
            case ">":
                predicate = new Greater(Float.parseFloat(operand));
                break;
            case "in":
                predicate = new In(parseFloatArray(operand));
                break;
            case "not in":
                predicate = new NotIn(parseFloatArray(operand));
                break;
            default:
                // Unreachable: the regex only admits the four operators above.
                throw new IOException("Unsupported split operator '" + operator + "' in: " + line);
        }
        return new Node(featureIndex, predicate);
    }

    /**
     * Parses a {@code Predict: value} line into the leaf value (boxed as {@link Float}).
     *
     * @throws IOException if the line does not have the expected shape
     */
    private static Float parseLeaf(final String line) throws IOException {
        final Matcher m = PREDICT_PATTERN.matcher(line);
        if (!m.matches()) {
            throw new IOException("Unrecognized leaf line: " + line);
        }
        return Float.parseFloat(m.group(1));
    }

    /**
     * Attaches {@code child} (a {@link Node} or a leaf value) to the deepest split on
     * {@code stack} that still has a free branch, filling the left branch before the right.
     * Splits whose two branches are both filled are popped, since their subtrees are complete.
     *
     * @throws IOException if no split can accept the child (malformed text, or a tree that
     *     consists of a single leaf, which this parser cannot represent)
     */
    private static void insert(final List<Node> stack, final Object child) throws IOException {
        while (!stack.isEmpty()) {
            final Node parent = stack.get(stack.size() - 1);
            if (parent.getLeftChild() == null) {
                parent.setLeftChild(child);
                return;
            }
            if (parent.getRightChild() == null) {
                parent.setRightChild(child);
                return;
            }
            stack.remove(stack.size() - 1);
        }
        throw new IOException(
                "Malformed tree: no open split for " + child
                        + " (single-leaf trees are not supported)");
    }

    /** Parses a category set such as {@code {0.0,1.0}} into a float array. */
    private static float[] parseFloatArray(final String set) {
        final StringTokenizer st = new StringTokenizer(set, "{,}");
        final float[] floats = new float[st.countTokens()];
        for (int i = 0; st.hasMoreTokens(); i++) {
            floats[i] = Float.parseFloat(st.nextToken());
        }
        return floats;
    }

    /**
     * Computes the forest's prediction for one feature vector.
     *
     * @param features feature values indexed by the feature numbers used in the model
     */
    public abstract float predict(final float[] features);

    /** Renders every tree (one {@code Tree i:} header per tree) in an indented textual form. */
    public String toDebugString() {
        final StringWriter sw = new StringWriter();
        try {
            for (int i = 0; i < trees.size(); i++) {
                sw.write("Tree " + i + ":\n");
                print(sw, "", trees.get(i));
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return sw.toString();
    }

    /** Recursively writes a node (left branch, "Else", right branch) or a leaf value. */
    private static void print(final Writer w, final String indent, final Object object)
            throws IOException {
        if (object instanceof Number) {
            w.write(indent + "Predict: " + object + "\n");
        } else if (object instanceof Node) {
            final Node node = (Node) object;
            w.write(indent + node + "\n");
            print(w, indent + " ", node.getLeftChild());
            w.write(indent + "Else\n");
            print(w, indent + " ", node.getRightChild());
        }
    }

    @Override
    public String toString() {
        return getClass().getSimpleName() + "{numTrees=" + trees.size() + "}";
    }

    /**
     * Binary split node. Each child is either another {@code Node} or a leaf value
     * ({@link Float} when produced by the parser). The left child is taken when the predicate
     * holds.
     */
    protected static class Node {
        private final int featureIndex;
        private final Predicate<Float> predicate;
        private Object leftChild;
        private Object rightChild;

        public Node(final int featureIndex, final Predicate<Float> predicate) {
            Objects.requireNonNull(predicate, "predicate");
            this.featureIndex = featureIndex;
            this.predicate = predicate;
        }

        public void setLeftChild(final Object leftChild) {
            this.leftChild = leftChild;
        }

        public void setRightChild(final Object rightChild) {
            this.rightChild = rightChild;
        }

        public Object getLeftChild() {
            return leftChild;
        }

        public Object getRightChild() {
            return rightChild;
        }

        /**
         * Walks from this node to a leaf for the given features.
         *
         * @return the leaf value reached (the first non-{@code Node} child)
         */
        public Object eval(final float[] features) {
            Object result = this;
            do {
                final Node node = (Node) result;
                result = node.predicate.test(features[node.featureIndex])
                        ? node.leftChild
                        : node.rightChild;
            } while (result instanceof Node);
            return result;
        }

        @Override
        public String toString() {
            return "If (feature " + featureIndex + " " + predicate + ")";
        }
    }

    /** Continuous split: {@code feature <= value}. */
    private static class LessOrEqual implements Predicate<Float> {
        private final float value;

        public LessOrEqual(final float value) {
            this.value = value;
        }

        @Override
        public boolean test(final Float f) {
            return f <= value;
        }

        @Override
        public String toString() {
            return "<= " + value;
        }
    }

    /** Continuous split: {@code feature > value}. */
    private static class Greater implements Predicate<Float> {
        private final float value;

        public Greater(final float value) {
            this.value = value;
        }

        @Override
        public boolean test(final Float f) {
            return f > value;
        }

        @Override
        public String toString() {
            return "> " + value;
        }
    }

    /** Categorical split: feature value is one of the listed categories. */
    private static class In implements Predicate<Float> {
        private final float[] array;

        public In(final float[] array) {
            this.array = array;
        }

        @Override
        public boolean test(final Float f) {
            for (int i = 0; i < array.length; i++) {
                if (array[i] == f) {
                    return true;
                }
            }
            return false;
        }

        @Override
        public String toString() {
            return "in " + Arrays.toString(array);
        }
    }

    /** Categorical split: feature value is none of the listed categories. */
    private static class NotIn implements Predicate<Float> {
        private final float[] array;

        public NotIn(final float[] array) {
            this.array = array;
        }

        @Override
        public boolean test(final Float f) {
            for (int i = 0; i < array.length; i++) {
                if (array[i] == f) {
                    return false;
                }
            }
            return true;
        }

        @Override
        public String toString() {
            return "not in " + Arrays.toString(array);
        }
    }
}
分類請使用:
import java.io.IOException;
import java.net.URL;
import java.util.HashMap;
import java.util.Map;
/**
 * Random-forest classifier evaluated from a model parsed out of Spark's {@code toDebugString()}
 * text. Each tree casts one vote for the label at the leaf it reaches; the label with the most
 * votes wins, and ties go to the smallest label.
 *
 * <p>NOTE(review): Spark's own classification model may aggregate per-class probabilities rather
 * than hard votes, which the debug string does not expose; results can differ near ties — confirm
 * if exact parity matters.
 */
public class RandomForestClassifier extends RandomForest {
    /**
     * @param model location of the text model to load
     * @throws IOException if the model cannot be read or parsed
     */
    public RandomForestClassifier(final URL model) throws IOException {
        super(model);
    }

    /**
     * Returns the majority-vote label over all trees.
     *
     * @param features feature values indexed by the feature numbers used in the model
     * @return the label predicted by the most trees (smallest such label on a tie)
     */
    @Override
    public float predict(final float[] features) {
        // Leaf values are produced by the parser as Float, hence the cast.
        final Map<Float, Integer> votes = new HashMap<>();
        for (final Node tree : trees) {
            votes.merge((Float) tree.eval(features), 1, Integer::sum);
        }

        // Single pass for the arg-max; HashMap order is arbitrary, so break ties explicitly.
        Float best = null;
        int bestCount = 0;
        for (final Map.Entry<Float, Integer> entry : votes.entrySet()) {
            final float label = entry.getKey();
            final int count = entry.getValue();
            if (count > bestCount || (count == bestCount && label < best)) {
                best = label;
                bestCount = count;
            }
        }
        if (best == null) {
            throw new IllegalStateException("Forest contains no trees");
        }
        return best;
    }
}
迴歸請使用:
import java.io.IOException;
import java.net.URL;
/**
 * Random-forest regressor backed by a model parsed from Spark's {@code toDebugString()} text.
 * The prediction is the arithmetic mean of the leaf values reached in every tree.
 */
public class RandomForestRegressor extends RandomForest {
    /**
     * @param model location of the text model to load
     * @throws IOException if the model cannot be read or contains no trees
     */
    public RandomForestRegressor(final URL model) throws IOException {
        super(model);
    }

    /**
     * Averages the per-tree leaf values for {@code features}.
     *
     * @param features feature values indexed by the feature numbers used in the model
     * @return the mean leaf value, narrowed to float
     */
    @Override
    public float predict(final float[] features) {
        final double mean = trees.stream()
                .mapToDouble(tree -> leafValue(tree, features))
                .average()
                .getAsDouble();
        return (float) mean;
    }

    /** Walks {@code tree} for {@code features} and returns the reached leaf as a double. */
    private static double leafValue(final Node tree, final float[] features) {
        final Number leaf = (Number) tree.eval(features);
        return leaf.doubleValue();
    }
}