2015-11-11 40 views
0

我做錯了什麼?sklearn BallTree給予意想不到的結果

我想要使用sklearn的BallTree來想出類似的集合,然後對可能缺少給定集合的項目產生一些建議。

import random 
from sklearn.neighbors import BallTree 
import numpy 

collections = [] # 10k sample collections of between 
        # 7 and 15 (of a possible 300...) items 

for sample in range(0, 10000): # build sample data 
    items = random.sample(range(1, 300), random.randint(7, 15)) 
    collections.append(items)  

darray = numpy.zeros((len(collections), max(map(len, collections)))) # 10k x 15 matrix 

for c_cnt, items in enumerate(collections): # populate matrix 
    for cnt, i in enumerate(sorted(items)): 
     darray[C_cnt][cnt] = i 

query = BallTree(darray).query(darray[0], k=15) 

nearest_neighbors = query[1][0] 

# test the results against the first item! 

all_sets = [set(darray[0]) & set(darray[item]) for item in nearest_neighbors] 
for item in all_sets: 
    print item # intersection of the neighbor 

我得到如下結果:

set([0.0, 130.0, 167.0, 290.0, 162.0, 144.0, 17.0, 214.0]) # Nearest neighbor is itself! Awesome! 
set([0.0]) # WTF? The second closest item shares only 1 item? 
set([0.0, 290.0]) 
set([0.0, 17.0]) 
set([0.0, 130.0]) 
set([0.0]) 
set([0.0]) 
set([0.0]) 
set([0.0]) 
set([0.0]) 
set([0.0]) 
set([0.0]) 
set([0.0, 162.0]) 
set([0.0, 144.0, 162.0]) # uhh okay, i would expect this to be higher up 
set([0.0, 144.0, 17.0]) 

我觀察到更高的建議項目往往具有非零值的長度相同,因爲我試圖比較陣列。是否有一些準備我可以用我的數據來解決這個問題?

回答

2

默認情況下,BallTree會計算向量之間的歐幾里德距離,因此它不適合您想要的計算類型。

舉一個簡單的例子,假設您有以下兩類:

​​3210

當您將它們轉換爲載體內darray,你在上面做了,他們成爲了這個:

darray[0] = [1, 3, 0] 
darray[1] = [1, 2, 3] 

的這些之間的歐幾里德距離並不反映組中類似條目的數量,這就是爲什麼結果不符合您的預期。

而不是歐幾里得距離,你正在尋找的距離度量可能是Jaccard distance,它衡量集合之間的相似度。 BallTree將其用於布爾的布爾表示;即,對於上述數據載體將成爲

​​

其中第一項表示如果1是在該組中,第二項表示如果2是在該組,依此類推。這是「單熱編碼」的一個版本。

爲您提供,你可以計算的結果這樣的樣本數據:

import numpy as np 
from sklearn.neighbors import BallTree 
from sklearn.feature_extraction import DictVectorizer 

# for replicability 
np.random.seed(0) 

# Compute the collections using a more efficient method 
collections = [np.random.choice(300, replace=False, 
           size=np.random.randint(7, 15)) 
       for _ in range(10000)] 

# Use DictVectorizer to compute binary representation of collections 
dicts = [dict(zip(c, np.ones_like(c))) for c in collections] 
darray = DictVectorizer(sparse=False, dtype=bool).fit_transform(dicts) 

# Compute 15 nearest neighbors for the first collection 
dist, ind = BallTree(darray, metric='jaccard').query(darray[0], k=15) 
for i in ind[0]: 
    print(set(collections[0]) & set(collections[i])) 

我得到如下結果:

{225, 226, 261, 166, 296, 52, 150, 246, 215, 221, 223} 
{52, 261, 221, 215} 
{225, 226, 166, 150} 
{223, 150, 215} 
{225, 261, 166, 221} 
{226, 261, 223} 
{261, 150, 221} 
{223, 52, 166, 215} 
{296, 226, 166, 223} 
{296, 221, 150} 
{223, 52, 215} 
{52, 261, 246} 
{296, 225, 52} 
{296, 225, 221} 
{225, 150, 223} 

注意,Jaccard相似不是簡單的規模交叉點,但這個大小通過聯合的大小標準化。僅路口的大小不具有距離度量的屬性,因此不能直接用BallTree計算。

編輯:我應該補充說,如果你在集合中有很多條目,這個方法就變得站不住了,因爲布爾編碼矩陣變得太大了。使用Jaccard距離計算非常高維的鄰居搜索的最好方法可能是通過局部敏感散列,儘管我不知道適用於此問題的易於使用的Python實現。