2016-10-30 94 views
1

我目前在tensorflow中實現http://www.aclweb.org/anthology/P15-1061有效計算Tensorflow中的成對排序損失函數

我已經實現成對排名損失函數(紙張的第2.5節)如下:

s_theta_y = tf.gather(tf.reshape(s_theta, [-1]), y_true_index) 
s_theta_c_temp = tf.reshape(tf.gather(tf.reshape(s_theta, [-1]), y_neg_index), [-1, classes_size]) 
s_theta_c = tf.reduce_max(s_theta_c_temp, reduction_indices=[1]) 

我不得不使用tf.gather而非tf.gather_nd因爲後者尚未與梯度實現血統。我還必須將所有的指數轉換成平坦矩陣。

如果tf.gather_nd與梯度下降來實現,我的代碼將是如下:

s_theta_y = tf.gather_nd(s_theta, y_t_index) 
s_theta_c_temp = tf.gather_nd(s_theta, y_neg_index) 
s_theta_c = tf.reduce_max(s_theta_c_temp, reduction_indices=[1]) 

s_theta是爲每一類標籤的計算出的分數,如紙。 y_true_index包含真實類的索引,以便計算s_theta_y。 y_neg_index是所有否定類的索引,其維數是#class-1或#class是關係被歸類爲其他。

但是,有幾個句子被歸類爲其他,因此,s_theta_y 不存在,我們不應該將其計算在內。爲了處理這種情況,我有一個常數因子0,它取消了這個詞,並且對於負類具有相同的維矢量,我只是複製索引的一個隨機值,因爲最後我們只對所有負面類別中的最大值(而不是索引)。

是否有更有效的方法來計算損失函數中的這些條款?我有一個印象,使用tf.gather與很多重塑很慢

回答

1

當然,它聽起來像gather_nd是你想要的,但直到漸變實施那裏,我會毫不猶豫地使用你的reshape()解決方案,因爲reshape()實際上是免費的。

C++ implementation of the reshape() op看起來像做了很多工作,但它只是快速檢查形狀信息。 「工作」發生在90行的CopyFrom上,這聽起來可能很昂貴,但實際上只是一個指針拷貝(CopyFrom調用CopyFromInternal拷貝指針)。

這樣做很有意義:底層緩衝區只是row-major order中的一個數字平面陣列,並且該排序不依賴於形狀信息。出於同樣的原因,像tf.transpose()需要複製一般。