1
我剛剛開始學習cntk。但是,我有一個基本問題阻礙了我的進步。我有以下的測試通過:爲什麼不一致的形狀numpy vs cntk?
import numpy as np
from cntk import input_variable, plus
def test_simple(self):
x_input = np.asarray([[1, 2, 2]], dtype=np.int64)
assert (1, 3) == x_input.shape
y_input = np.asarray([[5, 3, 3]], dtype=np.int64)
assert (1, 3) == y_input.shape
x = input_variable(x_input.shape[1])
assert (3,) == x.shape
y = input_variable(y_input.shape[1])
assert (3,) == y.shape
x_plus_y = plus(x, y)
assert (3,) == x_plus_y.shape
res = x_plus_y.eval({x: x_input, y: y_input})
assert 6 == res[0, 0, 0]
assert 5 == res[0, 0, 1]
assert 5 == res[0, 0, 2]
據我所知,輸出的形狀爲(1,1,3)作爲第一和第二軸線是分批和分別缺省的動態軸。
但是,爲什麼我需要將輸入變量的形狀設置爲(3,)而不是(1,3)。使用(1,3)失敗。
爲什麼圖中輸入節點的形狀與用作該節點輸入的numpy數據之間存在不一致?
謝謝 水稻