2017-06-22 70 views
1

我想用TensorFlow的tf.image.rot90(),後者只需要一個圖像在時間創建函數batch_rot90(batch_of_images),前者應採取間歇n個圖像的一次(形狀= [N,X,Y,F])。如何通過張量在TensorFlow元素循環?

所以很自然的,人們應該只是通過批處理中的所有圖像itterate和旋轉它們一個接一個。在numpy中,這看起來像:

def batch_rot90(batch): 
    for i in range(batch.shape[0]): 
    batch_of_images[i] = rot90(batch[i,:,:,:]) 
    return batch 

這是如何在TensorFlow中完成的? 使用tf.while_loop我得到了他的遠:

batch = tf.placeholder(tf.float32, shape=[2, 256, 256, 4])  
def batch_rot90(batch, k, name=''): 
     i = tf.constant(0) 
     def cond(batch, i): 
     return tf.less(i, tf.shape(batch)[0]) 
     def body(im, i): 
     batch[i] = tf.image.rot90(batch[i], k) 
     i = tf.add(i, 1) 
     return batch, i 
     r = tf.while_loop(cond, body, [batch, i]) 
     return r 

但分配給im[i]是不允許的,而且我感到困惑的是什麼有R返回。

我知道有可能是使用tf.batch_to_space()這種特殊情況的方法,但我相信它有可能需要某種形式的循環了。

回答

0

更新答:

x = tf.placeholder(tf.float32, shape=[2, 3]) 

def cond(batch, output, i): 
    return tf.less(i, tf.shape(batch)[0]) 

def body(batch, output, i): 
    output = output.write(i, tf.add(batch[i], 10)) 
    return batch, output, i + 1 

# TensorArray is a data structure that support dynamic writing 
output_ta = tf.TensorArray(dtype=tf.float32, 
       size=0, 
       dynamic_size=True, 
       element_shape=(x.get_shape()[1],)) 
_, output_op, _ = tf.while_loop(cond, body, [x, output_ta, 0]) 
output_op = output_op.stack() 

with tf.Session() as sess: 
    print(sess.run(output_op, feed_dict={x: [[1, 2, 3], [0, 0, 0]]})) 

我想你應該考慮使用tf.scatter_update更新批一個圖像,而不是使用batch[i] = ...。詳情請參閱this link。在你的情況,我建議身體的第一行更改爲:

tf.scatter_update(batch, i, tf.image.rot90(batch[i], k)) 

+0

這並不因爲工作'batch'只是一個佔位符 – Bastiaan

+0

是的,你是對的。我的錯。 –

+0

我更新我的回答,仍然使用'while_loop'。但我認爲你自己的解決方案更優雅。 +1。附:我發現'map_fn'在內部使用'TensorArray'和'while_loop'。 –

1

中有TF的地圖功能,將工作:

def batch_rot90(batch, k, name=''): 
    fun = lambda x: tf.images.rot90(x, k = 1) 
    return = tf.map_fn(fun, batch)