2017-10-14 86 views
0

我試圖收集特定張量/(向量/矩陣)在角膜張量內的索引。因此,我試圖使用tf.gathertf.where來獲取在收集功能中使用的索引。張量流條件軸

然而,當測試相等時,tf.where爲匹配值提供元素明智的索引。我希望能夠找到張量(向量)的索引(行),這些索引與另一個相等。

這對找到張量中與一組感興趣的熱點向量相匹配的單向量向量特別有用。

我有一些代碼來說明到目前爲止的缺點:

# standard 
import tensorflow as tf 
import numpy as np 
from sklearn.preprocessing import LabelBinarizer 
sess = tf.Session() 

# one-hot vector encoding labels 
l = LabelBinarizer() 
l.fit(['a','b','c']) 

# input tensor 
t = tf.constant(l.transform(['a','a','c','b', 'a'])) 

# find the indices where 'c' is label 
# ***THIS WORKS*** 
np.all(t.eval(session = sess) == l.transform(['c']), axis = 1) 

# We need to do everything in tensorflow and then wrap in Lambda layer for keras so... 
from keras import backend as K 
# ***THIS DOES NOT WORK*** 
K.all(t.eval(session = sess) == l.transform(['c']), axis = 1) 

# go on from here to get smaller subset of vectors from another tensor with the indicies given by `tf.gather` 

顯然上面的代碼顯示我曾嘗試軸得到這個條件的工作,並在numpy的確實很好,但tensorflow版本是不容易從numpy移植。

有沒有更好的方法來做到這一點?

回答

1

同樣對你做了什麼,我們可以用tf.reduce_all這是tensorflow相當於np.all

tf.reduce_all(t.eval(session = sess) == l.transform(['c']), axis = 1)