2013-03-11 24 views
3

如何在通過像__rsub__這樣的右手運算符來訪問numpy數組的屬性?參數類型在正常運算符和反射運算符超載方面不同(__sub__/__rsub__)

我寫在Python一個非常簡單的類,它定義了兩個功能:

class test(object): 
    def __sub__(self, other): 
     return other 

    def __rsub__(self, other): 
     return other 

基本上他們應該這樣做。左手操作如預期__sub__作品,但它似乎numpy的陣列被剝離其在右手操作特性

from skimage import data 
from skimage.color import rgb2gray 
lena = data.lena() 

grayLena = rgb2gray(lena) 
t = test() 

## overloaded - operator 
left_hand = t - grayLena 
print left_hand 
# Output: 
#array([[ 0.60802863, 0.60802863, 0.60779059, ..., 0.64137412, 
# 0.57998235, 0.46985725], 
# [ 0.60802863, 0.60802863, 0.60779059, ..., 0.64137412, 
# 0.57998235, 0.46985725], 
# [ 0.60802863, 0.60802863, 0.60779059, ..., 0.64137412, 
# 0.57998235, 0.46985725], 
# ..., 
# [ 0.13746353, 0.13746353, 0.16881412, ..., 0.37271804, 
# 0.35559529, 0.34377725], 
# [ 0.14617059, 0.14617059, 0.18730588, ..., 0.36788784, 
# 0.37292549, 0.38467529], 
# [ 0.14617059, 0.14617059, 0.18730588, ..., 0.36788784, 
# 0.37292549, 0.38467529]]) 

right_hand = grayLena - t 
print right_hand 
# Output: 
# array([[0.6080286274509803, 0.6080286274509803, 0.6077905882352941, ..., 
# 0.6413741176470589, 0.5799823529411765, 0.4698572549019608], 
# [0.6080286274509803, 0.6080286274509803, 0.6077905882352941, ..., 
# 0.6413741176470589, 0.5799823529411765, 0.4698572549019608], 
# [0.6080286274509803, 0.6080286274509803, 0.6077905882352941, ..., 
# 0.6413741176470589, 0.5799823529411765, 0.4698572549019608], 
# ..., 
# [0.1374635294117647, 0.1374635294117647, 0.1688141176470588, ..., 
# 0.3727180392156863, 0.35559529411764706, 0.34377725490196076], 
# [0.1461705882352941, 0.1461705882352941, 0.18730588235294118, ..., 
# 0.3678878431372549, 0.37292549019607846, 0.3846752941176471], 
# [0.1461705882352941, 0.1461705882352941, 0.18730588235294118, ..., 
# 0.3678878431372549, 0.37292549019607846, 0.3846752941176471]], dtype=object) 

所以這兩個操作之間的不同之處在於__rsub__接收D型數組=目的。如果我只是設置這個數組的dtype,一切都會正常工作。

但是,它只適用於__rsub__之外的返回值。在我的__rsub__我得到的只是垃圾,我不能轉換回,即如果我做

npArray = np.array(other, dtype=type(other)) 

我得到的類型(在我的情況下,浮動)的一維數組。但由於某種原因,形狀信息丟失了。有沒有人完成這個或一個想法如何我可以訪問數組的原始屬性(形狀和類型)?

回答

1

我不知道什麼樣的ndarray的機器內部的精確控制流程,但發生的事情在你的情況或多或少是明確的:

什麼ndarray被委託給你的對象的__rsub__方法不是整體減法操作,但是從數組中的每個項目減去對象。顯然,當它必須將操作委託給對象的方法時,返回類型被設置爲object,而不管返回的是什麼。你可以檢查它與你的代碼的這個輕微修改:

class test(object): 
    def __sub__(self, other): 
     return other 

    def __rsub__(self, other): 
     return other if other != 1 else 666 

In [11]: t = test() 

In [12]: t - np.arange(4) 
Out[12]: array([0, 1, 2, 3]) 

In [13]: np.arange(4) - t 
Out[13]: array([0, 666, 2, 3], dtype=object) 

我不認爲有一個簡單的方法來覆蓋這種行爲。您可以嘗試making test a subclass of ndarray__array_priority__和濫用一點點的__array_wrap__方法:

class test(np.ndarray): 
    __array_priority__ = 100 

    def __new__(cls): 
     obj = np.int32([1]).view(cls) 
     return obj 

    def __array_wrap__(self, arr, context) : 
     if context is not None : 
      ufunc = context[0] 
      args = context[1] 
      if ufunc == np.subtract : 
       if self is args[0] : 
        return args[1] 
       elif self is args[1] : 
        return args[0] 
     return arr 

現在:

>>> t = test() 
>>> np.arange(4) - t 
array([0, 1, 2, 3]) 
>>> t - np.arange(4) 
array([0, 1, 2, 3]) 

但是:

>>> np.arange(4) + t 
test([1, 2, 3, 4]) 
>>> t + np.arange(4) 
test([1, 2, 3, 4]) 

這是一個有點浪費,因爲我們正在爲t內部的1添加到數組中的每個值,然後默默地丟棄克,但我想不出有什麼方法可以壓倒這一點。

+0

太好了,我會試試看。我只需要它是出於落後的兼容性原因,所以它不一定要高效,它只需要給出正確的結果。 – 2013-03-11 20:23:22

+0

@ user2156909更多的指針:我已經爲單個元素數組創建了'test',以便它可以針對任何形狀進行廣播,先用0-D數組嘗試它,然後失敗。你可能也想考慮給數組賦予什麼值,我選擇'1'來表明只有當'ufunc'是'subtract'時,我們才放棄操作的結果。如果你的對象可能出現在其他操作中,你可能想要檢查更多的'ufuncs',或給它另一個值。 – Jaime 2013-03-11 20:40:11

+0

在我的情況下,子類化ndarray無論如何都是更好的方法,因爲我只是想減去測試類中的數據數組。如果想要調用一個更花哨的函數,'__array_wrap__'也可以正常工作。謝謝 – 2013-03-11 21:03:02