2017-01-21 71 views
1

我必須爲一個學校項目創建一個OCR Programm,所以我開始在維基百科的幫助下創建反向傳播算法。爲了訓練我的網絡,我使用了幾天前提取的MNIST數據庫,以便獲得真實的圖像文件。但現在錯誤總是在237左右,訓練一段時間後,錯誤和權重變成了NaN。我的代碼有什麼問題?神經網絡:反向傳播不工作(Java)

A screenshot of my images folder

這裏是我的主類,其中應培養我網:

package de.Marcel.NeuralNetwork; 

import java.awt.Color; 
import java.awt.image.BufferedImage; 
import java.io.File; 
import java.io.IOException; 

import javax.imageio.ImageIO; 

public class OCR { 
    public static void main(String[] args) throws IOException { 
     // create network 
     NeuralNetwork net = new NeuralNetwork(784, 450, 5, 0.2); 

    // load Images 
    File file = new File("images"); 

    int images= 0; 
    double error = 0; 
    for (File f : file.listFiles()) { 
     BufferedImage image = ImageIO.read(f); 

     int t = -1; 
     double[] pixels = new double[784]; 
     for (int x = 0; x < image.getWidth(); x++) { 
      for (int y = 0; y < image.getHeight(); y++) { 
       t++; 
       Color c = new Color(image.getRGB(x, y)); 

       if (c.getRed() == 0 && c.getGreen() == 0 && c.getBlue() == 0) { 
        pixels[t] = 1; 
       } else if (c.getRed() == 255 && c.getGreen() == 255 && c.getBlue() == 255) { 
        pixels[t] = 0; 
       } 
      } 
     } 

     try { 
      if (f.getName().startsWith("1")) { 
       net.learn(pixels, new double[] { 1, 0, 0, 0, 0 }); 
       error += net.getError(); 

       images++; 
      } else if (f.getName().startsWith("2")) { 
       net.learn(pixels, new double[] { 0, 1, 0, 0, 0 }); 
       error += net.getError(); 

       images++; 
      } else if (f.getName().startsWith("3")) { 
       net.learn(pixels, new double[] { 0, 0, 1, 0, 0 }); 
       error += net.getError(); 

       images++; 
      } else if (f.getName().startsWith("4")) { 
       net.learn(pixels, new double[] { 0, 0, 0, 1, 0 }); 
       error += net.getError(); 

       images++; 
      } else if (f.getName().startsWith("5")) { 
       net.learn(pixels, new double[] { 0, 0, 0, 0, 1 }); 
       error += net.getError(); 

       images++; 
      } else if (f.getName().startsWith("6")) { 
       break; 
      } 
     } catch (Exception e) { 
      e.printStackTrace(); 
     } 
    } 

    error = error/iterations; 

    System.out.println("Trained images: " + images); 
    System.out.println("Error: " + error); 

    //save 
    System.out.println("Save"); 
    try { 
     net.saveNetwork("network.nnet"); 
    } catch (Exception e) { 
     e.printStackTrace(); 
    } 
} 
} 

...這是我的神經元類:

package de.Marcel.NeuralNetwork; 

public class Neuron { 
    private double input, output; 

public Neuron() { 

} 

public void setInput(double input) { 
    this.input = input; 
} 

public void setOutput(double output) { 
    this.output = output; 
} 

public double getInput() { 
    return input; 
} 

public double getOutput() { 
    return output; 
} 

}

。 ..最後我的NeuralNetwork

package de.Marcel.NeuralNetwork; 

import java.io.File; 
import java.io.FileWriter; 
import java.util.Random; 

