2015-04-30 25 views
5

如何在Encog 3.4(Github目前正在開發的版本)中暫停遺傳算法?如何在Encog中暫停/序列化遺傳算法?

我正在使用Encog的Java版本。

我想修改農曆例子,伴隨着Encog。我想暫停/序列化遺傳算法,然後在後續階段繼續/反序列化。

當我打電話給train.pause();它只是返回null - 這是從代碼很明顯,因爲該方法始終返回null

我認爲這將是非常直接的,因爲可以有一個場景,我想訓練一個神經網絡,用它進行一些預測,然後繼續用遺傳算法訓練,因爲我在恢復之前獲得更多的數據有更多預測 - 無需從頭開始重新開始訓練。

請注意,我並非試圖序列化或持久化一個神經網絡,而是整個遺傳算法。

回答

4

並非恩戈的所有培訓師都支持簡單的暫停/恢復。如果他們不支持它,他們會返回null,就像這樣。遺傳算法訓練器比支持暫停/恢復的簡單傳播訓練器複雜得多。爲了保存遺傳算法的狀態,您必須保存整個羣體以及評分函數(可能或不可序列化)。我修改了Lunar Lander示例,向您展示如何保存/重新加載您的神經網絡羣體來完成此操作。你可以看到它訓練了50次迭代,然後往返(加載/保存)遺傳算法,然後再訓練50次。

package org.encog.examples.neural.lunar; 

import java.io.File; 
import java.io.IOException; 

import org.encog.Encog; 
import org.encog.engine.network.activation.ActivationTANH; 
import org.encog.ml.MLMethod; 
import org.encog.ml.MLResettable; 
import org.encog.ml.MethodFactory; 
import org.encog.ml.ea.population.Population; 
import org.encog.ml.genetic.MLMethodGeneticAlgorithm; 
import org.encog.ml.genetic.MLMethodGenomeFactory; 
import org.encog.neural.networks.BasicNetwork; 
import org.encog.neural.pattern.FeedForwardPattern; 
import org.encog.util.obj.SerializeObject; 

public class LunarLander { 

    public static BasicNetwork createNetwork() 
    { 
     FeedForwardPattern pattern = new FeedForwardPattern(); 
     pattern.setInputNeurons(3); 
     pattern.addHiddenLayer(50); 
     pattern.setOutputNeurons(1); 
     pattern.setActivationFunction(new ActivationTANH()); 
     BasicNetwork network = (BasicNetwork)pattern.generate(); 
     network.reset(); 
     return network; 
    } 

    public static void saveMLMethodGeneticAlgorithm(String file, MLMethodGeneticAlgorithm ga) throws IOException 
    { 
     ga.getGenetic().getPopulation().setGenomeFactory(null); 
     SerializeObject.save(new File(file),ga.getGenetic().getPopulation()); 
    } 

    public static MLMethodGeneticAlgorithm loadMLMethodGeneticAlgorithm(String filename) throws ClassNotFoundException, IOException { 
     Population pop = (Population) SerializeObject.load(new File(filename)); 
     pop.setGenomeFactory(new MLMethodGenomeFactory(new MethodFactory(){ 
      @Override 
      public MLMethod factor() { 
       final BasicNetwork result = createNetwork(); 
       ((MLResettable)result).reset(); 
       return result; 
      }},pop)); 

     MLMethodGeneticAlgorithm result = new MLMethodGeneticAlgorithm(new MethodFactory(){ 
      @Override 
      public MLMethod factor() { 
       return createNetwork(); 
      }},new PilotScore(),1); 

     result.getGenetic().setPopulation(pop); 

     return result; 
    } 


    public static void main(String args[]) 
    { 
     BasicNetwork network = createNetwork(); 

     MLMethodGeneticAlgorithm train; 


     train = new MLMethodGeneticAlgorithm(new MethodFactory(){ 
      @Override 
      public MLMethod factor() { 
       final BasicNetwork result = createNetwork(); 
       ((MLResettable)result).reset(); 
       return result; 
      }},new PilotScore(),500); 

     try { 
      int epoch = 1; 

      for(int i=0;i<50;i++) { 
       train.iteration(); 
       System.out 
         .println("Epoch #" + epoch + " Score:" + train.getError()); 
       epoch++; 
      } 
      train.finishTraining(); 

      // Round trip the GA and then train again 
      LunarLander.saveMLMethodGeneticAlgorithm("/Users/jeff/projects/trainer.bin",train); 
      train = LunarLander.loadMLMethodGeneticAlgorithm("/Users/jeff/projects/trainer.bin"); 

      // Train again 
      for(int i=0;i<50;i++) { 
       train.iteration(); 
       System.out 
         .println("Epoch #" + epoch + " Score:" + train.getError()); 
       epoch++; 
      } 
      train.finishTraining(); 

     } catch(IOException ex) { 
      ex.printStackTrace(); 
     } catch (ClassNotFoundException e) { 
      // TODO Auto-generated catch block 
      e.printStackTrace(); 
     } 

     int epoch = 1; 

     for(int i=0;i<50;i++) { 
      train.iteration(); 
      System.out 
        .println("Epoch #" + epoch + " Score:" + train.getError()); 
      epoch++; 
     } 
     train.finishTraining(); 

     System.out.println("\nHow the winning network landed:"); 
     network = (BasicNetwork)train.getMethod(); 
     NeuralPilot pilot = new NeuralPilot(network,true); 
     System.out.println(pilot.scorePilot()); 
     Encog.getInstance().shutdown(); 
    } 
} 
+0

非常感謝你,這解釋和展示了這個概念。 但我加載遺傳算法後,我沒有得到一個好的答案,似乎我必須在調用'train.getMethod()'之前訓練遺傳算法至少一次。 並且關於你的答案的一個筆記,我認爲你在重新加載50次後訓練神經網絡50次(即100次)。 – Tmr

+1

是的,上面的例子訓練了50次,然後保存/加載和訓練50次以上。遺傳算法使用了全部的神經網絡,調用getMethod只是從人口中返回頂部的神經網絡,所以在GA訓練之前它不會太有用。 – JeffHeaton