2016-03-03

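Finding multiclass thresholds for a simple threshold classifier

I have a function that returns a numeric value for an instance, and I later use this value to classify the instance into one of three classes. The classes are fairly separable, as the figure below shows (the three colors represent the three classes).

[Figure: Histogram over occurrences of different classes]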

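So I want two thresholds, k1 and k2, such that everything to the left of k1 is classified as red, everything to the right of k2 as blue, and everything in between as green.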

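I started with a modified version of Kadane's algorithm, based on this solution. I first sort all (color, value) tuples by their value, then build an array in which every green instance is 1 and every non-green instance is -1. So I get an array that looks like this: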

[-1, -1, -1, -1, 1, -1, -1, ..., 1, 1, 1, -1, 1, ..., -1, -1, -1, -1] 

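That is, there are many -1s (red) at the start, many greens around the middle, and mostly blues toward the end. If I now run Kadane's algorithm, will I get the optimal split?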

Here is the code I tested:

import java.util.*; 

public class Kadanes { 
    private static Color[] correctClasses = new Color[]{Color.RED, Color.GREEN, Color.GREEN, Color.GREEN, Color.BLUE, Color.GREEN, Color.RED, Color.RED, Color.RED, Color.BLUE, Color.BLUE, Color.GREEN, Color.BLUE, Color.GREEN, Color.GREEN, Color.BLUE, Color.BLUE, Color.GREEN, Color.GREEN, Color.RED, Color.RED, Color.GREEN, Color.GREEN, Color.BLUE, Color.GREEN, Color.GREEN, Color.BLUE, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.BLUE, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.BLUE, Color.BLUE, Color.GREEN, Color.BLUE, Color.GREEN, Color.BLUE, Color.GREEN, Color.RED, Color.BLUE, Color.GREEN, Color.BLUE, Color.BLUE, Color.GREEN, Color.GREEN, Color.BLUE, Color.GREEN, Color.GREEN, Color.GREEN, Color.RED, Color.BLUE, Color.GREEN, Color.BLUE, Color.GREEN, Color.GREEN, Color.RED, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.BLUE, Color.BLUE, Color.GREEN, Color.RED, Color.BLUE, Color.GREEN, Color.GREEN, Color.BLUE, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.RED, Color.BLUE, Color.BLUE, Color.RED, Color.GREEN, Color.BLUE, Color.GREEN, Color.BLUE, Color.GREEN, Color.GREEN, Color.GREEN, Color.BLUE, Color.GREEN, Color.BLUE, Color.RED, Color.GREEN, Color.GREEN, Color.GREEN, Color.BLUE, Color.BLUE, Color.GREEN, Color.GREEN, Color.RED, Color.BLUE, Color.BLUE, Color.RED, Color.GREEN, Color.GREEN, Color.GREEN, Color.BLUE, Color.BLUE, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.BLUE, Color.GREEN, Color.GREEN, Color.BLUE, Color.GREEN, Color.BLUE, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.BLUE, Color.BLUE, Color.RED, Color.BLUE, Color.GREEN, Color.RED, Color.RED, Color.BLUE, Color.BLUE, Color.GREEN, Color.GREEN, Color.GREEN, Color.RED, Color.BLUE, Color.BLUE, Color.GREEN, Color.BLUE, Color.RED, Color.GREEN, Color.RED, Color.BLUE, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.RED, Color.GREEN, 
Color.RED, Color.GREEN, Color.RED, Color.BLUE, Color.GREEN, Color.GREEN, Color.BLUE, Color.RED, Color.BLUE, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.RED, Color.GREEN, Color.GREEN, Color.GREEN, Color.RED, Color.GREEN, Color.GREEN, Color.GREEN, Color.RED, Color.BLUE, Color.GREEN, Color.RED, Color.GREEN, Color.GREEN, Color.BLUE, Color.GREEN, Color.GREEN, Color.BLUE, Color.GREEN, Color.GREEN, Color.RED, Color.GREEN, Color.RED, Color.RED, Color.GREEN, Color.GREEN, Color.BLUE, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.BLUE, Color.BLUE, Color.GREEN, Color.GREEN, Color.RED, Color.BLUE, Color.BLUE, Color.GREEN, Color.RED, Color.RED, Color.BLUE, Color.RED, Color.GREEN, Color.RED, Color.RED, Color.RED, Color.RED, Color.RED, Color.GREEN, Color.BLUE, Color.GREEN, Color.RED, Color.GREEN, Color.GREEN, Color.GREEN, Color.RED, Color.RED, Color.GREEN, Color.GREEN, Color.GREEN, Color.RED, Color.GREEN, Color.RED, Color.BLUE, Color.GREEN, Color.GREEN, Color.GREEN, Color.BLUE, Color.GREEN, Color.GREEN, Color.GREEN, Color.RED, Color.BLUE, Color.RED, Color.RED, Color.GREEN, Color.RED, Color.GREEN, Color.BLUE, Color.RED, Color.RED, Color.GREEN, Color.BLUE, Color.RED, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.BLUE, Color.RED, Color.RED, Color.RED, Color.GREEN, Color.BLUE, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.RED, Color.BLUE, Color.BLUE, Color.GREEN, Color.RED, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.BLUE, Color.BLUE, Color.BLUE, Color.GREEN, Color.BLUE, Color.BLUE, Color.GREEN, Color.GREEN, Color.RED, Color.BLUE, Color.BLUE, Color.RED, Color.BLUE, Color.BLUE, Color.GREEN, Color.BLUE, Color.RED, Color.GREEN, Color.BLUE, Color.GREEN, Color.RED, Color.GREEN, Color.BLUE, Color.BLUE, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.BLUE, Color.GREEN, Color.RED, Color.GREEN, Color.BLUE, Color.RED, Color.GREEN, Color.BLUE, Color.BLUE, Color.RED, Color.BLUE, Color.BLUE, Color.GREEN, Color.GREEN, 
Color.GREEN, Color.BLUE, Color.GREEN, Color.GREEN, Color.GREEN, Color.BLUE, Color.BLUE, Color.RED, Color.GREEN, Color.BLUE, Color.BLUE, Color.GREEN, Color.GREEN, Color.RED, Color.GREEN, Color.BLUE, Color.GREEN, Color.RED, Color.GREEN, Color.RED, Color.BLUE, Color.BLUE, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.BLUE, Color.GREEN, Color.GREEN, Color.GREEN, Color.RED, Color.GREEN, Color.BLUE, Color.BLUE, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.BLUE, Color.GREEN, Color.GREEN, Color.RED, Color.GREEN, Color.GREEN, Color.BLUE, Color.BLUE, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.BLUE, Color.RED, Color.GREEN, Color.GREEN, Color.RED, Color.BLUE, Color.BLUE, Color.BLUE, Color.RED, Color.BLUE, Color.GREEN, Color.BLUE, Color.GREEN, Color.BLUE, Color.GREEN, Color.BLUE, Color.RED, Color.BLUE, Color.BLUE, Color.RED, Color.RED, Color.BLUE, Color.BLUE, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.BLUE, Color.GREEN, Color.BLUE, Color.BLUE, Color.BLUE, Color.RED, Color.BLUE, Color.BLUE, Color.GREEN, Color.GREEN, Color.RED, Color.GREEN, Color.GREEN, Color.BLUE, Color.BLUE, Color.BLUE, Color.BLUE, Color.RED, Color.BLUE, Color.BLUE, Color.BLUE, Color.BLUE, Color.RED, Color.BLUE, Color.RED, Color.BLUE, Color.BLUE, Color.BLUE, Color.BLUE, Color.GREEN, Color.BLUE, Color.BLUE, Color.BLUE, Color.BLUE, Color.RED, Color.RED, Color.RED, Color.BLUE, Color.BLUE, Color.BLUE, Color.BLUE, Color.GREEN}; 
    private static double[] predictedValues = new double[]{0.0, 0.34, 2.0, 2.67, 7.53, -0.04, 2.0, -3.55, 3.78, 0.33, 3.0, -0.21, 1.41, -0.37, 0.84, 3.94, 8.34, 0.0, -1.39, 3.0, -1.63, 0.0, 3.0, 1.26, 0.0, 0.0, 0.0, 0.0, 0.61, 0.0, 3.34, 0.57, -1.05, 0.63, 0.0, 0.71, 0.0, 2.34, -0.41, -1.77, 3.0, 0.62, 0.93, 1.55, 2.0, 8.0, -1.55, 5.75, 0.0, 0.0, -0.25, 0.0, 1.0, 10.51, 0.0, 0.47, 0.78, -1.08, -1.51, 1.0, 1.0, 0.0, 4.33, -0.6, 0.37, 6.0, 1.16, -4.07, 2.0, 0.91, -0.05, 1.78, 0.0, 0.0, 0.0, 0.0, 0.0, 1.64, 1.55, 4.44, 2.78, 1.47, 3.75, 0.0, 7.59, 0.0, 0.94, 2.46, -0.23, -0.2, 0.0, 0.39, -2.31, 3.0, -1.15, 2.0, -0.76, -1.33, 0.0, 0.61, 0.77, -1.77, -1.08, 0.0, -3.2, 3.46, 1.0, 0.0, 0.0, 3.33, 0.0, 0.0, 2.81, 0.0, 0.0, 0.0, 3.0, 0.0, -0.88, 1.65, -1.09, -0.35, 0.0, 0.0, 5.0, 0.0, 2.88, -0.72, 0.87, 7.0, 7.48, -1.98, 1.0, 1.11, 4.0, 1.53, 0.0, 8.07, 1.54, 4.23, 0.0, -0.73, 6.61, 0.07, 0.0, -4.32, -1.77, 2.05, -1.08, 4.3, 1.61, 2.96, 3.0, 0.0, 3.66, 0.0, 0.0, 0.05, -0.77, -1.0, 0.0, 5.43, 2.12, -1.55, 2.3, 0.0, 3.6, 0.0, 0.0, -10.21, 2.0, 0.55, -0.63, 0.0, 1.0, 0.0, 0.0, 1.28, 3.0, 0.0, 0.44, 1.27, 2.12, 2.17, 1.76, -1.9, 5.42, 1.0, 3.76, -3.55, -0.82, 0.0, 0.11, -1.7, -0.33, 0.0, 0.0, -2.01, 0.0, 3.52, 2.0, 6.0, 0.92, 7.22, 0.0, 0.0, 0.0, 0.0, 0.36, -1.77, 0.0, -3.32, -0.91, 2.69, -0.86, -0.27, 3.28, -1.02, 0.41, -0.6, 2.61, 0.0, 0.36, 0.0, 0.91, 0.0, -2.82, 0.0, -1.77, 0.0, -0.33, 3.94, -2.55, 8.0, 3.29, 2.7, -4.4, 9.0, 0.0, 2.81, -0.23, -2.51, 2.0, -0.19, 0.0, 0.0, 0.0, 0.8, 8.33, 0.0, 0.59, 0.0, 0.41, 0.0, 0.8, 1.7, 3.27, 0.0, 0.34, -1.83, 0.0, -1.0, 0.29, 3.71, -0.44, -0.59, 1.25, 2.3, -1.56, 0.0, 6.21, -0.68, 0.0, 0.0, -0.3, 0.0, 1.0, 0.86, 0.0, 0.0, 0.0, 0.0, 0.41, 1.91, -0.17, -0.77, 1.0, 3.0, 2.0, 3.0, -0.71, 0.0, 0.62, 0.0, 2.54, 1.14, 0.0, 0.0, 3.27, 0.0, 0.96, -0.33, 0.0, 0.0, 1.91, -0.2, 0.0, 0.0, 0.6, 0.0, -0.82, 1.0, -0.54, 6.52, -2.48, 2.0, 0.0, 0.0, 1.61, 0.0, 0.0, 0.0, -0.17, 0.0, 1.0, -5.36, 2.73, 0.0, 7.97, 3.67, 0.0, -0.88, 0.93, 0.0, 3.0, -1.03, 
-0.64, 2.78, 0.0, 1.0, 3.0, 0.0, 0.46, 0.0, -0.63, 0.0, 4.0, 4.0, 1.61, 0.0, 0.0, 1.07, 0.0, 1.0, 18.39, -1.82, 0.0, 0.86, -0.42, -1.77, -0.61, 0.0, 0.68, -3.13, 0.53, 0.0, 3.0, 0.0, 2.47, 0.0, -1.74, 5.31, 0.0, 0.3, 0.0, 0.0, 4.0, 1.0, 0.64, 1.0, 0.0, -1.77, 3.31, -1.77, -0.43, -3.55, 0.94, 8.59, 0.0, 1.81, 3.69, -1.77, -0.32, 0.0, 3.0, 1.93, -1.47, 1.0, 3.21, 0.0, 0.0, 0.0, 0.33, 0.0, 0.0, -0.39, 0.0, 1.0, 0.0, 1.98, 0.0, 0.0, 7.45, 0.72, 0.34, 0.0, 0.35, 0.0, -2.74, 0.28, 4.0, 3.0, -0.91, -4.43, 0.0, 2.28, 3.0, -2.5, -2.66, 2.0, -0.66, 3.0, 11.06, 1.43, 3.0, 0.0, -0.79, 6.3, 0.94, 3.92, -4.43, 5.14, -2.35, 8.83, 1.04, 2.6, 5.0, 3.72}; 

    private static List<Tuple> previousResults = new ArrayList<>(); 
    static { 
     for(int i=0; i<correctClasses.length; i++) { 
      previousResults.add(new Tuple(correctClasses[i], predictedValues[i])); 
     } 
    } 


    public static void main(String[] args) { 
     double[] exampleThresholds = new double[]{-1.65, 1.65}; 
     double[] thresholds = getThreshold(); 
     System.out.println(Arrays.toString(thresholds)); 

     System.out.println("Example threshold accuracy: " + getAccuracy(exampleThresholds)); 
     System.out.println("Optimal threshold accuracy: " + getAccuracy(thresholds)); 
    } 


    private static double[] getThreshold() { 
     Collections.sort(previousResults, Collections.reverseOrder()); 

     int max_so_far = 0; 
     int max_ending_here = 0; 
     int max_start_index = 0; 
     int startIndex = 0; 
     int max_end_index = -1; 

     for(int i = 0; i < previousResults.size(); i++) { 
      int currentElementScore = (previousResults.get(i).correct == Color.GREEN ? 1 : -1); 
      if(max_ending_here + currentElementScore < 0) { 
       startIndex = i+1; 
       max_ending_here = 0; 
      } else { 
       max_ending_here += currentElementScore; 
      } 

      if(max_ending_here > max_so_far) { 
       max_so_far = max_ending_here; 
       max_start_index = startIndex; 
       max_end_index = i; 
      } 
     } 

     double lowThreshold = getAvgValue(max_start_index-1, max_start_index); 
     double highThreshold = getAvgValue(max_end_index, max_end_index+1); 

     return new double[]{lowThreshold, highThreshold}; 
    } 


    private static double getAccuracy(double[] thresholds) { 
     int numCorrectlyClassified = 0; 
     for(int i=0; i<correctClasses.length; i++) { 
      Color predictedClassification = classify(predictedValues[i], thresholds[0], thresholds[1]); 
      if(predictedClassification == correctClasses[i]) { 
       numCorrectlyClassified++; 
      } 
     } 

     return (double) numCorrectlyClassified/correctClasses.length; 
    } 

    private static Color classify(double value, double lowThresh, double highThresh) { 
     if(value < lowThresh) return Color.RED; 
     if(value > highThresh) return Color.BLUE; 
     return Color.GREEN; 
    } 


    private static double getAvgValue(int index1, int index2) { 
     if(index1 < 0) { 
      return Double.NEGATIVE_INFINITY; 
     } else if (index2 >= previousResults.size()) { 
      return Double.POSITIVE_INFINITY; 
     } 

     return (previousResults.get(index1).predicted + previousResults.get(index2).predicted)/2; 
    } 


    static class Tuple implements Comparable<Tuple> { 
     private Color correct; 
     private double predicted; 

     Tuple(Color correct, double predicted) { 
      this.correct = correct; 
      this.predicted = predicted; 
     } 

     public String toString() { 
      return "[" + correct.name() + ", " + predicted + "]"; 
     } 

     @Override 
     public int compareTo(Tuple o) { 
      double diff = o.predicted - predicted; 
      return diff != 0 ? (int) Math.signum(diff) : correct.compareTo(o.correct); 
     } 
    } 

    enum Color { 
     BLUE, GREEN, RED 
    } 
} 

The output I get is:

[0.0, 0.0] 
Example threshold accuracy: 0.5602678571428571 
Optimal threshold accuracy: 0.49107142857142855 

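So the "optimal" thresholds it finds both sit at 0.0, and the example thresholds I typed in as a quick guess perform better. Is my implementation wrong, or can Kadane's algorithm simply not solve this simple problem? If it can't, which algorithm could I use?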

I think you need to define what you mean by "optimal". –

@GordonLinoff Maximum accuracy. – Limon

Answer

You can't solve this problem with Kadane's algorithm, because it optimizes the number of greens taken, not the accuracy. Suppose you have this:

1, 1, -1, 1, -1, 1,.., -1, 1, 1, 1 

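The algorithm maximizes the sum, so it will take both the leading ones and the last three (and therefore everything between them), because the middle part only sums to -1.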
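The mismatch between the two objectives can be made concrete with a hypothetical toy input (not taken from the question's data): labels already sorted by predicted value, scored once by Kadane's +1/-1 sum and once by the accuracy of the resulting red | green | blue split. The sketch mirrors the reset rule in the question's `getThreshold`; `KadaneVsAccuracy`, `kadane` and `correct` are made-up names for illustration.

```java
// Illustrative sketch; class and method names are hypothetical.
public class KadaneVsAccuracy {

    enum Color { RED, GREEN, BLUE }

    // Kadane's maximum-sum subarray over +1 (green) / -1 (non-green),
    // returning the half-open range {start, end} it selects.
    static int[] kadane(Color[] sorted) {
        int bestSum = 0, bestStart = 0, bestEnd = 0;
        int runSum = 0, runStart = 0;
        for (int k = 0; k < sorted.length; k++) {
            runSum += sorted[k] == Color.GREEN ? 1 : -1;
            if (runSum < 0) {              // a negative-sum prefix never helps: restart
                runSum = 0;
                runStart = k + 1;
            } else if (runSum > bestSum) { // strictly better run found
                bestSum = runSum;
                bestStart = runStart;
                bestEnd = k + 1;
            }
        }
        return new int[]{bestStart, bestEnd};
    }

    // Correct predictions when sorted[0..i) is called red,
    // sorted[i..j) green and sorted[j..n) blue.
    static int correct(Color[] sorted, int i, int j) {
        int count = 0;
        for (int k = 0; k < sorted.length; k++) {
            Color predicted = k < i ? Color.RED : (k < j ? Color.GREEN : Color.BLUE);
            if (predicted == sorted[k]) count++;
        }
        return count;
    }

    public static void main(String[] args) {
        // Hypothetical labels, already sorted by predicted value.
        Color[] sorted = {Color.GREEN, Color.BLUE, Color.BLUE, Color.GREEN, Color.GREEN};
        int[] range = kadane(sorted);
        int best = 0;
        for (int i = 0; i <= sorted.length; i++)     // brute force over all splits
            for (int j = i; j <= sorted.length; j++)
                best = Math.max(best, correct(sorted, i, j));
        System.out.println("Kadane picks [" + range[0] + ", " + range[1] + "): "
                + correct(sorted, range[0], range[1]) + " of " + sorted.length + " correct");
        System.out.println("Best split: " + best + " of " + sorted.length + " correct");
    }
}
```

Kadane picks [3, 5) and gets 2 of 5 right, while calling only the first element green and the rest blue gets 3 of 5: Kadane is rewarded for leaving blues out of the green band, whereas accuracy only credits blues that land to the right of it (and reds that land to the left).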

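So instead, I would use an N log N algorithm to find the thresholds. First sort the array. Precompute partial-count arrays holding the number of greens, reds and blues between 0 and x (for every x). Iterate through the sorted array for the first threshold, and binary-search the best position for the second one, using accuracy as the guiding metric. To compute the accuracy, use the precomputed partial counts. For the binary search you need to work out how the metric changes as the second threshold moves.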
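One way to avoid relying on the binary search (a swapped-in technique, not the answer's exact method): for a fixed first cut, the accuracy as a function of the second cut is a difference of prefix counts that can rise and fall, so it is not guaranteed to be unimodal. With the same partial counts, the number of correct predictions for cut positions i ≤ j is red[i] + (green[j] − green[i]) + (blue[n] − blue[j]), which splits into a term in i and a term in j, so a single scan over j with a running maximum of the i-term finds the best pair after sorting. The sketch assumes the question's `classify` rule (value < low → red, value > high → blue) and only cuts between distinct values, since the question's data has many tied 0.0 entries; `ThresholdScan`, `bestThresholds`, `cut` and `usable` are hypothetical names.

```java
import java.util.Arrays;
import java.util.Comparator;

// Illustrative sketch; class and method names are hypothetical.
public class ThresholdScan {

    enum Color { RED, GREEN, BLUE }

    // Threshold between sorted positions p-1 and p; +/- infinity at the ends.
    static double cut(double[] v, int p) {
        if (p == 0) return Double.NEGATIVE_INFINITY;
        if (p == v.length) return Double.POSITIVE_INFINITY;
        return (v[p - 1] + v[p]) / 2;
    }

    // Only cut between distinct values (or at an end); otherwise tied values
    // would have to fall on both sides of the same threshold.
    static boolean usable(double[] v, int p) {
        return p == 0 || p == v.length || v[p - 1] < v[p];
    }

    // Returns {low, high, correctCount} for: value < low -> RED,
    // value > high -> BLUE, otherwise GREEN.
    static double[] bestThresholds(double[] values, Color[] labels) {
        int n = values.length;
        Integer[] order = new Integer[n];
        for (int k = 0; k < n; k++) order[k] = k;
        Arrays.sort(order, Comparator.comparingDouble(k -> values[k]));

        double[] v = new double[n];
        // red[k], green[k], blue[k] = label counts among the k smallest values.
        int[] red = new int[n + 1], green = new int[n + 1], blue = new int[n + 1];
        for (int k = 0; k < n; k++) {
            v[k] = values[order[k]];
            Color c = labels[order[k]];
            red[k + 1] = red[k] + (c == Color.RED ? 1 : 0);
            green[k + 1] = green[k] + (c == Color.GREEN ? 1 : 0);
            blue[k + 1] = blue[k] + (c == Color.BLUE ? 1 : 0);
        }

        // correct(i, j) = (red[i] - green[i]) + (green[j] - blue[j]) + blue[n],
        // so while scanning j it suffices to remember the best i-term so far.
        int bestLeft = Integer.MIN_VALUE, bestLeftIdx = 0;
        int best = -1, bestI = 0, bestJ = 0;
        for (int j = 0; j <= n; j++) {
            if (!usable(v, j)) continue;
            if (red[j] - green[j] > bestLeft) { // i = j allowed (empty green band)
                bestLeft = red[j] - green[j];
                bestLeftIdx = j;
            }
            int correct = bestLeft + green[j] - blue[j] + blue[n];
            if (correct > best) {
                best = correct;
                bestI = bestLeftIdx;
                bestJ = j;
            }
        }
        return new double[]{cut(v, bestI), cut(v, bestJ), best};
    }

    public static void main(String[] args) {
        double[] values = {-3, -2, -1, 0, 1, 2, 3};
        Color[] labels = {Color.RED, Color.RED, Color.GREEN, Color.GREEN,
                          Color.GREEN, Color.BLUE, Color.BLUE};
        System.out.println(Arrays.toString(bestThresholds(values, labels))); // [-1.5, 1.5, 7.0]
    }
}
```

The returned thresholds plug directly into the question's `getAccuracy`, and the third entry divided by the number of instances should match the accuracy it reports.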

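There may be a few corner cases. If you can afford N^2, just try all pairs of thresholds and use the precomputed arrays to speed up the evaluation.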
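A minimal sketch of the N^2 fallback, assuming the labels are already sorted by predicted value and the values are distinct (with ties, such as the question's many 0.0 entries, i and j should only be tried at positions between distinct values). The prefix arrays make each (i, j) pair an O(1) evaluation; `ExhaustiveThresholds` and `bestSplit` are hypothetical names.

```java
import java.util.Arrays;

// Illustrative sketch; class and method names are hypothetical.
public class ExhaustiveThresholds {

    enum Color { RED, GREEN, BLUE }

    // Tries every split of labels already sorted by predicted value:
    // sorted[0..i) -> RED, sorted[i..j) -> GREEN, sorted[j..n) -> BLUE.
    // Returns {i, j, correctCount}; O(N^2) pairs, O(1) work per pair.
    static int[] bestSplit(Color[] sorted) {
        int n = sorted.length;
        // Prefix counts: red[k] = number of RED labels in sorted[0..k), etc.
        int[] red = new int[n + 1], green = new int[n + 1], blue = new int[n + 1];
        for (int k = 0; k < n; k++) {
            red[k + 1] = red[k] + (sorted[k] == Color.RED ? 1 : 0);
            green[k + 1] = green[k] + (sorted[k] == Color.GREEN ? 1 : 0);
            blue[k + 1] = blue[k] + (sorted[k] == Color.BLUE ? 1 : 0);
        }
        int[] best = {0, 0, -1};
        for (int i = 0; i <= n; i++) {
            for (int j = i; j <= n; j++) {
                int correct = red[i] + (green[j] - green[i]) + (blue[n] - blue[j]);
                if (correct > best[2]) best = new int[]{i, j, correct};
            }
        }
        return best;
    }

    public static void main(String[] args) {
        Color r = Color.RED, g = Color.GREEN, b = Color.BLUE;
        Color[] sorted = {r, g, r, g, g, b, g, b}; // hypothetical, already sorted
        System.out.println(Arrays.toString(bestSplit(sorted))); // [1, 5, 6]
    }
}
```

Turning the winning (i, j) into actual thresholds can be done the same way as the question's `getAvgValue`: take the midpoint between the neighbouring sorted values, or ±infinity at the ends.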

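The reason I thought Kadane's algorithm would work is that the ranges are fairly separable, with almost no overlap between red and blue, so wouldn't it optimize not only the number of greens taken but the others as well? I don't quite follow your second paragraph: Kadane picks a range, so even if the middle part sums to 0, the start and end parts are both greater than 0, so it would include the whole range? – Limon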

@Limon Yes, my point is that because the middle part sums to 0, the solution will include the whole range, which may not be what you want. – Sorin