
Backpropagation neural network not working properly

I'm just starting out with machine learning and neural networks, so I'm still trying to understand how backpropagation works. I'm trying to build a simple NN in Java using a straightforward matrix-based approach. If I use only one training example the network works perfectly, but if I try to use more, the output is always the average of the expected training outputs.

Network diagram: http://neuralnetworksanddeeplearning.com/images/tikz21.png
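For reference, these are the batch update rules I'm trying to implement (my reading of the matrix-based notation from the book; averaging the gradients over the examples is my own choice). Each column of $x$, $a^2$, $a^3$ and $y$ holds one of the $m$ training examples:

$$\delta^3 = (a^3 - y)\odot \sigma'(z^3),\qquad \delta^2 = \big((w^3)^{T}\,\delta^3\big)\odot \sigma'(z^2)$$
$$\frac{\partial J}{\partial w^3} = \frac{1}{m}\,\delta^3\,(a^2)^{T},\qquad \frac{\partial J}{\partial b^3} = \frac{1}{m}\sum_{k=1}^{m}\delta^3_{\,\cdot\,k}$$

and analogously for $w^2$ and $b^2$, with $a^1 = x$.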

package neuralnetwork; 
/** 
* @author Paolo Pellizzoni 
*/ 

public class NeuralNetwork { 

static final int in_l = 2;   // input layer size
static final int h_l = 5;    // hidden layer size
static final int out_l = 1;  // output layer size

public static double[][] w2 = new double[h_l][in_l];   // weights input -> hidden
public static double[][] w3 = new double[out_l][h_l];  // weights hidden -> output
public static double[] b2 = new double[h_l];            // hidden-layer biases
public static double[] b3 = new double[out_l];           // output-layer biases

public static double[][] x = {{3,4},{2,3}};    // each column is one training example: (3,2) and (4,3)
public static double[][] y = {{0.3,0.7}};      // expected outputs for the two examples
public static double[][] test = {{3}, {2}};    // single test input (3,2)
// using x = {{3},{2}} and y = {{0.3}} it works 

    public static void main(String[] args) { 
     trainNN(0.2); 
     double[][] m = a_3(test); 

     for(int i=0; i<m.length; i++){ 
      for(int j=0; j<m[0].length; j++){ 
     System.out.print(m[i][j]+" "); 
      } 
      System.out.println(); 
    } 
    } 
    // ---------- FUNCTIONS ---------- 

    // fill a weight matrix with random values in [0, 1)
    static void inizialize_weights(double[][] m){ 
    for(int i=0; i<m.length; i++){ 
      for(int j=0; j<m[0].length; j++){ 
     m[i][j]= Math.random(); 
      } 
    } 
    } 
    // training: full-batch gradient descent for 500 iterations
    static void trainNN(double rate){ 
     inizialize_weights(w2); 
     inizialize_weights(w3); 

     for(int c=0; c<500; c++){ 
      double[][] dJ_w3 = dJ_w3(x, y); 
      double[][] dJ_w2 = dJ_w2(x, y); 
      double[] dJ_b3 = dJ_b3(x, y); 
      double[] dJ_b2 = dJ_b2(x, y); 
      w3 = matrix_sum(w3, dJ_w3, -rate); 
      w2 = matrix_sum(w2, dJ_w2, -rate); 
      b3 = vect_sum(b3, dJ_b3, -rate); 
      b2 = vect_sum(b2, dJ_b2, -rate); 
     } 
    } 

    // ----- forward pass: each column of `inputs` is one example -----
    static double[][] a_3(double[][] inputs){ 
     return sigmoid(z_3(inputs)); 
    } 
    static double[][] z_3(double[][] inputs){ 
     return matrix_sum_vect(matrix_product(w3, a_2(inputs)), b3, 1); 
    } 
    static double[][] a_2(double[][] inputs){ 
     return sigmoid(z_2(inputs)); 
    } 
    static double[][] z_2(double[][] inputs){ 
     return matrix_sum_vect(matrix_product(w2, inputs), b2, 1); 
    } 

    // error term of the output layer: (a3 - y) element-wise times sigmoid'(z3)
    static double[][] delta3 (double[][] inputs, double[][] y){ 
     return matrix_hadamard(
       matrix_sum(a_3(inputs), y, -1), 
       sigmoid_prime(z_3(inputs)) 
     ); 
    } 
    // error term of the hidden layer, backpropagated through w3
    static double[][] delta2 (double[][] inputs, double[][] y){ 
     return matrix_hadamard(
       matrix_product(
         transpose_matrix(w3), 
         delta3(inputs, y)), 
       sigmoid_prime(z_2(inputs)) 
     ); 
    } 
    // gradient of the cost w.r.t. w3, averaged over the training examples (columns)
    static double[][] dJ_w3 (double[][] inputs, double[][] y){ 
     double[][] dJ_w3 = new double[out_l][h_l]; 
     double[][] delta3 = delta3(inputs, y); 
     double[][] a2 = a_2(inputs); 
     for(int i=0; i<delta3.length; i++){ 
      for(int j=0; j<a2.length; j++){ 
       double tmp = 0; 
       for(int k=0; k<a2[0].length; k++){ 
        tmp += a2[j][k]*delta3[i][k]; 
       } 
       dJ_w3[i][j] = tmp/a2[0].length; 
      } 
     } 

     return dJ_w3; 
    } 
    // gradient of the cost w.r.t. w2, averaged over the training examples (columns)
    static double[][] dJ_w2 (double[][] inputs, double[][] y){ 
     double[][] dJ_w2 = new double[h_l][in_l]; 
     double[][] delta2 = delta2(inputs, y); 
     double[][] a1 = inputs; 

     for(int i=0; i<delta2.length; i++){ 
      for(int j=0; j<a1.length; j++){ 
       double tmp = 0; 
       for(int k=0; k<a1[0].length; k++){ 
        tmp += a1[j][k]*delta2[j][k]; 
       } 
       dJ_w2[i][j] = tmp/a1[0].length; 
      } 
     } 

     return dJ_w2; 
    } 
    // gradient of the cost w.r.t. b3, averaged over the training examples
    static double[] dJ_b3 (double[][] inputs, double[][] y){ 
     double[] dJ_b3 = new double[out_l]; 
     double[][] delta3 = delta3(inputs, y); 
     for(int i=0; i<delta3.length; i++){ 
      double tmp = 0; 
      for(int k=0; k<delta3[0].length; k++){ 
       tmp += delta3[i][k]; 
      } 
      dJ_b3[i] = tmp/delta3[0].length; 
     } 

     return dJ_b3; 
    } 
    // gradient of the cost w.r.t. b2, averaged over the training examples
    static double[] dJ_b2 (double[][] inputs, double[][] y){ 
     double[] dJ_b2 = new double[h_l]; 
     double[][] delta2 = delta2(inputs, y); 
     for(int i=0; i<delta2.length; i++){ 
      double tmp = 0; 
      for(int k=0; k<delta2[0].length; k++){ 
       tmp += delta2[i][k]; 
      } 
      dJ_b2[i] = tmp/delta2[0].length; 
     } 

     return dJ_b2; 
    } 


    // ----- Math ----- 


    static double[][] matrix_product(double[][] a, double[][] b){ // matrix multiplication 
     int m1ColLength = a[0].length; 
     int m2RowLength = b.length; 
     if(m1ColLength != m2RowLength) return null; 
     int mRRowLength = a.length;  
     int mRColLength = b[0].length; 
     double[][] mResult = new double[mRRowLength][mRColLength]; 
     for(int i = 0; i < mRRowLength; i++) {   
      for(int j = 0; j < mRColLength; j++) {  
       for(int k = 0; k < m1ColLength; k++) { 
        mResult[i][j] += a[i][k] * b[k][j]; 
       } 
      } 
     } 
     return mResult; 
    } 
    static double[][] matrix_sum(double[][] a, double[][] b, double is_sum){ //matrix sum 
     int m1ColLength = a[0].length; 
     int m2RowLength = b.length;  
     int m1RowLength = a.length;  
     int m2ColLength = b[0].length; 
     if(m1ColLength != m2ColLength || m1RowLength != m2RowLength) return null; 
     double[][] mResult = new double[m1RowLength][m1ColLength]; 
     for(int i = 0; i < m1RowLength; i++) {   
      for(int j = 0; j < m1ColLength; j++) {  
       mResult[i][j]=a[i][j]+(b[i][j])*is_sum; 
      } 
     } 
     return mResult; 
    } 
    static double[] vect_sum(double[] a, double[] b, double is_sum){ // vector sum 
     int m2RowLength = b.length;  
     int m1RowLength = a.length;  
     if(m1RowLength != m2RowLength) return null; 
     double[] mResult = new double[m1RowLength]; 
     for(int i = 0; i < m1RowLength; i++) {   
      mResult[i]=a[i]+(b[i])*is_sum; 
     } 
     return mResult; 
    } 
    static double[][] matrix_sum_vect(double[][] a, double[] b, double is_sum){ // adds a vector to each column 
     int m1ColLength = a[0].length; 
     int m2RowLength = b.length;  
     int m1RowLength = a.length;  
     if(m1RowLength != m2RowLength) return null; 
     double[][] mResult = new double[m1RowLength][m1ColLength]; 
     for(int i = 0; i < m1RowLength; i++) {   
      for(int j = 0; j < m1ColLength; j++) {  
       mResult[i][j]=a[i][j]+(b[i])*is_sum; 
      } 
     } 
     return mResult; 
    } 
    static double[][] matrix_hadamard(double[][] a, double[][] b){ // hadamard product 
     int m1ColLength = a[0].length; 
     int m2RowLength = b.length;  
     int m1RowLength = a.length;  
     int m2ColLength = b[0].length; 
     if(m1ColLength != m2ColLength || m1RowLength != m2RowLength) return null; 
     double[][] mResult = new double[m1RowLength][m1ColLength]; 
     for(int i = 0; i < m1RowLength; i++) {   
      for(int j = 0; j < m1ColLength; j++) {  
       mResult[i][j]=a[i][j]*b[i][j]; 
      } 
     } 
     return mResult; 
    } 
    static double[][] matrix_x_scalar(double[][] a, double scalar){ // matrix times scalar 
     int m1ColLength = a[0].length; 
     int m1RowLength = a.length;  
     double[][] mResult = new double[m1RowLength][m1ColLength]; 
     for(int i = 0; i < m1RowLength; i++) {   
      for(int j = 0; j < m1ColLength; j++) {  
       mResult[i][j]=a[i][j]*scalar; 
      } 
     } 
     return mResult; 
    } 
    static double[][] transpose_matrix(double [][] m){ 
     double[][] mResult = new double[m[0].length][m.length]; 
     for (int i = 0; i < m.length; i++) 
      for (int j = 0; j < m[0].length; j++) 
       mResult[j][i] = m[i][j]; 
     return mResult; 
    } 
    static double sigmoid(double z) { 
    return 1.0/(1.0+Math.exp(-z)); 
    } 
    static double[][] sigmoid(double[][] z) { 
     for(int i=0; i<z.length; i++){ 
      for(int j=0; j<z[0].length; j++){ 
       z[i][j]= sigmoid(z[i][j]); 
      } 
     } 
    return z; 
    } 
    static double sigmoid_prime(double z) { 
    return sigmoid(z)*(1-sigmoid(z)); 
    } 
    static double[][] sigmoid_prime(double[][] z) { 
     for(int i=0; i<z.length; i++){ 
      for(int j=0; j<z[0].length; j++){ 
       z[i][j]= sigmoid_prime(z[i][j]); 
      } 
     } 
    return z; 
    }// ----- end math ----- 







} 

What I am fairly sure of is that the error is hiding in the dJ_w3 and dJ_w2 functions, maybe in the k loop that averages all the gradients, but I just can't find it. Can you help me?

Answer


Found the problem: I just had to increase the number of training iterations to 50000.
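In case it helps anyone else, the only change was the loop bound in trainNN; the rest of the class stays exactly as posted above:

    static void trainNN(double rate){ 
     inizialize_weights(w2); 
     inizialize_weights(w3); 

     // 500 full-batch updates were not enough for the weights to move away from
     // the "average output" solution; per the fix above, 50000 iterations are used instead
     for(int c=0; c<50000; c++){ 
      double[][] dJ_w3 = dJ_w3(x, y); 
      double[][] dJ_w2 = dJ_w2(x, y); 
      double[] dJ_b3 = dJ_b3(x, y); 
      double[] dJ_b2 = dJ_b2(x, y); 
      w3 = matrix_sum(w3, dJ_w3, -rate); 
      w2 = matrix_sum(w2, dJ_w2, -rate); 
      b3 = vect_sum(b3, dJ_b3, -rate); 
      b2 = vect_sum(b2, dJ_b2, -rate); 
     } 
    } 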
