2015-12-04 49 views
6

我仍然試圖用自己的圖像數據運行Tensorflow。 我能夠從這個例子link在Tensorflow中使用字符串標籤

創建具有conevert_to()函數.tfrecords文件現在我,我想用代碼的網絡免受例如link訓練。

但它在read_and_decode()函數中失敗。我在功能的變化是:

label = tf.decode_raw(features['label'], tf.string) 

的錯誤是:

TypeError: DataType string for attr 'out_type' not in list of allowed values: float32, float64, int32, uint8, int16, int8, int64 

那麼如何1)閱讀和2)使用字符串標籤中tensorflow培訓。

回答

4

convert_to_records.py腳本會創建一個.tfrecords文件,其中每個記錄是一個Example協議緩衝區。該協議緩衝區支持使用bytes_list kind的字符串功能。

tf.decode_raw op用於將二進制字符串解析爲圖像數據;它不是爲了解析字符串(文本)標籤而設計的。假設features['label']是一個tf.string張量,您可以使用tf.string_to_number op將其轉換爲數字。在TensorFlow程序中對字符串處理的其他支持有限,所以如果您需要執行一些更復雜的函數來將字符串標籤轉換爲整數,那麼您應該在修改後的版本convert_to_tensor.py的Python中執行此轉換。

+0

是'string_to_number'只是爲了將_numeric_字符串轉換爲數字,但?我得到一個任意字符串值的異常(即「test」),而'tf.string_to_number(「20」)'可以正常工作,併產生一個'20.0'tf.float32'張量。 –

+1

是的。如果您有文本字符串標籤並需要將其轉換爲數字,則可以使用['tf.feature_column.categorical_column _ *()'] (https://www.tensorflow.org/api_docs/python/tf/feature_column)API,例如['tf.feature_column.categorical_column_with_vocabulary_list()'](https://www.tensorflow.org/api_docs/python/tf/feature_column/categorical_column_with_vocabulary_list)或['tf.feature_column.categorical_column_with_hash_bucket()' ](https://www.tensorflow.org/api_docs/python/tf/feature_column/categorical_column_with_hash_bucket)。 – mrry

1

爲了增加@mrry的答案,假設你的字符串是ascii,您可以:

def _bytes_feature(value): 
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) 

def write_proto(cls, filepath, ..., item_id): # itemid is an ascii encodable string 
    # ... 
    with tf.python_io.TFRecordWriter(filepath) as writer: 
     example = tf.train.Example(features=tf.train.Features(feature={ 
      # write it as a bytes array, supposing your string is `ascii` 
      'item_id': _bytes_feature(bytes(item_id, encoding='ascii')), # python 3 
      # ... 
     })) 
     writer.write(example.SerializeToString()) 

然後:

def parse_single_example(cls, example_proto, graph=None): 
    features_dict = tf.parse_single_example(example_proto, 
     features={'item_id': tf.FixedLenFeature([], tf.string), 
     # ... 
     }) 
    # decode as uint8 aka bytes 
    instance.item_id = tf.decode_raw(features_dict['item_id'], tf.uint8) 

,然後當你得到它回到你的會話,變換回到字符串:

item_id, ... = session.run(your_tfrecords_iterator.get_next()) 
print(str(item_id.flatten(), 'ascii')) # python 3 

我把uint8把戲從這個related so answer。適用於我,但歡迎評論/改進。

+0

我有一個TFRecord組成的圖像,其中一個功能是磁盤上該映像的路徑。路徑的形式爲'path \ to \ images \ image432.jpg'。該路徑的長度從「88」到「91」不等。當我解碼這個特殊的功能爲'tf.decode_raw(features ['train/path'],tf.uint8)',我得到 'ValueError:所有形狀必須完全定義:[TensorShape([Dimension(None)]) ,TensorShape([Dimension(256),Dimension(256),Dimension(1)]),TensorShape([])]',第一維對應於路徑 – dpk