2017-08-14 94 views
0

建立在此question我期待在滿足tf.where條件的行中第一次更新二維張量的值。這裏是我使用模擬示例代碼:Tensorflow更新每行中的第一個匹配元素

tf.reset_default_graph() 
graph = tf.Graph() 
with graph.as_default(): 
    val = "hello" 
    new_val = "goodbye" 
    matrix = tf.constant([["word","hello","hello"], 
          ["word", "other", "hello"], 
          ["hello", "hello","hello"], 
          ["word", "word", "word"] 
         ]) 
    matching_indices = tf.where(tf.equal(matrix, val)) 
    first_matching_idx = tf.segment_min(data = matching_indices[:, 1], 
           segment_ids = matching_indices[:, 0]) 

sess = tf.InteractiveSession(graph=graph) 
print(sess.run(first_matching_idx)) 

這將輸出[1,2,0],其中1是第一問候的第1行中的位置,該圖2是第一的放置你好在第2行,並且0是第3行中第一個問候的位置。

但是,我找不到一種方法來獲取第一個匹配索引以更新新值 - 基本上我希望第一個「你好」變成「再見」。我嘗試過使用tf.scatter_update(),但它似乎不適用於2D張量。有沒有辦法按照描述修改二維張量?

+0

你使用scatter_update的想法似乎有前途的。 scatter_update的文檔似乎表明它可以處理高維張量。也許仍然值得找出具體問題是什麼。請注意,該問題也可以通過將行索引*行長度偏移索引來轉換爲1D問題。 –

回答

0

一個簡單的解決方法是使用tf.py_func與numpy的陣列

def ch_val(array, val, new_val): 
    idx = np.array([[s, list(row).index(val)] 
       for s, row in enumerate(array) if val in row]) 
    idx = tuple((idx[:, 0], idx[:, 1])) 
    array[idx] = new_val 
    return array 

... 
matrix = tf.Variable([["word","hello","hello"], 
         ["word", "other", "hello"], 
         ["hello", "hello","hello"], 
         ["word", "word", "word"] 
        ]) 
matrix = tf.py_func(ch_val, [matrix, 'hello', 'goodbye'], tf.string) 
with tf.Session() as sess: 
    sess.run(tf.global_variables_initializer()) 
    print(sess.run(matrix)) 
    # results: [['word' 'goodbye' 'hello'] 
     ['word' 'other' 'goodbye'] 
     ['goodbye' 'hello' 'hello'] 
     ['word' 'word' 'word']] 
+0

謝謝!訓練時py_func會導致速度問題嗎?如果每個批次都執行此操作,我擔心python函數會顯着減慢速度 – reese0106

+0

它不會顯着減慢速度,但會比沒有'tf.py_func'的最佳版本慢 –

相關問題