2017-09-24 41 views
2

我使用keras實現了一個分類程序。我有一大套圖像,我想用for循環預測每個圖像。keras預測內存交換無限期增加

但是,每次計算新圖像時,交換內存都會增加。我試圖刪除預測函數內的所有變量(並且我確信它在這個函數內部存在問題),但內存仍然增加。

for img in images: 
    predict(img, model, categ_par, gl_par) 

和相應的功能:

def predict(image_path, model, categ_par, gl_par): 
    print("[INFO] loading and preprocessing image...") 

    orig = cv2.imread(image_path) 

    image = load_img(image_path, target_size=(gl_par.img_width, gl_par.img_height)) 
    image = img_to_array(image) 

    # important! otherwise the predictions will be '0' 
    image = image/255 

    image = np.expand_dims(image, axis=0) 

    # build the VGG16 network 
    if(categ_par.method == 'VGG16'): 
     model = applications.VGG16(include_top=False, weights='imagenet') 

    if(categ_par.method == 'InceptionV3'): 
     model = applications.InceptionV3(include_top=False, weights='imagenet') 

    # get the bottleneck prediction from the pre-trained VGG16 model 
    bottleneck_prediction = model.predict(image) 

    # build top model 
    model = Sequential() 
    model.add(Flatten(input_shape=bottleneck_prediction.shape[1:])) 
    model.add(Dense(256, activation='relu')) 
    model.add(Dropout(0.5)) 
    model.add(Dense(categ_par.n_class, activation='softmax')) 

    model.load_weights(categ_par.top_model_weights_path) 

    # use the bottleneck prediction on the top model to get the final classification 
    class_predicted = model.predict_classes(bottleneck_prediction) 
    probability_predicted = (model.predict_proba(bottleneck_prediction)) 

    classe = pd.DataFrame(list(zip(categ_par.class_indices.keys(), list(probability_predicted[0])))).\ 
    rename(columns = {0:'type', 1: 'prob'}).reset_index(drop=True) 
    #print(classe) 
    del model 
    del bottleneck_prediction 
    del image 
    del orig 
    del class_predicted 
    del probability_predicted 

    return classe.set_index(['type']).T 
+1

每次進行預測時,您似乎都在構建新模型。你確定你想要嗎? –

回答

2

如果您正在使用TensorFlow後端,你將建立每個IMG的模型在for循環。 TensorFlow只是將圖形附加到圖形等上,這意味着內存只是上升。這是一個衆所周知的事件,必須在超參數優化期間處理,當您要構建多個模型時,還需要在此處進行處理。

from keras import backend as K 

,並把這個在年底預測():

K.clear_session() 

或者你也可以建立一個模型,並且把這個作爲輸入預測功能,這樣你就不會建立每個新時間。