我有以下問題:NumPy的apply_along_axis錯D型使用NumPy的時候infered
代碼:
import numpy as np
get_label = lambda x: 'SMALL' if x.sum() <= 10 else 'BIG'
arr = np.array([[1, 2], [30, 40]])
print np.apply_along_axis(get_label, 1, arr)
arr = np.array([[30, 40], [1, 2]])
print np.apply_along_axis(get_label, 1, arr)
輸出:
['SMALL' 'BIG']
['BIG' 'SMA'] # String 'SMALL' is stripped!
我可以看到,NumPy的以某種方式從第一推斷數據類型函數返回的值。我想出了以下解決辦法 - 從函數返回與NumPy陣列明確指出D型,而不是字符串,重塑結果:
def get_label_2(x):
if x.sum() <= 10:
return np.array(['SMALL'], dtype='|S5')
else:
return np.array(['BIG'], dtype='|S5')
arr = np.array([[30, 40], [1, 2]])
print np.apply_along_axis(get_label_2, 1, arr).reshape(arr.shape[0])
你知道這個問題更優雅的解決方案?
它使用第一個輸入從測試計算推斷dtype。如果返回'BIG',則字符串大小設置爲3個字符長。 – hpaulj