2017-05-31 44 views
2

的選擇分支我有一個多輸出Keras模型類似於此的結構:列車多輸出keras模型

s = some_shared_layers()(input) 
non_trainable1 = Dense(trainable=False) (s) 
non_trainable2 = Dense(trainable=False) (s) 
trainable = Dense() (s) 

model = Model(input, outputs=[non_trainable1, non_trainable2, trainable]) 

我的模型首先計算一個直傳並使用該第一2個輸出以操縱輸入。然後計算另一個正向傳球以獲得第三個輸出。

out1, out2,_ =model.predict(input_data) 
processed_data = foo(input_data, out1, out2) 
_,_, out3 = model.predict(processed_data) 

我應該如何調用model.fit()訓練只有trainable層?如果我排除其他產出的損失,凱拉斯警告we will not be expecting any data to be passed to "non_trainable1" during training並將其從計算圖中排除。

有沒有更好的方法來構建這個任務的模型?

回答

0

如果我正確理解它,你根本不需要這些層,事實上,你應該有兩個模型,一個用於預測,另一個用於訓練。

非可訓練:

model1 = Model(input, [non_trainable1, non_trainable2]) 
#model 1 doesn't need to be compiled, since you won't train it  

可訓練:

model2 = Model(input, trainable) 
model2.compile(loss=onlyTheLossForTrainable)  

使用它們:

out1, out2 =model1.predict(input_data) 
processed_data = foo(input_data, out1, out2) 

model2.fit(processed_data, expected_outputs, ....)  
+0

我目前使用作爲一種變通方法,它如果我按照您所描述的運行代碼,則會起作用。但是,如果我嘗試以更復雜的方式使用處理函數'foo' - 比如在'ImageDataGenerator'內部 - 然後'model2.fit_generator()'產生一個'ValueError:Tensor'non_trainable1「不是該圖的一個元素。' – Manas

+0

爲什麼在發電機內?你在發電機裏究竟想要做什麼?這聽起來你正在使用張量而不是使用預測。 –

+0

我正在處理圖像,並正在使用生成器進行數據增強。作爲一個天真的第一次嘗試,我嘗試傳遞'foo'作爲'preprocessing_function'參數(ref [docs](https://keras.io/preprocessing/image/))。 'foo'拍攝圖像,調用'model1.predict()'並返回一個編輯後的圖像。 – Manas