2017-09-03 69 views
0

我正在創建一個可區分貓和狗的圖像分類器。我有以下代碼:Tensorflow中的排名不匹配錯誤

import cv2 
import os 
from tqdm import tqdm 
import numpy as np 
import tensorflow as tf 
# Size (in pixels) every image is resized to before it is fed to the network.
img_height = 128 
img_width = 128 

# Directory whose listing provides the training file names.
path = "./train" 
# File names of the training set; next_batch derives each label from the name.
file = os.listdir(path) 
# Module-level placeholders; next_batch builds its own local lists with the
# same names, so these two are not touched by it.
index = [] 
images = [] 

# image size and channels 
channels = 3 
# Flattened size of one input image (not referenced by the visible code).
n_inputs = img_width * img_height * channels 

# First convolutional layer 
conv1_fmaps = 96 # Number of feature maps created by this layer 
conv1_ksize = 4 # kernel size 4x4 
conv1_stride = 2 
conv1_pad = "SAME" 

# Second convolutional layer 
conv2_fmaps = 192 
conv2_ksize = 4 
conv2_stride = 4 
conv2_pad = "SAME" 

# Third layer is a pooling layer 
pool3_fmaps = conv2_fmaps # Pooling keeps the channel count of conv2 

n_fc1 = 192 # Number of units in the fully connected layer 
n_outputs = 2 # Number of logits produced by the output layer 

with tf.name_scope("inputs"):
    # Image batches are laid out as (batch, height, width, channels): that is
    # what cv2.resize(image, (img_width, img_height)) produces and what
    # tf.layers.conv2d expects by default, so height comes before width here.
    X = tf.placeholder(tf.float32, shape=[None, img_height, img_width, channels], name="X")
    X_reshaped = tf.reshape(X, shape=[-1, img_height, img_width, channels])
    # Labels are fed as one-hot rows of shape [batch, 2].
    y = tf.placeholder(tf.int32, shape=[None, 2], name="y")

# Two strided ReLU convolutions that shrink the spatial size of the input.
conv1 = tf.layers.conv2d(X_reshaped, filters=conv1_fmaps, kernel_size=conv1_ksize,
                         strides=conv1_stride, padding=conv1_pad,
                         activation=tf.nn.relu, name="conv1")
conv2 = tf.layers.conv2d(conv1, filters=conv2_fmaps, kernel_size=conv2_ksize,
                         strides=conv2_stride, padding=conv2_pad,
                         activation=tf.nn.relu, name="conv2")

# Training schedule: number of passes over the data and images per batch.
n_epochs = 10 
batch_size = 250 

with tf.name_scope("pool3"): 
    # 2x2 max pooling with stride 2 halves the spatial size of conv2's output.
    pool3 = tf.nn.max_pool(conv2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding="VALID") 
    # Spatial size: 128 -> 64 (conv1, stride 2) -> 16 (conv2, stride 4)
    # -> 8 (pooling), hence the hard-coded 8 * 8 when flattening.
    pool3_flat = tf.reshape(pool3, shape=[-1, pool3_fmaps * 8 * 8]) 

with tf.name_scope("fc1"):
    # Fully connected ReLU layer on top of the flattened pooling output.
    # (The original line was missing the comma before ``name=``.)
    fc1 = tf.layers.dense(pool3_flat, n_fc1, activation=tf.nn.relu, name="fc1")
with tf.name_scope("output"):
    # Raw class scores; the loss consumes these logits directly.
    logits = tf.layers.dense(fc1, n_outputs, name="output")
    # Class probabilities derived from the logits.
    Y_proba = tf.nn.softmax(logits, name="Y_proba")

with tf.name_scope("train"): 
    # NOTE(review): `y` is declared as a [None, 2] one-hot placeholder, while the
    # sparse loss wants labels whose rank is rank(logits) - 1 (class indices).
    # This call is where the reported "Rank mismatch" ValueError is raised.
    xentropy=tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits,  labels=y)  
    # Average the per-example loss over the batch and minimise it with Adam.
    loss = tf.reduce_mean(xentropy) 
    optimizer = tf.train.AdamOptimizer() 
    training_op = optimizer.minimize(loss) 

