2
的選擇分支我有一個多輸出Keras模型類似於此的結構:列車多輸出keras模型
s = some_shared_layers()(input)
non_trainable1 = Dense(trainable=False) (s)
non_trainable2 = Dense(trainable=False) (s)
trainable = Dense() (s)
model = Model(input, outputs=[non_trainable1, non_trainable2, trainable])
我的模型首先計算一個直傳並使用該第一2個輸出以操縱輸入。然後計算另一個正向傳球以獲得第三個輸出。
out1, out2,_ =model.predict(input_data)
processed_data = foo(input_data, out1, out2)
_,_, out3 = model.predict(processed_data)
我應該如何調用model.fit()
訓練只有trainable
層?如果我排除其他產出的損失,凱拉斯警告we will not be expecting any data to be passed to "non_trainable1" during training
並將其從計算圖中排除。
有沒有更好的方法來構建這個任務的模型?
我目前使用作爲一種變通方法,它如果我按照您所描述的運行代碼,則會起作用。但是,如果我嘗試以更復雜的方式使用處理函數'foo' - 比如在'ImageDataGenerator'內部 - 然後'model2.fit_generator()'產生一個'ValueError:Tensor'non_trainable1「不是該圖的一個元素。' – Manas
爲什麼在發電機內?你在發電機裏究竟想要做什麼?這聽起來你正在使用張量而不是使用預測。 –
我正在處理圖像,並正在使用生成器進行數據增強。作爲一個天真的第一次嘗試,我嘗試傳遞'foo'作爲'preprocessing_function'參數(ref [docs](https://keras.io/preprocessing/image/))。 'foo'拍攝圖像,調用'model1.predict()'並返回一個編輯後的圖像。 – Manas