2

我在查看TF Slim介紹性文檔,並且從我所瞭解的情況來看,每次運行只需要一批圖像數據(32幅圖像)。很顯然,人們想通過這個循環來訓練許多不同的批次。介紹不包括這一點。這怎麼能正確完成。我想應該有一些方法來指定一個加載批處理函數,它應該在開始批處理訓練事件時自動調用,但我似乎無法在介紹中找到一個簡單的例子。Tensorflow Slim的批量培訓

# Note that this may take several minutes. 

import os 

from datasets import flowers 
from nets import inception 
from preprocessing import inception_preprocessing 

slim = tf.contrib.slim 
image_size = inception.inception_v1.default_image_size 


def get_init_fn(): 
    """Returns a function run by the chief worker to warm-start the training.""" 
    checkpoint_exclude_scopes=["InceptionV1/Logits", "InceptionV1/AuxLogits"] 

    exclusions = [scope.strip() for scope in checkpoint_exclude_scopes] 

    variables_to_restore = [] 
    for var in slim.get_model_variables(): 
     excluded = False 
     for exclusion in exclusions: 
      if var.op.name.startswith(exclusion): 
       excluded = True 
       break 
     if not excluded: 
      variables_to_restore.append(var) 

    return slim.assign_from_checkpoint_fn(
     os.path.join(checkpoints_dir, 'inception_v1.ckpt'), 
     variables_to_restore) 


train_dir = '/tmp/inception_finetuned/' 

with tf.Graph().as_default(): 
    tf.logging.set_verbosity(tf.logging.INFO) 

    dataset = flowers.get_split('train', flowers_data_dir) 
    images, _, labels = load_batch(dataset, height=image_size, width=image_size) 

    # Create the model, use the default arg scope to configure the batch norm parameters. 
    with slim.arg_scope(inception.inception_v1_arg_scope()): 
     logits, _ = inception.inception_v1(images, num_classes=dataset.num_classes, is_training=True) 

    # Specify the loss function: 
    one_hot_labels = slim.one_hot_encoding(labels, dataset.num_classes) 
    slim.losses.softmax_cross_entropy(logits, one_hot_labels) 
    total_loss = slim.losses.get_total_loss() 

    # Create some summaries to visualize the training process: 
    tf.scalar_summary('losses/Total Loss', total_loss) 

    # Specify the optimizer and create the train op: 
    optimizer = tf.train.AdamOptimizer(learning_rate=0.01) 
    train_op = slim.learning.create_train_op(total_loss, optimizer) 

    # Run the training: 
    final_loss = slim.learning.train(
     train_op, 
     logdir=train_dir, 
     init_fn=get_init_fn(), 
     number_of_steps=2) 


print('Finished training. Last batch loss %f' % final_loss) 
+0

是不是代碼示例中的函數load_batch undefined共享?我不熟悉你的例子,但我會開始閱讀這個功能,以瞭解批處理過程。 – pltrdy

+0

它在這裏給出https://github.com/tensorflow/models/blob/master/slim/slim_walkthough.ipynb但是除了獲得批量外,這沒有任何作用。 –

+0

所以你基本上只需要迭代批次? – pltrdy

回答

1

slim.learning.train函數包含一個訓練循環,所以你給的代碼不會對圖像的多個批次的事實火車。

請參閱here in the source code,其中train_step_fn在while循環內被調用。 train_step(默認值爲train_step_fn)包含行sess.run([train_op, global_step]...),該行實際上在單批圖像上運行訓練操作。

+0

好吧,我在load_batch函數中放了一個print語句,並且訓練了超過1步,發現加載批處理函數只被調用一次,所以這意味着相同的數據被用於多個步驟,因此這個問題。 –

+0

此外,我沒有在調用learning.train時指定load_batch函數,那麼它如何才能「知道」使用它來加載新批次? –

+0

我已經做了更多的研究,看起來有一個隊列可以從每批自動加載的地方建立起來。爲了測試這個,我在這裏有一個相關的問題http://stackoverflow.com/questions/41868871/tensorflow-slim-debugging-during-training。請儘可能評論。 –