2014-01-24 51 views
2

我有一個數組表示3D點之間的函數。因此,作爲索引它獲得6元組。現在我想對這個數組的元素應用一個函數,但是這個函數不僅依賴於元素的值,還依賴於它的索引。因此,如果A是矩陣,以及m和n是我們的3D點即A [M,N]存儲其值,而k是0和3然後f(A,k)[m,n]之間的值等於:基於索引函數的向量化數組操作

- m[k]**2如果m==n

- m[k]**2-n[k]**2否則

以下是我的代碼:

import numpy as np 
def func(a,k): 
    b=np.empty(a.shape) 
    for i in range(a.flatten().size): 
     ind=np.unravel_index(i,a.shape) 
     if ind[0:3]==ind[3:6]: 
      b[ind]=a[ind]*ind[0:3][k]**2 
     else: 
      b[ind]=a[ind]*(ind[0:3][k]**2-ind[3:6][k]**2) 
    return b 
a=np.arange(729).reshape((3,3,3,3,3,3)) 
print func(a,2) 

反正是有vecotrizing這個代碼?

P.S.這是我實際需要做的簡化版本。

回答

2

使用numpy.indices()創建索引數組,那麼你可以vecotrizing計算:

import numpy as np 
def func(a,k): 
    b=np.empty(a.shape) 
    for i in range(a.flatten().size): 
     ind=np.unravel_index(i,a.shape) 
     if ind[0:3]==ind[3:6]: 
      b[ind]=a[ind]*ind[0:3][k]**2 
     else: 
      b[ind]=a[ind]*(ind[0:3][k]**2-ind[3:6][k]**2) 
    return b 

def func2(a,k): 
    b = np.empty(a.shape) 
    ind = np.indices(a.shape).reshape(6, -1) 
    mask = np.all(ind[:3] == ind[3:6], axis=0) 
    ar = a.ravel() 
    br = b.ravel() 
    br[mask] = ar[mask]*ind[k, mask]**2 
    mask = ~mask 
    br[mask] = ar[mask]*(ind[k, mask]**2 - ind[3+k, mask]**2) 
    return b 

a = np.arange(729).reshape((3,3,3,3,3,3)) 
b1 = func(a, 2) 
b2 = func2(a, 2) 
np.allclose(b1, b2) 

這裏是%timeit結果:

%timeit func(a, 2) 
%timeit func2(a, 2) 

輸出:

100 loops, best of 3: 16.4 ms per loop 
1000 loops, best of 3: 579 µs per loop 

你可以爲您的情況優化一點:

def func3(a,k): 
    b = np.empty(a.shape) 
    ind = np.indices(a.shape).reshape(6, -1) 
    mask = ~np.all(ind[:3] == ind[3:6], axis=0) 
    ar = a.ravel() 
    br = b.ravel() 
    br[:] = ar*ind[k]**2 
    br[mask] -= ar[mask]*ind[3+k, mask]**2 
    return b