2017-02-14 16 views
4

我試圖將功能轉換爲AVX版本。函數本身基本上只是比較float和返回true/false取決於計算。AVX版本沒有預期的那麼快

這裏是原來的功能:

// Scalar reference test.
// Each row is read as two halves of COL_COUNT/2 floats, and column `col` of the
// first half is paired with column `col` of the second half.
// Returns true as soon as, for some column,
//     thisFloat[col] < -otherFloat[half + col]   or
//     -thisFloat[half + col] > otherFloat[col];
// returns false if no column satisfies either condition.
// Both arrays must hold at least 2 * (COL_COUNT/2) floats.
bool testSingle(float* thisFloat, float* otherFloat)
{
    const unsigned int half = COL_COUNT / 2;

    for (unsigned int col = 0; col != half; ++col)
    {
        const float ownLow = thisFloat[col];
        const float ownHigh = thisFloat[half + col];
        const float peerLow = otherFloat[col];
        const float peerHigh = otherFloat[half + col];

        if (ownLow < -peerHigh)
            return true;
        if (-ownHigh > peerLow)
            return true;
    }

    return false;
}

而且,這是AVX版

__m256 testAVX(float* thisFloat, __m256* otherFloatInAVX) 
{ 
    __m256 vTemp1; 
    __m256 vTemp2; 
    __m256 vTempResult; 
    __m256 vEndResult = _mm256_set1_ps(0.0f); 

    for (unsigned int k = 0; k < COL_COUNT/2; k++) 
    { 

     vTemp1 = _mm256_cmp_ps(_mm256_set1_ps(thisFloat[k]), otherFloatInAVX[COL_COUNT/2 + k], _CMP_LT_OQ); 

     vTemp2 = _mm256_cmp_ps(_mm256_set1_ps(-thisFloat[COL_COUNT/2 + k]), otherFloatInAVX[k], _CMP_GT_OQ); 

     vTempResult = _mm256_or_ps(vTemp1, vTemp2); 
     vEndResult = _mm256_or_ps(vTempResult, vEndResult); 
     if (_mm256_movemask_ps(vEndResult) == 255) 
     { 
      break; 
     } 

    } 

    return vEndResult; 

} 

這裏是完整的代碼。我在開始時生成了一些隨機的浮點數,並將其保存到AVX中以便在AVX版本中進行計算。變量thisFloat中的值將與otherFloat1,otherFloat2,...,otherFloat8進行比較。

#define ROW_COUNT 1000000 
#define COL_COUNT 46 

float randomNumberFloat(float Min, float Max) 
{ 
    return ((float(rand())/float(RAND_MAX)) * (Max - Min)) + Min; 
} 

