如何訓練用於OCR的神經網絡?
對於我的APCS期末項目,我正在開發一個應用程序,它會:
- 允許用戶在繪圖面板上繪製數字;
- 將每個筆畫(以x-y座標列表表示)縮放/平移到100x100的範圍內;
- 根據縮放後的筆畫生成圖像;
- 由該圖像生成二值二維數組(白色爲0,其他顏色爲1);
- 並將該二值數組傳遞給神經元對象以進行字符識別。
下面的類代表神經元:
import java.awt.*;
import java.util.*;
import java.io.*;
/**
 * A single artificial neuron with one weight per pixel of a binary input image.
 *
 * <p>The neuron computes a weighted sum of its inputs, squashes it with a logistic sigmoid, and
 * "digitizes" the sigmoid value into an integer 0-9 by splitting [0, 1] into ten equal bands.
 *
 * <p>NOTE(review): a single neuron produces one scalar, so this design asks one linear function
 * of the pixels to rank all ten digits along one axis. This is likely why recognition is poor;
 * a common alternative is one output per digit and picking the largest. The public API is kept
 * unchanged here.
 */
public class Neuron
{
    /** One weight per input pixel, indexed {@code [row][column]} like the input arrays. */
    private double[][] weights;

    /** Step size applied by {@link #train(int[][], int)}. */
    public static double LEARNING_RATE = 0.01;

    /** File the randomly initialised weights are written to by {@link #Neuron(int, int)}. */
    private static final String DEFAULT_WEIGHTS_FILE = "training.txt";

    /**
     * Creates a neuron with {@code r x c} weights drawn uniformly from [-1, 1) and writes them to
     * {@code training.txt}: one matrix row per line, values separated by single spaces.
     *
     * @param r number of weight rows
     * @param c number of weight columns
     * @throws UncheckedIOException if {@code training.txt} cannot be opened for writing
     *         (previously this surfaced as a NullPointerException)
     */
    public Neuron(int r, int c)
    {
        weights = new double[r][c];
        for (double[] row : weights)
        {
            for (int j = 0; j < row.length; j++)
                row[j] = 2 * Math.random() - 1; // Math.random() is in [0, 1), so this is in [-1, 1)
        }
        writeWeights(weights, DEFAULT_WEIGHTS_FILE);
    }

    /**
     * Loads a {@code Drawing.DEF_HEIGHT x Drawing.DEF_WIDTH} weight matrix from a text file of
     * whitespace-separated numbers, filled row by row. Values beyond the matrix size are ignored
     * (previously they caused an ArrayIndexOutOfBoundsException); missing values leave the
     * corresponding weights at 0. Prints the loaded matrix to standard output.
     *
     * <p>If the file cannot be opened, prints an error and terminates the JVM with status 1.
     *
     * @param fileName path of the weights file
     */
    public Neuron(String fileName)
    {
        int r = Drawing.DEF_HEIGHT, c = Drawing.DEF_WIDTH;
        weights = new double[r][c];
        try (Scanner input = new Scanner(new File(fileName)))
        {
            int i = 0, j = 0;
            while (i < r && input.hasNext())
            {
                weights[i][j] = input.nextDouble();
                j++;
                if (j == c) // end of row: continue filling the next one
                {
                    i++;
                    j = 0;
                }
            }
        }
        catch (FileNotFoundException e)
        {
            System.out.println("Error: could not open " + fileName);
            System.exit(1);
        }
        for (double[] a : weights)
            System.out.println(Arrays.toString(a));
    }

    /**
     * Writes the current weights to {@code fileName} in the format read by
     * {@link #Neuron(String)}, so the result of {@link #train(int[][], int)} can be persisted.
     *
     * @param fileName destination path (created or overwritten)
     * @throws UncheckedIOException if the file cannot be opened for writing
     */
    public void save(String fileName)
    {
        writeWeights(weights, fileName);
    }

    /**
     * Computes the weighted sum of {@code bin} and returns its digitized activation.
     *
     * @param bin binary image, at least as large as the weight matrix (extra cells are ignored)
     * @return a value in 0-9, see {@link #activate(double)}
     * @throws IllegalArgumentException if {@code bin} is smaller than the weight matrix
     * @throws NullPointerException if {@code bin} is null
     */
    public int feedforward(int[][] bin)
    {
        checkInputShape(bin);
        double sum = 0;
        for (int i = 0; i < weights.length; i++)
        {
            for (int j = 0; j < weights[i].length; j++)
                sum += weights[i][j] * bin[i][j];
        }
        return activate(sum);
    }

