2016-03-09 34 views
1

我已經應用kNN算法對手寫數字進行分類。這些數字最初是8 * 8的矩陣格式,並且被展平成1 * 64的向量。 在kNN算法中改變k的值 - Java

目前我已經實現了kNN算法,但只使用k = 1。我嘗試了幾種做法,卻不斷遇到錯誤,所以不太確定該如何改變k值。如果有人能爲我指出正確的方向,我將非常感激。訓練數據集可以在here找到,驗證集可以在here找到。

ImageMatrix.java

import java.util.*; 

/**
 * One labelled sample: a vector of 64 integer values together with its class code.
 */
public class ImageMatrix {
    /** Number of values every sample must contain; enforced by the constructor. */
    private static final int EXPECTED_LENGTH = 64;

    private final int[] data;
    private final int classCode;
    // Never assigned anywhere in this class, so getCurData() always returns 0.
    // Kept only so that existing callers of getCurData() still compile.
    private int curData;

    /**
     * Creates a sample.
     *
     * @param data      the 64 values of the sample; the array is stored as-is, not copied
     * @param classCode the label of the sample
     * @throws IllegalArgumentException if {@code data} is null or does not hold exactly 64 values
     */
    public ImageMatrix(int[] data, int classCode) {
        // Explicit check rather than `assert`: assertions are skipped unless the JVM
        // is started with -ea, so a malformed sample would otherwise slip through.
        if (data == null || data.length != EXPECTED_LENGTH) {
            throw new IllegalArgumentException(
                    "data must contain exactly " + EXPECTED_LENGTH + " values but had "
                            + (data == null ? "null" : String.valueOf(data.length)));
        }
        this.data = data;
        this.classCode = classCode;
    }

    /** Returns a readable one-line description (class code and data) ending in a newline. */
    @Override
    public String toString() {
        return "Class Code: " + classCode + " Data :" + Arrays.toString(data) + "\n";
    }

    /** Returns the backing array itself (not a copy). */
    public int[] getData() {
        return data;
    }

    /** Returns the class code given at construction. */
    public int getClassCode() {
        return classCode;
    }

    /** Returns {@code curData}; nothing in this class ever sets it, so this is always 0. */
    public int getCurData() {
        return curData;
    }
}

ImageMatrixDB.java

import java.util.*; 
import java.io.*; 
import java.util.ArrayList; 
public class ImageMatrixDB implements Iterable<ImageMatrix> { 
    private List<ImageMatrix> list = new ArrayList<ImageMatrix>(); 

    /**
     * Reads samples from a comma-separated text file and appends them to this database.
     *
     * <p>Every non-blank line must be a comma-separated list of integers. The last
     * integer is the class code; all preceding integers are the sample data. Blank
     * lines are skipped and whitespace around a field is ignored.
     *
     * @param f path of the file to read
     * @return this database, so calls can be chained
     * @throws IOException if the file cannot be read or a line contains no comma
     * @throws NumberFormatException if a field is not a valid integer
     */
    public ImageMatrixDB load(String f) throws IOException {
     try (
      FileReader fr = new FileReader(f);
      BufferedReader br = new BufferedReader(fr)) {
      String line;
      int lineNumber = 0;

      while ((line = br.readLine()) != null) {
       lineNumber++;
       // A trailing newline or spacer line is not a sample; skip it instead of failing.
       if (line.trim().isEmpty()) {
        continue;
       }
       int lastComma = line.lastIndexOf(',');
       if (lastComma < 0) {
        throw new IOException(
          "Malformed line " + lineNumber + " in " + f + ": expected comma-separated values");
       }
       int classCode = Integer.parseInt(line.substring(1 + lastComma).trim());
       int[] data = Arrays.stream(line.substring(0, lastComma).split(","))
            .map(String::trim)
            .mapToInt(Integer::parseInt)
            .toArray();
       list.add(new ImageMatrix(data, classCode));
      }
     }
     return this;
    }

    /** Prints every stored matrix to standard output, one {@code println} call per matrix. */
    public void printResults() {
     list.forEach(System.out::println);
    }


    /** Returns an iterator over the stored matrices, as provided by the backing list. */
    public Iterator<ImageMatrix> iterator() {
     final List<ImageMatrix> entries = list;
     return entries.iterator();
    }

    /// kNN implementation /// 
    /**
     * Computes the Euclidean distance between two vectors, truncated to an {@code int}.
     *
     * <p>The squared differences are accumulated in a {@code double} so that large
     * component differences cannot overflow; results too large for an {@code int}
     * saturate at {@link Integer#MAX_VALUE}.
     *
     * @param a first vector
     * @param b second vector, same length as {@code a}
     * @return the distance with its fractional part discarded
     * @throws IllegalArgumentException if the vectors differ in length
     */
    public static int distance(int[] a, int[] b) {
     if (a.length != b.length) {
      throw new IllegalArgumentException(
        "Vectors must have the same length: " + a.length + " vs " + b.length);
     }
     double sum = 0;
     for (int i = 0; i < a.length; i++) {
      // Widen before subtracting: an int difference (and its square) can overflow.
      double diff = (double) a[i] - b[i];
      sum += diff * diff;
     }
     return (int) Math.sqrt(sum);
    }


