2015-03-31 323 views
0

我想通過使用寄存器(逐行信息)的矢量算法創建矩陣乘法。打開外循環4次 我有 無效matvec_XMM(雙*一,雙* X,雙* Y,INT N,INT磅) 函數返回壞的結果 問題據算法wchich我必須使用:C++矩陣乘法

i = 1,n,4 
    r0 = r1 = r2 = r3 = 0 
    j = 1,n,8 
     r0 = r0 + aij * xj + ai,j+1 * xj+1 + … + ai,j+7 * xj+7 
     r1 = r1 + ai+1,j * xj + ai+1,j+1 * xj+1 + … + ai+1,j+7 * xj+7 
     r2 = r2 + ai+2,j * xj + ai+2,j+1 * xj+1 + … + ai+2,j+7 * xj+7 
     r3 = r3 + ai+3,j * xj + ai+3,j+1 * xj+1 + … + ai+3,j+7 * xj+7 
    end j 
    yi = r0; yi+1 = r1; yi+2 = r2; yi+3 = r3; 
end i 

這是馬代碼:

#include "stdafx.h" 
#include <iostream> 
#include "mvec.h" 
#include <emmintrin.h> 

using namespace std; 

void mult_naive(double *a, double *x, double *y, int n) 
{ 
    int i, j, ij; 
    double register reg; 

    for(i=0, ij=0; i<n; ++i) 
    { 
     reg = 0; 

     for(j=0; j<n; ++j, ++ij) 
     { 
      reg += a[ij]*x[j]; 
     } 

     y[i] = reg; 
    } 
} 

