1

如何訓練OCR神經網絡?對於我的APCS期末項目,我正在製作一個應用程序,它會:

  • 允許用戶在繪圖面板上繪製數字;
  • 將每個筆畫(由x-y座標列表表示)縮放/平移到100x100的範圍內;
  • 根據縮放後的筆畫生成圖像;
  • 根據該圖像生成二值二維數組(白色爲0,其他爲1);
  • 並將該二值數組傳給神經元(Neuron)對象進行字符識別。

下面的類代表神經元:

import java.awt.*; 
import java.util.*; 
import java.io.*; 

/**
 * A single unit with one weight per input cell, a logistic (sigmoid)
 * activation bucketed into the digits 0-9, and a perceptron-style update rule.
 *
 * NOTE(review): because the sigmoid output of one neuron is split into ten
 * ordered ranges, this models the digits as points on a single scale rather
 * than as ten separate classes; a multiclass network normally has one output
 * per class and takes the arg max.
 */
public class Neuron
{
    /** One weight per input cell; weights[i][j] pairs with bin[i][j] in feedforward. */
    private double[][] weights;

    /** Step size used by {@link #train(int[][], int)}. */
    public static double LEARNING_RATE = 0.01;

    /** File the randomly initialized weights are written to. */
    private static final String TRAINING_FILE = "training.txt";

    /**
     * Creates an r-by-c neuron with weights drawn uniformly from [-1, 1) and
     * writes them to "training.txt" (one row per line, values separated by a
     * single space).
     *
     * Writing the file is best-effort: if it cannot be opened, the weights are
     * still initialized and an error message is printed instead of crashing.
     *
     * @param r number of weight rows
     * @param c number of weight columns
     */
    public Neuron(int r, int c)
    {
        weights = new double[r][c];
        for (double[] row : weights)
        {
            for (int j = 0; j < row.length; j++)
                row[j] = 2 * Math.random() - 1; // Math.random() is in [0, 1), so this is in [-1, 1)
        }

        // try-with-resources closes the writer even if printing fails; a failed
        // open must not fall through to using a null writer.
        try (PrintWriter printer = new PrintWriter(TRAINING_FILE))
        {
            writeRows(printer, weights);
        }
        catch (FileNotFoundException e)
        {
            System.out.println("Error: could not write " + TRAINING_FILE + " (" + e.getMessage() + ")");
        }
    }

    /**
     * Loads a Drawing.DEF_HEIGHT x Drawing.DEF_WIDTH weight matrix from a
     * whitespace-separated text file, filling it row by row, then prints it.
     *
     * Values beyond the matrix size are ignored. If the file holds fewer values,
     * the remaining weights stay 0.0 and a warning is printed. Exits the
     * program with status 1 if the file cannot be opened.
     *
     * @param fileName path of the weights file
     * @throws java.util.InputMismatchException if a token is not a number
     */
    public Neuron(String fileName)
    {
        int rows = Drawing.DEF_HEIGHT, cols = Drawing.DEF_WIDTH;
        weights = new double[rows][cols];

        try (Scanner input = new Scanner(new File(fileName)))
        {
            int total = rows * cols;
            int count = 0;
            // Bounded by total so an oversized file cannot index past the matrix.
            while (count < total && input.hasNext())
            {
                weights[count / cols][count % cols] = input.nextDouble();
                count++;
            }
            if (count < total)
                System.out.println("Warning: " + fileName + " held only " + count + " of " + total + " weights");
        }
        catch (FileNotFoundException e)
        {
            System.out.println("Error: could not open " + fileName);
            System.exit(1);
        }

        for (double[] a : weights)
            System.out.println(Arrays.toString(a));
    }

    /** Writes each row on its own line with values separated by a single space. */
    private static void writeRows(PrintWriter out, double[][] rows)
    {
        for (double[] row : rows)
        {
            for (int j = 0; j < row.length; j++)
            {
                if (j > 0)
                    out.print(' ');
                out.print(row[j]);
            }
            out.println();
        }
    }

    /**
     * Computes the weighted sum of the input and passes it through
     * {@link #activate(double)}.
     *
     * @param bin binary image, at least as large as the weight matrix in both
     *            dimensions (0 for white, 1 otherwise, as produced by the caller)
     * @return the digitized activation, in 0..9
     */
    public int feedforward(int[][] bin)
    {
        double sum = 0;
        for (int i = 0; i < weights.length; i++)
        {
            for (int j = 0; j < weights[i].length; j++)
                sum += weights[i][j] * bin[i][j];
        }
        return activate(sum);
    }

