2017-08-04 21 views
1

的列表numpy的數組:廣播通過與給定一個鄰接表陣列

adj_list = [array([0,1]),array([0,1,2]),array([0,2])] 

和指數的陣列,

ind_arr = array([0,1,2]) 

目標:

A = np.zeros((3,3)) 
for i in ind_arr: 
    A[i,list(adj_list[x])] = 1.0/float(adj_list[x].shape[0]) 


目前,我寫道:

A[ind_list[:],adj_list[:]] = 1./len(adj_list[:]) 

並嘗試在此腳手架索引的各種配置。

+0

'adj_list'的數組元素是否總是採用該序列格式 - '[0,1,2 ...]'? – Divakar

+0

@Divakar他們按升序排列。但是,它們很稀疏 – bordeo

回答

1

這裏有一個方法 -

lens = np.array([len(i) for i in adj_list]) 
col_idx = np.concatenate(adj_list) 
out = np.zeros((len(lens), col_idx.max()+1)) 
row_idx = np.repeat(np.arange(len(lens)), lens) 
vals = np.repeat(1.0/lens, lens) 
out[row_idx, col_idx] = vals 

樣品輸入,輸出 -

In [494]: adj_list = [np.array([0,2]),np.array([0,1,4])] 

In [496]: out 
Out[496]: 
array([[ 0.5  , 0.  , 0.5  , 0.  , 0.  ], 
     [ 0.33333333, 0.33333333, 0.  , 0.  , 0.33333333]]) 

稀疏矩陣輸出

此外,如果你想節省內存並創建稀疏矩陣代替,這是一個簡單的擴展 -

In [506]: from scipy.sparse import csr_matrix 

In [507]: csr_matrix((vals, (row_idx, col_idx)), shape=(len(lens), col_idx.max()+1)) 
Out[507]: 
<2x5 sparse matrix of type '<type 'numpy.float64'>' 
    with 5 stored elements in Compressed Sparse Row format> 

In [508]: _.toarray() 
Out[508]: 
array([[ 0.5  , 0.  , 0.5  , 0.  , 0.  ], 
     [ 0.33333333, 0.33333333, 0.  , 0.  , 0.33333333]]) 
1

我不認爲你可以完全消除由於混合數據類型的循環,但可以減少循環嵌套雙到一個單一的一個:

A = np.zeros((2, 3)) 
for i, arr in enumerate(adj_list): 
    arr_size = len(arr) 
    A[i, :arr_size] = 1./arr_size 

A 
# array([[ 0.5  , 0.5  , 0.  ], 
#  [ 0.33333333, 0.33333333, 0.33333333]]) 

或者,如果陣列中的數字是實際上列的位置:

A = np.zeros((2, 3)) 
for i, arr in enumerate(adj_list): 
    A[i, arr] = 1./len(arr) 

A 
# array([[ 0.5  , 0.5  , 0.  ], 
#  [ 0.33333333, 0.33333333, 0.33333333]]) 

使用從sklearnMultiLabelBinarizer另一種選擇(但可能不會那樣有效):

from sklearn.preprocessing import MultiLabelBinarizer 
​ 
mlb = MultiLabelBinarizer() 

adj_list = [np.array([0,1]),np.array([0,1,2])] 
​ 
sizes = np.fromiter(map(len, adj_list), dtype=int) 
mlb.fit_transform(adj_list)/sizes[:,None] 

# array([[ 0.5  , 0.5  , 0.  ], 
#  [ 0.33333333, 0.33333333, 0.33333333]])