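What am I doing wrong? sklearn BallTree gives unexpected results
I want to use sklearn's BallTree to find similar collections, and then generate suggestions for items that a given collection might be missing.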
import random
from sklearn.neighbors import BallTree
import numpy

collections = []  # 10k sample collections of between
                  # 7 and 15 (of a possible 300...) items
for sample in range(0, 10000):  # build sample data
    items = random.sample(range(1, 300), random.randint(7, 15))
    collections.append(items)

darray = numpy.zeros((len(collections), max(map(len, collections))))  # 10k x 15 matrix
for c_cnt, items in enumerate(collections):  # populate matrix
    for cnt, i in enumerate(sorted(items)):  # sorted item ids, zero-padded on the right
        darray[c_cnt][cnt] = i

# query() expects a 2D array, so pass the first row as a 1 x 15 slice
query = BallTree(darray).query(darray[:1], k=15)
nearest_neighbors = query[1][0]  # row indices of the 15 nearest neighbors

# test the results against the first item!
all_sets = [set(darray[0]) & set(darray[item]) for item in nearest_neighbors]
for item in all_sets:
    print(item)  # intersection of the neighbor
I get results like the following:
set([0.0, 130.0, 167.0, 290.0, 162.0, 144.0, 17.0, 214.0]) # Nearest neighbor is itself! Awesome!
set([0.0]) # WTF? The second closest item shares only 1 item?
set([0.0, 290.0])
set([0.0, 17.0])
set([0.0, 130.0])
set([0.0])
set([0.0])
set([0.0])
set([0.0])
set([0.0])
set([0.0])
set([0.0])
set([0.0, 162.0])
set([0.0, 144.0, 162.0]) # uhh okay, i would expect this to be higher up
set([0.0, 144.0, 17.0])
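I've noticed that the higher-ranked suggestions tend to have the same number of non-zero values as the array I'm comparing against. Is there some preparation I can do to my data to fix this?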
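
To make the observation above concrete, here is a minimal sketch on three hand-made collections (toy data, not my real collections). It compares the sorted, zero-padded encoding from my code with one possible alternative: a 0/1 "multi-hot" row per collection, queried with BallTree's `jaccard` metric. The helper `to_multi_hot`, the constant `N_ITEMS`, and the choice of Jaccard distance are my own assumptions for illustration; I don't know whether this is the right preparation for the real data.

```python
import numpy as np
from sklearn.neighbors import BallTree

N_ITEMS = 300  # assumption: item ids are integers in 1..299, as in the sample data


def to_multi_hot(collections, n_items=N_ITEMS):
    """Hypothetical helper: column j of a row is 1 if item j is in that collection."""
    encoded = np.zeros((len(collections), n_items))
    for row, items in enumerate(collections):
        encoded[row, list(items)] = 1
    return encoded


# Toy data chosen by hand for this illustration
collections = [
    [17, 130, 144, 162, 167, 214, 290],  # row 0: the query
    [130, 144, 162, 167, 214, 290],      # row 1: shares 6 of 7 items, but is shorter
    [20, 125, 150, 160, 170, 210, 285],  # row 2: shares nothing, same length
]

# Sorted, zero-padded encoding (what the code in the question builds)
width = max(map(len, collections))
padded = np.zeros((len(collections), width))
for row, items in enumerate(collections):
    padded[row, :len(items)] = sorted(items)

# Euclidean distance compares position by position, so the shifted row 1
# ends up far from the query while the unrelated row 2 ends up close.
padded_order = BallTree(padded).query(padded[:1], k=3)[1][0]
print(padded_order)   # [0 2 1]

# Multi-hot encoding with Jaccard distance ranks by shared membership instead.
multi_hot = to_multi_hot(collections)
jaccard_order = BallTree(multi_hot, metric="jaccard").query(multi_hot[:1], k=3)[1][0]
print(jaccard_order)  # [0 1 2]
```

In the padded encoding, dropping a single item shifts every later value one column to the left and leaves a trailing zero, which would explain why neighbors with the same number of non-zero values rank highest. With a 10k x 300 matrix the multi-hot rows are mostly zeros, so memory use and query speed would still need checking on the real data.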