int main(int argc, char** argv) 
{ 

    float** thisFloat = new float*[ROW_COUNT]; 
    for (int i = 0; i < ROW_COUNT; ++i) 
     thisFloat[i] = new float[COL_COUNT]; 

    float** otherFloat1 = new float*[ROW_COUNT]; 
    for (int i = 0; i < ROW_COUNT; ++i) 
     otherFloat1[i] = new float[COL_COUNT]; 

    float** otherFloat2 = new float*[ROW_COUNT]; 
    for (int i = 0; i < ROW_COUNT; ++i) 
     otherFloat2[i] = new float[COL_COUNT]; 

    float** otherFloat3 = new float*[ROW_COUNT]; 
    for (int i = 0; i < ROW_COUNT; ++i) 
     otherFloat3[i] = new float[COL_COUNT]; 

    float** otherFloat4 = new float*[ROW_COUNT]; 
    for (int i = 0; i < ROW_COUNT; ++i) 
     otherFloat4[i] = new float[COL_COUNT]; 

    float** otherFloat5 = new float*[ROW_COUNT]; 
    for (int i = 0; i < ROW_COUNT; ++i) 
     otherFloat5[i] = new float[COL_COUNT]; 

    float** otherFloat6 = new float*[ROW_COUNT]; 
    for (int i = 0; i < ROW_COUNT; ++i) 
     otherFloat6[i] = new float[COL_COUNT]; 

    float** otherFloat7 = new float*[ROW_COUNT]; 
    for (int i = 0; i < ROW_COUNT; ++i) 
     otherFloat7[i] = new float[COL_COUNT]; 

    float** otherFloat8 = new float*[ROW_COUNT]; 
    for (int i = 0; i < ROW_COUNT; ++i) 
     otherFloat8[i] = new float[COL_COUNT]; 

    // save to AVX 
    __m256** otherFloatInAVX = new __m256*[ROW_COUNT]; 
    for (int i = 0; i < ROW_COUNT; ++i) 
     otherFloatInAVX[i] = new __m256[COL_COUNT]; 

    // variable for results 
    unsigned int* resultsSingle = new unsigned int[ROW_COUNT]; 
    __m256* resultsAVX = new __m256[ROW_COUNT]; 


    // Generate Random Values 
    for (unsigned int i = 0; i < ROW_COUNT; i++) 
    { 
     for (unsigned int j = 0; j < COL_COUNT; j++) 
     { 
      thisFloat[i][j] = randomNumberFloat(-1000.0f, 1000.0f); 
      otherFloat1[i][j] = randomNumberFloat(-1000.0f, 1000.0f); 
      otherFloat2[i][j] = randomNumberFloat(-1000.0f, 1000.0f); 
      otherFloat3[i][j] = randomNumberFloat(-1000.0f, 1000.0f); 
      otherFloat4[i][j] = randomNumberFloat(-1000.0f, 1000.0f); 
      otherFloat5[i][j] = randomNumberFloat(-1000.0f, 1000.0f); 
      otherFloat6[i][j] = randomNumberFloat(-1000.0f, 1000.0f); 
      otherFloat7[i][j] = randomNumberFloat(-1000.0f, 1000.0f); 
      otherFloat8[i][j] = randomNumberFloat(-1000.0f, 1000.0f); 

     } 

     for (unsigned int j = 0; j < COL_COUNT/2; j++) 
     { 
      otherFloatInAVX[i][j] = _mm256_setr_ps(otherFloat1[i][j], otherFloat2[i][j], otherFloat3[i][j], otherFloat4[i][j], otherFloat5[i][j], otherFloat6[i][j], otherFloat7[i][j], otherFloat8[i][j]); 
      otherFloatInAVX[i][COL_COUNT/2 + j] = _mm256_setr_ps(-otherFloat1[i][j], -otherFloat2[i][j], -otherFloat3[i][j], -otherFloat4[i][j], -otherFloat5[i][j], -otherFloat6[i][j], -otherFloat7[i][j], -otherFloat8[i][j]); 
     } 
    } 

    // do normal test 
    auto start_normal = std::chrono::high_resolution_clock::now(); 
    for (unsigned int i = 0; i < ROW_COUNT; i++) 
    { 
     resultsSingle[i] = testSingle(thisFloat[i], otherFloat1[i]); 
     resultsSingle[i] = testSingle(thisFloat[i], otherFloat2[i]); 
     resultsSingle[i] = testSingle(thisFloat[i], otherFloat3[i]); 
     resultsSingle[i] = testSingle(thisFloat[i], otherFloat4[i]); 
     resultsSingle[i] = testSingle(thisFloat[i], otherFloat5[i]); 
     resultsSingle[i] = testSingle(thisFloat[i], otherFloat6[i]); 
     resultsSingle[i] = testSingle(thisFloat[i], otherFloat7[i]); 
     resultsSingle[i] = testSingle(thisFloat[i], otherFloat8[i]); 
    } 
    auto end_normal = std::chrono::high_resolution_clock::now(); 

    auto duration_normal = std::chrono::duration_cast<std::chrono::milliseconds>(end_normal - start_normal); 
    std::cout << "Duration of normal test: " << duration_normal.count() << " ms \n"; 

    // do AVX test 

    auto start_avx = std::chrono::high_resolution_clock::now(); 
    for (unsigned int i = 0; i < ROW_COUNT; i++) 
    { 
     resultsAVX[i] = testAVX(thisFloat[i], otherFloatInAVX[i]); 
    } 
    auto end_avx = std::chrono::high_resolution_clock::now(); 


    auto duration_avx = std::chrono::duration_cast<std::chrono::milliseconds>(end_avx - start_avx); 
    std::cout << "Duration of AVX test: " << duration_avx.count() << " ms"; 
return 0; 
} 

然後,我測兩者的運行時間,並得到

Duration of normal test: 290 ms 
Duration of AVX test: 159 ms 

的AVX版本是1.82x速度比原來的一個。

是否仍有可能改進AVX版本?或者我以錯誤的方式做了AVX?由於我同時進行了8次計算,因此我預計它可能會快5到6倍。

+0

如果你只是陷入內存/高速緩存帶寬,你有沒有切入? – Elalfer

+1

SIMD性能的關鍵是在數據寄存器中對數據進行大量計算,以幫助掩蓋從內存中加載大量數據的開銷。否則,你只是寫了一個非常複雜的''memcpy''。 –

+0

我認爲你的AVX例程有一個錯誤 - 在標量代碼相同的條件下它不會「早」出來 - 你需要將測試從'== 255'改爲'!= 0'。 (注意:這是早上的早上,我還沒有喝咖啡,但是經常檢查,這個*看起來像是一個bug。) –

回答

2

我認爲AVX版必須具有相同的API標量(所以我一點都改變了它):

bool testAVX(float * thisFloat, float * otherFloat) 
{ 
    size_t k = 0, size = COL_COUNT/2, sizeAligned = size/8 * 8; 

    __m256 zero = _mm256_set1_ps(0); 
    for (; k < sizeAligned; k += 8) 
    { 
     __m256 _thisFloat1 = _mm256_loadu_ps(thisFloat + k); 
     __m256 _thisFloat2 = _mm256_loadu_ps(thisFloat + k + size); 
     __m256 _otherFloat1 = _mm256_loadu_ps(otherFloat + k); 
     __m256 _otherFloat2 = _mm256_loadu_ps(otherFloat + k + size); 

     __m256 compareMask1 = _mm256_cmp_ps(_thisFloat1, _mm256_sub_ps(zero, _otherFloat2), _CMP_LT_OQ); 
     __m256 compareMask2 = _mm256_cmp_ps(_mm256_sub_ps(zero, _thisFloat2), _otherFloat1, _CMP_GT_OQ); 

     __m256 compareMask = _mm256_or_ps(compareMask1, compareMask2); 

     if (!_mm256_testz_ps(compareMask, compareMask)) 
      return true; 
    } 

    for (; k < size; k++) 
    { 
     if (thisFloat[k] < -otherFloat[size + k] || -thisFloat[size + k] > otherFloat[k]) 
      return true; 
    } 
    return false; 
} 

所以它會更容易在他們之間比較這些版本。

+1

@PaulR是的。你是對的。 – ErmIg

相關問題