2017-07-09 46 views
0

我想與keras一起使用交替更新規則。 也就是說每批我想打電話給一個常規的基於漸變的步驟,然後調用一個自定義步驟。與keras的自定義交替更新規則

我想通過繼承優化器或回調來實現它(並使用批次調用)。但是,兩者都不會,因爲它們都缺少批處理數據和批處理標籤(而且我都需要)。

關於如何實現keras自定義交替更新的任何想法?

如果需要的話,我不介意直接打電話tensorflow具體方法,只要我能保持使用該項目與keras框架包裹(與model.fit,model.predict ..)

回答

0

嘗試創建一個自定義的回調

import keras.callbacks as callbacks 

class JSONMetrics(callbacks.Callback): 

_model  = None 
_each_epoch = None 
_metrics = None 
_epoch  = None 
_file_json = None 

def __init__(self,model,each_epoch,logger=None): 

    self._file_json = "file_log.json" 
    self._model  = model 
    self._each_epoch= each_epoch 
    self._epoch  = 0 
    self._metrics = {'loss':[], 'acc':[]} 

def on_epoch_begin(self, epoch, logs): 
    # print('Epoch {0} begin'.format(epoch)) 
    try: 
     with open(self._file_json, 'r') as f: 
      self._metrics = json.load(f) 

def on_epoch_end(self, epoch, logs): 
    self._logger.info('Nemesis: Epoch {0} end'.format(epoch)) 

    self._metrics['loss'].append(logs.get('loss')) 
    self._metrics['acc'].append(logs.get('acc')) 
    with open(self._file_json, 'w') as f: 
     data = json.dump(self._metrics, f) 

    if self._epoch % self._each_epoch == 0: 

     file_name = 'weights%08d.h5' % self._epoch 
     #print('Saving weights at {0} file'.format(file_name)) 
     self._model.save_weights(file_name) 

    self._epoch += 1 

可以喚起self.model解決您的問題,並保存例如ACC和損失。