2017-09-20 42 views
1

我想在keras中實現一個暹羅網絡,我想使用Keras圖像數據生成器將圖像轉換應用於2個輸入圖像。按照在docs- https://keras.io/preprocessing/image/的例子中,我試圖實現它像這個 -我如何結合兩個keras生成器函數

datagen_args = dict(rotation_range=10, 
        width_shift_range=0.1, 
        height_shift_range=0.1, 
        horizontal_flip=True) 

in_gen1 = ImageDataGenerator(**datagen_args) 
in_gen2 = ImageDataGenerator(**datagen_args) 

train_generator = zip(in_gen1, in_gen2) 

model.fit(train_generator.flow([pair_df[:, 0,::],pair_df[:, 1,::]], 
          y_train,batch_size=16), epochs, verbose = 1) 

但這個代碼拋出這個錯誤:

類型錯誤:ZIP參數#1必須支持迭代

我試過使用建議使用itertools.izip,但是這會引發同樣的錯誤。

我該如何解決這個問題?

編輯:如果有人有興趣,這個工作finally-

datagen_args = dict(
    featurewise_center=False, 
    rotation_range=10, 
    width_shift_range=0.1, 
    height_shift_range=0.1, 
    horizontal_flip=True) 

in_gen1 = ImageDataGenerator(**datagen_args) 
in_gen2 = ImageDataGenerator(**datagen_args) 

in_gen1 = in_gen1.flow(pair_df[:, 0,::], y_train, batch_size = 16, shuffle = False) 
in_gen2 = in_gen2.flow(pair_df[:, 1,::], y_train, batch_size = 16, shuffle = False) 

for e in range(epochs): 
    batches = 0 
    for x1, x2 in itertools.izip(in_gen1,in_gen2): 
    # x1, x2 are tuples returned by the generator, check whether targets match 
     assert sum(x1[1] != x2[1]) == 0 
     model.fit([x1[0], x2[0]], x1[1], verbose = 1) 
     batches +=1 
     if(batches >= len(pair_df)/16): 
      break 

回答

1

您需要先使用流量法迭代它們轉換的東西。

嘗試以下操作:

datagen_args = dict(rotation_range=10, 
        width_shift_range=0.1, 
        height_shift_range=0.1, 
        horizontal_flip=True) 

in_gen1 = ImageDataGenerator(**datagen_args) 
in_gen2 = ImageDataGenerator(**datagen_args) 

gen1_flow = in_gen1.flow(X_train[:,0, ::],y_train, batch_size=16) 
gen2_flow = in_gen2.flow(X_train[:,1, ::],y_train, batch_size=16) 

train_generator = zip(gen1_flow, gen2_flow) 

model.fit_generator(train_generator, 
        steps_per_epoch=len(X_train)/16, 
        epochs=epochs) 
+0

謝謝!這工作。 – azure31