我一直在嘗試使用谷歌的基於RNN seq2seq model.基於谷歌TensorFlow seq2seq模型墜毀在訓練
我一直在訓練文摘模型,並在大約大小1GB的文本數據正在餵養。該模型很快填滿了我的整個RAM(8GB),開始填滿交換內存(進一步8GB)並崩潰後,我必須做一個硬關機。
我LSTM網絡的配置如下:
model: AttentionSeq2Seq
model_params:
attention.class: seq2seq.decoders.attention.AttentionLayerDot
attention.params:
num_units: 128
bridge.class: seq2seq.models.bridges.ZeroBridge
embedding.dim: 128
encoder.class: seq2seq.encoders.BidirectionalRNNEncoder
encoder.params:
rnn_cell:
cell_class: GRUCell
cell_params:
num_units: 128
dropout_input_keep_prob: 0.8
dropout_output_keep_prob: 1.0
num_layers: 1
decoder.class: seq2seq.decoders.AttentionDecoder
decoder.params:
rnn_cell:
cell_class: GRUCell
cell_params:
num_units: 128
dropout_input_keep_prob: 0.8
dropout_output_keep_prob: 1.0
num_layers: 1
optimizer.name: Adam
optimizer.params:
epsilon: 0.0000008
optimizer.learning_rate: 0.0001
source.max_seq_len: 50
source.reverse: false
target.max_seq_len: 50
我試圖減少來自32個批次大小爲16,但它仍然沒有幫助。爲了防止我的模型佔用整個RAM並導致崩潰,我應該做出哪些具體的更改? (如減小數據大小,減少堆疊LSTM單元的數量,進一步減少批量大小等)
我的系統運行Python 2.7x,TensorFlow 1.1.0版和CUDA 8.0。該系統具有4GB內存的Nvidia Geforce GTX-1050Ti(768個CUDA內核),該系統擁有8GB內存和另外8GB交換內存。
嘿,謝謝你的回覆, 但是我懷疑是因爲'get_batch'函數已經被抽象出來了,並且由框架在內部處理,它由google編寫。所以除非框架有bug,'get_batch'函數是正確的。 是否有任何其他超參數可以修改以解決我的問題? –
我不確定。其他人盲目調試問題真的很難。但是,這個問題很可能是由一個bug引入的。儘量減少您的整體火車設置爲1MB左右。如果沒有幫助,請確認所有超參數都已正確傳遞,尤其是seq_len之一。 –