2017-08-21 81 views
0

下面是我的代碼片段。我使用隊列加載訓練數據,並使用 feed(feed_dict)加載驗證圖像。隨着訓練的進行,訓練損失和訓練準確率都表現正常。但是在驗證階段,驗證損失和準確率很奇怪:無論我運行多少步,驗證損失都太高,驗證準確率都太低,就像隨機猜測一樣。然而,當我在函數 load_validate_img_data 中將 'is_training' 參數設置爲 True 而不是 False 時,驗證損失和準確率就正常了。我使用 batch_norm 的方式有什麼問題嗎?TensorFlow 的批標準化(batch normalization)是否導致了訓練損失和驗證損失之間的不一致?

def inference(inputs,
              num_classes=1000,
              is_training=True,
              dropout_keep_prob=0.5,
              reuse=None,
              scope='alexnet'):
    """Build the convolutional classifier graph and return its logits.

    Five conv layers (each followed by slim.batch_norm and ReLU through the
    shared arg_scope) are interleaved with pooling, then dropout and a 1x1
    conv ('fc7', no normalizer and no activation) produce the class scores.

    Args:
        inputs: image batch tensor fed to the first conv layer.
        num_classes: number of output channels of the final 1x1 conv.
        is_training: forwarded to both slim.batch_norm and slim.dropout.
        dropout_keep_prob: keep probability handed to slim.dropout.
        reuse: forwarded to tf.variable_scope and to slim.batch_norm.
        scope: name of the enclosing variable scope; also passed as the
            'scope' entry of the batch-norm parameters.

    Returns:
        A pair (net, end_points). Both elements are the same tensor: the
        'fc7' output with axes 1 and 2 squeezed away.
    """
    batch_norm_params = {
        'is_training': is_training,
        'decay': 0.95,
        'reuse': reuse,
        'scope': scope,
    }
    # Defaults shared by every conv / fully-connected layer built below.
    layer_defaults = dict(
        normalizer_fn=slim.batch_norm,
        activation_fn=tf.nn.relu,
        biases_initializer=tf.constant_initializer(0.1),
        weights_regularizer=slim.l2_regularizer(WEIGHT_DECAY),
        normalizer_params=batch_norm_params,
    )
    # (output depth, kernel size) for conv2..conv4; each is followed by a
    # 2x2, stride-2 max pool carrying the matching index.
    middle_stages = [(64, [3, 3]), (128, [2, 2]), (256, [2, 2])]

    with slim.arg_scope([slim.conv2d, slim.fully_connected], **layer_defaults):
        with slim.arg_scope([slim.conv2d], padding='SAME'), \
                slim.arg_scope([slim.max_pool2d], padding='VALID'), \
                tf.variable_scope(scope, [inputs], reuse=reuse):

            # conv1 is the only strided, VALID-padded convolution.
            net = slim.conv2d(inputs, 32, [3, 3], 2, scope='conv1', padding='VALID')
            net = slim.max_pool2d(net, [2, 2], 2, scope='pool1')

            for stage, (depth, kernel) in enumerate(middle_stages, 2):
                net = slim.conv2d(net, depth, kernel, scope='conv%d' % stage)
                net = slim.max_pool2d(net, [2, 2], 2, scope='pool%d' % stage)

            net = slim.conv2d(net, 512, [2, 2], scope='conv5')
            net = slim.avg_pool2d(net, [2, 2], scope='pool5')

            net = slim.dropout(net, dropout_keep_prob, is_training=is_training, scope='dropout6')

            # Classifier head: plain 1x1 conv, overriding the shared
            # batch-norm and ReLU defaults.
            net = slim.conv2d(net, num_classes, [1, 1], activation_fn=None, normalizer_fn=None, scope='fc7')
            net = tf.squeeze(net, [1, 2], name='fc8/squeezed')
            end_points = net

            return net, end_points

