2016-08-18 98 views
3

我訓練了一個遞歸神經網絡(LSTM)並保存了權重和metagraph。當我檢索metagraph進行預測時,只要序列長度與訓練過程中的序列長度相同,一切都可以正常工作。Tensorflow:檢索metagraph時修改佔位符的形狀

LSTM的好處之一是輸入的序列長度可以變化(例如,如果輸入是形成句子的字母,則句子的長度可以變化)。

從metagraph中檢索圖形時,如何更改輸入的序列長度?

更多細節與代碼:

在培訓過程中,我使用佔位符xy養活數據。對於預測,我檢索這些佔位符,但無法設法更改其形狀(從[None, previous_sequence_length=100, n_input][None, new_sequence_length=50, n_input])。

在文件model.py,定義體系結構和佔位符:

self.x = tf.placeholder("float32", [None, self.n_steps, self.n_input], name='x_input') 
self.y = tf.placeholder("float32", [None, self.n_classes], name='y_labels') 
tf.add_to_collection('x', self.x) 
tf.add_to_collection('y', self.y) 
... 

def build_model(self): 
    #using the placeholder self.x to build the model 
    ... 
    tf.split(0, self.n_input, self.x) # split input for RNN cell 
    ... 

在文件prediction.py,我檢索預測元圖:

with tf.Session() as sess: 
    latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir=checkpoint_dir) 
    new_saver = tf.train.import_meta_graph(latest_checkpoint + '.meta') 
    new_saver.restore(sess, latest_checkpoint) 
    x = tf.get_collection('x')[0] 
    y = tf.get_collection('y')[0] 
    ... 
    sess.run(..., feed_dict={x: batch_x}) 

這裏是我的錯誤:

ValueError: Cannot feed value of shape (128, 50, 2) for Tensor u'placeholders/x_input:0', which has shape '(?, 100, 2)' 

注:我設法解決這個問題n 不使用metagraph,而是重新從頭開始重建模型並僅加載保存的權重(而不是元數據圖)。

編輯:與None更換self.n_steps和修改tf.split(0, self.n_input, self.x)tf.split(0, self.x.get_shape()[1], self.x)時,我得到了以下錯誤:TypeError: Expected int for argument 'num_split' not Dimension(None).

+0

你通常不能在事實後改變張量的形狀。但是,您可以做的一件事是*不*修正訓練過程中所有維度的形狀,但不指定它們。您提供的張量的尺寸必須與佔位符的形狀兼容,但您不必強制首先指定所有佔位符尺寸。在這裏,嘗試將「無」替換爲self.n_steps。 –

+0

我在發佈問題之前就嘗試過這樣做,但在創建模型期間的某個時候,我有'tf.split(0,self.n_input,self.x)'。當我不知道/修正'self.n_input'時,我把'self.x.get_shape()[1]'('tf.split(0,self.x.get_shape()[1],self) x)的')。但是,我得到以下錯誤:'TypeError:參數'num_split'的預期int不是Dimension(無).'。 – BiBi

回答

2

當你定義varible,我建議你把它寫如下

[None, None, n_input] 

代替:

[None, new_sequence_length=50, n_input] 

它適用於我的情況。我希望它有幫助

+0

我試過了(參考初始文章中的評論),但由於'tf.split'函數在輸入時需要使用分割數量,所以在此解決方案中爲「無」,因此它不起作用。 – BiBi