
Estimator: don't save a checkpoint at the last step

I am using Estimator and I train the model in a loop in order to feed it data. Every step ends up being a last step, and every last step saves a checkpoint. I would like to avoid saving a checkpoint on every iteration to improve training performance (speed). I could not find any information on how to do this. Do you have any ideas/suggestions/solutions?

classifier = Estimator(
    model_fn=cnn_model_fn,
    model_dir="./temp_model_Adam",
    config=tf.contrib.learn.RunConfig(
        save_checkpoints_secs=None,
        save_checkpoints_steps=100,
        save_summary_steps=None
    )
)

# Train the model
for e in range(0, 10):
    numbers = np.arange(10000)
    np.random.shuffle(numbers)
    for step in range(0, 2000):
        classifier.fit(
            input_fn=lambda: read_images_for_training_as_batch(step, path, 5, numbers),
            steps=1
        )

Answer


The API has changed a bit since then, but from what I can see you are using the fit (now train) method incorrectly: you should pass steps=2000 and have your input function return an iterator over your dataset. Today there is tf.estimator.inputs.numpy_input_fn, which can help you when you have a small dataset; otherwise you have to use the tf.data.Dataset API.
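
For the small-dataset case, here is a minimal sketch with tf.estimator.inputs.numpy_input_fn (the arrays, the feature key 'x' and the batch size are placeholders, not from the original post):

import numpy as np
import tensorflow as tf

# Hypothetical in-memory data; replace with your own arrays.
train_x = np.random.rand(10000, 28, 28, 1).astype(np.float32)
train_y = np.random.randint(0, 10, size=10000).astype(np.int32)

# numpy_input_fn batches, shuffles and repeats the arrays for you,
# so a single train() call can run all 2000 steps.
train_input_fn = tf.estimator.inputs.numpy_input_fn(
    x={'x': train_x},
    y=train_y,
    batch_size=5,
    num_epochs=None,  # repeat indefinitely; train() stops after `steps`
    shuffle=True)

# classifier.train(input_fn=train_input_fn, steps=2000)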

Something like this (it loads .wav files):

import numpy as np
import tensorflow as tf
from tensorflow.contrib.framework.python.ops import audio_ops as contrib_audio
from tensorflow.python.ops import io_ops
# ... (get_files, get_classes, desired_samples and num_map_threads are defined elsewhere)

def input_fn(num_epochs, batch_size, shuffle=False, mode='training'):

    def input_fn_bound():
        def _read_file(fn, label):
            return io_ops.read_file(fn), label

        def _decode(data, label):
            pcm = contrib_audio.decode_wav(data,
                                           desired_channels=1,
                                           desired_samples=desired_samples)
            return pcm.audio, label

        filenames = get_files(mode)
        classes = get_classes(mode)
        labels = {'class': np.array(classes)}
        dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))

        if shuffle:
            dataset = dataset.shuffle(buffer_size=len(filenames))
        dataset = dataset.map(_read_file, num_parallel_calls=num_map_threads)
        dataset = dataset.map(_decode, num_parallel_calls=num_map_threads)
        dataset = dataset.map(lambda wav, label: ({'wav': wav}, label))

        dataset = dataset.repeat(num_epochs)
        dataset = dataset.batch(batch_size)
        dataset = dataset.prefetch(2)  # Load the next batch while the current one is processed on the GPU
        iterator = dataset.make_one_shot_iterator()
        features, labels = iterator.get_next()
        return features, labels

    return input_fn_bound

# .... 

estimator.train(input_fn=input_fn(
                    num_epochs=None,
                    batch_size=64,
                    shuffle=True,
                    mode='training'),
                steps=10000)
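
To tie this back to the checkpoint question: because training now happens in one long train() call, a checkpoint is written on the schedule you configure, not at the end of every fit(steps=1) call. A minimal sketch of that configuration with the current tf.estimator.RunConfig (the values below are illustrative; cnn_model_fn and input_fn are the ones from the snippets above):

config = tf.estimator.RunConfig(
    save_checkpoints_steps=1000,  # write a checkpoint every 1000 global steps
    save_checkpoints_secs=None,   # disable the time-based schedule
    keep_checkpoint_max=5)        # keep only the 5 most recent checkpoints

estimator = tf.estimator.Estimator(
    model_fn=cnn_model_fn,
    model_dir="./temp_model_Adam",
    config=config)

# One train() call for the whole run: checkpoints follow the RunConfig
# schedule instead of being saved after every single step.
estimator.train(input_fn=input_fn(num_epochs=None, batch_size=64,
                                  shuffle=True, mode='training'),
                steps=10000)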