def get_softmax_loss(logits, labels, name = 'train', weight_decay = 0.0005):
    """Return mean softmax cross-entropy plus an L2 penalty on all trainable variables.

    Args:
        logits: unscaled class scores, one row per example.
        labels: integer class ids; one-hot encoded here with LABEL_NUM classes.
        name: unused; kept so existing callers that pass it keep working.
        weight_decay: coefficient applied to the summed tf.nn.l2_loss of every
            variable returned by tf.trainable_variables(). Defaults to the
            previously hard-coded 0.0005.

    Returns:
        Scalar tensor: cross-entropy averaged over the batch plus the L2 term.
    """
    one_hot_labels = slim.one_hot_encoding(labels, LABEL_NUM)

    cross_entropy = tf.nn.softmax_cross_entropy_with_logits(
        labels = one_hot_labels, logits = logits)
    softmax_loss = tf.reduce_mean(cross_entropy)

    # The penalty is computed over the whole graph's trainable variables, so
    # it is identical no matter which logits this loss is attached to.
    trainable_vars = tf.trainable_variables()
    regularization_loss = tf.add_n([tf.nn.l2_loss(v) for v in trainable_vars]) * weight_decay

    return softmax_loss + regularization_loss

def get_train_op(loss):
    """Create the momentum training op and its companion ops.

    Args:
        loss: scalar loss tensor to minimise.

    Returns:
        A triple (train_op, loss_update, lr_update):
        train_op is built by slim.learning.create_train_op with a momentum
        optimizer (momentum 0.9, initial learning rate 0.01);
        loss_update is `loss` gated on every op in the UPDATE_OPS collection,
        or `loss` itself when that collection is empty;
        lr_update, when run, halves the learning rate with a floor of 1e-6.
    """
    learning_rate = tf.Variable(0.01, trainable=False)
    with tf.name_scope('lr_update'):
        halved_rate = tf.maximum(learning_rate*0.5, 0.000001)
        lr_update = tf.assign(learning_rate, halved_rate)

    momentum_optimizer = tf.train.MomentumOptimizer(learning_rate, 0.9)

    # The step counter lives in a non-trainable variable named "step".
    global_step = tf.get_variable(
        "step", [], initializer=tf.constant_initializer(0.0), trainable=False)
    train_op = slim.learning.create_train_op(
        loss, momentum_optimizer, global_step = global_step)

    pending_updates = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    if not pending_updates:
        return train_op, loss, lr_update

    # Evaluating loss_update forces the collected update ops to run first.
    grouped_updates = tf.group(*pending_updates)
    loss_update = control_flow_ops.with_dependencies([grouped_updates], loss)

    return train_op, loss_update, lr_update

def get_train_acc(logits, labels, name = 'train'):
    """Return the fraction of examples whose arg-max logit equals the label.

    Args:
        logits: class scores; the arg-max is taken along axis 1.
        labels: integer class ids compared element-wise with the predictions.
        name: unused; kept so existing callers that pass it keep working.

    Returns:
        Scalar float32 tensor holding the mean of the per-example matches.
    """
    # tf.argmax replaces the deprecated alias tf.arg_max; same arguments.
    predictions = tf.argmax(logits, 1)
    accuracy = tf.reduce_mean(tf.cast(tf.equal(predictions, labels), tf.float32))

    return accuracy

def load_validate_img_data():
    """Load up to 400 validation images and build the validation sub-graph.

    Images are read with cv2 from ~/data/*.bmp; each label is the file name
    without its extension, converted to int64. The validation graph reuses
    the training variables (reuse=True) and runs with is_training=False.

    Returns:
        A 6-tuple (validate_img, validate_label, input_imgs, input_labels,
        validate_accuracy, validate_loss): the two numpy arrays, the image
        and label placeholders (batch of 100), and the accuracy / loss
        tensors computed from those placeholders.

    Raises:
        IOError: if cv2 cannot decode one of the matched files.
    """
    # Local import so this block does not depend on the file's import header.
    import os

    max_images = 400
    batch = 100

    # glob does not expand '~' itself, so resolve the home directory first.
    validate_img_root = os.path.expanduser('~/data/')

    # Sorted so the same files are picked on every run.
    img_roots = sorted(glob(validate_img_root + '*.bmp'))

    validate_img = []
    validate_label = []
    for root in img_roots[:max_images]:
        img = cv2.imread(root)
        if img is None:
            # cv2.imread signals failure by returning None, not by raising.
            raise IOError('failed to read validation image: %s' % root)

        # The label is the file name with its extension stripped.
        validate_label.append(os.path.splitext(os.path.basename(root))[0])
        validate_img.append(img)

    # Rescale pixels to [0, 1]. Feeding raw 0-255 values to a network trained
    # on [0, 1] inputs gives chance-level validation accuracy.
    # NOTE(review): assumes the training loader scales to [0, 1] by dividing
    # by 255 -- confirm against ut._load_batch_t.
    validate_img = np.array(validate_img).astype(np.float32) / 255.0
    validate_label = np.array(validate_label).astype(np.int64)

    with tf.name_scope('validate_input'):
        input_imgs = tf.placeholder(tf.float32, shape = (batch, ORIGINAL_SIZE[0], ORIGINAL_SIZE[1], CHANNELS), name = 'imgs')
        input_labels = tf.placeholder(tf.int64, shape = (batch,), name = 'labels')
    transfer_input_imgs = ut._resize_crop_img(input_imgs, RESIZE_TO, RESIZE_TO, process_type = 'validate')

    # Share the training weights and switch batch norm / dropout to
    # inference behaviour.
    logits, out_data = face_train.inference(transfer_input_imgs, num_classes=LABEL_NUM, is_training = False, reuse = True)

    validate_accuracy = get_train_acc(logits, input_labels, name = 'validate')
    validate_loss = get_softmax_loss(logits, input_labels, name = 'validate')

    return validate_img, validate_label, input_imgs, input_labels, validate_accuracy, validate_loss

