2016-05-29 17 views
1

對不起,作爲一個新手火炬,但我保證通過文件和互聯網搜索了很多。如何在火炬中獲取/設置模型的重量(增量)?

有兩個主要需求我需要, 第一個是獲得一個或多個批次訓練後的重量增量, 第二個是將新的權重設置爲模型。

這意味着我想通過我自己的方法(使用外部庫)更新權重, 是否可以在火炬中實現?

看來火炬有一個摘要module class [1]但它的接口並不能滿足我所有的需求。

[1] https://github.com/torch/nn/blob/master/doc/module.md#nn.Module

回答

2

最後,我參照幾個同事找到了答案。

正確理解getParameters() [1]是解決問題的關鍵。 getParameters()將得到平坦的parameters(權重)和gradParameters(權重增量),更重要的是,它是一個內存轉換,應該只記錄一次。

這意味着getParameters()的返回值就是我們想要的值,返回值的更改將反映到更新權重的原始模型中。

因此,我們不僅可以通過由getParameters()返回的parameters獲得平坦權重,還可以將權重設置爲parameters:copy()。我們絕對可以使用其他torch.Tensor()方法來修改權重。

[1] https://github.com/torch/nn/blob/master/doc/module.md#flatparameters-flatgradparameters-getparameters