2016-03-21 57 views
3

通常,我試圖將距離矩陣拆分爲K摺疊。具體地,對於3 x 3的情況下,我的距離矩陣可能是這樣的:更快地將ndarray的值「分配」到基於賦值的其他ndarray中?

full = np.array([ 
    [0, 0, 3], 
    [1, 0, 1], 
    [2, 1, 0] 
]) 

我也有隨機產生的分配,其長度等於所述總和超過在距離矩陣的所有元素的列表。對於K = 3情況下,它可能是這樣的:

assignments = np.array([0, 1, 0, 2, 1, 1, 0, 0]) 

我想創建K = 33 x 3矩陣零,其中距離矩陣的值是「分佈式」根據作業列表。代碼是比文字更精確,所以這裏是我當前的嘗試:

def assign(full, assignments): 
    folds = [np.zeros(full.shape) for _ in xrange(np.max(assignments) + 1)] 
    rows, cols = full.shape 
    a = 0 
    for r in xrange(rows): 
     for c in xrange(cols): 
      for i in xrange(full[r, c]): 
       folds[assignments[a]][r, c] += 1 
       a += 1 
    return folds 

此作品(慢),而在這個例子中,

folds = assign(full, assignments) 
for f in folds: 
    print f 

回報

[[ 0. 0. 2.] 
[ 0. 0. 0.] 
[ 1. 1. 0.]] 
[[ 0. 0. 1.] 
[ 0. 0. 1.] 
[ 1. 0. 0.]] 
[[ 0. 0. 0.] 
[ 1. 0. 0.] 
[ 0. 0. 0.]] 

的期望。不過,我目前的嘗試速度很慢,特別是對於的情況下N大。我怎樣才能提高這個功能的速度?我應該在這裏使用一些numpy魔術嗎?

我有一個想法是轉換爲sparse矩陣和循環非零條目。這將不僅有助於一點,但是,

回答

1

您可以使用add.at做緩衝到位操作:

import numpy as np 

full = np.array([ 
    [0, 0, 3], 
    [1, 0, 1], 
    [2, 1, 0] 
]) 

assignments = np.array([0, 1, 0, 2, 1, 1, 0, 0]) 

res = np.zeros((np.max(assignments) + 1,) + full.shape, dtype=int) 

r, c = np.nonzero(full) 
n = full[r, c] 

r = np.repeat(r, n) 
c = np.repeat(c, n) 

np.add.at(res, (assignments, r, c), 1) 

print(res) 
+0

不錯 - 我不知道像'add'這樣的通用函數有一個[at](http://docs.scipy.org/doc/numpy-1.10.1/reference/generated/numpy.ufunc.at.html )修飾符。另外,對於和我一樣的人來說,我發現稀疏矩陣的「非零」方法對此很有幫助。 –

1

你只需要弄清楚處於平坦輸出項目將每次獲得遞增,然後用bincount它們聚集:

def assign(full, assignments): 
    assert len(assignments) == np.sum(full) 

    rows, cols = full.shape 
    n = np.max(assignments) + 1 

    full_flat = full.reshape(-1) 
    full_flat_non_zero = full_flat != 0 
    full_flat_indices = np.repeat(np.where(full_flat_non_zero)[0], 
            full_flat[full_flat_non_zero]) 
    folds_flat_indices = full_flat_indices + assignments*rows*cols 

    return np.bincount(folds_flat_indices, 
         minlength=n*rows*cols).reshape(n, rows, cols) 

>>> assign(full, assignments) 
array([[[0, 0, 2], 
     [0, 0, 0], 
     [1, 1, 0]], 

     [[0, 0, 1], 
     [0, 0, 1], 
     [1, 0, 0]], 

     [[0, 0, 0], 
     [1, 0, 0], 
     [0, 0, 0]]]) 

你可能想打印出你的例子中的每一箇中間數組,看看究竟發生了什麼。

+0

酷使用bincount的。我最終選擇了另一個,因爲add.at更直接。 –

+0

'add.at'唯一的問題是它通常非常慢:如果使用'bincount'的速度快10倍,不會感到驚訝。 – Jaime