2017-07-07 74 views
1

這裏必須處理類似的問題:Calling BLAS/LAPACK directly using the SciPy interface and Cython但是不同,因爲我使用SciPy示例中的實際代碼_test_dgemmhttps://github.com/scipy/scipy/blob/master/scipy/linalg/cython_blas.pyx這是非常快(當輸入矩陣輸入時,比輸入矩陣快numpy.dot,比輸入矩陣輸入快5倍,否則快20倍)。如果傳遞Mx1 1xN向量,則不會產生結果。它產生與傳遞矩陣的numpy.dot相同的值。爲了清晰起見,我沒有發佈任何答案,所以我最小化了代碼。下面是dgemm.pyx.使用Cython中的Scipy cython_blas接口不處理向量Mx1 1xN

import numpy as np 
cimport numpy as np 
from scipy.linalg.cython_blas cimport dgemm 
from cython cimport boundscheck 

@boundscheck(False) 
cpdef int fast_dgemm(double[:,::1] a, double[:,::1] b, double[:,::1] c, double alpha=1.0, double beta=0.0) nogil except -1: 

    cdef: 
     char *transa = 'n' 
     char *transb = 'n' 
     int m, n, k, lda, ldb, ldc 
     double *a0=&a[0,0] 
     double *b0=&b[0,0] 
     double *c0=&c[0,0] 

    ldb = (&a[1,0]) - a0 if a.shape[0] > 1 else 1 
    lda = (&b[1,0]) - b0 if b.shape[0] > 1 else 1 

    k = b.shape[0] 
    if k != a.shape[1]: 
     with gil: 
      raise ValueError("Shape mismatch in input arrays.") 
    m = b.shape[1] 
    n = a.shape[0] 
    if n != c.shape[0] or m != c.shape[1]: 
     with gil: 
      raise ValueError("Output array does not have the correct shape.") 
    ldc = (&c[1,0]) - c0 if c.shape[0] > 1 else 1 
    dgemm(transa, transb, &m, &n, &k, &alpha, b0, &lda, a0, 
       &ldb, &beta, c0, &ldc) 
    return 0 

這裏有一個樣本測試腳本:

import numpy as np; 

a=np.random.randn(1000); 
b=np.random.randn(1000); 
a.resize(len(a),1); 
a=np.array(a, order='c'); 
b.resize(1,len(b)); 
b=np.array(b, order='c'); 
c = np.empty((a.shape[0],b.shape[1]), float, order='c'); 

from dgemm import _test_dgemm; 
_test_dgemm(a,b,c); 

如果你想在Windows上發揮它與Python 3.5的x64這裏的setup.py通過命令提示符下鍵入python setup.py build_ext --inplace --compiler=msvc來構建它

from Cython.Distutils import build_ext 
import numpy as np 
import os 

try: 
    from setuptools import setup 
    from setuptools import Extension 
except ImportError: 
    from distutils.core import setup 
    from distutils.extension import Extension 

module = 'dgemm' 

ext_modules = [Extension(module, sources=[module + '.pyx'], 
       include_dirs=['C://Program Files (x86)//Windows Kits//10//Include//10.0.10240.0//ucrt','C://Program Files (x86)//Microsoft Visual Studio 14.0//VC//include','C://Program Files (x86)//Windows Kits//8.1//Include//shared'], 
       library_dirs=['C://Program Files (x86)//Windows Kits//8.1//bin//x64', 'C://Windows//System32', 'C://Program Files (x86)//Microsoft Visual Studio 14.0//VC//lib//amd64', 'C://Program Files (x86)//Windows Kits//8.1//Lib//winv6.3//um//x64', 'C://Program Files (x86)//Windows Kits//10//Lib//10.0.10240.0//ucrt//x64'], 
       extra_compile_args=['/Ot', '/favor:INTEL64', '/EHsc', '/GA'], 
       language='c++')] 

setup(
    name = module, 
    ext_modules = ext_modules, 
    cmdclass = {'build_ext': build_ext}, 
    include_dirs = [np.get_include(), os.path.join(np.get_include(), 'numpy')] 
    ) 

任何幫助非常感謝!

回答

1

如果我看到它是正確的,你可以嘗試使用帶有c-memory-layout的fortran-routines。

即使您明顯知道,我想先詳細說明行主順序(c-memory-layout)和列主順序(fortran-memory-layout),按順序推斷我的答案。

因此,如果我們有一個2x3矩陣(即2行3列)A,並將其存儲在一些連續的內存,我們得到:

row-major-order(A) = A11, A12, A13, A21, A22, A23 
col-major-order(A) = A11, A21, A12, A22, A13, A33 

這意味着,如果我們得到一個連續的內存,它代表了矩陣的行主要順序,並將其解釋爲列主要順序的矩陣,我們將得到完全不同的矩陣!

但是,我們大家一起來看看置矩陣A^t我們不難看出:

row-major-order(A) = col-major-order(A^t) 
col-major-order(A) = row-major-order(A^t) 

這意味着,如果我們想獲得的矩陣C以行優先順序作爲結果, blas-routine應該將列轉換矩陣C寫入這個非常大的內存中(在所有這些我們不能改變的地方)。但是,C^t=(AB)^t=B^t*A^tB^tA^t是按列主要順序重新解釋的原始矩陣。

現在,讓An x k - 矩陣和B一個k x m - 矩陣,dgemm程序的調用應該如下:

dgemm(transa, transb, &m, &n, &k, &alpha, b0, &m, a0, &k, &beta, c0, &m) 

正如你所看到的,你在交換一些nm您碼。

+0

謝謝你 - 爲載體工作。我發佈了一個編輯來匹配語法。 C格式是故意的,因爲它是從NumPy調用的,儘管底層庫是Fortran。很好的解釋。 – Matt