我們可以在索引方面稍微聰明些,節省大約4倍的成本。
首先讓建立正確的形狀的一些數據:
seed = np.random.randint(0, 100, (200,206))
data = np.random.randint(0, 100, (4e5,206))
seed[:, 0] = np.arange(200)
data[:, 0] = np.random.randint(0, 200, 4e5)
diam = np.empty(200)
原答覆的時間:
%%timeit
for i in range(200):
diam[i] = spd.cdist(seed[np.newaxis, i, 1:], data[data[:, 0]==i][:,1:]).max()
1 loops, best of 3: 1.35 s per loop
moarningsun的回答是:
%%timeit
seed_repeated = seed[data[:,0]]
dist_to_center = np.sqrt(np.sum((data[:,1:]-seed_repeated[:,1:])**2, axis=1))
diam = np.zeros(len(seed))
np.maximum.at(diam, data[:,0], dist_to_center)
1 loops, best of 3: 1.33 s per loop
Divakar的回答是:
%%timeit
data_sorted = data[data[:, 0].argsort()]
seed_ext = np.repeat(seed,np.bincount(data_sorted[:,0]),axis=0)
dists = np.sqrt(((data_sorted[:,1:] - seed_ext[:,1:])**2).sum(1))
shift_idx = np.append(0,np.nonzero(np.diff(data_sorted[:,0]))[0]+1)
diam_out = np.maximum.reduceat(dists,shift_idx)
1 loops, best of 3: 1.65 s per loop
正如我們所看到的,除了更大的內存佔用之外,還沒有真正獲得任何矢量化解決方案。爲了避免這種情況,我們需要返回到原來的答案,這是真的做這些事情的正確方法,而是試圖減少索引量:
%%timeit
idx = data[:,0].argsort()
bins = np.bincount(data[:,0])
counter = 0
for i in range(200):
data_slice = idx[counter: counter+bins[i]]
diam[i] = spd.cdist(seed[None, i, 1:], data[data_slice, 1:]).max()
counter += bins[i]
1 loops, best of 3: 281 ms per loop
仔細檢查答案:
np.allclose(diam, dam_out)
True
這是假設python循環不好的問題。他們往往是,但不是在所有情況下。
這實際上是相當合理的代碼。你的for循環相對於'cdist'內完成的計算量相對較小。由於'cdist'是一個相當優化的速度,收益不會很大。 – Daniel
@Ophion - 雖然可以避免重複的線性搜索data [:, 0] == i,從O(n ** 2)到O(n log(n))甚至O(n )。 – 2015-11-06 20:19:23
@moarningsun是的,但是可能的和可用的是兩個不同的東西,特別是考慮到O(n * m)而不是O(n^2)和n << m。到目前爲止,沒有任何解決方案比OP更快,並且所有解決方案都有更多的內存開銷。 – Daniel