2011-03-29 59 views
2

我正在做一些關於Java 7中的叉/加入框架的性能研究。爲了改善測試結果,我想在測試過程中使用不同的遞歸算法。其中之一是乘法矩陣。在Java中叉加入矩陣乘法

public class MatrixMultiply { 

    static final int DEFAULT_GRANULARITY = 16; 

    /** The quadrant size at which to stop recursing down 
    * and instead directly multiply the matrices. 
    * Must be a power of two. Minimum value is 2. 
    **/ 
    static int granularity = DEFAULT_GRANULARITY; 

    public static void main(String[] args) { 

    final String usage = "Usage: java MatrixMultiply <threads> <matrix size (must be a power of two)> [<granularity>] \n Size and granularity must be powers of two.\n For example, try java MatrixMultiply 2 512 16"; 

    try { 
     int procs; 
     int n; 
     try { 
     procs = Integer.parseInt(args[0]); 
     n = Integer.parseInt(args[1]); 
     if (args.length > 2) granularity = Integer.parseInt(args[2]); 
     } 

     catch (Exception e) { 
     System.out.println(usage); 
     return; 
     } 

     if (((n & (n - 1)) != 0) || 
      ((granularity & (granularity - 1)) != 0) || 
      granularity < 2) { 
     System.out.println(usage); 
     return; 
     } 

     float[][] a = new float[n][n]; 
     float[][] b = new float[n][n]; 
     float[][] c = new float[n][n]; 
     init(a, b, n); 

     FJTaskRunnerGroup g = new FJTaskRunnerGroup(procs); 
     g.invoke(new Multiplier(a, 0, 0, b, 0, 0, c, 0, 0, n)); 
     g.stats(); 

     // check(c, n); 
    } 
    catch (InterruptedException ex) {} 
    } 


    // To simplify checking, fill with all 1's. Answer should be all n's. 
    static void init(float[][] a, float[][] b, int n) { 
    for (int i = 0; i < n; ++i) { 
     for (int j = 0; j < n; ++j) { 
     a[i][j] = 1.0F; 
     b[i][j] = 1.0F; 
     } 
    } 
    } 

    static void check(float[][] c, int n) { 
    for (int i = 0; i < n; i++) { 
     for (int j = 0; j < n; j++) { 
     if (c[i][j] != n) { 
      throw new Error("Check Failed at [" + i +"]["+j+"]: " + c[i][j]); 
     } 
     } 
    } 
    } 

    /** 
    * Multiply matrices AxB by dividing into quadrants, using algorithm: 
    * <pre> 
    *  A  x  B        
    * 
    * A11 | A12  B11 | B12  A11*B11 | A11*B12  A12*B21 | A12*B22 
    * |----+----| x |----+----| = |--------+--------| + |---------+-------| 
    * A21 | A22  B21 | B21  A21*B11 | A21*B21  A22*B21 | A22*B22 
    * </pre> 
    */ 


    static class Multiplier extends FJTask { 
    final float[][] A; // Matrix A 
    final int aRow;  // first row of current quadrant of A 
    final int aCol;  // first column of current quadrant of A 

    final float[][] B; // Similarly for B 
    final int bRow; 
    final int bCol; 

    final float[][] C; // Similarly for result matrix C 
    final int cRow; 
    final int cCol; 

    final int size;  // number of elements in current quadrant 

    Multiplier(float[][] A, int aRow, int aCol, 
       float[][] B, int bRow, int bCol, 
       float[][] C, int cRow, int cCol, 
       int size) { 
     this.A = A; this.aRow = aRow; this.aCol = aCol; 
     this.B = B; this.bRow = bRow; this.bCol = bCol; 
     this.C = C; this.cRow = cRow; this.cCol = cCol; 
     this.size = size; 
    } 

    public void run() { 

     if (size <= granularity) { 
     multiplyStride2(); 
     } 

     else { 
     int h = size/2; 

     coInvoke(new FJTask[] { 
      seq(new Multiplier(A, aRow, aCol, // A11 
          B, bRow, bCol, // B11 
          C, cRow, cCol, // C11 
          h), 
       new Multiplier(A, aRow, aCol+h, // A12 
          B, bRow+h, bCol, // B21 
          C, cRow, cCol, // C11 
          h)), 

      seq(new Multiplier(A, aRow, aCol, // A11 
          B, bRow, bCol+h, // B12 
          C, cRow, cCol+h, // C12 
          h), 
       new Multiplier(A, aRow, aCol+h, // A12 
          B, bRow+h, bCol+h, // B22 
          C, cRow, cCol+h, // C12 
          h)), 

      seq(new Multiplier(A, aRow+h, aCol, // A21 
          B, bRow, bCol, // B11 
          C, cRow+h, cCol, // C21 
          h), 
       new Multiplier(A, aRow+h, aCol+h, // A22 
          B, bRow+h, bCol, // B21 
          C, cRow+h, cCol, // C21 
          h)), 

      seq(new Multiplier(A, aRow+h, aCol, // A21 
          B, bRow, bCol+h, // B12 
          C, cRow+h, cCol+h, // C22 
          h), 
       new Multiplier(A, aRow+h, aCol+h, // A22 
          B, bRow+h, bCol+h, // B22 
          C, cRow+h, cCol+h, // C22 
          h)) 
     }); 
     } 
    } 

