2012-08-31 43 views
0

我試圖使用ufunc將一個N * 1 numpy數組整數有效地映射到N * 3 numpy浮點數組。使用ufunc映射numpy數組

我有什麼至今:

map = {1: (0, 0, 0), 2: (0.5, 0.5, 0.5), 3: (1, 1, 1)} 
ufunc = numpy.frompyfunc(lambda x: numpy.array(map[x], numpy.float32), 1, 1) 

input = numpy.array([1, 2, 3], numpy.int32) 

ufunc(input)給出了一個3×3陣列D型對象。我想這個數組,但與D型float32。

+2

'map'和'input'是Python內置函數。最好不要爲這些名稱分配新的值,因爲它使得很難訪問Python內置函數。 – unutbu

+0

'frompyfunc'的文檔說「返回的ufunc總是返回PyObject數組」。無論這個原因是什麼,有一個相當簡單的解決方法:提交一個適當的輸入類型的輸出矩陣作爲「out」參數。 – Alexey

回答

1

你可以使用np.hstack

import numpy as np 
mapping = {1: (0, 0, 0), 2: (0.5, 0.5, 0.5), 3: (1, 1, 1)} 
ufunc = np.frompyfunc(lambda x: np.array(mapping[x], np.float32), 1, 1, dtype = np.float32) 

data = np.array([1, 2, 3], np.int32) 
result = np.hstack(ufunc(data)) 
print(result) 
# [ 0. 0. 0. 0.5 0.5 0.5 1. 1. 1. ] 
print(result.dtype) 
# float32 
print(result.shape) 
# (9,) 
1

您可以使用ndarray看中指數得到相同的結果,我認爲它應該比frompyfunc快:

map_array = np.array([[0,0,0],[0,0,0],[0.5,0.5,0.5],[1,1,1]], dtype=np.float32) 
index = np.array([1,2,3,1]) 
map_array[index] 

或者你也可以使用列表理解:

map = {1: (0, 0, 0), 2: (0.5, 0.5, 0.5), 3: (1, 1, 1)} 
np.array([map[i] for i in [1,2,3,1]], dtype=np.float32)  
+0

輸入列表非常大,所以我試圖避免創建中間列表或數組。 –

1

如果你的映射是一個numpy數組,你可以使用花式索引g下這樣的​​:

>>> valmap = numpy.array([(0, 0, 0), (0.5, 0.5, 0.5), (1, 1, 1)]) 
>>> input = numpy.array([1, 2, 3], numpy.int32) 
>>> valmap[input-1] 
array([[ 0. , 0. , 0. ], 
     [ 0.5, 0.5, 0.5], 
     [ 1. , 1. , 1. ]]) 
1

除非我誤讀了文檔的np.frompyfunc上一個標量輸出對象確實是:使用ndarray作爲輸入時,你會得到一個ndarraydtype=obj

一種解決方法是使用np.vectorize功能:

F = np.vectorize(lambda x: mapper.get(x), 'fff') 

在這裏,我們迫使F的輸出的dtype爲3個浮子(因此'fff')。

>>> mapper = {1: (0, 0, 0), 2: (0.5, 1.0, 0.5), 3: (1, 2, 1)} 
>>> inp = [1, 2, 3] 
>>> F(inp) 
(array([ 0. , 0.5, 1. ], dtype=float32), array([ 0., 0.5, 1.], dtype=float32), array([ 0. , 0.5, 1. ], dtype=float32)) 

OK,並不完全符合我們想:這是三個浮點陣列的元組(因爲我們送「FFF」),第一陣列等價於[mapper[i][0] for i in inp]。因此,通過一些操作:

>>> np.array(F(inp)).T 
array([[ 0. , 0. , 0. ], 
     [ 0.5, 0.5, 0.5], 
     [ 1. , 1. , 1. ]], dtype=float32)