0
如何在 Scala Spark(MLlib)中取得 StratifiedKFold
我搜索了一下而沒有找到任何東西,所以寫了一個可以和 Scala、Spark(MLlib)一起使用的 StratifiedKFold 方法,並在下面發佈答案。
/**
 * Splits sample indices into stratified folds, one set of folds per label.
 *
 * The first `nSamples` positions of `labels` are grouped by label value, and each
 * group is cut into exactly `k` consecutive folds whose sizes differ by at most one
 * (the first `size % k` folds receive the extra element). A label with fewer than
 * `k` samples therefore gets some empty folds.
 *
 * @param nSamples number of leading samples to split; must not exceed `labels.length`
 * @param k        number of folds to produce for every label; must be positive
 * @param labels   label of each sample, indexed by sample position
 * @param shuffle  when true, each label's indices are permuted with `bshuffle` before being cut
 * @return a pair of (label -> its `k` folds of sample indices, number of distinct values in `labels`)
 * @throws IllegalArgumentException if `k` is not positive or `nSamples` exceeds `labels.length`
 */
def StratifiedKFold(nSamples: Int, k: Int, labels: List[Int],shuffle: Boolean = false): (Map[Int,List[List[Int]]],Int)= {
  require(k > 0, s"k must be positive, got $k")
  require(nSamples <= labels.length,
    s"nSamples ($nSamples) must not exceed the number of labels (${labels.length})")

  // List#apply is O(n); index through a Vector so the grouping stays linear.
  val labelAt = labels.toVector
  val unqLabels = labels.distinct
  val idxsByLabel = (0 until nSamples).toArray.groupBy(labelAt)

  // Iterate over the labels that actually occur instead of assuming they are 0..n-1,
  // and build the map directly so no placeholder entry can leak into the result.
  val stratifiedIdxs: Map[Int, List[List[Int]]] = unqLabels.map { label =>
    // A label may occur only beyond the first nSamples positions; it then gets k empty folds.
    val members = idxsByLabel.getOrElse(label, Array.empty[Int])
    val ordered = if (shuffle) bshuffle(members) else members
    val pool = List.concat(ordered).toVector

    // Balanced sizes: `extra` folds of (base + 1) followed by folds of `base`.
    val base = pool.length / k
    val extra = pool.length % k
    val sizes = List.tabulate(k)(fold => if (fold < extra) base + 1 else base)
    val starts = sizes.scanLeft(0)(_ + _)

    label -> sizes.zip(starts).map { case (size, start) => pool.slice(start, start + size).toList }
  }.toMap

  (stratifiedIdxs, unqLabels.length)
}
以下是適用於 JavaRDD 的 StratifiedKFold 的 Java 版本。
/**
 * Stratified k-fold split of a labeled data set.
 *
 * <p>All points sharing a label are gathered into one list, shuffled, and cut into k folds
 * whose sizes differ by at most one (e.g. 11 samples into 3 folds gives [4, 4, 3]). Fold i of
 * the result is the pair (training RDD, testing RDD), where testing holds the i-th slice of
 * every label and training holds all the other slices.
 *
 * <p>If a label has fewer than k samples, some folds contain no sample of that label; this is
 * logged as a warning but is not treated as an error.
 *
 * <p>The intermediate per-label fold RDD is marked for persistence and is intentionally NOT
 * unpersisted here: the returned RDDs are lazy, so the cache is only filled (and only useful)
 * once the caller runs actions on them. Without the cache every returned RDD would re-run the
 * random shuffle and see a different split.
 *
 * @param data labeled points to split
 * @param kBC  broadcast number of folds; must be positive
 * @return an array of k (training, testing) RDD pairs
 * @throws IllegalArgumentException if the number of folds is not positive
 */
private static Tuple2<JavaRDD<LabeledPoint>, JavaRDD<LabeledPoint>>[] stritifyKFolds(JavaRDD<LabeledPoint> data, Broadcast<Integer> kBC){
    final int k = kBC.value();
    if(k <= 0){
        throw new IllegalArgumentException("folds number must be positive, got " + k);
    }

    JavaRDD<List<List<LabeledPoint>>> foldedLP = data.mapToPair(//map to (label, single-element list) pair.
        lp -> {
            List<LabeledPoint> list = new ArrayList<>();
            list.add(lp);
            return new Tuple2<>(lp.label(), list);
        }
    ).reduceByKey(//merge the lists so each label ends up with one list of all its points.
        (List<LabeledPoint> list1, List<LabeledPoint> list2) -> {
            list1.addAll(list2);
            return list1;
        }
    ).values() //the label itself is no longer needed once the points are grouped.
    .map(//split each label's list into k folds.
        list -> {
            //shuffle first so that fold membership is random within the label.
            Collections.shuffle(list);
            int total = list.size();
            if(total < kBC.value()){
                log.warn("category size {} is less than folds number {}, this will break stratification and leads to bad folds splitting", total, kBC.value());
            }
            return splitIntoFolds(list, kBC.value());
        }
    );
    //persist is lazy: the folds are materialized by the first action on any returned RDD
    //and then shared by all of them, which keeps the random split consistent across folds.
    foldedLP.persist(StorageLevel.MEMORY_AND_DISK_SER());

    @SuppressWarnings({"unchecked", "rawtypes"}) //generic array creation is not expressible in Java.
    Tuple2<JavaRDD<LabeledPoint>, JavaRDD<LabeledPoint>>[] result = new Tuple2[k];
    //build each fold as a tuple of (training RDD, testing RDD).
    for(int i=0; i<k; i++){
        final int ii = i; //must be effectively final to be captured by the lambdas.
        JavaRDD<LabeledPoint> training_i = foldedLP.flatMap(
            folds -> {
                List<LabeledPoint> noII = new ArrayList<>();
                for(int j=0; j<folds.size(); j++){
                    if(j != ii){
                        noII.addAll(folds.get(j)); //every slice except the ii-th one.
                    }
                }
                return noII.iterator();
            }
        );
        JavaRDD<LabeledPoint> testing_i = foldedLP.flatMap(
            folds -> folds.get(ii).iterator() //only the ii-th slice.
        );
        result[i] = new Tuple2<>(training_i, testing_i);
    }
    //NOTE: no foldedLP.unpersist() here. Nothing has been computed yet, so unpersisting now
    //would cancel the cache and let every returned RDD recompute (and re-shuffle) the folds.
    return result;
}

/**
 * Copies {@code items} into exactly {@code k} consecutive folds whose sizes differ by at most
 * one; the first {@code items.size() % k} folds receive the extra element.
 */
private static <T> List<List<T>> splitIntoFolds(List<T> items, int k){
    int total = items.size();
    int avg = total / k;    //minimum number of elements per fold.
    int remain = total % k; //leftover elements, handed out one each to the leading folds.
    List<List<T>> folds = new ArrayList<>(k);
    int index = 0;
    for(int i=0; i<k; i++){
        int count = (i < remain) ? avg + 1 : avg;
        //copy the slice: subList returns a non-serializable view backed by the source list.
        folds.add(new ArrayList<>(items.subList(index, index + count)));
        index += count;
    }
    return folds;
}