2017-06-19 177 views
0

我一直在嘗試使用谷歌的基於RNN seq2seq model.基於谷歌TensorFlow seq2seq模型墜毀在訓練

我一直在訓練文摘模型,並在大約大小1GB的文本數據正在餵養。該模型很快填滿了我的整個RAM(8GB),開始填滿交換內存(進一步8GB)並崩潰後,我必須做一個硬關機。

我LSTM網絡的配置如下:

model: AttentionSeq2Seq 
model_params: 
    attention.class: seq2seq.decoders.attention.AttentionLayerDot 
    attention.params: 
    num_units: 128 
    bridge.class: seq2seq.models.bridges.ZeroBridge 
    embedding.dim: 128 
    encoder.class: seq2seq.encoders.BidirectionalRNNEncoder 
    encoder.params: 
    rnn_cell: 
     cell_class: GRUCell 
     cell_params: 
     num_units: 128 
     dropout_input_keep_prob: 0.8 
     dropout_output_keep_prob: 1.0 
     num_layers: 1 
    decoder.class: seq2seq.decoders.AttentionDecoder 
    decoder.params: 
    rnn_cell: 
     cell_class: GRUCell 
     cell_params: 
     num_units: 128 
     dropout_input_keep_prob: 0.8 
     dropout_output_keep_prob: 1.0 
     num_layers: 1 
    optimizer.name: Adam 
    optimizer.params: 
    epsilon: 0.0000008 
    optimizer.learning_rate: 0.0001 
    source.max_seq_len: 50 
    source.reverse: false 
    target.max_seq_len: 50 

我試圖減少來自32個批次大小爲16,但它仍然沒有幫助。爲了防止我的模型佔用整個RAM並導致崩潰,我應該做出哪些具體的更改? (如減小數據大小,減少堆疊LSTM單元的數量,進一步減少批量大小等)

我的系統運行Python 2.7x,TensorFlow 1.1.0版和CUDA 8.0。該系統具有4GB內存的Nvidia Geforce GTX-1050Ti(768個CUDA內核),該系統擁有8GB內存和另外8GB交換內存。

回答

0

你模特看起來很小。火車數據中唯一的一件大事。請檢查以確保您的get_batch()函數沒有錯誤。有可能每個批次實際上都加載了整個數據集以供培訓,以防出現問題。

爲了快速證明這一點,只需將您的訓練數據大小減小到非常小(例如當前大小的1/10)並查看是否有幫助。請注意,它不應該幫助,因爲您正在使用小批量。但如果解決了這個問題,請修復get_batch()函數。

+0

嘿,謝謝你的回覆, 但是我懷疑是因爲'get_batch'函數已經被抽象出來了,並且由框架在內部處理,它由google編寫。所以除非框架有bug,'get_batch'函數是正確的。 是否有任何其他超參數可以修改以解決我的問題? –

+0

我不確定。其他人盲目調試問題真的很難。但是,這個問題很可能是由一個bug引入的。儘量減少您的整體火車設置爲1MB左右。如果沒有幫助,請確認所有超參數都已正確傳遞,尤其是seq_len之一。 –