2017-03-04 59 views
4

我想在TensorFlow中創建一個簡單的神經網絡。唯一棘手的部分是我有一個自定義的操作,我用py_func實現。當我將輸出從py_func傳遞到Dense圖層時,TensorFlow抱怨排名應該是已知的。特定的錯誤是:TensorFlow輸出`py_func`具有未知的排名/形狀

ValueError: Inputs to `Dense` should have known rank. 

我不知道如何保持自己的數據的形狀,當我通過它通過py_func。我的問題是如何獲得正確的形狀?我有一個簡單的例子來說明這個問題。

def my_func(x): 
    return np.sinh(x).astype('float32') 

inp = tf.convert_to_tensor(np.arange(5)) 
y = tf.py_func(my_func, [inp], tf.float32, False) 

with tf.Session() as sess: 
    with sess.as_default(): 
     print(inp.shape) 
     print(inp.eval()) 
     print(y.shape) 
     print(y.eval()) 

從這個片斷的輸出是:

(5,) 
[0 1 2 3 4] 
<unknown> 
[ 0.  
1.17520118 3.62686038 10.01787472 27.28991699] 

爲什麼y.shape<unknown>?我想要的形狀是(5,)inp相同。謝謝!

+0

的可能的複製[Tensorflow:Py \ _func返回未知形狀](https://stackoverflow.com/questions/38992445/tensorflow-py-func-returns-unknown-shape) – gkcn

+0

@gkcn也許,自從我問了一段時間後,我發現了一段時間提交人回答自己的這個問題。我記得他的解決方案不適合我,這就是爲什麼我寫了我的問題。 –

回答

8

由於py_func可以執行任意Python代碼的輸出任何東西,TensorFlow想不通的形狀(它需要分析函數體的Python代碼)可以代替給形狀手動

y.set_shape(inp.get_shape())