2017-06-06 103 views
3

我使用CNTK作爲Keras的後端。我正在嘗試使用我在C++中使用Keras進行訓練的模型。Keras-CNTK保存模型-v2格式

我已經訓練並保存了我的模型,使用了HDF5中的Keras。我現在如何使用CNTK API將其保存爲model-v2格式?

我嘗試這樣做:

model = load_model('model2.h5') 
cntk.ops.functions.Function.save(model, 'CNTK_model2.pb') 

,但我得到了以下錯誤:

TypeError: save() missing 1 required positional argument: 'filename' 

如果tensorflow是後端我會做這樣的:

model = load_model('model2.h5') 
sess = K.get_session() 
tf_saver = tf.train.Saver() 
tf_saver.save(sess=sess, save_path=checkpoint_path) 

我怎樣才能達到同樣的目的?

回答

0

你可以做這樣的事情 model.outputs[0].save('CNTK_model2.pb') 我假設你在這裏呼籲model.compile(即這是我曾嘗試:-)

+0

我太累了,@nikosk: - '模型= load_model( 'model2.h5') model.compile(虧損= 'categorical_crossentropy', 優化= '亞當', 度量= [ '準確性']) model.outputs [0] .save( 'CNTK_model2.pb')' 但加載模型時: - 'C.load_model( 'CNTK_model2.pb')' 我得到了以下錯誤: - 'if is_file: return cntk_py.Function.load(model,devic e) raise ValueError('無法加載既不是文件也不是字節緩衝區的模型') RuntimeError:SWIG導向器方法錯誤。 – Lenni

3

按照意見here唯一的情況下,我能使用此:

import cntk as C 
import keras.backend as K 

keras_model = K.load_model('my_keras_model.h5') 

C.combine(keras_model.model.outputs).save('my_cntk_model') 
cntk_model = C.load_model('my_cntk_model') 
0

你看到這個錯誤是因爲keras' cntk後端使用用戶定義的函數分批軸,這是不能被序列做重塑的原因。我們已經在CNTK v2.2中解決了這個問題。請將您的cntk升級到v2.2,並將keras升級爲最後一個主控。 請參閱本拉請求: https://github.com/fchollet/keras/pull/7907