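Using the SciPy cython_blas interface from Cython does not work for Mx1 and 1xN vectors

There is a similar question here: Calling BLAS/LAPACK directly using the SciPy interface and Cython. Mine is different because I am using the actual code from the SciPy example _test_dgemm in https://github.com/scipy/scipy/blob/master/scipy/linalg/cython_blas.pyx. It is very fast: about 5x faster than numpy.dot when both inputs are matrices, and about 20x faster otherwise. When I pass it an Mx1 vector and a 1xN vector, however, it produces no result. When I pass it matrices, it produces the same values as numpy.dot. Because no answers have been posted yet, I have minimized the code for clarity. Here is dgemm.pyx: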
import numpy as np
cimport numpy as np
from scipy.linalg.cython_blas cimport dgemm
from cython cimport boundscheck

@boundscheck(False)
cpdef int fast_dgemm(double[:,::1] a, double[:,::1] b, double[:,::1] c, double alpha=1.0, double beta=0.0) nogil except -1:
    # Computes c = alpha * a.dot(b) + beta * c for C-contiguous a, b and c.
    # BLAS is column-major, so it reads each C-ordered buffer as a transpose;
    # passing b before a makes dgemm compute b.T.dot(a.T) = a.dot(b).T.
    cdef:
        char *transa = 'n'
        char *transb = 'n'
        int m, n, k, lda, ldb, ldc
        double *a0=&a[0,0]
        double *b0=&b[0,0]
        double *c0=&c[0,0]
    # Leading dimensions: distance in elements between row 0 and row 1.
    ldb = (&a[1,0]) - a0 if a.shape[0] > 1 else 1
    lda = (&b[1,0]) - b0 if b.shape[0] > 1 else 1
    k = b.shape[0]
    if k != a.shape[1]:
        with gil:
            raise ValueError("Shape mismatch in input arrays.")
    m = b.shape[1]
    n = a.shape[0]
    if n != c.shape[0] or m != c.shape[1]:
        with gil:
            raise ValueError("Output array does not have the correct shape.")
    ldc = (&c[1,0]) - c0 if c.shape[0] > 1 else 1
    dgemm(transa, transb, &m, &n, &k, &alpha, b0, &lda, a0,
          &ldb, &beta, c0, &ldc)
    return 0
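fast_dgemm passes b's buffer as DGEMM's A argument and a's buffer as its B argument. The minimal NumPy sketch below (no compilation needed; the helper as_fortran_view is hypothetical, not a SciPy function) illustrates why that ordering works: a column-major routine reads a C-contiguous buffer as the transpose of the array, so DGEMM computes b.T.dot(a.T). That product equals a.dot(b).T, and the column-major result is exactly a.dot(b) when read back in C order.

```python
import numpy as np


def as_fortran_view(arr):
    """Hypothetical helper: reinterpret the buffer of a C-contiguous 2-D
    array the way a column-major (Fortran) routine such as BLAS reads it."""
    arr = np.ascontiguousarray(arr)
    rows, cols = arr.shape
    # Same memory, read column by column: the result has shape (cols, rows),
    # and consecutive columns start cols elements apart (its leading dimension).
    return arr.ravel(order="C").reshape((cols, rows), order="F")


a = np.random.randn(4, 1)   # Mx1
b = np.random.randn(1, 3)   # 1xN

# What BLAS sees for each C-ordered buffer is the transpose of the array.
assert np.array_equal(as_fortran_view(a), a.T)
assert np.array_equal(as_fortran_view(b), b.T)

# dgemm(A = b's buffer, B = a's buffer) computes b.T @ a.T, the transpose of
# a @ b, i.e. a @ b itself once the column-major result is read in C order.
c_seen_by_blas = as_fortran_view(b) @ as_fortran_view(a)
assert np.allclose(c_seen_by_blas, (a @ b).T)
```

The comment in the helper is the detail that matters for vectors: in this view the distance between columns is the number of C columns, even when the array has only one row.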
Here is a sample test script:
import numpy as np
from dgemm import fast_dgemm

# Mx1 column vector and 1xN row vector, both C-contiguous
a = np.ascontiguousarray(np.random.randn(1000, 1))
b = np.ascontiguousarray(np.random.randn(1, 1000))
c = np.empty((a.shape[0], b.shape[1]), dtype=float, order='C')

fast_dgemm(a, b, c)
print(np.allclose(c, np.dot(a, b)))  # True only if c matches numpy.dot
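One way to see why this call does nothing is to replay the leading-dimension checks that the reference BLAS DGEMM performs. With TRANSA = TRANSB = 'N' it requires LDA >= max(1, M), LDB >= max(1, K) and LDC >= max(1, M); if a check fails it passes the parameter number (8, 10 or 13) to XERBLA instead of computing the product. The sketch below uses hypothetical helper names and mirrors how fast_dgemm computes lda, ldb and ldc. For the test script's shapes it flags parameter 8: b has a single row, so lda falls back to 1 although M = 1000. Taking each leading dimension from the number of columns instead (for example lda = max(1, b.shape[1]), and likewise for a and c) passes the checks for C-contiguous inputs. That change is an assumption about a fix, not code taken from SciPy.

```python
import numpy as np


def ld_as_in_fast_dgemm(arr):
    """Hypothetical helper: leading dimension computed the way fast_dgemm
    does it, i.e. (&arr[1,0]) - &arr[0,0] in elements, or 1 for a single row."""
    if arr.shape[0] > 1:
        return arr.strides[0] // arr.itemsize
    return 1


def ld_from_columns(arr):
    """Hypothetical helper: for a C-contiguous array, the column-major view
    has a leading dimension equal to the number of columns."""
    return max(1, arr.shape[1])


def dgemm_rejected_params(a, b, c, ld=ld_as_in_fast_dgemm):
    """Return the DGEMM parameter numbers that reference BLAS would reject
    for the call in fast_dgemm (TRANSA = TRANSB = 'N', A = b, B = a)."""
    m, n, k = b.shape[1], a.shape[0], b.shape[0]
    lda, ldb, ldc = ld(b), ld(a), ld(c)
    rejected = []
    if lda < max(1, m):   # parameter 8, LDA (NROWA = M for TRANSA = 'N')
        rejected.append(8)
    if ldb < max(1, k):   # parameter 10, LDB (NROWB = K for TRANSB = 'N')
        rejected.append(10)
    if ldc < max(1, m):   # parameter 13, LDC
        rejected.append(13)
    return rejected


a = np.random.randn(1000, 1)   # Mx1, as in the test script
b = np.random.randn(1, 1000)   # 1xN
c = np.empty((1000, 1000))

print(dgemm_rejected_params(a, b, c))                      # [8]
print(dgemm_rejected_params(a, b, c, ld=ld_from_columns))  # []
```

The column-count rule also covers the other failing shape, a 1xK row times a KxN matrix, where the row-difference rule under-reports both LDB and LDC.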
If you want to try it on Windows with Python 3.5 x64, here is the setup.py. Build it by typing python setup.py build_ext --inplace --compiler=msvc at the command prompt:
from Cython.Distutils import build_ext
import numpy as np
import os

try:
    from setuptools import setup
    from setuptools import Extension
except ImportError:
    from distutils.core import setup
    from distutils.extension import Extension

module = 'dgemm'

ext_modules = [Extension(module, sources=[module + '.pyx'],
    include_dirs=['C://Program Files (x86)//Windows Kits//10//Include//10.0.10240.0//ucrt',
                  'C://Program Files (x86)//Microsoft Visual Studio 14.0//VC//include',
                  'C://Program Files (x86)//Windows Kits//8.1//Include//shared'],
    library_dirs=['C://Program Files (x86)//Windows Kits//8.1//bin//x64',
                  'C://Windows//System32',
                  'C://Program Files (x86)//Microsoft Visual Studio 14.0//VC//lib//amd64',
                  'C://Program Files (x86)//Windows Kits//8.1//Lib//winv6.3//um//x64',
                  'C://Program Files (x86)//Windows Kits//10//Lib//10.0.10240.0//ucrt//x64'],
    extra_compile_args=['/Ot', '/favor:INTEL64', '/EHsc', '/GA'],
    language='c++')]

setup(
    name = module,
    ext_modules = ext_modules,
    cmdclass = {'build_ext': build_ext},
    include_dirs = [np.get_include(), os.path.join(np.get_include(), 'numpy')]
)
Any help is greatly appreciated!
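Thank you, that works for vectors. I posted an edit to match the syntax. The C ordering is intentional because the function is called from NumPy, even though the underlying library is Fortran. Great explanation. – Matt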