    /**
     * Applies the logistic sigmoid to {@code n} and maps the result onto 0-9 by tenths:
     * [0, 0.1) gives 0, [0.1, 0.2) gives 1, ..., [0.9, 1] gives 9. NaN maps to 0.
     * Prints the sigmoid and digitized values to standard output.
     *
     * @param n the weighted input sum
     * @return the digitized sigmoid value
     */
    public int activate(double n)
    {
        double sig = 1.0 / (1 + Math.exp(-n));
        int digitized = 0;
        // k / 10.0 is correctly rounded, so it equals the literal 0.k bit for bit.
        for (int k = 9; k >= 1; k--)
        {
            if (sig >= k / 10.0)
            {
                digitized = k;
                break;
            }
        }
        System.out.println("Sigmoid value: " + sig + "\nDigitized value: " + digitized);
        return digitized;
    }

    /**
     * Performs one update: guesses with {@link #feedforward(int[][])}, then moves every weight by
     * {@code LEARNING_RATE * (desired - guess) * input}. The weights change in memory only; call
     * {@link #save(String)} to persist them.
     *
     * <p>NOTE(review): the error is the difference of two digitized integers, so the update
     * ignores how far the raw sigmoid output is from the target band.
     *
     * @param bin binary training image
     * @param desired the known correct digit
     * @throws IllegalArgumentException if {@code bin} is smaller than the weight matrix
     */
    public void train(int[][] bin, int desired)
    {
        int guess = feedforward(bin); // also validates the input shape
        int error = desired - guess;
        for (int i = 0; i < weights.length; i++)
        {
            for (int j = 0; j < weights[i].length; j++)
                weights[i][j] += LEARNING_RATE * error * bin[i][j];
        }
    }

    /** Rejects inputs too small to cover the weight matrix, with a descriptive message. */
    private void checkInputShape(int[][] bin)
    {
        Objects.requireNonNull(bin, "bin");
        if (bin.length < weights.length)
            throw new IllegalArgumentException(
                "input has " + bin.length + " rows, expected at least " + weights.length);
        for (int i = 0; i < weights.length; i++)
        {
            if (bin[i].length < weights[i].length)
                throw new IllegalArgumentException("input row " + i + " has " + bin[i].length
                    + " columns, expected at least " + weights[i].length);
        }
    }

