2017-05-03 69 views
6

閱讀https://github.com/tensorflow/models/blob/master/tutorials/image/cifar10/cifar10_multi_gpu_train.py的功能average_gradients以下注釋提供:Note that this function provides a synchronization point across all towers.是功能average_gradients阻塞調用,並是什麼意思synchronization pointTensorflow CIFAR同步點

我認爲這是一個阻塞呼叫,因爲爲了計算梯度的平均值,每個梯度都必須單獨計算?但是,等待所有單個漸變計算的阻止代碼在哪裏?

回答

6

average_gradients本身不是阻塞功能。它可能是張量流操作的另一個函數,這仍然是一個同步點。是什麼讓它阻塞是因爲它使用參數tower_grads,這取決於在前面的for循環中創建的所有圖形。

基本上這裏發生的是創建訓練圖。首先,在for循環for i in xrange(FLAGS.num_gpus)中創建了幾個圖表「線程」。每個看起來像這樣:

計算損失 - >計算梯度 - >附加到tower_grads

每個的那些曲線圖「線程」被分配給一個不同的GPU通過with tf.device('/gpu:%d' % i)並且每一個可以運行彼此獨立的(並且稍後將並行運行)。現在下一次使用tower_grads時沒有設備規範,它會在主設備上創建一個圖繼續,將所有這些單獨的圖「線程」綁定到一個單獨的圖上。在運行average_gradients函數中的圖之前,Tensorflow將確保作爲創建tower_grads的一部分的每個圖形「線程」都已完成。因此稍後調用sess.run([train_op, loss])時,這將成爲圖的同步點。