0
我想使用CNN來解決deblurring的任務,並且我有訓練數據,它是一個包含文件名的png圖像和相應文本文件的目錄。在Tensorflow中傳輸CNN模型時,如何從目錄讀取圖像作爲輸入和輸出?
由於數據太大而無法用一步添加到內存中,並且是否有任何API或某種方法使我可以讀取blury圖像作爲輸入,並將其作爲預期結果讀取到地面實況以訓練?
我花了不少時間來解決這個問題,但我糊塗了在網上API介紹後讀的API。
我想使用CNN來解決deblurring的任務,並且我有訓練數據,它是一個包含文件名的png圖像和相應文本文件的目錄。在Tensorflow中傳輸CNN模型時,如何從目錄讀取圖像作爲輸入和輸出?
由於數據太大而無法用一步添加到內存中,並且是否有任何API或某種方法使我可以讀取blury圖像作爲輸入,並將其作爲預期結果讀取到地面實況以訓練?
我花了不少時間來解決這個問題,但我糊塗了在網上API介紹後讀的API。
該方法並不困惑。 tensorflow提供TFrecords文件以充分利用內存。
def create_cord():
writer = tf.python_io.TFRecordWriter("train.tfrecords")
for index in xrange(66742):
blur_file_name = get_file_name(index, True)
orig_file_name = get_file_name(index, False)
blur_image_path = cwd + blur_file_name
orig_image_path = cwd + orig_file_name
blur_image = Image.open(blur_image_path)
orig_image = Image.open(orig_image_path)
blur_image = blur_image.resize((IMAGE_HEIGH, IMAGE_WIDTH))
orig_image = orig_image.resize((IMAGE_HEIGH, IMAGE_WIDTH))
blur_image_raw = blur_image.tobytes()
orig_image_raw = orig_image.tobytes()
example = tf.train.Example(features=tf.train.Features(feature={
"blur_image_raw": tf.train.Feature(bytes_list=tf.train.BytesList(value=[blur_image_raw])),
'orig_image_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[orig_image_raw]))
}))
writer.write(example.SerializeToString())
writer.close()
讀取數據集:
def read_and_decode(filename):
filename_queue = tf.train.string_input_producer([filename])
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(serialized_example,
features={
'blur_image_raw': tf.FixedLenFeature([], tf.string),
'orig_image_raw': tf.FixedLenFeature([], tf.string),
})
blur_img = tf.decode_raw(features['blur_image_raw'], tf.uint8)
blur_img = tf.reshape(blur_img, [IMAGE_WIDTH, IMAGE_HEIGH, 3])
blur_img = tf.cast(blur_img, tf.float32) * (1./255) - 0.5
orig_img = tf.decode_raw(features['blur_image_raw'], tf.uint8)
orig_img = tf.reshape(orig_img, [IMAGE_WIDTH, IMAGE_HEIGH, 3])
orig_img = tf.cast(orig_img, tf.float32) * (1./255) - 0.5
return blur_img, orig_img
if __name__ == '__main__':
# create_cord()
blur, orig = read_and_decode("train.tfrecords")
blur_batch, orig_batch = tf.train.shuffle_batch([blur, orig],
batch_size=3, capacity=1000,
min_after_dequeue=100)
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
# 啓動隊列
threads = tf.train.start_queue_runners(sess=sess)
for i in range(3):
v, l = sess.run([blur_batch, orig_batch])
print(v.shape, l.shape)
您正在尋找這樣的:http://stackoverflow.com/a/36947632/2505209?以圖片爲例以及標籤。 – hars