2017-07-28 20 views
1

我正在試圖在張量流中訓練一個神經網絡。我使用tf.train.batch_join()函數加載數據及其標籤。我這樣做:張量流中的tf.train.batch_join()函數如何工作?

image_batch, label_batch, image_batch_f = tf.train.batch_join(
     images_and_labels, batch_size=batch_size_placeholder, 
     #shapes=[(args.image_size, args.image_size, 3),()], enqueue_many=True, 
     shapes=[(args.image_height, args.image_width, 3),(), (args.image_height, args.image_width, 3)], enqueue_many=True, 
     capacity=4 * nrof_preprocess_threads * args.batch_size, 
     allow_smaller_final_batch=True) 
    image_batch = tf.identity(image_batch, 'image_batch') 
    image_batch = tf.identity(image_batch, 'input') 
    label_batch = tf.identity(label_batch, 'label_batch') 
    image_batch_f = tf.identity(image_batch_f, 'flipped_images_batch') 

在這裏,我得到了三批數據。一批圖像,一批標籤和一批與圖像批次中相同圖像的翻轉圖像。我想提取一批圖像和翻轉圖像的功能。以下各行通過網絡傳遞批量數據。

# Build the inference graph 
    prelogits, _ = network.inference(image_batch, args.keep_probability, 
     phase_train=phase_train_placeholder, feature_dimension=args.embedding_size, 
     weight_decay=args.weight_decay) 


    features = tf.nn.l2_normalize(prelogits, 1, 1e-10, name='embeddings') 

    #getting the flipped embeddings 
    prelogits_f, _ = network.inference(image_batch_f,args.keep_probability, 
        phase_train=phase_train_placeholder,feature_dimension=args.embedding_size, 
        weight_decay=args.weight_decay,reuse=True) 
    features_flipped_images = tf.nn.l2_normalize(prelogits_f,1,1e-10,name='embeddings_f') 

爲了獲取這兩個功能,我在features和features_flipped_images ops上運行了一個session.run()。這樣的事情:

feed_dict = {phase_train_placeholder:False, batch_size_placeholder:batch_size} 
emb, emb_f = sess.run([features, features_flipped_images],feed_dict=feed_dict) 

我的問題是以下。我猜測,當我在功能上運行會話時,即batch_join函數將派發一批batch_size大小的圖像。但是當我在features_flipped_images上執行session.run()時,該函數還會從batch_join函數中獲取一批翻轉的圖像。在執行features_flipped_images時,batch_join函數是否派發一批新的翻轉圖像?或者它是在執行特徵時生成的同一批翻轉圖像?如果沒有,那我該怎麼做?我想提取一批圖像和一批翻轉圖像的特徵。

回答

0

我的猜測是每次運行[features,features_flipped_images]只會得到同一批數據。讓我們舉個例子:

imgs_batch,labels_batch = tf.train.batch([img, label]...) 

那麼,如果你想看到什麼是批處理:

imgs_data, labels_data = sess.run([imgs_batch, labels_batch]) 

你看,當你運行sess.run([特點,features_flipped_images]很相似,。 )。我不認爲你會得到兩批,否則,imgs_data和labels_data不相互對應。

+0

我不確定它是否是同一批次,因爲如果我將圖像和翻轉圖像的功能連接起來並使用連接功能進行匹配,那麼當它實際上應該改善匹配性能時,我的系統性能會明顯下降。我還不清楚批量加載器。 –

相關問題