2017-04-20 52 views
2

我使用MXnet用於訓練CNN(以R)時保存的模型,我可以訓練模型沒有任何錯誤與下面的代碼:如何使用MXnet

model <- mx.model.FeedForward.create(symbol=network, 
            X=train.iter, 
            ctx=mx.gpu(0), 
            num.round=20, 
            array.batch.size=batch.size, 
            learning.rate=0.1, 
            momentum=0.1, 
            eval.metric=mx.metric.accuracy, 
            wd=0.001, 
            batch.end.callback=mx.callback.log.speedometer(batch.size, frequency = 100) 
    ) 

但由於這個過程是時間 - 在夜間我在服務器上運行它,並且爲了在完成培訓後使用它,我想保存該模型。

我用:

save(list = ls(), file="mymodel.RData") 

mx.model.save("mymodel", 10) 

但沒有人可以拯救的典範!例如當我加載"mymodel.RData"時,我無法預測測試集的標籤!

另一個例子是,當我加載"mymodel.RData",並嘗試用下面的代碼繪製它:

graph.viz(model$symbol$as.json()) 

我收到以下錯誤:

Error in model$symbol$as.json() : external pointer is not valid 

任何人可以給我節省的解決方案然後加載這個模型以備將來使用?

感謝

+0

@marbel保存模型:你可以看到,如果你能幫助我嗎? – Mohammad

+0

我找到了一個解決方案,我儘快在這裏發佈,我可以測試它:) – Mohammad

回答

2

您可以通過

model <- mx.model.FeedForward.create(symbol=network, 
           X=train.iter, 
           ctx=mx.gpu(0), 
           num.round=20, 
           array.batch.size=batch.size, 
           learning.rate=0.1, 
           momentum=0.1, 
           eval.metric=mx.metric.accuracy, 
           wd=0.001, 
           epoch.end.callback=mx.callback.save.checkpoint("model_prefix") 
           batch.end.callback=mx.callback.log.speedometer(batch.size, frequency = 100) 
) 
0

保存你的訓練進度快照最好的做法是使用save_snapshot(http://mxnet.io/api/python/module.html#mxnet.module.Module.save_checkpoint)爲每一個時代的訓練後回調的一部分。在R中,等價命令可能是mx.callback.save.checkpoint,但我沒有使用R並且不確定它的用法。

使用這些快照還可以讓您充分利用使用AWS Spot market(https://aws.amazon.com/ec2/spot/pricing/)的低成本選項,例如現在可提供16個K80 GPU並提供實例,價格爲3.8美元/小時,與按需價格相比14.4美元。這種80%-90%的折扣在現貨市場中很常見,並且可以優化培訓的速度和成本,只要您正確使用這些快照即可。