    /**
     * Applies the logistic function to n and maps the result onto ten buckets
     * of width 0.1: [0, 0.1) -> 0, [0.1, 0.2) -> 1, ..., [0.9, 1] -> 9.
     * Prints the sigmoid and digitized values.
     *
     * @param n weighted input sum
     * @return bucket index in 0..9 (0 if n is NaN)
     */
    public int activate(double n)
    {
        double sig = 1.0 / (1 + Math.exp(-n));

        // (k + 1) / 10.0 is correctly rounded, so it is the same double as the
        // literal 0.1, 0.2, ... 0.9; the sigmoid is monotonic, so counting the
        // thresholds it reaches gives the bucket.
        int digitized = 0;
        while (digitized < 9 && sig >= (digitized + 1) / 10.0)
            digitized++;

        System.out.println("Sigmoid value: " + sig + "\nDigitized value: " + digitized);
        return digitized;
    }

    /**
     * Performs one update: guesses with {@link #feedforward(int[][])}, then
     * moves every weight by LEARNING_RATE * (desired - guess) * bin[i][j].
     *
     * @param bin     binary training image
     * @param desired expected digit, compared against the 0..9 guess
     */
    public void train(int[][] bin, int desired)
    {
        int guess = feedforward(bin);
        int error = desired - guess;

        for (int i = 0; i < weights.length; i++)
        {
            for (int j = 0; j < weights[i].length; j++)
                weights[i][j] += LEARNING_RATE * error * bin[i][j];
        }
    }
}

我使用另一個類來「訓練」神經元。這個類TrainingConsole.java讀取保存着隨機生成權重的「training.txt」,向其輸入訓練樣本(圖像 -> 二值二維數組),並根據誤差、學習率以及bin中對應的值調整權重:

import java.awt.image.BufferedImage; 
import java.io.*; 
import java.util.Arrays; 
import java.util.Scanner; 

import javax.imageio.ImageIO; 

public class TrainingConsole 
{ 

    private File folder; 
    private File data; 

    public TrainingConsole(String dataFileName, String folderName) 
    { 
     data = new File(dataFileName); 
     folder = new File(folderName); 
    } 

    public void changeFolder(String folderName) 
    { 
     folder = new File(folderName); 
    } 

    public void feedAll(int desired) 
    { 
     System.out.println(Arrays.toString(folder.listFiles())); 
     for (int i = 1; i < folder.listFiles().length; i++) //To exclude folder 
     { 
      BufferedImage img = new BufferedImage(Drawing.DEF_WIDTH,Drawing.DEF_HEIGHT,BufferedImage.TYPE_INT_RGB); 
      try 
      { 

       String name = folder.listFiles()[i].getName(); 
       if (name.substring(name.length()-4).equals(".png")) 
        img = ImageIO.read(folder.listFiles()[i]); 
      } 
      catch(IOException e) 
      {System.out.println("Error?");} 

      int[][] bin = new int[Drawing.DEF_WIDTH][Drawing.DEF_HEIGHT]; 

      if (img != null) 
      { 
       for (int y = 0; y < img.getHeight(); y++) 
       { 
        for (int x = 0; x < img.getWidth(); x++) 
        { 
         int rgb = img.getRGB(x,y); 
         //System.out.println(rgb); 
         if (rgb == -1) //White 
          bin[y][x] = 0; 
         else 
          bin[y][x] = 1; 
        } 
       } 
       for (int[] a : bin) 
        System.out.println(Arrays.toString(a)); 
       train(bin,desired); 
      } 
     } 
    } 

    public void train(int[][] bin, int desired) { 
     int guess = feedforward(bin); 
     int error = desired - guess; 

     Scanner input = null; 
     try { 
      input = new Scanner(data); 
     } catch (FileNotFoundException e) { 
      System.exit(1); 
     } 
     double[][] weights = new double[Drawing.DEF_HEIGHT][Drawing.DEF_WIDTH]; 
     int i = 0, j = 0; 
     while (input.hasNext() && i < Drawing.DEF_HEIGHT) { 
      weights[i][j] = input.nextDouble(); 
      j++; 
      if (j > weights[i].length - 1) { 
       i++; 
       j = 0; 
      } 
     } 

     for (int k = 0; k < weights.length; k++) { 
      for (int l = 0; l < weights[k].length; l++) 
       weights[k][l] += IMGNeuron.LEARNING_RATE * error * bin[k][l]; 
     } 

     data = new File(data.getName()); 
     PrintWriter output = null; 
     try { 
      output = new PrintWriter(data); 
     } catch (FileNotFoundException e) { 
      System.out.println("Cannot find data"); 
     } 
     for (int m = 0; m < weights.length; m++) { 
      for (int n = 0; n < weights[m].length - 1; n++) 
       output.print(weights[m][n] + " "); 
      output.print(weights[m][weights[m].length - 1]); 
      output.println(); 
     } 
     output.close(); 
    } 

