2016-04-01 27 views
0

我想在張量流中實現一個簡單的線性迴歸(最終將其擴展到更高級的模型)。我目前的代碼如下:張量流線性迴歸非常緩慢

def linear_regression(data, labels): 
    # Setup placeholders and variables 
    num_datapoints = data.shape[0] 
    num_features = data.shape[1] 
    x = tf.placeholder(tf.float32, [None, num_features]) 
    y_ = tf.placeholder(tf.float32, [None]) 
    coeffs = tf.Variable(tf.random_normal(shape=[num_features, 1])) 
    bias = tf.Variable(tf.random_normal(shape=[1])) 

    # Prediction 
    y = tf.matmul(x, coeffs) + bias 

    # Cost function 
    cost = tf.reduce_sum(tf.pow(y-y_, 2))/(2.*num_datapoints) 


    # Optimizer 
    NUM_STEPS = 500 
    optimizer = tf.train.AdamOptimizer() 
    train_step = optimizer.minimize(lasso_cost) 

    # Fit the model 
    init = tf.initialize_all_variables() 
    cost_history = np.zeros(NUM_STEPS) 
    sess = tf.Session() 
    sess.run(init) 

    for i in range(NUM_STEPS): 
     if i % 100 == 0: 
      print 'Step:', i 
     for xi, yi in zip(data, labels): 
      sess.run(train_step, feed_dict={x: np.expand_dims(xi, axis=0), 
       y_: np.expand_dims(yi, axis=0)}) 
      cost_history[i] = sess.run(lasso_cost, feed_dict={x: data, 
       y_:labels}) 

    return sess.run(coeffs), cost_history 

該代碼工作,並找到正確的係數。但是,它非常緩慢。在我的MacBook Pro上,只需幾分鐘即可爲具有1000個數據點和10個功能的數據集運行幾個訓練時期。由於我運行的是OSX,因此我沒有GPU加速功能,這可以解釋一些緩慢的情況,但我認爲它可能會比這更快。我已經嘗試過不同的優化器,但性能非常相似。

有沒有一些明顯的方法來加速這段代碼?否則,感覺像tensorflow對於這些類型的問題幾乎沒有用處。

回答

4

它很慢,因爲你訓練網絡的點需要NUM_STEPS * num_datapoints迭代(這導致5萬個週期)。

所有你真正需要訓練你的網絡是

for i in range(NUM_STEPS): 
    sess.run(train_step, feed_dict={x: data, y_:labels}) 

這僅舉幾秒鐘。

+0

感謝您的評論。這種改變當然會加快代碼的速度,但如果我這樣做,它似乎會收斂到完全隨機的(和錯誤的解決方案)。其實我也有意要問這個問題。我最初的代碼是一次性訓練所有數據,但沒有找到正確的解決方案。我在網上發現了幾個不同的示例代碼,並且他們都在每個數據點上進行訓練。出於某種原因,這似乎改善了最小化的收斂性?我不明白爲什麼。 – user3468216

+0

您使用的默認學習率可能相當高,因此會導致不收斂的結果。因此,當您創建優化器時,請明確選擇學習速率,例如AdamOptimizer(learning_rate = 0.0001)。 –

+0

那麼,我已經用不同的優化器和學習速率進行了很多實驗,並且我沒有發現任何組合在同時訓練整個數據集時給出了正確的結果(我總是得到一個與係數完全無關的模型投放)。也許我錯過了一些顯而易見的東西,但是如果有人能夠或者可以指出我的線性迴歸的tensorflow實現是有效的,並且效率很高,我會非常感激。 – user3468216