2017-07-29 80 views
0

我想在這裏顯然執行以下代碼如何正確獲取Tensorflow中的形狀,以便我可以再次成形?

def f(x): 
    (_, H, W, C) = tf.shape(x) 
    x_reshaped = tf.reshape(x, (-1,C)) 
    res = x_reshaped/(H*W*C) 
    return res 

但是,問題是,我不知道H,W在高級所以他們?,?。所以重塑和倍增不起作用。現在我的問題是,如何正確執行上述計算,以便res是一個正確的tensorflow節點,可以在會話中稍後進行計算?

回答

1

下面應該工作:

X = tf.placeholder(tf.float32, shape=[None, None, None, 40]) 

def f(x): 
    s = tf.shape(x) 
    x_reshaped = tf.reshape(x, [-1,s[3]]) 
    res = tf.div(x_reshaped, tf.cast((s[0]*s[1]*s[2]), tf.float32)) 
    return res 

out = f(X) 

sess = tf.Session() 
sess.run(out, {X:np.random.normal(size=(10,20,30,40))}) 
0

我假設你想x是形狀(batch_size, H*W*C),這意味着每x項目是「扁平化」的圖像數據。在這種情況下,正確的代碼將是:

x_reshaped = tf.reshape(x, (-1, H*W*C)) 

但沒有看到更多的代碼,我不能確定。例如,如果您的神經網絡設計爲卷積,則根本不需要重塑

相關問題