    /**
     * Classifies a sample with the 1-nearest-neighbour rule.
     *
     * <p>Scans the training set and returns the class code of the sample with the
     * smallest {@code distance} to {@code curData}; on equal distances the sample
     * encountered first wins. Returns 0 when the training set yields no samples.
     *
     * @param trainingSet labelled samples to compare against
     * @param curData     the vector to classify
     * @return the class code of the nearest training sample
     */
    public static int classify(ImageMatrixDB trainingSet, int[] curData) {
     int nearestLabel = 0;
     int nearestDistance = Integer.MAX_VALUE;
     Iterator<ImageMatrix> candidates = trainingSet.iterator();
     while (candidates.hasNext()) {
      ImageMatrix candidate = candidates.next();
      int candidateDistance = distance(candidate.getData(), curData);
      if (candidateDistance >= nearestDistance) {
       continue; // not strictly closer than the best so far
      }
      nearestDistance = candidateDistance;
      nearestLabel = candidate.getClassCode();
     }
     return nearestLabel;
    }


    /** Returns the number of matrices currently held in the backing list. */
    public int size() {
     final int count = list.size();
     return count;
    }


    /**
     * Loads the training and validation files, classifies every validation sample
     * against the training set and prints the percentage classified correctly.
     *
     * @param argv command-line arguments (not used)
     * @throws IOException if either data file cannot be read
     */
    public static void main(String[] argv) throws IOException {
     ImageMatrixDB trainingSet = new ImageMatrixDB();
     ImageMatrixDB validationSet = new ImageMatrixDB();
     trainingSet.load("cw2DataSet1.csv");
     validationSet.load("cw2DataSet2.csv");

     int hits = 0;
     Iterator<ImageMatrix> samples = validationSet.iterator();
     while (samples.hasNext()) {
      ImageMatrix sample = samples.next();
      int predicted = classify(trainingSet, sample.getData());
      if (predicted == sample.getClassCode()) {
       hits++;
      }
     }

     double accuracy = (double) hits / validationSet.size() * 100;
     System.out.println("Accuracy: " + accuracy + "%");
     System.out.println();
    }
+0

雖然你的問題是關於'classify'方法,但我認爲對圖像使用歐幾里得距離並不是個好主意。一旦把圖像展平,就會丟失相關性信息。例如,同一個人但背景顏色不同的兩幅圖像會產生較大的歐幾里得距離 – MGoksu

回答

2

在classify的for循環中,你是在尋找最接近測試點的那一個訓練樣本。你需要把這段代碼改成找出最接近測試數據的K個訓練點。然後你應該對這K個點中的每一個調用getClassCode,並找出其中佔多數(即最常見)的類代碼。classify將返回你找到的多數類代碼。

您可以用任何適合您需要的方式來打破平局(即有兩個或更多最頻繁的類代碼對應相同數量的訓練數據)。

我其實沒有什麼Java經驗,但僅僅通過查閱語言參考,我想出了下面的實現。

/**
 * Classifies a sample by majority vote among its k nearest training samples.
 *
 * <p>Training samples are ranked by {@code distance} to {@code curData}, and the
 * class codes of the closest {@code k} are counted. If several class codes share
 * the highest count, the one the vote map happens to iterate first is returned;
 * ties are not broken in any particular way.
 *
 * @param trainingSet labelled samples to compare against
 * @param curData     the vector to classify
 * @param k           number of neighbours that vote; capped at the training-set size
 * @return the most frequent class code among the neighbours, or 0 if the training
 *         set is empty
 * @throws IllegalArgumentException if {@code k} is less than 1
 */
public static int classify(ImageMatrixDB trainingSet, int[] curData, int k) {
    if (k < 1) {
        throw new IllegalArgumentException("k must be at least 1 but was " + k);
    }

    // Each row holds {distance to curData, class code} for one training sample.
    int[][] distances = new int[trainingSet.size()][2];
    int i = 0;
    for (ImageMatrix matrix : trainingSet) {
        distances[i][0] = distance(matrix.getData(), curData);
        distances[i][1] = matrix.getClassCode();
        i++;
    }

    // Nearest first. comparingInt avoids the overflow risk of subtracting the keys.
    Arrays.sort(distances, Comparator.comparingInt((int[] row) -> row[0]));

    // Never read past the end of the array when k exceeds the number of samples.
    int neighbours = Math.min(k, distances.length);

    // Count how often each class code occurs among the nearest neighbours.
    Map<Integer, Integer> votes = new HashMap<>();
    for (int n = 0; n < neighbours; n++) {
        votes.merge(distances[n][1], 1, Integer::sum);
    }

    // Pick the class code with the highest count.
    int label = 0;
    int maxVotes = -1;
    for (Map.Entry<Integer, Integer> entry : votes.entrySet()) {
        if (entry.getValue() > maxVotes) {
            maxVotes = entry.getValue();
            label = entry.getKey();
        }
    }
    return label;
}

你需要做的只是將K作爲參數傳入。但請記住,上面的代碼沒有以特定的方式處理平局。

+0

感謝您的幫助。在看過你的最初嘗試後,我發現我的問題是什麼,真的很有幫助。 – Ben411916