2015-12-31 65 views
3

我試圖在TensorFlow在cifar10實施例中描述來讀取以類似的方式標號:讀取多個字節爲一個值在TensorFlow

.... 
label_bytes = 2 # it was 1 in the original version 
result.key, value = reader.read(filename_queue) 
record_bytes = tf.decode_raw(value, tf.uint8) 
result.label = tf.cast(tf.slice(record_bytes, [0], [label_bytes]), tf.int32) 
.... 

的問題是,如果是label_byte大於1 (例如2),result.label似乎變成兩個元素(每個元素都是1個字節)的張量。我只想將連續的label_bytes字節表示爲單個值。我怎麼做?

感謝

回答

3

用它創建第二個解碼器,解碼INT16和採取的第一個元素爲您的標籤

shorts = tf.decode_raw(value, tf.int16) 
result.label = tf.cast(shorts[0], tf.int32) 

有可能是一個更好的解決方案,但它的工作原理。

+0

謝謝,這也是我想出來的,我想這對於解碼兩次來說太浪費了。我仍然在尋找更好的解決方案。 – Zk1001

相關問題