2016-07-27 38 views
0

我使用的代碼從this github,以下this教程。 有幾個變化,我做了,因爲我正在對我的數據進行CNN培訓。但是,可能是我在'create_lmdb.py'文件中執行的更改存在問題。這兩個數據庫之間的區別是:錯誤與create_lmdb.py文件

第一:我正在訓練我的網絡與32x32圖像。 秒:我的數據庫只包含灰度圖像。 但我也訓練我的網絡二進制分類。

修改後,這是我的文件:

import os 
import glob 
import random 
import numpy as np 

import cv2 

import caffe 
from caffe.proto import caffe_pb2 
import lmdb 

#Size of images 
IMAGE_WIDTH = 32 
IMAGE_HEIGHT = 32 

def transform_img(img, img_width=IMAGE_WIDTH, img_height=IMAGE_HEIGHT): 

    #Histogram Equalization 
    img = cv2.equalizeHist(img) 
    #img[:, :, 1] = cv2.equalizeHist(img[:, :, 1]) not a RGB 
    #img[:, :, 2] = cv2.equalizeHist(img[:, :, 2]) 

    #Image Resizing 
    img = cv2.resize(img, (img_width, img_height), interpolation = cv2.INTER_CUBIC) # make sure all the images are at the same size 

    return img 


def make_datum(img, label): 
    #image is numpy.ndarray format. BGR instead of RGB 
    return caffe_pb2.Datum(
     channels=1, #not an RGB image 
     width=IMAGE_WIDTH, 
     height=IMAGE_HEIGHT, 
     label=label, 
     data=img.tostring()) 

train_lmdb = '/home/roishik/Desktop/Thesis/Code/cafe_cnn/first/input/train_lmdb' 
validation_lmdb = '/home/roishik/Desktop/Thesis/Code/cafe_cnn/first/input/validation_lmdb' 

os.system('rm -rf ' + train_lmdb) 
os.system('rm -rf ' + validation_lmdb) 


train_data = [img for img in glob.glob("../input/train/*png")] 
test_data = [img for img in glob.glob("../input/test1/*png")] 

#Shuffle train_data 
random.shuffle(train_data) 

print 'Creating train_lmdb' 

in_db = lmdb.open(train_lmdb, map_size=int(1e12)) 
with in_db.begin(write=True) as in_txn: 
    for in_idx, img_path in enumerate(train_data): 
     if in_idx % 6 == 0: 
      continue 
     img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE) 
     img = transform_img(img, img_width=IMAGE_WIDTH, img_height=IMAGE_HEIGHT) 
     if 'cat' in img_path: 
      label = 0 
     else: 
      label = 1 
     datum = make_datum(img, label) 
     in_txn.put('{:0>5d}'.format(in_idx), datum.SerializeToString()) 
     print '{:0>5d}'.format(in_idx) + ':' + img_path 
in_db.close() 


print '\nCreating validation_lmdb' 

in_db = lmdb.open(validation_lmdb, map_size=int(1e12)) 
with in_db.begin(write=True) as in_txn: 
    for in_idx, img_path in enumerate(train_data): 
     if in_idx % 6 != 0: 
      continue 
     img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE) 
     img = transform_img(img, img_width=IMAGE_WIDTH, img_height=IMAGE_HEIGHT) 

     prec=int(img_path[(img_path.index('prec_')+5):(img_path.index('prec_')+8)]) 

     if prec>50: 
      label = 1 
     else: 
      label = 0 

     datum = make_datum(img, label) 
     in_txn.put('{:0>5d}'.format(in_idx), datum.SerializeToString()) 
     print '{:0>5d}'.format(in_idx) + ':' + img_path 
in_db.close() 

print '\nFinished processing all images' 

但我認爲,根據訓練的結果:.MDB輸出文件已損壞(可能爲空或東西 - 即使它的權重47MB)。

任何人都可以看到這個文件有問題嗎?或者,或者給我一個關於構建lmdb文件的好教程的鏈接?

非常感謝您的幫助! 謝謝

回答

0

好的,我解決了! 尋找更深的代碼後,我發現我只是驗證數據集的標籤更新(並跳過訓練數據):P 它可以在這一段代碼可以看到:

img = transform_img(img, img_width=IMAGE_WIDTH, img_height=IMAGE_HEIGHT) 
     if 'cat' in img_path: 
      label = 0 
     else: 
      label = 1 

屬於原創教程。

結論:如果你不能訪問你的lmdb文件,馬貝是因爲創建它的函數被打破。

0

如果你想創建一個'lmdb'圖像數據集來訓練分類網絡,不要出汗! Caffe已經有了一個專門用於這個目的的工具!
您正在尋找$CAFFE_ROOT/build/tools/convert_imageset工具,你可以找到很詳細的(如果我可以這麼說;)教程here

+0

非常感謝。我開始建立一個新文件,然後發現我的錯誤:P – roishik