    /** 
    * Version of matrix multiplication that steps 2 rows and columns 
    * at a time. Adapted from Cilk demos. 
    * Note that the results are added into C, not just set into C. 
    * This works well here because Java array elements 
    * are created with all zero values. 
    **/ 

    void multiplyStride2() { 
     for (int j = 0; j < size; j+=2) { 
     for (int i = 0; i < size; i +=2) { 

      float[] a0 = A[aRow+i]; 
      float[] a1 = A[aRow+i+1]; 

      float s00 = 0.0F; 
      float s01 = 0.0F; 
      float s10 = 0.0F; 
      float s11 = 0.0F; 

      for (int k = 0; k < size; k+=2) { 

      float[] b0 = B[bRow+k]; 

      s00 += a0[aCol+k] * b0[bCol+j]; 
      s10 += a1[aCol+k] * b0[bCol+j]; 
      s01 += a0[aCol+k] * b0[bCol+j+1]; 
      s11 += a1[aCol+k] * b0[bCol+j+1]; 

      float[] b1 = B[bRow+k+1]; 

      s00 += a0[aCol+k+1] * b1[bCol+j]; 
      s10 += a1[aCol+k+1] * b1[bCol+j]; 
      s01 += a0[aCol+k+1] * b1[bCol+j+1]; 
      s11 += a1[aCol+k+1] * b1[bCol+j+1]; 
      } 

      C[cRow+i] [cCol+j] += s00; 
      C[cRow+i] [cCol+j+1] += s01; 
      C[cRow+i+1][cCol+j] += s10; 
      C[cRow+i+1][cCol+j+1] += s11; 
     } 
     } 
    } 

    } 

} 

該代碼可用於舊版本的叉的書面/ join框架:

我Doug Lea的網站()下載下面的例子。所以我必須重寫它。我重寫的代碼實現我自己的接口,看起來像這樣:

public class Java7MatrixMultiply implements Algorithm { 
    private static final int SIZE = 32; 
    private static final int THRESHOLD = 8; 

    private float[][] a = new float[SIZE][SIZE]; 
    private float[][] b = new float[SIZE][SIZE]; 
    private float[][] c = new float[SIZE][SIZE]; 

    ForkJoinPool forkJoinPool; 

    @Override 
    public void initialize() { 
     init(a, b, SIZE); 
    } 

    @Override 
    public void execute() { 
     MatrixMultiplyTask mainTask = new MatrixMultiplyTask(a, 0, 0, b, 0, 0, c, 0, 0, SIZE); 
     forkJoinPool = new ForkJoinPool(); 
     forkJoinPool.invoke(mainTask); 

     System.out.println("Terminated!"); 
    } 

    @Override 
    public void printResult() { 
     check(c, SIZE); 

     for (int i = 0; i < SIZE; i++) { 
      for (int j = 0; j < SIZE; j++) { 
       System.out.print(c[i][j] + " "); 
      } 

      System.out.println(); 
     } 
    } 

    // To simplify checking, fill with all 1's. Answer should be all n's. 
    static void init(float[][] a, float[][] b, int n) { 
     for (int i = 0; i < n; ++i) { 
      for (int j = 0; j < n; ++j) { 
       a[i][j] = 1.0F; 
       b[i][j] = 1.0F; 
      } 
     } 
    } 

    static void check(float[][] c, int n) { 
     for (int i = 0; i < n; i++) { 
      for (int j = 0; j < n; j++) { 
       if (c[i][j] != n) { 
        //throw new Error("Check Failed at [" + i + "][" + j + "]: " + c[i][j]); 
        System.out.println("Check Failed at [" + i + "][" + j + "]: " + c[i][j]); 
       } 
      } 
     } 
    } 

    private class MatrixMultiplyTask extends RecursiveAction { 
     private final float[][] A; // Matrix A 
     private final int aRow; // first row of current quadrant of A 
     private final int aCol; // first column of current quadrant of A 

     private final float[][] B; // Similarly for B 
     private final int bRow; 
     private final int bCol; 

     private final float[][] C; // Similarly for result matrix C 
     private final int cRow; 
     private final int cCol; 

     private final int size; 

     MatrixMultiplyTask(float[][] A, int aRow, int aCol, float[][] B, 
       int bRow, int bCol, float[][] C, int cRow, int cCol, int size) { 
      this.A = A; 
      this.aRow = aRow; 
      this.aCol = aCol; 
      this.B = B; 
      this.bRow = bRow; 
      this.bCol = bCol; 
      this.C = C; 
      this.cRow = cRow; 
      this.cCol = cCol; 
      this.size = size; 
     } 