void matvec_XMM(double* a, double* x, double* y, int n, int lb) 
{ 
int i, j; 

memset((void *)y, 0, n*sizeof(double)); 
double res0[2]; 
double res1[2]; 
double res2[2]; 
double res3[2]; 

__m128d ry0, ry1, ry2, ry3, ra0, rx0; 
double *ptr_a, *ptr_x, *ptr_y; 
const int nr = 4; 

ptr_a = a; 

for (i = 0; i < n; i+=nr) 
{ 
    ry0 = _mm_setzero_pd(); 
    ry1 = _mm_setzero_pd(); 
    ry2 = _mm_setzero_pd(); 
    ry3 = _mm_setzero_pd(); 

    ptr_y = &y[i]; 
    ptr_x = x; 

    for (j = 0; j<n; j+=lb) 
    { 

     _mm_prefetch((const char *)(ptr_a + lb*nr), _MM_HINT_NTA); 
     _mm_prefetch((const char *)(ptr_x + lb), _MM_HINT_T0); 

     //----1 
     rx0 = _mm_load_pd(ptr_x);  
     ra0 = _mm_load_pd(ptr_a); 
     ra0 = _mm_mul_pd(ra0, rx0); 
     ry0 = _mm_add_pd(ry0, ra0); 

     ra0 = _mm_load_pd(ptr_a + 2); 
     ra0 = _mm_mul_pd(ra0, rx0); 
     ry1 = _mm_add_pd(ry1, ra0); 

     ra0 = _mm_load_pd(ptr_a + 4); 
     ra0 = _mm_mul_pd(ra0, rx0); 
     ry2 = _mm_add_pd(ry2, ra0); 

     ra0 = _mm_load_pd(ptr_a + 6); 
     ra0 = _mm_mul_pd(ra0, rx0); 
     ry3 = _mm_add_pd(ry3, ra0); 

      //----2 
     rx0 = _mm_load_pd(ptr_x + 2); 
     ra0 = _mm_load_pd(ptr_a + 8); 
     ra0 = _mm_mul_pd(ra0, rx0); 
     ry0 = _mm_add_pd(ry0, ra0); 

     ra0 = _mm_load_pd(ptr_a + 10); 
     ra0 = _mm_mul_pd(ra0, rx0); 
     ry1 = _mm_add_pd(ry1, ra0); 

     ra0 = _mm_load_pd(ptr_a + 12); 
     ra0 = _mm_mul_pd(ra0, rx0); 
     ry2 = _mm_add_pd(ry2, ra0); 

     ra0 = _mm_load_pd(ptr_a + 14); 
     ra0 = _mm_mul_pd(ra0, rx0); 
     ry3 = _mm_add_pd(ry3, ra0); 

     //----3 
     rx0 = _mm_load_pd(ptr_x + 4);  
     ra0 = _mm_load_pd(ptr_a + 16); 
     ra0 = _mm_mul_pd(ra0, rx0); 
     ry0 = _mm_add_pd(ry0, ra0); 

     ra0 = _mm_mul_pd(ra0, rx0); 
     ry1 = _mm_add_pd(ry1, ra0); 

     ra0 = _mm_load_pd(ptr_a + 20); 
     ra0 = _mm_mul_pd(ra0, rx0); 
     ry2 = _mm_add_pd(ry2, ra0); 

     ra0 = _mm_load_pd(ptr_a + 22); 
     ra0 = _mm_mul_pd(ra0, rx0); 
     ry3 = _mm_add_pd(ry3, ra0); 

     //----4 
     rx0 = _mm_load_pd(ptr_x + 6);  
     ra0 = _mm_load_pd(ptr_a + 24); 
     ra0 = _mm_mul_pd(ra0, rx0); 
     ry0 = _mm_add_pd(ry0, ra0); 

     ra0 = _mm_load_pd(ptr_a + 26); 
     ra0 = _mm_mul_pd(ra0, rx0); 
     ry1 = _mm_add_pd(ry1, ra0); 

     ra0 = _mm_load_pd(ptr_a + 28); 
     ra0 = _mm_mul_pd(ra0, rx0); 
     ry2 = _mm_add_pd(ry2, ra0); 

     ra0 = _mm_load_pd(ptr_a + 30); 
     ra0 = _mm_mul_pd(ra0, rx0); 
     ry3 = _mm_add_pd(ry3, ra0); 

     ptr_a += lb*nr; 
     ptr_x += lb; 
    } 

    _mm_store_pd(res0, ry0); 
    *ptr_y = res0[0] + res0[1]; 

    _mm_store_pd(res1, ry1); 
    *(ptr_y + 1) = res1[0] + res1[1]; 

    _mm_store_pd(res2, ry2); 
    *(ptr_y + 2)= res2[0] + res2[1]; 

    _mm_store_pd(res3, ry3); 
    *(ptr_y + 3) = res3[0] + res3[1]; 

} 
} 



#include "stdafx.h" 
#include <iostream> 
#include <cmath> 
#include "windows.h" 
#include "mvec.h" 

using namespace std; 


int main(int argc, char* argv[]) 
{ 
    double *a, *x, *y, *z; 
    int n; 
    DWORD tstart; 
    const int lb = 8; 
    double elaps_time; 
    cout << "Program Mat_Vect: performance y = y +A*x\n"; 

#ifdef _DEBUG 
    cout << "DEBUG version\n"; 
#else 
    cout << "RELEASE version\n"; 
#endif 

    cout << "Input dimension\n"; 
    cin >> n; 

    n = n/lb; 
    n = lb*n; 

    try 
    { 
     a = new double [n*n]; 
     x = new double [n+1]; 
     y = new double [n]; 
     z = new double [n]; 
    } 
    catch(bad_alloc aa) 
    { 
     cout << "memory allocation error" << endl; 
     system("pause"); 
     exit(1); 
    } 

    memset((void *)a, 0, _msize((void *)a)); 
    memset((void *)x, 0, _msize((void *)x)); 
    memset((void *)y, 0, _msize((void *)y)); 

    cout << "start\n"; 

    prepare(a, x, n); 

    //-------------------------naive algorithm-----------------------// 
    cout << "naive algorithm: \n"; 
    tstart = GetTickCount(); 
    mult_naive(a, x, z, n); 
    elaps_time = (double)(GetTickCount()-tstart)/1000.0; 
    cout << "naive algorithm: " << elaps_time << " sec" << endl; 

    //-------------------------algorithm which uses XMM registers-----------------------// 
    delete [] a; 
    delete [] x; 
    a = (double *)_aligned_malloc(n*n*sizeof(double), 16); 
    x = (double *)_aligned_malloc(n*sizeof(double), 16); 
    if(!a || !x) 
    { 
     cout << "memory allocation error" << endl; 
     system("pause"); 
     exit(1); 
    } 
    cout << "algorithm which uses XMM: \n"; 
    prepare(a, x, n); 
    tstart = GetTickCount(); 
    matvec_XMM(a, x, y, n, lb); 
    elaps_time = (double)(GetTickCount()-tstart)/1000.0; 
    check(y, z, n); 
    cout << "algorithm which uses XMM: " << elaps_time << " sec" << endl; 

    delete [] y; 
    delete [] z; 
    _aligned_free(a); 
    _aligned_free(x); 

    system("pause"); 
    return 0; 
} 


