2016-10-04 60 views
2

我訓練了 TensorFlow 的 CIFAR-10 模型,想把自己的單張圖像(32×32,jpg/png)輸入給它。如何用 TensorFlow CIFAR-10 教程中的模型測試自己的圖像?

我希望輸出每個類別的標籤及其對應的概率,但在這方面遇到了一些困難。

在 Stack Overflow 上搜索之後,我找到了一些相關帖子(例如 this),並據此修改了 cifar10_eval.py。

但它根本不起作用。

錯誤消息:

InvalidArgumentErrorTraceback (most recent call last) in() ----> 1 evaluate()

in evaluate() 86 # Restores from checkpoint 87 print("ckpt.model_checkpoint_path ", ckpt.model_checkpoint_path) ---> 88 saver.restore(sess, ckpt.model_checkpoint_path) 89 # Assuming model_checkpoint_path looks something like: 90 # /my-favorite-path/cifar10_train/model.ckpt-0,

/home/huray/anaconda2/envs/tensorflow/lib/python2.7/site-packages/tensorflow/python/training/saver.pyc in restore(self, sess, save_path) 1127 raise ValueError("Restore called with invalid save path %s" % save_path)
1128 sess.run(self.saver_def.restore_op_name, -> 1129 {self.saver_def.filename_tensor_name: save_path}) 1130 1131 @staticmethod

/home/huray/anaconda2/envs/tensorflow/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in run(self, fetches, feed_dict, options, run_metadata) 380 try: 381 result = self._run(None, fetches, feed_dict, options_ptr, --> 382 run_metadata_ptr) 383 if run_metadata: 384 proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)

/home/huray/anaconda2/envs/tensorflow/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in _run(self, handle, fetches, feed_dict, options, run_metadata) 653 movers = self._update_with_movers(feed_dict_string, feed_map) 654 results = self._do_run(handle, target_list, unique_fetches, --> 655 feed_dict_string, options, run_metadata) 656 657 # User may have fetched the same tensor multiple times, but we

