2017-10-06 107 views
0

在創建和加載.tfrecord文件的情況下我遇到了以下問題:tf.contrib.data.TFRecordDataset無法讀取* .tfrecord

生成dataset.tfrecord文件

的文件夾/ Batch_manager /資產包含了一些*。TIF被用來生成一個dataset.tfrecord文件圖片:

def _save_as_tfrecord(self, path, name): 
    self.__filename = os.path.join(path, name + '.tfrecord') 
    writer = tf.python_io.TFRecordWriter(self.__filename) 
    print('Writing', self.__filename) 
    for index, img in enumerate(self.load(get_iterator=True, n_images=1)): 
     img = img[0] 
     image_raw = img.tostring() 
     rows = img.shape[0] 
     cols = img.shape[1] 
     try: 
      depth = img.shape[2] 
     except IndexError: 
      depth = 1 
     example = tf.train.Example(features=tf.train.Features(feature={ 
      'height': self._int64_feature(rows), 
      'width': self._int64_feature(cols), 
      'depth': self._int64_feature(depth), 
      'label': self._int64_feature(int(self.target[index])), 
      'image_raw': self._bytes_feature(image_raw) 
       })) 
     writer.write(example.SerializeToString()) 
    writer.close() 

從dataset.tfrecord文件中讀取

下一頁我嘗試在那裏的道路走向的dataset.tfrecord文件引導到使用此文件中讀取:

def dataset_input_fn(self, path): 
    dataset = tf.contrib.data.TFRecordDataset(path) 

    def parser(record): 
     keys_to_features = { 
      "height": tf.FixedLenFeature((), tf.int64, default_value=""), 
      "width": tf.FixedLenFeature((), tf.int64, default_value=""), 
      "depth": tf.FixedLenFeature((), tf.int64, default_value=""), 
      "label": tf.FixedLenFeature((), tf.int64, default_value=""), 
      "image_raw": tf.FixedLenFeature((), tf.string, default_value=""), 
     } 
     print(record) 
     features = tf.parse_single_example(record, features=keys_to_features) 
     print(features) 
     label = features['label'] 
     height = features['height'] 
     width = features['width'] 
     depth = features['depth'] 
     image = tf.decode_raw(features['image_raw'], tf.float32) 
     image = tf.reshape(image, [height, width, -1]) 
     label = tf.cast(features["label"], tf.int32) 

     return {"image_raw": image, "height": height, "width": width, "depth":depth, "label":label} 

    dataset = dataset.map(parser) 
    dataset = dataset.shuffle(buffer_size=10000) 
    dataset = dataset.batch(32) 
    iterator = dataset.make_one_shot_iterator() 

    # `features` is a dictionary in which each value is a batch of values for 
    # that feature; `labels` is a batch of labels. 
    features = iterator.get_next() 

    return Features 

錯誤消息:

類型錯誤:預期的Int64,STR「而不是「類型」了」。

什麼是錯的這一段代碼?我成功驗證了dataset.tfrecord實際上包含正確的圖像和元數據!

+0

'self.load(...)簡單地'返回可用於在每個圖片基礎加載的迭代器。我'敢肯定,這個問題是要麼是因爲我建'example'變量和寫入的方式來dataset.tfrecord或者進行解析'tf.contrib.data.TFRecordDataset(路徑)的方式'和'的解析器「函數給'.map(func)' –

回答

0

的錯誤發生,因爲我複製並粘貼該實施例中,其設置所有的鍵 - 值對的值爲空字符串,致default_value=""。去除所有tf.FixedLenFeature固定的問題。