I am trying to use tflearn with my own data. ValueError: Cannot feed value of shape (64, 32, 32) for Tensor u'InputData/X:0', which has shape '(?, 32, 32, 1)'
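I have 19748 grayscale images that I want to train my model on. I used tflearn's image_preloader method to feed in the images, and all images are resized to 32 × 32. But when I start training, I get this error: "ValueError: Cannot feed value of shape (64, 32, 32) for Tensor u'InputData/X:0', which has shape '(?, 32, 32, 1)'"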
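The error message compares two shapes: the batch being fed, (64, 32, 32), and the placeholder created by `input_data`, (?, 32, 32, 1), where `?` means "any size". The sketch below is a simplified model of that compatibility rule, not TensorFlow's actual code; the function name `is_compatible` is hypothetical. It shows that the batch is rejected because it has rank 3 (no channel axis) while the placeholder has rank 4.

```python
from typing import Optional, Sequence


def is_compatible(value_shape: Sequence[int],
                  placeholder_shape: Sequence[Optional[int]]) -> bool:
    """Simplified feed check (illustration only): the ranks must agree,
    and each dimension must match unless the placeholder leaves it
    unspecified (None, printed by TensorFlow as '?')."""
    if len(value_shape) != len(placeholder_shape):
        return False
    return all(p is None or p == v
               for v, p in zip(value_shape, placeholder_shape))


# Shape of the placeholder built by input_data(shape=[None, 32, 32, 1])
placeholder = (None, 32, 32, 1)

print(is_compatible((64, 32, 32), placeholder))     # False: rank 3 vs rank 4
print(is_compatible((64, 32, 32, 1), placeholder))  # True
```

Under this model, the batch size of 64 is never the problem; only the missing trailing channel dimension is.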
I have tried everything I know but cannot solve it. There are similar questions on Stack Overflow, but their solutions did not work for me.
Here is my code:
from __future__ import division, print_function, absolute_import
import tflearn
import pickle
from tflearn.layers.core import input_data, dropout, fully_connected
from tflearn.layers.conv import conv_2d, max_pool_2d
from tflearn.layers.normalization import local_response_normalization
from tflearn.layers.estimator import regression
from time import gmtime, strftime
from tflearn.data_utils import image_preloader
import numpy as np
dataset_file = 'noww.txt'
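# Pre-allocated arrays (both are replaced by the return values of image_preloader below)
X = np.zeros((19748,32,32,1))
Y = np.zeros((19748,10))
# Load the images listed in dataset_file, resized to 32x32, with one-hot labels
X, Y = image_preloader(dataset_file, image_shape=(32, 32), mode='file', categorical_labels=True, normalize=True)
# Input layer: expects 4-D batches of shape (batch, 32, 32, 1)
network = input_data(shape=[None, 32, 32, 1])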
network = conv_2d(network, 64, 3, activation='relu')
network = conv_2d(network, 64, 3, activation='relu')
network = max_pool_2d(network, 2, strides=2)
network = conv_2d(network, 128, 3, activation='relu')
network = conv_2d(network, 128, 3, activation='relu')
network = max_pool_2d(network, 2, strides=2)
network = conv_2d(network, 256, 3, activation='relu')
network = conv_2d(network, 256, 3, activation='relu')
network = conv_2d(network, 256, 3, activation='relu')
network = max_pool_2d(network, 2, strides=2)
network = fully_connected(network, 1024, activation='relu')
network = dropout(network, 0.5)
network = fully_connected(network, 1024, activation='relu')
network = dropout(network, 0.5)
network = fully_connected(network, 10, activation='softmax')
network = regression(network, optimizer='rmsprop',
                     loss='categorical_crossentropy',
                     learning_rate=0.0001)
model = tflearn.DNN(network, checkpoint_path='model_1',
                    max_checkpoints=1, tensorboard_verbose=0)
model.fit(X, Y, n_epoch=200, shuffle=True,
          show_metric=True, batch_size=64, snapshot_step=200,
          snapshot_epoch=False, run_id='model_1')
Please help.
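Grayscale images loaded as 32 × 32 arrays carry no channel axis, so stacking 64 of them yields (64, 32, 32). A minimal NumPy sketch of adding the missing axis is shown below. The helper name `add_channel_axis` is hypothetical, the random data is a stand-in for real images, and whether `np.asarray` can be applied directly to the objects returned by `image_preloader` (which load images on the fly rather than being plain arrays) is an assumption to verify.

```python
import numpy as np


def add_channel_axis(images) -> np.ndarray:
    """Stack a sequence of 2-D grayscale images (H, W) into one 4-D
    array of shape (N, H, W, 1), matching a [None, H, W, 1] input layer."""
    batch = np.asarray(images, dtype=np.float32)  # shape (N, H, W)
    if batch.ndim != 3:
        raise ValueError(f"expected shape (N, H, W), got {batch.shape}")
    # Same result as batch.reshape(-1, H, W, 1)
    return batch[..., np.newaxis]


# Stand-in data: 64 random 32x32 grayscale images without a channel axis
fake_batch = [np.random.rand(32, 32) for _ in range(64)]
print(np.asarray(fake_batch).shape)        # (64, 32, 32)
print(add_channel_axis(fake_batch).shape)  # (64, 32, 32, 1)
```

Indexing with `np.newaxis` appends a size-1 dimension without copying pixel values around, so the reshaped batch holds exactly the same data in the layout the placeholder expects.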