public class NeuralNetwork { 
    private Neuron[] inputNeurons, hiddenNeurons, outputNeurons; 
    private double[] weightMatrix1, weightMatrix2; 
    private double learningRate, error; 

public NeuralNetwork(int inputCount, int hiddenCount, int outputCount, double learningRate) { 
    this.learningRate = learningRate; 

    // create Neurons 
    // create Input 
    this.inputNeurons = new Neuron[inputCount]; 
    for (int i = 0; i < inputCount; i++) { 
     this.inputNeurons[i] = new Neuron(); 
    } 
    // createHidden 
    this.hiddenNeurons = new Neuron[hiddenCount]; 
    for (int i = 0; i < hiddenCount; i++) { 
     this.hiddenNeurons[i] = new Neuron(); 
    } 
    // createOutput 
    this.outputNeurons = new Neuron[outputCount]; 
    for (int i = 0; i < outputCount; i++) { 
     this.outputNeurons[i] = new Neuron(); 
    } 

    // create weights 
    Random random = new Random(); 
    // weightMatrix1 
    this.weightMatrix1 = new double[inputCount * hiddenCount]; 
    for (int i = 0; i < inputCount * hiddenCount; i++) { 
     this.weightMatrix1[i] = (random.nextDouble() * 2 - 1)/0.25; 
    } 
    // weightMatrix2 
    this.weightMatrix2 = new double[hiddenCount * outputCount]; 
    for (int i = 0; i < hiddenCount * outputCount; i++) { 
     this.weightMatrix2[i] = (random.nextDouble() * 2 - 1)/0.25; 
    } 
} 

public void calculate(double[] input) throws Exception { 
    // verfiy input length 
    if (input.length == inputNeurons.length) { 
     // forwardPropagation 
     // set input array as input and output of input neurons 
     for (int i = 0; i < input.length; i++) { 
      inputNeurons[i].setInput(input[i]); 
      inputNeurons[i].setOutput(input[i]); 
     } 

     // calculate output of hiddenNeurons 
     for (int h = 0; h < hiddenNeurons.length; h++) { 
      Neuron hNeuron = hiddenNeurons[h]; 
      double totalInput = 0; 

      // sum up totalInput of Neuron 
      for (int i = 0; i < inputNeurons.length; i++) { 
       Neuron iNeuron = inputNeurons[i]; 
       totalInput += iNeuron.getOutput() * weightMatrix1[h * inputNeurons.length + i]; 
      } 

      // set input 
      hNeuron.setInput(totalInput); 

      // calculate output by applying sigmoid 
      double calculatedOutput = sigmoid(totalInput); 

      // set output 
      hNeuron.setOutput(calculatedOutput); 
     } 

     // calculate output of outputNeurons 
     for (int o = 0; o < outputNeurons.length; o++) { 
      Neuron oNeuron = outputNeurons[o]; 
      double totalInput = 0; 

      // sum up totalInput of Neuron 
      for (int h = 0; h < hiddenNeurons.length; h++) { 
       Neuron hNeuron = hiddenNeurons[h]; 
       totalInput += hNeuron.getOutput() * weightMatrix2[o * hiddenNeurons.length + h]; 
      } 

      // set input 
      oNeuron.setInput(totalInput); 

      // calculate output by applying sigmoid 
      double calculatedOutput = sigmoid(totalInput); 

      // set output 
      oNeuron.setOutput(calculatedOutput); 
     } 
    } else { 
     throw new Exception("[NeuralNetwork] input array is either too small or to big"); 
    } 
} 

public void learn(double[] input, double[] output) throws Exception { 
    double partialOutput = 0; 

    // verfiy input length 
    if (input.length == inputNeurons.length) { 
     // forwardPropagation 
     // set input array as input and output of input neurons 
     for (int i = 0; i < input.length; i++) { 
      inputNeurons[i].setInput(input[i]); 
      inputNeurons[i].setOutput(input[i]); 
     } 

     // calculate output of hiddenNeurons 
     for (int h = 0; h < hiddenNeurons.length; h++) { 
      Neuron hNeuron = hiddenNeurons[h]; 
      double totalInput = 0; 

      // sum up totalInput of Neuron 
      for (int i = 0; i < inputNeurons.length; i++) { 
       Neuron iNeuron = inputNeurons[i]; 
       totalInput += iNeuron.getOutput() * weightMatrix1[h * inputNeurons.length + i]; 
      } 

      // set input 
      hNeuron.setInput(totalInput); 

      // calculate output by applying sigmoid 
      double calculatedOutput = sigmoid(totalInput); 

      // set output 
      hNeuron.setOutput(calculatedOutput); 
     } 

     // calculate output of outputNeurons 
     for (int o = 0; o < outputNeurons.length; o++) { 
      Neuron oNeuron = outputNeurons[o]; 
      double totalInput = 0; 

      // sum up totalInput of Neuron 
      for (int h = 0; h < hiddenNeurons.length; h++) { 
       Neuron hNeuron = hiddenNeurons[h]; 
       totalInput += hNeuron.getOutput() * weightMatrix2[o * hiddenNeurons.length + h]; 
      } 

      // set input 
      oNeuron.setInput(totalInput); 

      // calculate output by applying sigmoid 
      double calculatedOutput = sigmoid(totalInput); 

      // set output 
      oNeuron.setOutput(calculatedOutput); 
     } 

     // backPropagation 
     double totalError = 0; 
     // calculate weights in matrix2 
     for (int h = 0; h < hiddenNeurons.length; h++) { 
      Neuron hNeuron = hiddenNeurons[h]; 

      for (int o = 0; o < outputNeurons.length; o++) { 
       Neuron oNeuron = outputNeurons[o]; 

       // calculate weight 
       double delta = learningRate * derivativeSigmoid(oNeuron.getInput()) 
         * (output[o] - oNeuron.getOutput()) * hNeuron.getOutput(); 

       // set new weight 
       weightMatrix2[h + o * hiddenNeurons.length] = weightMatrix2[h + o * hiddenNeurons.length] + delta; 

       // update partial output 
       partialOutput += (derivativeSigmoid(oNeuron.getInput()) * (output[o] - oNeuron.getOutput()) 
         * weightMatrix2[h + o * hiddenNeurons.length]); 

       //calculate error 
       totalError += Math.pow((output[o] - oNeuron.getOutput()), 2); 
      } 
     } 

     //set error 
     this.error = 0.5 * totalError; 

     // calculate weights in matrix1 
     for (int i = 0; i < inputNeurons.length; i++) { 
      Neuron iNeuron = inputNeurons[i]; 

      for (int h = 0; h < hiddenNeurons.length; h++) { 
       Neuron hNeuron = hiddenNeurons[h]; 

       // calculate weight 
       double delta = learningRate * derivativeSigmoid(hNeuron.getInput()) * partialOutput 
         * (iNeuron.getOutput()); 

       // set new weight 
       weightMatrix1[i + h * inputNeurons.length] = weightMatrix1[i + h * inputNeurons.length] + delta; 
      } 
     } 
    } else { 
     throw new Exception("[NeuralNetwork] input array is either too small or to big"); 
    } 
} 

// save Network 
public void saveNetwork(String fileName) throws Exception { 
    File file = new File(fileName); 
    FileWriter writer = new FileWriter(file); 

    writer.write("weightmatrix1:"); 
    writer.write(System.lineSeparator()); 

    // write weightMatrix1 
    for (double d : weightMatrix1) { 
     writer.write(d + "-"); 
    } 

    writer.write(System.lineSeparator()); 
    writer.write("weightmatrix2:"); 
    writer.write(System.lineSeparator()); 

    // write weightMatrix2 
    for (double d : weightMatrix2) { 
     writer.write(d + "-"); 
    } 

    // save 
    writer.close(); 
} 

// sigmoid function 
private double sigmoid(double input) { 
    return Math.exp(input * (-1)); 
} 

private double derivativeSigmoid(double input) { 
    return sigmoid(input) * (1 - sigmoid(input)); 
} 

public double getError() { 
    return error; 
} 
} 
+0

NaN表示不是數字。它發生在你被零除的時候。例如,double error = 10/0;錯誤將等於NaN。此外,如果你做doubleVar = 1 +錯誤; anotherVar也會是NaN。 – Zack

回答

0

看起來你的sigmoid函數是不正確的。它應該是1 /(1 + exp(-x))。

如果您仍然遇到NaN錯誤,可能是因爲使用函數可能會導致矯枉過正,尤其是對於大數字(即小於-10和大於10的數字)。

使用sigmoid(x)的預先計算值的數組可能會阻止較大數據集的此問題,並且還可以幫助程序更高效地運行。

希望這會有所幫助!

相關問題