2015-07-12 81 views
1

我寫了一個函數,它接受一組隨機笛卡爾座標並返回保留在某個空間域內的子集。爲了說明:使用數組索引在3D數組上應用2D數組函數

grid = np.ones((5,5)) 
grid = np.lib.pad(grid, ((10,10), (10,10)), 'constant') 

>> np.shape(grid) 
(25, 25) 

random_pts = np.random.random(size=(100, 2)) * len(grid) 

def inside(input): 
    idx = np.floor(input).astype(np.int) 
    mask = grid[idx[:,0], idx[:,1]] == 1 
    return input[mask] 

>> inside(random_pts) 
array([[ 10.59441506, 11.37998288], 
     [ 10.39124766, 13.27615815], 
     [ 12.28225713, 10.6970708 ], 
     [ 13.78351949, 12.9933591 ]]) 

但現在我想同時產生n套random_pts,並保持滿足同樣功能的條件n相應的子集的能力。所以,如果n=3

random_pts = np.random.random(size=(3, 100, 2)) * len(grid) 

沒有求助於for循環,我怎麼可能指數我的變量,使得inside(random_pts)回報像

array([[[ 17.73323523, 9.81956681], 
     [ 10.97074592, 2.19671642], 
     [ 21.12081044, 12.80412997]], 

     [[ 11.41995519, 2.60974757]], 

     [[ 9.89827156, 9.74580059], 
     [ 17.35840479, 7.76972241]]]) 
+0

只是好奇 - 做了發佈的解決方案爲您工作? – Divakar

+0

這是功能,但不是很實用;當函數需要迭代調用時,額外的數組操作會導致性能下降。我希望有一個更直接的方法來切分輸入,可以達到相同的結果。 –

+1

因此,對於問題中的發佈數據,您希望有三個獨立的數組,對嗎?爲了獲得這樣的單獨數組,'np.split'是最爲人所知的方法,但速度很慢,因爲存在一個分裂多數組操作。我認爲這是這裏最慢的部分。如果您對非分割輸出沒問題,則發佈的解決方案中的「out_cat_array」可能是您的輸出。 – Divakar

回答

1

一種方法 -

def inside3d(input): 
    # Get idx in 3D 
    idx3d = np.floor(input).astype(np.int) 

    # Create a similar mask as witrh 2D case, but in 3D now 
    mask3d = grid[idx3d[:,:,0], idx3d[:,:,1]]==1 

    # Count of mask matches for each index in 0th dim  
    counts = np.sum(mask3d,axis=1) 

    # Index into input to get masked matches across all elements in 0th dim 
    out_cat_array = input.reshape(-1,2)[mask3d.ravel()] 

    # Split the rows based on the counts, as the final output 
    return np.split(out_cat_array,counts.cumsum()[:-1]) 

驗證結果 -

創建3D隨機輸入:

In [91]: random_pts3d = np.random.random(size=(3, 100, 2)) * len(grid) 

隨着inside3d:

In [92]: inside3d(random_pts3d) 
Out[92]: 
[array([[ 10.71196268, 12.9875877 ], 
     [ 10.29700184, 10.00506662], 
     [ 13.80111411, 14.80514828], 
     [ 12.55070282, 14.63155383]]), array([[ 10.42636137, 12.45736944], 
     [ 11.26682474, 13.01632751], 
     [ 13.23550598, 10.99431284], 
     [ 14.86871413, 14.19079225], 
     [ 10.61103434, 14.95970597]]), array([[ 13.67395756, 10.17229061], 
     [ 10.01518846, 14.95480515], 
     [ 12.18167251, 12.62880968], 
     [ 11.27861513, 14.45609646], 
     [ 10.895685 , 13.35214678], 
     [ 13.42690335, 13.67224414]])] 

隨着內:

In [93]: inside(random_pts3d[0]) 
Out[93]: 
array([[ 10.71196268, 12.9875877 ], 
     [ 10.29700184, 10.00506662], 
     [ 13.80111411, 14.80514828], 
     [ 12.55070282, 14.63155383]]) 

In [94]: inside(random_pts3d[1]) 
Out[94]: 
array([[ 10.42636137, 12.45736944], 
     [ 11.26682474, 13.01632751], 
     [ 13.23550598, 10.99431284], 
     [ 14.86871413, 14.19079225], 
     [ 10.61103434, 14.95970597]]) 

In [95]: inside(random_pts3d[2]) 
Out[95]: 
array([[ 13.67395756, 10.17229061], 
     [ 10.01518846, 14.95480515], 
     [ 12.18167251, 12.62880968], 
     [ 11.27861513, 14.45609646], 
     [ 10.895685 , 13.35214678], 
     [ 13.42690335, 13.67224414]])