2016-08-05 72 views
3

我是一個AI和tensorflow的總和,所以請原諒,如果這是一個愚蠢的問題。 我已經使用基於本教程的腳本培養了tensorflow網絡:Tensorflow'功能'格式

https://www.tensorflow.org/versions/r0.10/tutorials/wide_and_deep/index.html

我認爲培訓是確定的。 現在我whant運行這個方法做一個預測單個輸入:

tf.contrib.learn.DNNClassifier.predict_proba(x=x) 

但我不能找到如何打造的「X」參數的任何文件... 我tryed:

x = {k: tf.SparseTensor(indices=[[0, 0]], values=[d_data[k]], shape=[1, 1]) for k in COLUMNS} 

其中: d_data是包含大約150個鍵/值對的字典。 COLUMNS是一個列表,包含所有需要的密鑰。 同樣的設置用於訓練網絡。

但得到的錯誤:

AttributeError: 'dict' object has no attribute 'dtype' 

所以... X不應該是一個「字典」 ...但它應該是呢? 任何人都可以給我一些方向嗎?

非常感謝。

回答

2

BaseEstimator類有更好的documentation

x: Matrix of shape [n_samples, n_features...]. Can be iterator that returns arrays of features. The training input samples for fitting the model. If set, `input_fn` must be `None`. 

我會考慮在這裏修復文檔。感謝您指出。

0

我得到了同樣的錯誤,但我認爲這是因爲我們正在使用tensorflow的舊版本(我在0.8.0),現在fit方法可以採用不同的輸入類型'input_fn',我認爲它可以採用字典的形式,看到here

def fit(self, x=None, y=None, input_fn=None, steps=None, batch_size=None, 
     monitors=None, max_steps=None): 

在我目前的版本此功能不會有「input_fn」,因此爲什麼它是強制性的輸入張量矩陣對象爲x。

您是否設法在此期間找到解決方案?