void check(double *y, double *z, int n) 
{ 
    int i; 
    for(i=0; i<n; i++) 
    { 
     if(fabs(z[i] - y[i]) > 1.0e-9) 
     { 
      cout << "error\n"; 
      return; 
     } 
    } 

    cout << "OK\n"; 
} 

void prepare(double *a, double *x, int n) 
{ 
    int i, j, ij; 

    for(i=0, ij=0; i<n; i++) 
    { 
     for(j=0; j<n; j++, ij++) 
     { 
      if(i == j) 
       a[ij] = 10.0; 
      else 
       a[ij] = (double)(i+1); 
     } 

     x[i] = 1.0; 
    } 
} 
+2

他們的代碼,你的心不是展示足以再現這個問題,因爲你不顯示你如何稱呼這個功能,它不清楚你的意思是「壞的結果」 – user463035818 2015-03-31 11:00:15

+0

請更具體地說明你遇到的問題。參數'lb'的用途是什麼?你的目標是存儲第一個矩陣行 - 第二個列 - 專業? – Codor 2015-03-31 11:21:21

+0

thx爲您的答覆,我固定內容 – michal 2015-03-31 11:53:15

回答

-1

我找到了解決辦法

void matvec_XMM(double* a, double* x, double* y, int n, int lb) 
{ 
    int i, j; 

    memset((void *)y, 0, n*sizeof(double)); 
    __declspec(align(16)) double res0[2]; 
    __declspec(align(16)) double res1[2]; 
    __declspec(align(16)) double res2[2]; 
    __declspec(align(16)) double res3[2]; 

    __m128d ry0, ry1, ry2, ry3, ra0, ra1, ra2, ra3, rx0; 
    double *ptr_a1, *ptr_a2, *ptr_a3, *ptr_a4, *ptr_x, *ptr_y; 
    const int nr = 4; 

    for (i = 0; i < n; i+=nr) 
    { 
     ry0 = _mm_setzero_pd(); 
     ry1 = _mm_setzero_pd(); 
     ry2 = _mm_setzero_pd(); 
     ry3 = _mm_setzero_pd(); 

     ptr_y = &y[i]; 

     for (j = 0; j<n; j+=lb) 
     { 
      ptr_a1 = &a[i * n + j]; 
      ptr_a2 = &a[(i + 1) * n + j]; 
      ptr_a3 = &a[(i + 2) * n + j]; 
      ptr_a4 = &a[(i + 3) * n + j]; 
      ptr_x = &x[j]; 

      _mm_prefetch((const char *)(ptr_a1 + lb), _MM_HINT_NTA); 
      _mm_prefetch((const char *)(ptr_a2 + lb), _MM_HINT_NTA); 
      _mm_prefetch((const char *)(ptr_a3 + lb), _MM_HINT_NTA); 
      _mm_prefetch((const char *)(ptr_a4 + lb), _MM_HINT_NTA); 
      _mm_prefetch((const char *)(ptr_x + lb), _MM_HINT_T0); 

      //-------------------------------------1 
      rx0 = _mm_load_pd(ptr_x); 

      ra0 = _mm_load_pd(ptr_a1); 
      ra1 = _mm_load_pd(ptr_a2); 
      ra2 = _mm_load_pd(ptr_a3); 
      ra3 = _mm_load_pd(ptr_a4); 

      ra0 = _mm_mul_pd(ra0, rx0); 
      ra1 = _mm_mul_pd(ra1, rx0); 
      ra2 = _mm_mul_pd(ra2, rx0); 
      ra3 = _mm_mul_pd(ra3, rx0); 

      ry0 = _mm_add_pd(ry0, ra0); 
      ry1 = _mm_add_pd(ry1, ra1); 
      ry2 = _mm_add_pd(ry2, ra2); 
      ry3 = _mm_add_pd(ry3, ra3); 

      //------------------------------------2 
      rx0 = _mm_load_pd(ptr_x + 2); 

      ra0 = _mm_load_pd(ptr_a1 + 2); 
      ra1 = _mm_load_pd(ptr_a2 + 2); 
      ra2 = _mm_load_pd(ptr_a3 + 2); 
      ra3 = _mm_load_pd(ptr_a4 + 2); 

      ra0 = _mm_mul_pd(ra0, rx0); 
      ra1 = _mm_mul_pd(ra1, rx0); 
      ra2 = _mm_mul_pd(ra2, rx0); 
      ra3 = _mm_mul_pd(ra3, rx0); 

      ry0 = _mm_add_pd(ry0, ra0); 
      ry1 = _mm_add_pd(ry1, ra1); 
      ry2 = _mm_add_pd(ry2, ra2); 
      ry3 = _mm_add_pd(ry3, ra3); 

      //-----------------------------------3 
      rx0 = _mm_load_pd(ptr_x + 4); 

      ra0 = _mm_load_pd(ptr_a1 + 4); 
      ra1 = _mm_load_pd(ptr_a2 + 4); 
      ra2 = _mm_load_pd(ptr_a3 + 4); 
      ra3 = _mm_load_pd(ptr_a4 + 4); 

      ra0 = _mm_mul_pd(ra0, rx0); 
      ra1 = _mm_mul_pd(ra1, rx0); 
      ra2 = _mm_mul_pd(ra2, rx0); 
      ra3 = _mm_mul_pd(ra3, rx0); 

      ry0 = _mm_add_pd(ry0, ra0); 
      ry1 = _mm_add_pd(ry1, ra1); 
      ry2 = _mm_add_pd(ry2, ra2); 
      ry3 = _mm_add_pd(ry3, ra3); 

      //----------------------------------4 
      rx0 = _mm_load_pd(ptr_x + 6); 

      ra0 = _mm_load_pd(ptr_a1 + 6); 
      ra1 = _mm_load_pd(ptr_a2 + 6); 
      ra2 = _mm_load_pd(ptr_a3 + 6); 
      ra3 = _mm_load_pd(ptr_a4 + 6); 

      ra0 = _mm_mul_pd(ra0, rx0); 
      ra1 = _mm_mul_pd(ra1, rx0); 
      ra2 = _mm_mul_pd(ra2, rx0); 
      ra3 = _mm_mul_pd(ra3, rx0); 

      ry0 = _mm_add_pd(ry0, ra0); 
      ry1 = _mm_add_pd(ry1, ra1); 
      ry2 = _mm_add_pd(ry2, ra2); 
      ry3 = _mm_add_pd(ry3, ra3); 
     } 

     _mm_store_pd(res0, ry0); 
     *ptr_y = res0[0] + res0[1]; 

     _mm_store_pd(res1, ry1); 
     *(ptr_y + 1) = res1[0] + res1[1]; 

     _mm_store_pd(res2, ry2); 
     *(ptr_y + 2) = res2[0] + res2[1]; 

     _mm_store_pd(res3, ry3); 
     *(ptr_y + 3) = res3[0] + res3[1]; 
    } 
}