要擁有完全動態的分區,您可以使用運算符來返回一個帶有動態形狀的張量,而不是Python時間固定數量的張量,但問題是張量的尺寸必須是矩形的,而且分區的長度可能不同。爲了解決這個問題,你可能會把你的可變大小列表編碼成shape()或使用TensorArray
的字符串。這裏有一個方法,通過使用PNG編碼陣列< =>字符串轉換
def dynamic_partition_png(vals, idx, max_partitions):
"""Encodes output of dynamic partition as a Tensor of png-encoded strings."""
max_idx = tf.reduce_max(idx)
max_vals = tf.reduce_max(idx)
with tf.control_dependencies([tf.Assert(max_vals<256, ["vals must be <256"])]):
outputs = tf.dynamic_partition(vals, idx, num_partitions=max_partitions)
png_outputs = []
dummy_png = tf.image.encode_png(([[[2]]]))
not_empty_ops = [] # ops that detect empty lists that aren't at the end
for i, o in enumerate(outputs):
reshaped_o = tf.reshape(tf.cast(o, tf.uint8), [-1, 1, 1])
png_output = tf.cond(tf.size(reshaped_o)>0, lambda: tf.image.encode_png(reshaped_o), lambda: dummy_png)
png_outputs.append(png_output)
not_empty_ops.append(tf.logical_or(i>max_idx, tf.size(reshaped_o)>0))
packed_tensor = tf.pack(png_outputs)
no_illegal_empty_lists = tf.reduce_all(tf.pack(not_empty_ops))
with tf.control_dependencies([tf.Assert(no_illegal_empty_lists, ["empty lists must be last"])]):
result = packed_tensor[:max_idx+1]
return result
def decode(p):
return tf.image.decode_png(p)[:, 0, 0]
sess = tf.Session()
vals = tf.constant([1,2,3,4,5])
idx = [0, 1, 1, 1, 1]
tf_vals = dynamic_partition_png(vals, idx, 3)
print(sess.run(decode(tf_vals[0]))) # => [1 2]
print(sess.run(decode(tf_vals[1]))) # => [3 4 5]
print(sess.run(decode(tf_vals[2]))) # => slice index 2 of dimension 0 out of bounds
什麼是在傳遞一個整數的類,這是建立模型的危害做呢?你可以在你的類中有'edit_partitions()'函數 – martianwars
請注意'tf.dynamic_partition'的輸出是一個Python列表。所以你可以確保空張量始終在最後,並使用Python列表索引來切斷尾部。您可能需要將其分成兩個'sess.run'調用,以便將長度變爲Python-land。 –