Roughly speaking, the idea is to gather the correct values from the dense indices and weights arrays, instead of doing it directly on the tf.SparseTensor. I'll post my code below; it assumes that the indices describing the same row are contiguous, so each row can be identified by a start index and an end index.
import numpy as np
import tensorflow as tf

sess = tf.Session()

# Suppose we have the following codebook:
# The first 3 rows are one-hot basis and these words have their own embeddings
# The last 2 rows corresponds to rare words and their embeddings are linear combinations of common words.
# Thus basis coefficients and weights are:
# indices:
#[[0, x, x]
# [1, x, x]
# [2, x, x]
# [1, 2, x]
# [0, 2, x]]
# weights:
#[[1.0, x, x]
# [1.0, x, x]
# [1.0, x, x]
# [0.1, 0.9, x]
# [0.7, 0.3, x]]
# embedding basis (for word 0, word 1 and word 2 respectively):
#[[1, 2, 3, 4, 5]
# [6, 7, 8, 9, 10]
# [11, 12, 13, 14, 15]]
# Which implies, the embeddings for word 3 and word 4 are:
# [6, 7, 8, 9, 10] * 0.1 + [11, 12, 13, 14, 15] * 0.9 = [ 10.5, 11.5, 12.5, 13.5, 14.5]
# [1, 2, 3, 4, 5] * 0.7 + [11, 12, 13, 14, 15] * 0.3 = [ 4., 5., 6., 7., 8.]
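# (Added sanity check, not in the original answer) Verify the two linear
# combinations above with plain NumPy; `basis` is a local throwaway array.
basis = np.arange(1, 16, dtype=np.float32).reshape(3, 5)
assert np.allclose(0.1 * basis[1] + 0.9 * basis[2], [10.5, 11.5, 12.5, 13.5, 14.5])
assert np.allclose(0.7 * basis[0] + 0.3 * basis[2], [4.0, 5.0, 6.0, 7.0, 8.0])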
X = tf.constant(np.array(list(range(1, 16))).reshape((3, 5)), dtype=tf.float32)
# For word i, its index range in sp_weights_vals and sp_ids_val is
# [start_index_in_sp_indices[i], end_index_in_sp_indices[i])
# index_range_in_sp_indices[i] is the index range of word i
# e.g.: For word 3, the relevant entries are those at indices 3 and 4 (i.e. the half-open interval [3, 5))
# This (start, end) representation is equivalent to the explicit sparse indices
# [[0, 0], [1, 0], [2, 0], [3, 0], [3, 1], [4, 0], [4, 1]]
start_index_in_sp_indices = [0, 1, 2, 3, 5]
end_index_in_sp_indices = [1, 2, 3, 5, 7]
word_ids_in_sentence = tf.placeholder(shape=[6], dtype=tf.int32)
sp_shape = tf.placeholder(tf.int64) # mini-batch size * sparsity
# Gather proper indices
start_indices_in_sentence = tf.gather(start_index_in_sp_indices, word_ids_in_sentence)
end_indices_in_sentence = tf.gather(end_index_in_sp_indices, word_ids_in_sentence)
print(sess.run(start_indices_in_sentence, feed_dict={word_ids_in_sentence: [1, 3, 2, 1, 4, 3]}))
print(sess.run(end_indices_in_sentence, feed_dict={word_ids_in_sentence: [1, 3, 2, 1, 4, 3]}))
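# Expected output of the two prints above for this sentence (added note):
#   start_indices_in_sentence: [1 3 2 1 5 3]
#   end_indices_in_sentence:   [2 5 3 2 7 5]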
# Not supported due to complicated shape
# elems = (start_indices_in_sentence, end_indices_in_sentence)
# word_embedding_indices = tf.foldl(lambda a, x: tf.concat(a, tf.range(x[0], x[1])), elems, initializer=[])
print("*" * 50)
indices_to_gather = []
sp_indices_to_feed = []
for i in range(6):
    indices_to_gather.append(tf.range(start_indices_in_sentence[i], end_indices_in_sentence[i]))
    sp_indices_to_feed.append(tf.stack(tf.map_fn(lambda x: (i, x),
                                                 tf.range(end_indices_in_sentence[i] - start_indices_in_sentence[i]),
                                                 dtype=(tf.int32, tf.int32)),
                                       axis=1))
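# (Added note, a sketch not in the original answer) The (row, offset) pairs built
# inside the loop above could also be constructed without tf.map_fn, e.g.:
#   n = end_indices_in_sentence[i] - start_indices_in_sentence[i]
#   tf.stack([tf.fill(tf.expand_dims(n, 0), i), tf.range(n)], axis=1)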
print("check indices to gather")
print(sess.run(indices_to_gather, feed_dict={word_ids_in_sentence: [1, 3, 2, 1, 4, 3]}))
indices_to_gather = tf.concat(indices_to_gather, axis=0)
print(sess.run(indices_to_gather, feed_dict={word_ids_in_sentence: [1, 3, 2, 1, 4, 3]}))
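# Expected value of indices_to_gather for this sentence (added note): [1 3 4 2 1 5 6 3 4]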
print("check indices to feed")
print(sess.run(sp_indices_to_feed, feed_dict={word_ids_in_sentence: [1, 3, 2, 1, 4, 3]}))
sp_indices_to_feed = tf.to_int64(tf.concat(sp_indices_to_feed, axis=0))
print(sess.run(sp_indices_to_feed, feed_dict={word_ids_in_sentence: [1, 3, 2, 1, 4, 3]}))
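# Expected value of sp_indices_to_feed for this sentence (added note):
# [[0 0] [1 0] [1 1] [2 0] [3 0] [4 0] [4 1] [5 0] [5 1]]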
# sp_indices_table = [[0, 0], [1, 0], [2, 0], [3, 0], [3, 1], [4, 0], [4, 1]]
sp_ids_val_table = tf.constant(np.array([0, 1, 2, 1, 2, 0, 2], dtype=np.int32))
sp_weights_val_table = tf.constant(np.array([1.0, 1.0, 1.0, 0.1, 0.9, 0.7, 0.3], dtype=np.float32))
sp_ids_val_to_feed = tf.gather(sp_ids_val_table, indices_to_gather)
sp_weights_val_to_feed = tf.gather(sp_weights_val_table, indices_to_gather)
print(sess.run([sp_indices_to_feed], feed_dict={word_ids_in_sentence: [1, 3, 2, 1, 4, 3]}))
print(sess.run([sp_ids_val_to_feed], feed_dict={word_ids_in_sentence: [1, 3, 2, 1, 4, 3]}))
print(sess.run([sp_weights_val_to_feed], feed_dict={word_ids_in_sentence: [1, 3, 2, 1, 4, 3]}))
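# Expected values for this sentence (added note):
#   sp_ids_val_to_feed:     [1 1 2 2 1 0 2 1 2]
#   sp_weights_val_to_feed: [1.0 0.1 0.9 1.0 1.0 0.7 0.3 0.1 0.9]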
sp_ids = tf.SparseTensor(sp_indices_to_feed, sp_ids_val_to_feed, sp_shape)
sp_weights = tf.SparseTensor(sp_indices_to_feed, sp_weights_val_to_feed, sp_shape)
y = tf.nn.embedding_lookup_sparse(X, sp_ids, sp_weights, combiner="sum")
# indices_to_gather has already been concatenated above, so it can be used directly
word_embedding_indices = indices_to_gather
print(sess.run(word_embedding_indices, feed_dict={word_ids_in_sentence: [1, 3, 2, 1, 4, 3]}))
print("*" * 50)
word_embeddings = sess.run(y, feed_dict={
    word_ids_in_sentence: [1, 3, 2, 1, 4, 3],
    sp_shape: [6, 3]})
print(word_embeddings)
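# The printed word_embeddings should be (added note, derived by hand from the basis and weights above):
# [[ 6.   7.   8.   9.  10. ]
#  [10.5 11.5 12.5 13.5 14.5]
#  [11.  12.  13.  14.  15. ]
#  [ 6.   7.   8.   9.  10. ]
#  [ 4.   5.   6.   7.   8. ]
#  [10.5 11.5 12.5 13.5 14.5]]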
Oh, I see. It uses a clever combination of dynamic_partition and gather. The idea is not to form a SparseTensor at all; instead, just do 'dynamic_partition' and 'gather' on the dense 'indices' and 'values' arrays. – soloice
This helps a lot! By the way, this code seems to work for only a single row. How can I extract multiple rows (possibly with duplicates)? Say, if I want to extract the 1st, 3rd and 3rd rows? That is, the same word may appear several times in the input sentence. I know I could run 'sparse_slice' several times and concatenate the results with 'tf.concat', but is there a better way? – soloice
Finally I got it. I can do a similar partition on the dense 'indices' and 'values'. I'll post a solution tomorrow that makes heavy use of 'tf.gather'. – soloice
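For reference, here is a minimal sketch (not code from this thread) of the dynamic_partition-plus-gather idea mentioned in the comments: instead of slicing a SparseTensor, partition the dense value tables by row id and pick out the partitions you need. The name row_of_each_entry is illustrative and is read off the sp_indices table above.

import tensorflow as tf

sp_ids_val_table = tf.constant([0, 1, 2, 1, 2, 0, 2], dtype=tf.int32)
sp_weights_val_table = tf.constant([1.0, 1.0, 1.0, 0.1, 0.9, 0.7, 0.3], dtype=tf.float32)
# Row id of each entry, read off sp_indices_table = [[0, 0], [1, 0], [2, 0], [3, 0], [3, 1], [4, 0], [4, 1]]
row_of_each_entry = tf.constant([0, 1, 2, 3, 3, 4, 4], dtype=tf.int32)

# One partition per word: partition i holds the dense entries belonging to row i
ids_per_word = tf.dynamic_partition(sp_ids_val_table, row_of_each_entry, num_partitions=5)
weights_per_word = tf.dynamic_partition(sp_weights_val_table, row_of_each_entry, num_partitions=5)

with tf.Session() as sess:
    print(sess.run(ids_per_word[3]))      # [1 2]   -> basis word ids for rare word 3
    print(sess.run(weights_per_word[3]))  # [0.1 0.9]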