/home/huray/anaconda2/envs/tensorflow/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in _do_run(self, handle, target_list, fetch_list, feed_dict, options, run_metadata) 721 if handle is None: 722 return self._do_call(_run_fn, self._session, feed_dict, fetch_list, --> 723 target_list, options, run_metadata) 724 else: 725 return self._do_call(_prun_fn, self._session, handle, feed_dict,

/home/huray/anaconda2/envs/tensorflow/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in _do_call(self, fn, *args) 741 except KeyError: 742 pass --> 743 raise type(e)(node_def, op, message) 744 745 def _extend_graph(self):

InvalidArgumentError: Assign requires shapes of both tensors to match. lhs shape= [18,384] rhs shape= [2304,384] [[Node: save/Assign_5 = Assign[T=DT_FLOAT, _class=["loc:@local3/weights"], use_locking=true, validate_shape=true, _device="/job:localhost/replica:0/task:0/cpu:0"](local3/weights, save/restore_slice_5)]]

任何與 CIFAR-10 相關的幫助都將不勝感激。

下面是我目前實現的、會出現上述錯誤的代碼:

from __future__ import absolute_import 
from __future__ import division 
from __future__ import print_function 

from datetime import datetime 
import math 
import time 

import numpy as np 
import tensorflow as tf 
import cifar10 

# Flag container from tf.app.flags; the DEFINE_* calls below add this
# script's options to it, and the functions in this file read them back as
# attributes (e.g. FLAGS.checkpoint_dir, FLAGS.eval_data).
FLAGS = tf.app.flags.FLAGS

# NOTE(review): FLAGS.batch_size is read in eval_once() but is not defined in
# this file; presumably the imported cifar10 module defines it -- confirm.
tf.app.flags.DEFINE_string('eval_dir', '/tmp/cifar10_eval',
          """Directory where to write event logs.""")
tf.app.flags.DEFINE_string('eval_data', 'test',
          """Either 'test' or 'train_eval'.""")
tf.app.flags.DEFINE_string('checkpoint_dir', '/tmp/cifar10_train',
          """Directory where to read model checkpoints.""")
tf.app.flags.DEFINE_integer('eval_interval_secs', 5,
          """How often to run the eval.""")
# Defaults to a single example, matching the single-image use in evaluate().
tf.app.flags.DEFINE_integer('num_examples', 1,
          """Number of examples to run.""")
tf.app.flags.DEFINE_boolean('run_once', False,
         """Whether to run eval only once.""")

def eval_once(saver, summary_writer, top_k_op, summary_op):
    """Run one evaluation pass and record precision @ 1.

    Restores the latest checkpoint found in FLAGS.checkpoint_dir, starts the
    graph's queue runners, runs `top_k_op` for enough batches to cover
    FLAGS.num_examples, prints the resulting precision and writes it (together
    with the output of `summary_op`) to `summary_writer`.

    Args:
      saver: Saver used to restore the checkpoint into the session.
      summary_writer: Summary writer that receives the precision summary.
      top_k_op: Top K op; its output is summed to count correct predictions.
      summary_op: Summary op evaluated once after the prediction loop.

    Returns:
      None. Prints 'No checkpoint file found' and returns early when no
      checkpoint is available.
    """
    with tf.Session() as sess:
        ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            # Restores from checkpoint
            saver.restore(sess, ckpt.model_checkpoint_path)
            # Assuming model_checkpoint_path looks something like:
            #   /my-favorite-path/cifar10_train/model.ckpt-0,
            # extract global_step from it.
            global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
        else:
            print('No checkpoint file found')
            return
        print("Check point : %s" % ckpt.model_checkpoint_path)

        # Start the queue runners.
        coord = tf.train.Coordinator()
        # Bound before the try block so coord.join() below always has a list,
        # even if starting the queue runners raises.
        threads = []
        try:
            for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
                threads.extend(qr.create_threads(sess, coord=coord, daemon=True,
                                                 start=True))

            # Round up so that at least FLAGS.num_examples examples are seen.
            num_iter = int(math.ceil(FLAGS.num_examples / FLAGS.batch_size))
            true_count = 0  # Counts the number of correct predictions.
            total_sample_count = num_iter * FLAGS.batch_size
            step = 0
            while step < num_iter and not coord.should_stop():
                predictions = sess.run([top_k_op])
                true_count += np.sum(predictions)
                step += 1

            # Compute precision @ 1.
            precision = true_count / total_sample_count
            print('%s: precision @ 1 = %.3f' % (datetime.now(), precision))

            summary = tf.Summary()
            summary.ParseFromString(sess.run(summary_op))
            summary.value.add(tag='Precision @ 1', simple_value=precision)
            summary_writer.add_summary(summary, global_step)
        except Exception as e:  # pylint: disable=broad-except
            # Hand the error to the coordinator so the queue threads stop.
            coord.request_stop(e)

        coord.request_stop()
        coord.join(threads, stop_grace_period_secs=10)


def evaluate():
    """Classify a single image with the trained CIFAR-10 model.

    Builds a fresh graph that reads the PNG file at the hard-coded `img_path`,
    crops/pads it to 24x24, standardizes it and passes it through
    `cifar10.inference` as a batch of one. The latest checkpoint found in
    FLAGS.checkpoint_dir is restored, after which the top-1 class index and
    the softmax probability of every class are printed.

    Returns:
      None. Prints 'No checkpoint file found' and returns early when no
      checkpoint is available.
    """
    # Class names in label-index order, used only to label the printed output.
    cifar10_classes = ["airplane", "automobile", "bird", "cat", "deer",
                       "dog", "frog", "horse", "ship", "truck"]

    # The restore error "lhs shape= [18,384] rhs shape= [2304,384]" on
    # local3/weights shows the graph was built for a different batch size than
    # the checkpoint expects per image (2304 / 128 = 18). Presumably
    # cifar10.inference() sizes that layer from FLAGS.batch_size -- confirm
    # against cifar10.py. Building the graph for a batch of one makes the
    # variable shapes match the checkpoint.
    # NOTE: this changes the flag for the rest of the process.
    FLAGS.batch_size = 1

    with tf.Graph().as_default():
        # TEST CODE: read one image from disk instead of cifar10.inputs().
        # Only PNG files are decoded here.
        img_path = "/TEST_IMAGEPATH/image.png"
        input_img = tf.image.decode_png(tf.read_file(img_path), channels=3)
        casted_image = tf.cast(input_img, tf.float32)

        reshaped_image = tf.image.resize_image_with_crop_or_pad(
            casted_image, 24, 24)

        # The original called the misspelt `per_image_withening` and then
        # discarded its result. `per_image_whitening` was renamed to
        # `per_image_standardization` in later TensorFlow releases, so use
        # whichever one this installation provides.
        standardize = getattr(tf.image, 'per_image_standardization', None)
        if standardize is None:
            standardize = tf.image.per_image_whitening
        float_image = standardize(reshaped_image)

        # Feed the standardized image (not the raw crop) as a batch of one.
        images = tf.expand_dims(float_image, 0)

        logits = cifar10.inference(images)
        probabilities = tf.nn.softmax(logits)
        _, top_k_pred = tf.nn.top_k(logits, k=1)

        with tf.Session() as sess:
            # NOTE(review): a plain Saver restores the raw trained variables;
            # the tutorial's eval script may restore moving averages instead
            # -- confirm whether that matters for your checkpoint.
            saver = tf.train.Saver()
            ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
            if ckpt and ckpt.model_checkpoint_path:
                print("ckpt.model_checkpoint_path ", ckpt.model_checkpoint_path)
                saver.restore(sess, ckpt.model_checkpoint_path)
            else:
                print('No checkpoint file found')
                return

            print("Check point : %s" % ckpt.model_checkpoint_path)
            # Fetch both tensors in one run so they come from the same pass.
            top_indices, class_probs = sess.run([top_k_pred, probabilities])
            print ("Predicted ", top_indices, " for your input image.")
            # One line per class: label and its softmax probability.
            for class_name, prob in zip(cifar10_classes, class_probs[0]):
                print('%s: %.4f' % (class_name, prob))

evaluate() 

回答

2

視頻 https://youtu.be/d9mSWqfo0Xw 展示了對單張圖像進行分類的例子。

在通過 python cifar10_train.py 訓練好網絡之後,我們對 CIFAR-10 數據集中的單張圖像 deer6.png 以及一張自己拍攝的火柴盒照片進行評估。對 TF 教程原始源代碼最重要的修改如下:

首先有必要將這些圖像轉換爲cifar10_input.py可以讀取的二進制形式。

這可以很容易地通過「How to create dataset similar to cifar-10」中給出的代碼片段來完成。

然後,爲了讀取轉換後的圖像(稱爲 input.bin),我們需要修改 cifar10_input.py 中的函數 inputs(),找到並改動如下部分:

else: 
    #filenames = [os.path.join(data_dir, 'test_batch.bin')] 
    filenames = [os.path.join(data_dir, 'input.bin')] 

(其中 data_dir 等於 './')

最後,爲了得到標籤,我們需要修改源文件 cifar10_eval.py 中的函數 eval_once():

 #while step < num_iter and not coord.should_stop(): 
     # predictions = sess.run([top_k_op]) 
     print(sess.run(logits[0])) 
     classification = sess.run(tf.argmax(logits[0], 0)) 
     cifar10classes = ["airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"] 
     print(cifar10classes[classification]) 

     #true_count += np.sum(predictions) 
     step += 1 

     # Compute precision @ 1. 
     precision = true_count/total_sample_count 
     # print('%s: precision @ 1 = %.3f' % (datetime.now(), precision)) 

當然,你還需要做一些其他的小修改。

相關問題