0
我有以下代碼numpy apply_over_axes強制keepdims = True?
import numpy as np
import sys
def barycenter(arr, axis=0) :
bc = np.mean(arr, axis, keepdims=False)
print("src shape:", arr.shape, ", **** trg shape:", bc.shape, "****")
sys.stdout.flush()
return bc
a = np.array([[[0.1, 0.2, 0.3], [0.2, 0.3, 0.4]],
[[0.4, 0.4, 0.4], [0.7, 0.6, 0.8]]], np.float)
e = barycenter(a, 2)
print("direct application =", e, "**** (trg shape =", e.shape, ") ****\n")
f = np.apply_over_axes(barycenter, a, 2)
print("application through apply_over_axes =", f, "**** (trg shape =", f.shape, ") ****\n")
產生以下輸出
src shape: (2, 2, 3) , **** trg shape: (2, 2) ****
direct application = [[ 0.2 0.3]
[ 0.4 0.7]] **** (trg shape = (2, 2)) ****
src shape: (2, 2, 3) , **** trg shape: (2, 2) ****
application through apply_over_axes = [[[ 0.2]
[ 0.3]]
[[ 0.4]
[ 0.7]]] **** (trg shape = (2, 2, 1)) ****
因此函數barycenter
的返回值是從什麼與apply_over_axes(barycenter, ...
獲得的不同。
這是爲什麼?
準確。我從文檔中錯過了這一部分。然後我又去了,發現它......我正要消除這個問題。 –