建立在此question我期待在滿足tf.where條件的行中第一次更新二維張量的值。這裏是我使用模擬示例代碼:Tensorflow更新每行中的第一個匹配元素
tf.reset_default_graph()
graph = tf.Graph()
with graph.as_default():
val = "hello"
new_val = "goodbye"
matrix = tf.constant([["word","hello","hello"],
["word", "other", "hello"],
["hello", "hello","hello"],
["word", "word", "word"]
])
matching_indices = tf.where(tf.equal(matrix, val))
first_matching_idx = tf.segment_min(data = matching_indices[:, 1],
segment_ids = matching_indices[:, 0])
sess = tf.InteractiveSession(graph=graph)
print(sess.run(first_matching_idx))
這將輸出[1,2,0],其中1是第一問候的第1行中的位置,該圖2是第一的放置你好在第2行,並且0是第3行中第一個問候的位置。
但是,我找不到一種方法來獲取第一個匹配索引以更新新值 - 基本上我希望第一個「你好」變成「再見」。我嘗試過使用tf.scatter_update(),但它似乎不適用於2D張量。有沒有辦法按照描述修改二維張量?
你使用scatter_update的想法似乎有前途的。 scatter_update的文檔似乎表明它可以處理高維張量。也許仍然值得找出具體問題是什麼。請注意,該問題也可以通過將行索引*行長度偏移索引來轉換爲1D問題。 –