2017-08-05

I am working on facial expression recognition with Keras. I collected a number of datasets, applied data augmentation to the images, and saved about 500,000 images (as pixel strings) to a .csv file in the same format as fer2013.csv. In short: how can Keras handle a large dataset that does not fit into memory?

Here is the code I am using:

def Zerocenter_ZCA_whitening_Global_Contrast_Normalize(pixel_list): 
    # pixel_list: one image as a flat list of ints parsed from a csv row
    # (renamed from `list`, which shadows the Python built-in)
    data = numpy.asarray(pixel_list).reshape(img_width, img_height) 
    data = ZeroCenter(data) 
    data = zca_whitening(flatten_matrix(data)).reshape(img_width, img_height) 
    data = global_contrast_normalize(data) 
    return numpy.rot90(data, 3) 



def load_data(): 
    train_x, train_y = [], [] 
    val_x, val_y = [], [] 
    test_x, test_y = [], [] 

    with open('ALL.csv') as f: 
        for row in csv.reader(f): 
            usage = str(row[2]) 
            if usage not in ("Training", "PublicTest", "PrivateTest"): 
                continue  # skip the header or any malformed row 

            pixels = [int(p) for p in row[1].split()] 
            data = Zerocenter_ZCA_whitening_Global_Contrast_Normalize(pixels) 

            if usage == "Training": 
                train_y.append(int(row[0])) 
                train_x.append(data.reshape(data_resh).tolist()) 
            elif usage == "PublicTest": 
                val_y.append(int(row[0])) 
                val_x.append(data.reshape(data_resh).tolist()) 
            elif usage == "PrivateTest": 
                test_y.append(int(row[0])) 
                test_x.append(data.reshape(data_resh).tolist()) 

    return train_x, train_y, val_x, val_y, test_x, test_y 
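
A rough back-of-the-envelope calculation (assuming 48×48 fer2013-style images; the per-object sizes are approximations for 64-bit CPython) shows why appending every preprocessed image as a nested Python list via `.tolist()` exhausts 16 GB, while a single float32 array would not:

```python
n_images, img_rows, img_cols = 500_000, 48, 48

# One contiguous float32 numpy array for the whole training set:
array_bytes = n_images * img_rows * img_cols * 4
print(f"numpy float32: {array_bytes / 1e9:.1f} GB")

# Nested Python lists store each pixel as a separate float object
# (roughly 24 bytes) plus an 8-byte list-slot pointer per element:
list_bytes = n_images * img_rows * img_cols * (24 + 8)
print(f"nested lists:  {list_bytes / 1e9:.1f} GB")
```

So even before the generator starts, the list representation alone is several times larger than the available RAM, which matches the freeze observed during `load_data()`.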

Then I load the data and feed it to the generator:

Train_x, Train_y, Val_x, Val_y, Test_x, Test_y = load_data() 

# Convert to numpy and add the single channel dimension in one step 
Train_x = numpy.asarray(Train_x).reshape(-1, img_rows, img_cols, 1) 
Test_x = numpy.asarray(Test_x).reshape(-1, img_rows, img_cols, 1) 
Val_x = numpy.asarray(Val_x).reshape(-1, img_rows, img_cols, 1) 

Train_x = Train_x.astype('float32') 
Test_x = Test_x.astype('float32') 
Val_x = Val_x.astype('float32') 

Train_y = np_utils.to_categorical(Train_y, nb_classes) 
Test_y = np_utils.to_categorical(Test_y, nb_classes) 
Val_y = np_utils.to_categorical(Val_y, nb_classes) 


datagen = ImageDataGenerator(
    featurewise_center=False, 
    samplewise_center=False, 
    featurewise_std_normalization=False, 
    samplewise_std_normalization=False, 
    zca_whitening=False, 
    width_shift_range=0.2, 
    height_shift_range=0.2, 
    horizontal_flip=True, 
    shear_range=0.03, 
    zoom_range=0.03, 
    vertical_flip=False) 

datagen.fit(Train_x) 

model.fit_generator(datagen.flow(Train_x, Train_y, 
    batch_size=batch_size), 
    samples_per_epoch=Train_x.shape[0], 
    nb_epoch=nb_epoch, 
    validation_data=(Val_x, Val_y)) 

When I run the code, RAM usage grows larger and larger until the computer freezes (I have 16 GB). It gets stuck when load_data() is called. Is there any fix for this that would work with my code?


You need to write a generator function that loads only some rows of the csv file into RAM at a time — [here is a good example](https://github.com/fchollet/keras/issues/2708). You are loading too much data at once. – DJK

Answer

This seems to be a duplicate of this question. Basically, you have to use fit_generator() instead of fit(), and pass in a function that loads data into your model one batch at a time instead of all at once.
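
A minimal sketch of such a batch generator, assuming the same ALL.csv layout as in the question (label, space-separated pixels, usage tag) and 48×48 images; the name `csv_batch_generator` and the default parameters are illustrative, not part of the Keras API:

```python
import csv
import numpy as np

def csv_batch_generator(path, usage="Training", batch_size=32,
                        img_rows=48, img_cols=48, nb_classes=7):
    """Yield (x, y) batches by reading the csv row by row,
    so only one batch is ever held in memory at a time."""
    while True:  # Keras generators are expected to loop forever
        batch_x, batch_y = [], []
        with open(path) as f:
            for row in csv.reader(f):
                if row[2] != usage:
                    continue
                pixels = np.asarray(row[1].split(), dtype='float32')
                batch_x.append(pixels.reshape(img_rows, img_cols, 1))
                batch_y.append(int(row[0]))
                if len(batch_x) == batch_size:
                    x = np.stack(batch_x)
                    # one-hot encode the labels for this batch
                    y = np.eye(nb_classes, dtype='float32')[batch_y]
                    yield x, y
                    batch_x, batch_y = [], []
```

You would then replace the `datagen.flow(Train_x, Train_y, ...)` call with something like `model.fit_generator(csv_batch_generator('ALL.csv', batch_size=batch_size), samples_per_epoch=n_train, nb_epoch=nb_epoch)`; any per-image preprocessing (such as the ZCA whitening function above) would be applied inside the generator loop rather than up front.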