2017-01-26 27 views
1

我剛剛開始學習cntk。但是,我有一個基本問題阻礙了我的進步。我有以下的測試通過:爲什麼不一致的形狀numpy vs cntk?

import numpy as np 
from cntk import input_variable, plus 

def test_simple(self): 

    x_input = np.asarray([[1, 2, 2]], dtype=np.int64) 
    assert (1, 3) == x_input.shape 

    y_input = np.asarray([[5, 3, 3]], dtype=np.int64) 
    assert (1, 3) == y_input.shape 

    x = input_variable(x_input.shape[1]) 
    assert (3,) == x.shape 

    y = input_variable(y_input.shape[1]) 
    assert (3,) == y.shape 

    x_plus_y = plus(x, y) 
    assert (3,) == x_plus_y.shape 

    res = x_plus_y.eval({x: x_input, y: y_input}) 

    assert 6 == res[0, 0, 0] 
    assert 5 == res[0, 0, 1] 
    assert 5 == res[0, 0, 2] 

據我所知,輸出的形狀爲(1,1,3)作爲第一和第二軸線是分批和分別缺省的動態軸。

但是,爲什麼我需要將輸入變量的形狀設置爲(3,)而不是(1,3)。使用(1,3)失敗。

爲什麼圖中輸入節點的形狀與用作該節點輸入的numpy數據之間存在不一致?

謝謝 水稻

回答

2

這是爲Function.forward解釋一點點的「論據」的說明。另一種描述是here。你們混淆的原因可能是CNTK做了一些「有用的」轉換。

如果您將輸入指定爲(1,3),則需要在沒有序列軸的小批次或(x,1,3)數組列表的情況下提供(1,3)數組列表如果是具有序列軸的小批次(其中x對於小批次中的每個序列而言可能不同)。同樣,如果您將輸入指定爲(3,),則需要提供(3,)向量列表或(x,3)向量列表。

混亂可能是由於沒有提供列表的情況而引起的。在那種情況下,CNTK在提供的張量的引導軸上迭代並且創建這些元素的列表,例如, (5,1,3)張量變成每個具有(1,3)形狀的一批5個元素。