2017-09-07 70 views
1

我有以下問題:NumPy的apply_along_axis錯D型使用NumPy的時候infered

代碼:

import numpy as np 
get_label = lambda x: 'SMALL' if x.sum() <= 10 else 'BIG' 
arr = np.array([[1, 2], [30, 40]]) 
print np.apply_along_axis(get_label, 1, arr) 
arr = np.array([[30, 40], [1, 2]]) 
print np.apply_along_axis(get_label, 1, arr) 

輸出:

['SMALL' 'BIG'] 
['BIG' 'SMA'] # String 'SMALL' is stripped! 

我可以看到,NumPy的以某種方式從第一推斷數據類型函數返回的值。我想出了以下解決辦法 - 從函數返回與NumPy陣列明確指出D型,而不是字符串,重塑結果:

def get_label_2(x): 
    if x.sum() <= 10: 
     return np.array(['SMALL'], dtype='|S5') 
    else: 
     return np.array(['BIG'], dtype='|S5') 
arr = np.array([[30, 40], [1, 2]]) 
print np.apply_along_axis(get_label_2, 1, arr).reshape(arr.shape[0]) 

你知道這個問題更優雅的解決方案?

+1

它使用第一個輸入從測試計算推斷dtype。如果返回'BIG',則字符串大小設置爲3個字符長。 – hpaulj

回答

1

您可以使用np.where

arr1 = np.array([[1, 2], [30, 40]]) 
arr2 = np.array([[30, 40], [1, 2]]) 

print(np.where(arr1.sum(axis=1)<=10,'SMALL','BIG')) 
print(np.where(arr2.sum(axis=1)<=10,'SMALL','BIG')) 
['SMALL' 'BIG'] 
['BIG' 'SMALL'] 

在功能:

def get_label(x, threshold, axis=1, label1='SMALL', label2='BIG'): 
    return np.where(x.sum(axis=axis) <= threshold, label1, label2) 
0

apply_along_axis不是一個完美的解決方案;這很方便,但並不快。本質上它是

In [277]: get_label = lambda x: 'SMALL' if x.sum() <= 10 else 'BIG' 
In [279]: np.array([get_label(row) for row in np.array([[30,40],[1,2]])]) 
Out[279]: 
array(['BIG', 'SMALL'], 
     dtype='<U5') 
In [280]: res = np.zeros((2,),dtype='S5') 
In [281]: arr = np.array([[30,40],[1,2]]) 
In [282]: for i in range(2): 
    ...:  res[i] = get_label(arr[i,:]) 
    ...:  
In [283]: res 
Out[283]: 
array([b'BIG', b'SMALL'], 
     dtype='|S5') 

除了它概括形狀和推導res dtype。

用一個簡單的「迭代行」情況是這樣,你也可以同樣做:

In [278]: np.array([get_label(row) for row in np.array([[1,2],[30,40]])]) 
Out[278]: 
array(['SMALL', 'BIG'], 
     dtype='<U5') 
In [279]: np.array([get_label(row) for row in np.array([[30,40],[1,2]])]) 
Out[279]: 
array(['BIG', 'SMALL'], 
     dtype='<U5') 

優雅的解決方案是避免Python層面的循環,顯式或隱,而是使用編譯陣列的方法,如給sum一個軸:

In [284]: arr.sum(axis=1) 
Out[284]: array([70, 3])