的字符串。這裏有一個方法，通過使用 PNG 編碼來實現陣列 <=> 字符串的轉換：
def dynamic_partition_png(vals, idx, max_partitions):
    """Encodes output of dynamic partition as a Tensor of png-encoded strings.

    Each partition produced by ``tf.dynamic_partition(vals, idx,
    max_partitions)`` is cast to uint8, reshaped to a single-column
    ``[n, 1, 1]`` image and PNG-encoded. This lets partitions of different
    lengths be packed into one 1-D string tensor. Decode a partition with
    ``tf.image.decode_png(s)[:, 0, 0]``.

    Args:
      vals: tensor of integer values to partition. Every value is stored as a
        uint8 pixel, so it must lie in [0, 255]; this is asserted at run time.
      idx: partition index for each element of ``vals``, as accepted by
        ``tf.dynamic_partition``.
      max_partitions: Python int, number of partitions to create.

    Returns:
      A 1-D string tensor holding the first ``max(idx) + 1`` PNG-encoded
      partitions. Partitions above ``max(idx)`` are dropped.

    Run-time assertions (triggered when the result is evaluated) fail if a
    value is outside [0, 255], or if an empty partition occurs at an index
    <= ``max(idx)``.
    """
    # tf.pack was renamed to tf.stack in TF 1.0; use whichever is available.
    stack = getattr(tf, "stack", None) or tf.pack

    max_idx = tf.reduce_max(idx)
    # Values become uint8 pixels and tf.cast would wrap anything outside
    # [0, 255] silently, so range-check vals (not idx) before partitioning.
    max_vals = tf.reduce_max(vals)
    min_vals = tf.reduce_min(vals)
    range_checks = [
        tf.Assert(max_vals < 256, ["vals must be <256"]),
        tf.Assert(min_vals >= 0, ["vals must be >=0"]),
    ]
    with tf.control_dependencies(range_checks):
        outputs = tf.dynamic_partition(vals, idx, num_partitions=max_partitions)

    # Stand-in string for empty partitions. Presumably encode_png cannot
    # encode a zero-height image; such entries are only legal past max_idx
    # and are sliced off below.
    dummy_png = tf.image.encode_png(tf.constant([[[2]]], dtype=tf.uint8))
    png_outputs = []
    # ops that detect empty lists that aren't at the end
    not_empty_ops = []
    for i, o in enumerate(outputs):
        reshaped_o = tf.reshape(tf.cast(o, tf.uint8), [-1, 1, 1])
        is_nonempty = tf.size(reshaped_o) > 0
        # In graph mode tf.cond calls its branch functions while the graph is
        # being built, so closing over the loop variable reshaped_o is safe.
        png_output = tf.cond(
            is_nonempty,
            lambda: tf.image.encode_png(reshaped_o),
            lambda: dummy_png,
        )
        png_outputs.append(png_output)
        not_empty_ops.append(tf.logical_or(i > max_idx, is_nonempty))

    packed_tensor = stack(png_outputs)
    no_illegal_empty_lists = tf.reduce_all(stack(not_empty_ops))
    # The slice op is created inside this context, so evaluating the result
    # also runs the "empty lists must be last" check.
    with tf.control_dependencies(
        [tf.Assert(no_illegal_empty_lists, ["empty lists must be last"])]
    ):
        result = packed_tensor[:max_idx + 1]
    return result
def decode(p):
    """Turn one PNG-encoded partition string back into a 1-D tensor of pixels.

    Strings built by dynamic_partition_png hold a single-column, single-channel
    image, so the values live along the first (height) axis of the decoded
    image.
    """
    image = tf.image.decode_png(p)
    column = image[:, 0, 0]
    return column
# Demo: split five values into two groups and round-trip them through PNG.
vals = tf.constant([1, 2, 3, 4, 5])
# Group 0 gets [1 2] and group 1 gets [3 4 5], matching the outputs below.
idx = [0, 0, 1, 1, 1]
tf_vals = dynamic_partition_png(vals, idx, 3)
with tf.Session() as sess:
    print(sess.run(decode(tf_vals[0])))  # => [1 2]
    print(sess.run(decode(tf_vals[1])))  # => [3 4 5]
    # Partition 2 is empty and lies past max(idx), so it was trimmed off.
    # Indexing it fails when the graph is run; report that instead of crashing.
    try:
        print(sess.run(decode(tf_vals[2])))
    except tf.errors.OpError as err:
        print(err.message)  # => slice index 2 of dimension 0 out of bounds
