2017-08-24 187 views
1

道歉,如果這是錯誤的地方提出我的問題(請幫助我與哪裏最好提高它,如果是這樣的話)。我是一個Keras和Python的新手,所以希望迴應有這個想法。如何使用Keras fit_generator批量培訓CNN?

我試圖訓練一個以圖像作爲輸入的CNN轉向模型。這是一個相當大的數據集,所以我創建了一個數據生成器來處理fit_generator()。我不清楚如何使這種方法在批次上進行訓練,所以我假定發生器必須將批次返回到fit_generator()。發電機看起來像這樣:

def gen(file_name, batchsz = 64): 
    csvfile = open(file_name) 
    reader = csv.reader(csvfile) 
    batchCount = 0 
    while True: 
     for line in reader: 
      inputs = [] 
      targets = [] 
      temp_image = cv2.imread(line[1]) # line[1] is path to image 
      measurement = line[3] # steering angle 
      inputs.append(temp_image) 
      targets.append(measurement) 
      batchCount += 1 
      if batchCount >= batchsz: 
       batchCount = 0 
       X = np.array(inputs) 
       y = np.array(targets) 
       yield X, y 
     csvfile.seek(0) 

它讀取包含遙測數據csv文件(轉向角等)和路徑圖像樣本,並返回大小的數組:BATCHSZ 到fit_generator(呼叫)看起來像這樣:

tgen = gen('h:/Datasets/dataset14-no.zero.speed.trn.csv', batchsz = 128) # Train data generator 
vgen = gen('h:/Datasets/dataset14-no.zero.speed.val.csv', batchsz = 128) # Validation data generator 
try: 
    model.fit_generator(
     tgen, 
     samples_per_epoch=113526, 
     nb_epoch=6, 
     validation_data=vgen, 
     nb_val_samples=20001 
    ) 

數據集包含113526個採樣點還沒有模型訓練更新輸出內容是這樣的(例如):

1020/113526 [..............................] - ETA: 27737s - loss: 0.0080 
    1021/113526 [..............................] - ETA: 27723s - loss: 0.0080 
    1022/113526 [..............................] - ETA: 27709s - loss: 0.0080 
    1023/113526 [..............................] - ETA: 27696s - loss: 0.0080 

哪似乎是按樣本進行訓練樣本(隨機?)。 由此產生的模型是無用的。我以前使用.fit()將整個數據集加載到內存中來訓練一個更小的數據集,並生成了一個至少可以工作的模型,即使效果很差。很明顯,我的fit_generator()方法出了問題。將非常感謝一些幫助。

+1

'samples_per_epoch'應該與[keras文檔](https://keras.io/models/sequential/)中建議的'total_samples/batch_size'相等。 'samples_per_epoch'指定在考慮完成時期之前調用生成器的次數,它不知道你正在使用的是什麼'batch_size' – gionni

+0

Thanks @gionni。從Keras 1.0.2更新到最新版本。這個版本適合生成器()參數更有意義。 – tinyMind

回答

2

此:

for line in reader: 
    inputs = [] 
    targets = [] 

...正在重置您的批處理在CSV文件中的每一行。你不是跟你的整個數據訓練,但只是一個單一的樣品128

建議:

for line in reader: 

    if batchCount == 0: 
     inputs = [] 
     targets = [] 
    .... 
    .... 

正如有人評論說,中配合發電機,samples_per_epoch應等於total_samples/batchsz

儘管如此,我認爲你的損失應該會減少。如果不是,那麼代碼中可能還存在另一個問題,可能是您加載數據的方式,或者模型的初始化或結構中。

嘗試繪製圖像和打印數據的生成:

for X,y in tgen: #careful, this is an infinite loop, make it stop 

    print(X.shape[0]) # is this really the number of batches you expect? 

    for image in X: 
     ...some method to plot X so you can see it, or just print  

    print(y) 

檢查是否產生值確定與你所期望的那樣。

+0

「...正在爲csv文件中的每一行重置您的批處理。」衛生署!應該已經發現了一個。 Wierd是因爲我有一些測試代碼來打印出yeiled數組,並且他們是正確大小和順序的批處理。 – tinyMind

+0

關於損失,最近我遇到了「凍結」損失的問題。我決定一遍又一遍地訓練多個時代的樣本,突然間,這個損失開始下降。然後我逐漸介紹其他例子,並開始正確訓練。我猜這個模型太複雜了,或者我沒有正確初始化權重,所以顯示一些有趣的結果需要更長的時間。 –

+0

謝謝丹尼爾。似乎現在訓練好了。雖然GPU負載很低,就好像GPU正在等待腳本一樣。 – tinyMind