2017-06-27 47 views
0

在下面的Fortran程序中,我使用Intel的MKL庫使用dgemm執行矩陣乘法。最初,我使用了matmul子程序,並得到了正確的結果。當我在下面的循環中將matmul翻譯爲dgemm時,我得到了所有零矢量,而不是正確的輸出。我感謝您的幫助。Fortran的MKL dgemm結果爲零

program spectral_norm  
implicit none 
! 
integer, parameter :: n = 5500, dp = kind(0.0d0) 
real(dp), allocatable :: A(:, :), u(:), v(:), Au(:), Av(:) 
integer :: i, j 

allocate(u(n), v(n), A(n, n), Au(n), Av(n)) 

do j = 1, n 
    do i = 1, n 
     A(i, j) = Ac(i, j) 
    end do 
end do 

u = 1 
do i = 1, 10 
    call dgemm('N','N', n, 1, n, 1.0, A, n, u, n, 0.0, Au, n) 
    call dgemm('N','N', n, 1, n, 1.0, Au, n, A, n, 0.0, v, n) 
    call dgemm('N','N', n, 1, n, 1.0, A, n, v, n, 0.0, Av, n) 
    call dgemm('N','N', n, 1, n, 1.0, Av, n, A, n, 0.0, u, n) 
    !v = matmul(matmul(A, u), A) 
    !u = matmul(matmul(A, v), A) 
end do 

write(*, "(f0.9)") sqrt(dot_product(u, v)/dot_product(v, v)) 

contains 

pure real(dp) function Ac(i, j) result(r) 
integer, intent(in) :: i, j 
r = 1._dp/((i+j-2) * (i+j-1)/2 + i) 
end function 

end program spectral_norm 

這給NaN,而從matmul輸出正確的是1.274224153

+1

使用MKL模塊。 Tjey可以幫助您識別許多錯誤,並且還包含簡單的Fortran 90接口以供這些子例程使用。 –

+0

@VladimirF - 我不認爲這是一個缺少模塊問題。我已經正確設置了VStudio中的include目錄,並設置了「使用MKL」選項。如果我試圖乘以任何其他兩個矩陣,它的工作原理。例如'call dgemm('t','n',n,n,n,1.0,A,n,A,n,0.0,A,n)給出正確的矩陣。 – AboAmmar

+2

我想你應該使用1.d0和0.d0,因爲你使用的是需要雙精度參數的dgemm。 'dgemm('N','N',n,1,n,1.d0,A,n,u,n,0.d0,Au,n)' –

回答

0

那麼,謝謝大家的建議。我想我找出了錯誤的根源。乘法的順序在兩種情況下相反,它應該是A * AuA * Av。這是因爲A的訂單爲n x n,而且AuAv的訂單爲n x 1。所以,由於尺寸不匹配,我們不能乘以Au * AAv * A。我發佈了下面的更正版本。

program spectral_norm  
implicit none 
! 
integer, parameter :: n = 5500, dp = kind(0.d0) 
real(dp), allocatable :: A(:,:), u(:), v(:), Au(:), Av(:) 
integer :: i, j 

allocate(u(n), v(n), A(n, n), Au(n), Av(n)) 

do j = 1, n 
    do i = 1, n 
     A(i, j) = Ac(i, j) 
    end do 
end do 

u = 1 
do i = 1, 10 
    call dgemm('N', 'N', n, 1, n, 1._dp, A, n, u, n, 0._dp, Au, n) 
    call dgemm('T', 'N', n, 1, n, 1._dp, A, n, Au, n, 0._dp, v, n) 
    call dgemm('N', 'N', n, 1, n, 1._dp, A, n, v, n, 0._dp, Av, n) 
    call dgemm('T', 'N', n, 1, n, 1._dp, A, n, Av, n, 0._dp, u, n) 
end do 

write(*, "(f0.9)") sqrt(dot_product(u, v)/dot_product(v, v)) 

contains 

pure real(dp) function Ac(i, j) result(r) 
    integer, intent(in) :: i, j 
    r = 1._dp/((i+j-2) * (i+j-1)/2 + i) 
end function 

end program spectral_norm 

這給正確的結果:

1.274224153 
Elapsed time 0.5150000  seconds