2016-04-23 65 views
12

我使用dynamic_rnn處理MNIST數據:在張量流中獲取dynamic_rnn的最後一個輸出嗎?

# LSTM Cell 
lstm = rnn_cell.LSTMCell(num_units=200, 
         forget_bias=1.0, 
         initializer=tf.random_normal) 

# Initial state 
istate = lstm.zero_state(batch_size, "float") 

# Get lstm cell output 
output, states = rnn.dynamic_rnn(lstm, X, initial_state=istate) 

# Output at last time point T 
output_at_T = output[:, 27, :] 

全碼:http://pastebin.com/bhf9MgMe

輸入到LSTM是(batch_size, sequence_length, input_size)

結果的output_at_T尺寸是(batch_size, sequence_length, num_units)其中num_units=200

我需要沿sequence_length尺寸獲得最後一個輸出。在上面的代碼中,這被硬編碼爲27。但是,我不知道提前sequence_length,因爲它可以在我的應用程序中按批次更改。

我嘗試:

output_at_T = output[:, -1, :] 

但它說負索引還沒有實現,我試圖使用佔位符變量以及一個常數(在其中我可以理想地喂sequence_length針對特定批次) ;既沒有工作。

任何方式在tensorflow atm中實現這樣的東西?

+1

你的序列長度是否相等? – danijar

+0

[獲取TensorFlow中動態\ _rnn的最後一個輸出]的可能重複(http://stackoverflow.com/questions/41273361/get-the-last-output-of-a-dynamic-rnn-in-tensorflow) – Alex

回答

0

您應該能夠使用tf.shape(output)訪問您的output張量的形狀。 tf.shape()函數將返回一個包含output張量大小的1d張量。在你的榜樣,這將是(batch_size, sequence_length, num_units)

然後,您應該能夠提取的output_at_T值作爲output[:, tf.shape(output)[1], :]

0

有一個在TensorFlow tf.shape一個功能,可以讓你獲得形狀的象徵性解釋,而不是None是由output._shape[1]返回。在獲取最後一個索引後,您可以使用tf.nn.embedding_lookup進行查找,這是特別建議在要提取的數據較高時推薦的,因爲這可以並行查找32 by default

# Let's first fetch the last index of seq length 
# last_index would have a scalar value 
last_index = tf.shape(output)[1] 
# Then let's reshape the output to [sequence_length,batch_size,num_units] 
# for convenience 
output_rs = tf.transpose(output,[1,0,2]) 
# Last state of all batches 
last_state = tf.nn.embedding_lookup(output_rs,last_index) 

這應該有效。

只是爲了澄清@Benoit Steiner所說的話。他的解決方案將無法正常工作,因爲tf.shape會返回形狀值的符號解釋,因此不能用於切片張量,即直接索引

2

我是新來的Stackoverflow,不能評論,所以我正在寫這個新的答案。 @VM_AI,最後一個索引是tf.shape(output)[1] - 1。 所以,重用你的答案:

# Let's first fetch the last index of seq length 
# last_index would have a scalar value 
last_index = tf.shape(output)[1] - 1 
# Then let's reshape the output to [sequence_length,batch_size,num_units] 
# for convenience 
output_rs = tf.transpose(output,[1,0,2]) 
# Last state of all batches 
last_state = tf.nn.embedding_lookup(output_rs,last_index) 

這對我的作品。

+0

最後的狀態已經在tensorflow中給你1.2 –

2
output[:, -1, :] 

現在與Tensorflow 1.x一起使用!!

+0

這隻適用於'output'中每個元素的'sequence_length'是相同的。 – Alex

+0

是的你是對的! –

2

這就是gather_nd的用途!

def extract_axis_1(data, ind): 
    """ 
    Get specified elements along the first axis of tensor. 
    :param data: Tensorflow tensor that will be subsetted. 
    :param ind: Indices to take (one for each element along axis 0 of data). 
    :return: Subsetted tensor. 
    """ 

    batch_range = tf.range(tf.shape(data)[0]) 
    indices = tf.stack([batch_range, ind], axis=1) 
    res = tf.gather_nd(data, indices) 

    return res 

在你的情況(假定sequence_length是1- d張量與每個軸線0元件的長度):

output = extract_axis_1(output, sequence_length - 1) 

現在輸出是尺寸[batch_size, num_cells]的張量。

+1

有沒有必要這樣做,dynamic_rnn作爲輸出的最後一個狀態。檢查我的答案。 – Escachator

+0

@Escachator只有在序列長度相同(批次內)時纔有效。 OP說:「我事先不知道'sequence_length',因爲它可以在我的應用程序中從一個批次變爲另一批次,所以這可能是一個公平的假設 - 我不知道。即使批次內的序列長度不同(並且填充了數據),我的答案仍然有效。 – Alex

+0

如果您知道(或得到)'ind',您知道批次中每個元素的序列長度。然後你可以在dynamic_rnn的'sequence_length'參數中輸入它,一旦它到達批處理的每個元素,它就會停止計算。然後你只需要final_state.h。我可能應該添加到我的答案。 – Escachator

5

您是否注意到dynamic_rnn有兩個輸出?

  1. 輸出1,讓我們稱之爲小時,在每個時間步驟中的所有輸出(即H_1,H_2等),
  2. 輸出2,final_state,有兩個要素:cell_state,最後輸出批次的每個元素(只要您輸入序列長度爲dynamic_rnn)。

所以從:

h, final_state= tf.dynamic_rnn(..., sequence_length=[batch_size_vector], ...) 

在批次中的每個元件的最後一個狀態是:

final_state.h 

注意它包括的情況下,當序列的長度是對於每個不同的批處理的元素,因爲我們使用sequence_length參數。

+0

'final_state.h'返回最後*輸出*或最後*隱藏狀態*? – bluesummers

+0

@bluesummers final_state.h是批處理中每個元素(即最終輸出)和final_state.c對單元格 – Escachator

+0

感謝狀態的最終激活。要明確,'final_state.h'將具有[batch_size,lstm_units]的維度? – bluesummers

相關問題