2016-12-31 76 views
2

方法tf.dynamic_partition中的變量num_partitions不是Tensor,而是int。 因此,如果事先不知道分區的數量,則不能通過例如統計唯一值的數量來推斷數據,也不能由tf.placeholder給出。 如何在這種動態場景中使用這種方法?dynamic_partition with dynamic num_partitions

如果這是不可能的,可能的解決方法是將此參數的值設置爲某個上限。然後在運行時會有一些空的列表。問題是這些空列表可以如何消除?

謝謝!

+0

什麼是在傳遞一個整數的類,這是建立模型的危害做呢?你可以在你的類中有'edit_partitions()'函數 – martianwars

+0

請注意'tf.dynamic_partition'的輸出是一個Python列表。所以你可以確保空張量始終在最後,並使用Python列表索引來切斷尾部。您可能需要將其分成兩個'sess.run'調用,以便將長度變爲Python-land。 –

回答

2

要擁有完全動態的分區,您可以使用運算符來返回一個帶有動態形狀的張量,而不是Python時間固定數量的張量,但問題是張量的尺寸必須是矩形的,而且分區的長度可能不同。爲了解決這個問題,你可能會把你的可變大小列表編碼成shape()或使用TensorArray的字符串。這裏有一個方法,通過使用PNG編碼陣列< =>字符串轉換

def dynamic_partition_png(vals, idx, max_partitions): 
    """Encodes output of dynamic partition as a Tensor of png-encoded strings.""" 
    max_idx = tf.reduce_max(idx) 
    max_vals = tf.reduce_max(idx) 
    with tf.control_dependencies([tf.Assert(max_vals<256, ["vals must be <256"])]): 
     outputs = tf.dynamic_partition(vals, idx, num_partitions=max_partitions) 
    png_outputs = [] 
    dummy_png = tf.image.encode_png(([[[2]]])) 
    not_empty_ops = [] # ops that detect empty lists that aren't at the end 
    for i, o in enumerate(outputs): 
     reshaped_o = tf.reshape(tf.cast(o, tf.uint8), [-1, 1, 1]) 
     png_output = tf.cond(tf.size(reshaped_o)>0, lambda: tf.image.encode_png(reshaped_o), lambda: dummy_png) 
     png_outputs.append(png_output) 
     not_empty_ops.append(tf.logical_or(i>max_idx, tf.size(reshaped_o)>0)) 
    packed_tensor = tf.pack(png_outputs) 
    no_illegal_empty_lists = tf.reduce_all(tf.pack(not_empty_ops)) 
    with tf.control_dependencies([tf.Assert(no_illegal_empty_lists, ["empty lists must be last"])]): 
     result = packed_tensor[:max_idx+1] 
    return result 

def decode(p): 
    return tf.image.decode_png(p)[:, 0, 0] 

sess = tf.Session() 
vals = tf.constant([1,2,3,4,5]) 
idx = [0, 1, 1, 1, 1] 
tf_vals = dynamic_partition_png(vals, idx, 3) 
print(sess.run(decode(tf_vals[0]))) # => [1 2] 
print(sess.run(decode(tf_vals[1]))) # => [3 4 5] 
print(sess.run(decode(tf_vals[2]))) # => slice index 2 of dimension 0 out of bounds