2016-04-18 20 views
0

我使用keras來構建推薦者模型。由於項目集很大,我想計算Hits @ N度量作爲精度的度量。也就是說,如果觀察到的項目在前N個預測中,則它被視爲相關推薦。計算點擊量在Theano中的度量標準

我能夠使用numpy在N函數中創建匹配。但是,當我試圖將它轉換爲keras的自定義損失函數時,我遇到了張量問題。具體來說,枚舉張量是不同的。當我研究語法來找到相同的東西時,我開始質疑整個方法。這是sl and而緩慢的,反映了我一般的蟒蛇熟悉。

def hits_at(y_true, y_pred): #numpy version 
    a=y_pred.argsort(axis=1) #ascending, sort by row, return index 
    a = np.fliplr(a) #reverse to get descending 
    a = a[:,0:10] #return only the first 10 columns of each row 
    Ybool = [] #initialze 2D arrray 
    for t, idx in enumerate(a): 
     ybool = np.zeros(num_items +1) #zero fill; 0 index is reserved 
     ybool[idx] = 1 #flip the recommended item from 0 to 1 
     Ybool.append(ybool) 
    A = map(lambda t: list(t), Ybool) 
    right_sum = (A * y_true).max(axis=1) #element-wise multiplication, then find the max 
    right_sum = right_sum.sum() #how many times did we score a hit? 
    return right_sum/len(y_true) #fraction of observations where we scored a hit 

我應該如何以更緊湊和張量友好的方式來解決這個問題?

更新: 我能夠得到頂端1工作的版本。我基於GRU4Rec描述 def custom_objective(y_true, y_pred): y_pred_idx_sort = T.argsort(-y_pred, axis=1)[:,0] #returns the first element, which is the index of the row with the largest value y_act_idx = T.argmax(y_true, axis=1)#returns an array of indexes with the top value return T.cast(-T.mean(T.nnet.sigmoid((T.eq(y_pred_idx_sort,y_act_idx)))), theano.config.floatX)

我只是要將前1個預測的數組與元素方面的實際數組進行比較。 Theano有一個eq()函數來做到這一點。

回答

0

與N無關,您的損失函數的可能值的數量是有限的。因此它不能以明智的張量方式進行區分,並且不能在Keras/Theano中將其用作損失函數。您可能會嘗試使用前N名球員的theano日誌丟失。

UPDATE:

在Keras - 你可以寫你自己的損失函數。他們有一個格式的聲明:

def loss_function(y_pred, y_true): 

y_true都和y_pred是numpy的陣列,所以你可能伊斯利獲得矢量v爲1時給出的例子是在頂部500,否則爲0。然後你可以將它轉換成theano tensor常量向量並以某種方式應用它:

return theano.tensor.net.binary_crossentropy(y_pred * v, y_true * v) 

這應該可以正常工作。

更新2:

登錄損失是一樣的東西是什麼binary_crossentropy。

+0

我目前使用的是categorical_crossentropy;我不熟悉「頂級N的日誌丟失」。是以某種方式擴展http://deeplearning.net/tutorial/logreg.html#the-model的問題嗎? –

+0

我更新了我的評論 –

+0

Hrm,似乎y_pred和y_true實際上是張量,而不是numpy數組。這使得它稍微複雜一點,因爲找到內置值並迭代它們比爲numpy陣列做同樣的事情更復雜。 –