    /** Writes {@code w} one row per line, values separated by single spaces; always closes the file. */
    private static void writeWeights(double[][] w, String fileName)
    {
        try (PrintWriter printer = new PrintWriter(fileName))
        {
            for (double[] row : w)
            {
                for (int j = 0; j < row.length; j++)
                {
                    if (j > 0)
                        printer.print(' ');
                    printer.print(row[j]);
                }
                printer.println();
            }
        }
        catch (FileNotFoundException e)
        {
            throw new UncheckedIOException("could not open " + fileName + " for writing", e);
        }
    }
}
我使用另一個類來「訓練」這個神經元。這個類TrainingConsole.java會讀取存放隨機初始權重的「training.txt」,向其提供訓練樣本(圖像 -> 二值二維數組),並根據誤差、學習率以及bin中對應的值來調整權重:
import java.awt.image.BufferedImage;
import java.io.*;
import java.util.Arrays;
import java.util.Scanner;
import javax.imageio.ImageIO;
public class TrainingConsole
{
private File folder;
private File data;
public TrainingConsole(String dataFileName, String folderName)
{
data = new File(dataFileName);
folder = new File(folderName);
}
public void changeFolder(String folderName)
{
folder = new File(folderName);
}
public void feedAll(int desired)
{
System.out.println(Arrays.toString(folder.listFiles()));
for (int i = 1; i < folder.listFiles().length; i++) //To exclude folder
{
BufferedImage img = new BufferedImage(Drawing.DEF_WIDTH,Drawing.DEF_HEIGHT,BufferedImage.TYPE_INT_RGB);
try
{
String name = folder.listFiles()[i].getName();
if (name.substring(name.length()-4).equals(".png"))
img = ImageIO.read(folder.listFiles()[i]);
}
catch(IOException e)
{System.out.println("Error?");}
int[][] bin = new int[Drawing.DEF_WIDTH][Drawing.DEF_HEIGHT];
if (img != null)
{
for (int y = 0; y < img.getHeight(); y++)
{
for (int x = 0; x < img.getWidth(); x++)
{
int rgb = img.getRGB(x,y);
//System.out.println(rgb);
if (rgb == -1) //White
bin[y][x] = 0;
else
bin[y][x] = 1;
}
}
for (int[] a : bin)
System.out.println(Arrays.toString(a));
train(bin,desired);
}
}
}
public void train(int[][] bin, int desired) {
int guess = feedforward(bin);
int error = desired - guess;
Scanner input = null;
try {
input = new Scanner(data);
} catch (FileNotFoundException e) {
System.exit(1);
}
double[][] weights = new double[Drawing.DEF_HEIGHT][Drawing.DEF_WIDTH];
int i = 0, j = 0;
while (input.hasNext() && i < Drawing.DEF_HEIGHT) {
weights[i][j] = input.nextDouble();
j++;
if (j > weights[i].length - 1) {
i++;
j = 0;
}
}
for (int k = 0; k < weights.length; k++) {
for (int l = 0; l < weights[k].length; l++)
weights[k][l] += IMGNeuron.LEARNING_RATE * error * bin[k][l];
}
data = new File(data.getName());
PrintWriter output = null;
try {
output = new PrintWriter(data);
} catch (FileNotFoundException e) {
System.out.println("Cannot find data");
}
for (int m = 0; m < weights.length; m++) {
for (int n = 0; n < weights[m].length - 1; n++)
output.print(weights[m][n] + " ");
output.print(weights[m][weights[m].length - 1]);
output.println();
}
output.close();
}
public int feedforward(int[][] bin)
{
double sum = 0;
Scanner input = null;
try
{
input = new Scanner(data);
}
catch(FileNotFoundException e)
{
System.out.println("Could not locate data");
}
double[][] weights = new double[Drawing.DEF_HEIGHT][Drawing.DEF_WIDTH];
int i = 0, j = 0;
while (i < Drawing.DEF_HEIGHT && j < Drawing.DEF_WIDTH)
{
//System.out.println("(" + i + " , " + j + ")");
weights[i][j] = input.nextDouble();
j++;
if (j > weights[i].length - 1)
{
i++;
j = 0;
}
}
for (int m = 0; m < weights.length; m++)
{
for (int n = 0; n < weights[m].length; n++)
sum += weights[m][n] * bin[m][n];
}
return activate(sum);
}
public int activate(double n)
{
double sig = 1.0/(1+Math.exp(-1*n));
int digitized = 0;
if (sig < 0.1)
digitized = 0;
else if (sig >= 0.1 && sig < 0.2)
digitized = 1;
else if (sig >= 0.2 && sig < 0.3)
digitized = 2;
else if (sig >= 0.3 && sig < 0.4)
digitized = 3;
else if (sig >= 0.4 && sig < 0.5)
digitized = 4;
else if (sig >= 0.5 && sig < 0.6)
digitized = 5;
else if (sig >= 0.6 && sig < 0.7)
digitized = 6;
else if (sig >= 0.7 && sig < 0.8)
digitized = 7;
else if (sig >= 0.8 && sig < 0.9)
digitized = 8;
else if (sig >= 0.9)
digitized = 9;
return digitized;
}
public static void main(String[] args)
{
Scanner input = new Scanner(System.in);
TrainingConsole trainer = new TrainingConsole("training.txt","Training_000");
System.out.println("--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------");
System.out.println("Training Console");
System.out.println("--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------");
for (int i = 0; i <= 9; i++) {
//System.out.print("Folder with training data for desired = " + i + ", or enter \"skip\" to skip: ");
//String folderName = input.nextLine().trim();
String folderName = "Training_00" + i;
//System.out.println(folderName);
if (!folderName.toLowerCase().equals("skip"))
{
trainer.changeFolder(folderName);
// System.out.print("Press enter to run: ");
// String noReason = input.nextLine();
trainer.feedAll(i);
}
System.out.println("----------------------------------------------------------------------------------------------------ava----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------");
}
}
}
之後構造神經元時,我把「training.txt」作爲權重矩陣傳入。但是,結果顯然不對:(原帖此處附有一張截圖)
請幫忙!我對神經網絡和機器學習非常陌生,目前不知道自己哪裏做錯了:是需要更多訓練樣本?還是激活函數實現得不好?任何建議都將不勝感激。如有需要,我也可以提供更多代碼。
也許是我看漏了,但你的網絡中似乎只有一個神經元。無論訓練多少次,這樣的網絡表現都會很差。你可以讀一讀 http://neuralnetworksanddeeplearning.com/chap1.html ,其中有一個與你想做的事情非常相似的例子。 – Chill
另外,你把sigmoid的輸出量化成整數來當作激活函數,這也不好……多類分類不是這樣做的。 – user2717954