2017-08-23 41 views

回答

1

NStepLSTM假定輸入是可能具有不同長度的序列的小批次。輸入是這些序列的列表。每個序列都由形狀爲(T, D)的變量表示,其中T是序列的長度,D是序列中每個項目的維數(如果您正在處理文本數據,則D可以是嵌入圖層的維度)。

然後,NStepLSTM.__call__返回的三個元組:最終隱藏狀態,最終細胞狀態,且輸出序列以類似的格式作爲輸入序列(變量的列表)。您可以將其與其他功能或鏈接結合使用。例如,您可以將輸出序列中的每個變量傳遞給某個損失函數以獲得損失。

+0

你的意思是輸入是minibatch列表嗎?假設我有一篇不同長度的文章。每個句子都是變量形狀(T,D)。 D是詞嵌入維度,所以輸入是整篇文章?那會是內存昂貴的!因爲如果我們輸入所有的序列到GPU,這會耗盡所有的GPU內存? – machen

+0

您的意思是返回第3個值也是返回所有時間步驟輸出,並且每個時間步驟中的所有序列都將返回? – machen

+0

如果整個序列太長而無法放入內存,則必須將句子拆分爲幾個片斷,然後進行截斷的BPTT(作爲Chainer的官方語言建模示例(ptb))。返回的元組的第三個元素與'NStepLSTM'的輸入具有相同的長度。 –

相關問題