2017-04-13 56 views
5

我正在嘗試執行與tf.decode_raw相反的操作。如何創建一個encode_raw tensorflow函數?

一個例子會給出dtype = tf.float32的張量,我想要一個函數encode_raw(),它接受一個浮點張量並返回一個字符串類型的張量。

這很有用,因爲我可以使用tf.write_file來寫入文件。

有誰知道如何使用現有函數在Tensorflow中創建這樣的功能?

回答

2

我會建議用tf.as_string作爲文本編寫數字。如果你真的想給他們寫一個二進制字符串,然而,事實證明是可行的:

import tensorflow as tf 

with tf.Graph().as_default(): 
    character_lookup = tf.constant([chr(i) for i in range(256)]) 
    starting_dtype = tf.float32 
    starting_tensor = tf.random_normal(shape=[10, 10], stddev=1e5, 
            dtype=starting_dtype) 
    as_string = tf.reduce_join(
     tf.gather(character_lookup, 
       tf.cast(tf.bitcast(starting_tensor, tf.uint8), tf.int32))) 
    back_to_tensor = tf.reshape(tf.decode_raw(as_string, starting_dtype), 
           [10, 10]) # Shape information is lost 
    with tf.Session() as session: 
    before, after = session.run([starting_tensor, back_to_tensor]) 
    print(before - after) 

這對我打印全零的數組。

+0

它的工作原理!謝謝! –

1

對於那些與Python 3個工作:

CHR()具有在Python 3不同的行爲改變與來自先前答案的代碼所獲得的字節輸出。 更換這行代碼

character_lookup = tf.constant([chr(i) for i in range(256)])

character_lookup = tf.constant([i.tobytes() for i in np.arange(256, dtype=np.uint8)])

修復此問題。

相關問題