2017-10-11 57 views
0

我想在另一個Keras網絡(B)內使用Keras網絡(A)。我首先訓練網絡A.然後我在網絡B中使用它來執行一些正則化。內部網絡B我想用evaluatepredict來從網絡A得到輸出。不幸的是,我一直無法得到這個工作,因爲這些函數需要一個numpy數組,而不是接收一個Tensorflow變量作爲輸入。keras正向傳遞與張量變量作爲輸入

這裏是我如何使用自定義正則內部網絡答:

class CustomRegularizer(Regularizer): 
    def __init__(self, model): 
     """model is a keras network""" 
     self.model = model 

    def __call__(self, x): 
     """Need to fix this part""" 
     return self.model.evaluate(x, x) 

我如何計算與Keras網絡與Tensorflow變量作爲輸入向前傳球?

作爲一個例子,這裏就是我與numpy的:

x = np.ones((1, 64), dtype=np.float32) 
model.predict(x)[:, :10] 

輸出:

array([[-0.0244251 , 3.31579041, 0.11801113, 0.02281714, -0.11048832, 
     0.13053198, 0.14661783, -0.08456061, -0.0247585 , 
0.02538805]], dtype=float32) 

隨着Tensorflow

x = tf.Variable(np.ones((1, 64), dtype=np.float32)) 
model.predict_function([x]) 

輸出:

--------------------------------------------------------------------------- 
ValueError        Traceback (most recent call last) 
<ipython-input-92-4ed9d86cd79d> in <module>() 
     1 x = tf.Variable(np.ones((1, 64), dtype=np.float32)) 
----> 2 model.predict_function([x]) 

~/miniconda/envs/bolt/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py in __call__(self, inputs) 
    2266   updated = session.run(self.outputs + [self.updates_op], 
    2267        feed_dict=feed_dict, 
-> 2268        **self.session_kwargs) 
    2269   return updated[:len(self.outputs)] 
    2270 

~/miniconda/envs/bolt/lib/python3.6/site-packages/tensorflow/python/client/session.py in run(self, fetches, feed_dict, options, run_metadata) 
    776  try: 
    777  result = self._run(None, fetches, feed_dict, options_ptr, 
--> 778       run_metadata_ptr) 
    779  if run_metadata: 
    780   proto_data = tf_session.TF_GetBuffer(run_metadata_ptr) 

~/miniconda/envs/bolt/lib/python3.6/site-packages/tensorflow/python/client/session.py in _run(self, handle, fetches, feed_dict, options, run_metadata) 
    952    np_val = subfeed_val.to_numpy_array() 
    953   else: 
--> 954    np_val = np.asarray(subfeed_val, dtype=subfeed_dtype) 
    955 
    956   if (not is_tensor_handle_feed and 

~/miniconda/envs/bolt/lib/python3.6/site-packages/numpy/core/numeric.py in asarray(a, dtype, order) 
    529 
    530  """ 
--> 531  return array(a, dtype, copy=False, order=order) 
    532 
    533 

ValueError: setting an array element with a sequence. 

回答

0

我不知道在哪裏的tensorflow變量進來的,但如果它的存在,你可以這樣做:

model.predict([sess.run(x)]) 

其中sess是tensorflow會議,即sess = tf.Session()

+0

我添加了上下文以瞭解網絡如何用於我的問題。我還沒有能夠調整你的答案來解決我的問題。 –

+0

對不起,但我認爲還需要更多的細節來幫助你調試。我唯一能想到的是你可以嘗試'cr([sess.run(x)])'和'cr = CustomRegularizer(model)'。 –