2016-03-01 74 views
2

我是Tensorflow的初學者。我從「入門」頁面 中選擇了一個適合行的示例,並且做出了我認爲對其進行了幾乎微不足道的修改,但完全失敗。 我不明白。初學者 - 解決一個簡單的凸優化

在修改後的版本中,數組b_data是兩個已知高斯的和,其權重未知。 嘗試解決這些重量。這是一個二次問題,可以作爲線性系統來解決 。

儘管真實權重爲0.4,0.2,但梯度下降給出w [0]爲負數, 和w [1]爲正數。

這就是問題:雖然問題是凸的(二次均勻),但tensorflow並沒有找到正確的答案。

我想我一定是做了錯誤的損失功能? 事實上,我認爲用損失

tf.reduce_sum(tf.square(b - b_data))

是我想要的(對應於平方2範數|| b - b_data ||^2),但是 嘗試這種更是雪上加霜,將導致在NaNs。

import numpy as np 
import matplotlib.pyplot as pl 
import tensorflow as tf 

RES = 200 
CEN = [0.2, 0.3, 0.6] 
SD = [0.1, 0.15, 0.07] 
X = np.linspace(0., 1., RES).astype(np.float32) 
G0 = np.exp(- np.power(X - CEN[0], 2)/SD[0]) 
G1 = np.exp(- np.power(X - CEN[1], 2)/SD[1]) 
B = np.vstack([G0,G1]) 
B = B.T 

b_data = 0.4*G0 + 0.2*G1 

# check numpy answer 
w_ = np.linalg.lstsq(B,b_data) 
print('numpy answer',w_[0])  # correct: 0.4, 0.2 

w = tf.Variable(tf.random_uniform([2,1], 0., 0.5)) 
b = tf.matmul(B,w) 

loss = tf.reduce_mean(tf.square(b - b_data)) 
#loss = tf.reduce_sum(tf.square(b - b_data)) 
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01) 
train = optimizer.minimize(loss) 

init = tf.initialize_all_variables() 

sess = tf.Session() 
sess.run(init) 

for step in xrange(8001): 
    sess.run(train) 
    if step % 100 == 0: 
     print(step, sess.run(loss), sess.run(tf.transpose(w))) 

print('w', sess.run(w)) 
bfit = sess.run(w[0,0])*G0 + sess.run(w[1,0])*G1 

pl.clf() 
pl.plot(G0,'g-') 
pl.plot(G1,'b-') 
pl.plot(b_data,'r-') 
pl.plot(bfit,'k-') 
pl.show() 
+0

你得到的錯誤究竟是什麼?或者它只是打印(步驟,sess.run(損失),sess.run(tf.transpose(w))) 返回NaN? –

+0

我補充說明。問題是張量流程代碼找不到正確的解決方案。 – bullwinkle

+0

我嘗試打印漸變,它快速到零,所以我認爲這是一個局部最小值。 '[wgrad,_] = optimizer.compute_gradients(loss,[w])[0]'但這個問題似乎是二次的,所以這很奇怪 –

回答

0

你就沒有優化配置適當的成本函數,但一個較大的減法播放你的陣列b_datab

> print(tf.square(b - b_data)) 
Tensor("Square:0", shape=(200, 200), dtype=float32) 

>print(tf.square(b[:, 0] - b_data)) 
Tensor("Square:0", shape=(200,), dtype=float32) 

這是通過在tensorflow實施(see for instance this issue)反直覺的廣播引起的。

如果用loss = tf.reduce_mean(tf.square(b[:, 0] - b_data))代替損失,優化成功並返回結果。