    public int feedforward(int[][] bin) 
    { 
     double sum = 0; 

     Scanner input = null; 
     try 
     { 
      input = new Scanner(data); 
     } 
     catch(FileNotFoundException e) 
     { 
      System.out.println("Could not locate data"); 
     } 
     double[][] weights = new double[Drawing.DEF_HEIGHT][Drawing.DEF_WIDTH]; 
     int i = 0, j = 0; 
     while (i < Drawing.DEF_HEIGHT && j < Drawing.DEF_WIDTH) 
     { 
      //System.out.println("(" + i + " , " + j + ")"); 
      weights[i][j] = input.nextDouble(); 
      j++; 
      if (j > weights[i].length - 1) 
      { 
       i++; 
       j = 0; 
      } 
     } 

     for (int m = 0; m < weights.length; m++) 
     { 
      for (int n = 0; n < weights[m].length; n++) 
       sum += weights[m][n] * bin[m][n]; 
     } 
     return activate(sum); 
    } 

    public int activate(double n) 
    { 
     double sig = 1.0/(1+Math.exp(-1*n)); 
     int digitized = 0; 

     if (sig < 0.1) 
      digitized = 0; 
     else if (sig >= 0.1 && sig < 0.2) 
      digitized = 1; 
     else if (sig >= 0.2 && sig < 0.3) 
      digitized = 2; 
     else if (sig >= 0.3 && sig < 0.4) 
      digitized = 3; 
     else if (sig >= 0.4 && sig < 0.5) 
      digitized = 4; 
     else if (sig >= 0.5 && sig < 0.6) 
      digitized = 5; 
     else if (sig >= 0.6 && sig < 0.7) 
      digitized = 6; 
     else if (sig >= 0.7 && sig < 0.8) 
      digitized = 7; 
     else if (sig >= 0.8 && sig < 0.9) 
      digitized = 8; 
     else if (sig >= 0.9) 
      digitized = 9; 

     return digitized; 
    } 

    public static void main(String[] args) 
    { 
     Scanner input = new Scanner(System.in); 
     TrainingConsole trainer = new TrainingConsole("training.txt","Training_000"); 

     System.out.println("--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------"); 
     System.out.println("Training Console"); 
     System.out.println("--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------"); 

     for (int i = 0; i <= 9; i++) { 
      //System.out.print("Folder with training data for desired = " + i + ", or enter \"skip\" to skip: "); 
      //String folderName = input.nextLine().trim(); 
      String folderName = "Training_00" + i; 
      //System.out.println(folderName); 
      if (!folderName.toLowerCase().equals("skip")) 
      { 
       trainer.changeFolder(folderName); 
//    System.out.print("Press enter to run: "); 
//    String noReason = input.nextLine(); 
       trainer.feedAll(i); 
      } 
      System.out.println("----------------------------------------------------------------------------------------------------ava----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------"); 
     } 
    } 

} 

之後構造神經元時,我傳入「training.txt」作爲權重矩陣。但是,這顯然不起作用(見截圖)。

請幫忙!我對神經網絡和機器學習非常陌生。目前我不知道自己做錯了什麼:是需要更多訓練樣本?還是我實現的激活函數不好?任何建議都將不勝感激。另外,如有需要,歡迎索取更多代碼。

+0

也許是我看漏了,但我在你的網絡中只看到一個神經元。無論訓練多少次,這樣都會表現得很差。你可以讀一下 http://neuralnetworksanddeeplearning.com/chap1.html ,其中有一個與你要做的非常相似的例子。 – Chill

+0

另外,你使用sigmoid並把輸出取整成int作爲激活函數也是不好的……多分類問題不是這樣做的。 – user2717954

回答

0

正如評論中指出的,這裏有兩個主要問題,我會更詳細地說明。

  1. 你的整個模型只是單個感知器,也就是說,你建立的是一個從輸入空間(像素)到類別(數字)的線性模型。這根本行不通,它也不是現代意義上的神經網絡。爲圖像處理設計的「現代」神經網絡由成千上萬個神經元組成,按層連接,層與層之間有非線性激活,並且可能以卷積核的形式排列(因爲這是目前圖像識別領域最先進的架構)。

  2. 你要解決的是多分類問題,但你實際上做的是排序(把類別當作有序數值來預測)。要讓神經網絡分成K類,你應該有K個輸出神經元,每個神經元產生一個信號,可解釋爲屬於某一類的「概率」(並非嚴格數學意義上的概率);分類時取arg max(即輸出值最高的神經元的編號)。

一旦你解決了整體架構上的這兩個重要問題,應該就能開始得到合理的結果;之後唯一缺少的就是調整超參數和獲取更多的訓練數據。