2017-10-20 72 views
0

我們已經培訓了一個tf-seq2seq模型來回答問題。主要框架是google/seq2seq。我們使用雙向RNN(GRU編碼器/解碼器128單元),增加了軟關注機制。如何加速tensorflow RNN推理時間

我們將最大長度限制爲100個字。它大多隻產生10〜20個單詞。

對於模型的推斷,我們嘗試兩種情況:

  1. 正常(貪心算法)。其推斷時間約爲40ms〜100ms
  2. 光束搜索。我們嘗試使用波束寬度5,其推斷時間約爲400ms〜1000ms。

所以,我們想嘗試使用波束寬度3,它的時間可能會減少,但它也可能影響最終效果。

那麼有什麼建議可以減少我們案件的推理時間嗎?謝謝。

+0

RNN向前傳播中最主要的限制因素之一是詞彙大小。 –

+0

感謝您的評論。我們的目標詞彙量大約是8000.如果我們設定了頻率限制,它可以減少到5000。我們可能會嘗試使用後者的小詞彙量。 –

+0

根據你的評論,我會建議你將算法放到一個更小的網絡中。 – lerner

回答

0

如果你的模型可以被修改並重新訓練的,所以據我所知,您可以:

  1. 網絡提煉到一個較小的一個,這被稱爲knowledge distillation

  2. 或者您可以減少詞彙大小,例如將其減半,然後使用複製機制在輸入序列上分配輸入。每個批次都需要一個大小不一的詞彙表,這些詞彙表應該屬於小詞彙表,即pointer generator

  3. 我認爲Tensorflow Lite是你在找什麼。

+0

謝謝!我們的應用程序需要爲用戶查詢輸出一個實時響應。所以我們只需要將一個查詢語句輸入到模型中,然後讓模型輸出一個響應。我們的batch_size = 1用於推斷。 –