2017-07-14 33 views
1

工作學習的榜樣tf.contrib.distributions.MultivariateNormalDiag的參數:無法通過優化

import numpy as np 
import tensorflow as tf 

## construct data 
np.random.seed(723888) 
N,P = 50,3 # number and dimensionality of observations 
Xbase = np.random.multivariate_normal(mean=np.zeros((P,)), cov=np.eye(P), size=N) 

## construct model 
X  = tf.placeholder(dtype=tf.float32, shape=(None, P), name='X') 
mu  = tf.Variable(np.random.normal(loc=0.0, scale=0.1, size=(P,)), dtype=tf.float32, name='mu') 
xDist = tf.contrib.distributions.MultivariateNormalDiag(loc=mu, scale_diag=tf.ones(shape=(P,), dtype=tf.float32), name='xDist') 
xProbs = xDist.prob(X, name='xProbs') 

## prepare optimizer 
eta  = 1e-3 # learning rate 
loss  = -tf.reduce_mean(tf.log(xProbs), name='loss') 
optimizer = tf.train.AdamOptimizer(learning_rate=eta).minimize(loss) 

## launch session 
with tf.Session() as sess: 
    tf.global_variables_initializer().run() 
    sess.run(optimizer, feed_dict={X: Xbase}) 

我想要做的優化過的tensorflow多元高斯分佈的參數,如我上面的例子。我可以成功運行諸如sess.run(loss, feed_dict={X: Xbase})之類的命令,所以我已經正確實施了分發。當我嘗試運行優化操作時,出現奇怪的錯誤消息:

InvalidArgumentError: -1 is not between 0 and 3 
    [[Node: gradients_1/xDist_7/xProbs/Prod_grad/InvertPermutation = InvertPermutation[T=DT_INT32, _device="/job:localhost/replica:0/task:0/cpu:0"](gradients_1/xDist_7/xProbs/Prod_grad/concat)]] 

Caused by op 'gradients_1/xDist_7/xProbs/Prod_grad/InvertPermutation' 

我不明白。

如果我使用tf.contrib.distributions.MultivariateNormalFullCovariance而不是tf.contrib.distributions.MultivariateNormalDiag,則會得到相同的錯誤消息。如果scale_diag而不是loc是正在優化的變量,則不會收到錯誤。

回答

0

我仍在調查爲什麼這是失敗的,但對於短期修復,是否會做出以下更改?

xLogProbs = xDist.log_prob(X, name='xLogProbs') 
loss  = -tf.reduce_mean(xLogProbs, name='loss') 

注: - 有時基本上更精確,因爲它是從不小於數值精確這實際上是優選tf.log(xProbs)。 (所有tf.Distributions都是如此。)

+0

使用'.log_prob'工程。謝謝;我沒有多少使用'tf.Distributions'。 – ostrichgroomer

+0

我追查過這個問題。看起來,TF的發佈版本沒有正確處理'reduce_prod'的負索引。這個問題已經在master中解決了。無論如何,'log_prob'是首選。另請注意:由於原因太複雜,最好使用'tf.get_variable'而不是'tf.Variable'。祝你好運! – jvdillon

+0

另請參閱:https://github.com/tensorflow/tensorflow/issues/10766 – jvdillon