2017-06-02 90 views
4

我在Theano中有一個實數矩陣,我想要生成另一個矩陣,使得在新矩陣的每一列中都有一個1.0,否則爲0.0。 1.0應該指示輸入矩陣列中最大值的位置。如何使用Theano將列中的最大值替換爲1.0?

例如。

0.0 0.0 0.0 1.0 
1.0 0.0 1.0 0.0 
0.0 1.0 0.0 0.0 

,我迄今使用的解決方案如下::

tmp = T.max(inp, axis = 0).dimshuffle('x',0) 
out = T.switch(T.eq(tmp, inp), 1.0, 0.0) 
如果下面的矩陣被用作輸入

1.0 2.0 3.0 5.0 
2.1 0.0 4.0 0.0 
0.0 3.0 1.0 4.0 

下面的矩陣已到作爲輸出生成

此解決方案似乎可行,但我不確定它有多強大。主要關注的是我比較當前值是否等於剛好到列中的最大值。這是否會發生,由於一些「四捨五入」的錯誤,最大的價值不會被我們認出來?

回答

1

試試這個:

out = T.eye(3)[T.argmax(inp, axis=0)].T # replace 3 with number of rows 

這將輸出:

0.0 0.0 0.0 1.0 
1.0 0.0 1.0 0.0 
0.0 1.0 0.0 0.0 
1

我會使用argmax函數而不是最大& dimshuffle方法。 Argmax返回最大值的索引,而不是值本身。

tmp = T.argmax(inp, axis = 0) 

然後,您可以初始化一個矩陣,全零和使用您的TMP陣列所需指數爲1.0(我沒能測試/現在提供了這部分代碼,但它應該是微不足道)