with tf.name_scope("eval"): 
    # NOTE(review): in_top_k presumably also expects rank-1 class ids as its
    # targets rather than one-hot rows - confirm against the TensorFlow docs.
    correct = tf.nn.in_top_k(logits, y, 1) 
    # Fraction of examples whose top-1 prediction matches the target.
    accuracy = tf.reduce_mean(tf.cast(correct, tf.float32)) 

# Op that initialises every variable of the graph.
init = tf.global_variables_initializer() 

with tf.name_scope("init_and_save"): 
    # Used by the training loop to write checkpoints.
    saver = tf.train.Saver() 

def next_batch(num):
    """Load and preprocess batch number ``num`` of the training images.

    The files ``num * batch_size`` up to (not including)
    ``(num + 1) * batch_size`` of the module-level ``file`` listing are read
    from ``path``, resized to ``img_width`` x ``img_height`` and scaled to
    the ``[0, 1]`` range.

    Args:
        num: Zero-based index of the batch to build.

    Returns:
        A two-element list ``[images, index]``. ``images`` is a float32 array
        holding one resized image per file; ``index`` is a list of one-hot
        label arrays, ``[0, 1]`` for file names containing "dog" and
        ``[1, 0]`` for all others.

    Raises:
        IndexError: If the batch extends past the end of the file listing.
        IOError: If a file cannot be decoded as an image.
    """
    index = []
    images = []
    # Data set Creation
    print("Creating batch dataset "+str(num+1)+"...")
    for f in tqdm(range(num * batch_size, (num+1)*batch_size)):
        name = file[f]
        # str.find returns -1 (truthy) when the text is absent and 0 (falsy)
        # when it matches at the start, so it cannot be used as a boolean;
        # a membership test expresses the intended check.
        if "dog" in name:
            index.append(np.array([0, 1]))
        else:
            index.append(np.array([1, 0]))

        # Every file needs its image loaded, whichever label it received.
        image_path = os.path.join(path, name)
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise IOError("Could not read image: " + image_path)
        # Pass the interpolation by keyword: positionally, the third to fifth
        # arguments of cv2.resize are dst, fx and fy, not the interpolation.
        image = cv2.resize(image, (img_width, img_height),
                           interpolation=cv2.INTER_LINEAR)
        images.append(image)

    # Stack into a single array and rescale the 8-bit pixel values to [0, 1].
    images = np.array(images, dtype=np.uint8)
    images = images.astype('float32')
    images = images/255

    print("\nBatch "+str(num+1)+" creation finished.")
    return [images, index]

# Train for n_epochs passes over the data, reporting the accuracy on each
# training batch and writing a checkpoint at the end of every epoch.
# NOTE(review): 25000 is presumably the number of files in ./train - confirm.
with tf.Session() as sess:
    sess.run(init)
    for epoch in range(n_epochs):
        for iteration in range(25000 // batch_size):
            X_batch, y_batch = next_batch(iteration)
            batch_feed = {X: X_batch, y: y_batch}
            # One optimisation step, then measure accuracy on the same batch.
            sess.run(training_op, feed_dict=batch_feed)
            acc_train = sess.run(accuracy, feed_dict=batch_feed)
            print(epoch, "Train accuracy:", acc_train)
        save_path = saver.save(sess, "./dogvscat_mnist_model.ckpt")

但我發現了這個錯誤:

ValueError: Rank mismatch: Rank of labels (received 2) should equal rank of logits minus 1 (received 2).

任何人都可以指出問題所在,並幫助我解決這個問題。我對此完全陌生。

回答

1

tf.nn.sparse_softmax_cross_entropy_with_logits 要求 rank(labels) = rank(logits) - 1,所以你需要重新定義標籤佔位符如下

... 
# Labels become a rank-1 vector of class indices instead of one-hot rows.
y = tf.placeholder(tf.int32, shape=[None], name="y") 
... 
xentropy=tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits,labels=y)  
... 
X_batch, y_batch = next_batch(iteration) 
# Turn the one-hot labels returned by next_batch into class indices.
y_batch = np.argmax(y_batch, axis=1) 

或者你可以直接改用 tf.nn.softmax_cross_entropy_with_logits,而不必改變標籤的佔位符。

# Dense variant: takes labels shaped like the logits (e.g. one-hot rows).
xentropy=tf.nn.softmax_cross_entropy_with_logits(logits=logits,labels=y) 
+0

謝謝。它正在工作! – tahsin314