# Training driver: builds the training and validation graphs, then runs
# EPOC_NUM epochs, printing training metrics every 20 steps and validation
# metrics every 200 steps.
with tf.Graph().as_default():

    images, labels = ut._load_batch_t(data_dir, ORIGINAL_SIZE, CHANNELS, BATCH_SIZE, RESIZE_TO, RESIZE_TO)

    # inference() returns a (logits, end_points) pair; only the logits are
    # needed for the loss and accuracy.
    logits, _ = face_train.inference(images, num_classes=LABEL_NUM)

    accuracy = get_train_acc(logits, labels)

    total_loss = get_softmax_loss(logits, labels)
    train_op, loss_update, lr_update = get_train_op(total_loss)

    # Built after the training graph because it reuses the training variables.
    validate_img, validate_label, img_placeholer,label_placeholder, validate_accuracy, validate_loss = load_validate_img_data()

    with tf.Session() as sess:

        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())

        saver = tf.train.Saver(tf.global_variables(), max_to_keep=10000)

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)

        total_step = 0
        epoc_step = int(SAMPLE_NUM/BATCH_SIZE)
        try:
            for epoc in range(EPOC_NUM):
                for step in range(epoc_step):

                    sess.run([train_op])

                    if total_step % 20 == 0:
                        # NOTE(review): this is a separate run from the training
                        # step above, so the printed metrics presumably come from
                        # a different input batch -- confirm against the loader.
                        loss, train_ac = sess.run([loss_update, accuracy])
                        print ('epoc : %d, step : %d, train_loss : %.2f, train_acc: %.3f' %(epoc, step, loss, train_ac))

                    if total_step % 200 == 0:
                        all_va_acc = 0
                        all_va_loss = 0

                        # Four batches of 100 cover the 400 validation images.
                        for i in range(4):
                            feed_dict = {
                                img_placeholer: validate_img[i*100 : (i+1)*100],
                                label_placeholder: validate_label[i*100 : (i+1)*100],
                            }

                            va_acc, va_loss = sess.run([validate_accuracy, validate_loss], feed_dict = feed_dict)
                            all_va_acc += va_acc
                            all_va_loss += va_loss

                        print ('validate_accuracy: %.2f, validate_loss: %.2f' % (all_va_acc/4.0, all_va_loss/4.0))

                    total_step += 1
        finally:
            # Always stop and join the queue-runner threads, even when
            # training is interrupted by an exception.
            coord.request_stop()
            coord.join(threads)

回答

0

推理過程中,批標準化(batch norm)使用的是 moving average mean 和 moving average variance,所以你需要將參數 is_training 設置爲 False。

def inference(inputs, 
     num_classes=1000, 
     is_training=False, 
     dropout_keep_prob=0.5, 
     reuse = None, 
     scope='alexnet'): 
+0

我已經弄清楚發生了什麼,我犯了一個愚蠢的錯誤。我使用歸一化到 (0, 1) 的圖像數據進行訓練,卻使用未歸一化的 (0, 255) 圖像數據進行測試。不過,仍然感謝您的回覆! –

相關問題