     @Override 
     protected void compute() {  
      if (size <= THRESHOLD) { 
       multiplyStride2(); 
      } else { 

       int h = size/2;    

       invokeAll(new MatrixMultiplyTask[] { 
         new MatrixMultiplyTask(A, aRow, aCol, // A11 
           B, bRow, bCol, // B11 
           C, cRow, cCol, // C11 
           h), 

         new MatrixMultiplyTask(A, aRow, aCol + h, // A12 
           B, bRow + h, bCol, // B21 
           C, cRow, cCol, // C11 
           h), 

         new MatrixMultiplyTask(A, aRow, aCol, // A11 
           B, bRow, bCol + h, // B12 
           C, cRow, cCol + h, // C12 
           h), 

         new MatrixMultiplyTask(A, aRow, aCol + h, // A12 
           B, bRow + h, bCol + h, // B22 
           C, cRow, cCol + h, // C12 
           h), 

         new MatrixMultiplyTask(A, aRow + h, aCol, // A21 
           B, bRow, bCol, // B11 
           C, cRow + h, cCol, // C21 
           h), 

         new MatrixMultiplyTask(A, aRow + h, aCol + h, // A22 
           B, bRow + h, bCol, // B21 
           C, cRow + h, cCol, // C21 
           h), 

         new MatrixMultiplyTask(A, aRow + h, aCol, // A21 
           B, bRow, bCol + h, // B12 
           C, cRow + h, cCol + h, // C22 
           h), 

         new MatrixMultiplyTask(A, aRow + h, aCol + h, // A22 
           B, bRow + h, bCol + h, // B22 
           C, cRow + h, cCol + h, // C22 
           h) }); 

      } 
     } 

     /** 
     * Version of matrix multiplication that steps 2 rows and columns at a 
     * time. Adapted from Cilk demos. Note that the results are added into 
     * C, not just set into C. This works well here because Java array 
     * elements are created with all zero values. 
     **/ 

     void multiplyStride2() { 
      for (int j = 0; j < size; j += 2) { 
       for (int i = 0; i < size; i += 2) { 

        float[] a0 = A[aRow + i]; 
        float[] a1 = A[aRow + i + 1]; 

        float s00 = 0.0F; 
        float s01 = 0.0F; 
        float s10 = 0.0F; 
        float s11 = 0.0F; 

        for (int k = 0; k < size; k += 2) { 

         float[] b0 = B[bRow + k]; 

         s00 += a0[aCol + k] * b0[bCol + j]; 
         s10 += a1[aCol + k] * b0[bCol + j]; 
         s01 += a0[aCol + k] * b0[bCol + j + 1]; 
         s11 += a1[aCol + k] * b0[bCol + j + 1]; 

         float[] b1 = B[bRow + k + 1]; 

         s00 += a0[aCol + k + 1] * b1[bCol + j]; 
         s10 += a1[aCol + k + 1] * b1[bCol + j]; 
         s01 += a0[aCol + k + 1] * b1[bCol + j + 1]; 
         s11 += a1[aCol + k + 1] * b1[bCol + j + 1]; 
        } 

        C[cRow + i][cCol + j] += s00; 
        C[cRow + i][cCol + j + 1] += s01; 
        C[cRow + i + 1][cCol + j] += s10; 
        C[cRow + i + 1][cCol + j + 1] += s11; 
       } 
      } 
     } 
    } 
} 

有時我的計算未能通過檢查。矩陣的某些區域具有與預期不同的值。這些不一致是隨機的,並不總是會發生。我懷疑計算方法出了問題,因爲我不得不重寫使用Seq類的部分。 Seq klass按順序執行任務,與invokeAll()方法不同。在當前版本的fork/join框架中,該類不再存在。我對矩陣乘法算法不是很熟悉,所以很難看出錯誤。有什麼建議麼?

回答

0

正如您已經注意到的那樣,屬於同一象限的子任務的順序執行對於此算法很重要。因此,您需要實現您自己的seq()函數,例如,如下所示,並將其用作原始代碼:

public ForkJoinTask<?> seq(final ForkJoinTask<?> a, final ForkJoinTask<?> b) { 
    return adapt(new Runnable() { 
     public void run() { 
      a.invoke(); 
      b.invoke(); 
     } 
    }); 
} 
+0

謝謝。它完美無瑕。 – TheArchitect 2011-03-29 13:56:02

1

您正在積累C[cRow + i][cCol + j] += s00;等的結果。這不是線程安全操作,因此您必須同步行或確保只有一個任務更新單元。沒有這個,你會看到隨機單元格設置不正確。

我會檢查你找到正確答案,用1

BTW併發:float可能不是這裏最好的選擇。它的精度數字相當低,而且在繁重的矩陣操作中(我假設你正在做或者沒有多少點使用多個線程),舍入誤差可能會佔用大部分或全部的精度。我建議改爲考慮double

例如float大約有7位數字的精度,一個經驗法則是錯誤與計算次數成正比。因此,對於1K x 1K矩陣,您可能會有4位數字的精度。對於10K x 10K,最多隻能有三個。 double具有16位數字的精度,這意味着在10K x 10K加密後可能有12位精度。

+0

感謝您的快速回復。我測試了併發性爲1的算法。沒有發生錯誤。你的解釋當然是對的,但使用同步並不是很有效。 Dough Lea在原始代碼中也沒有使用這種方法。是否有可能重新實現計算方法,因此不需要同步? – TheArchitect 2011-03-29 13:12:13