8
我已經盡力遵循關於神經網絡結構的在線指南,但我必須缺少一些基本的東西。給定一組經過訓練的權重及其偏差,我想簡單地用這些權重手動預測輸入,而不使用預測方法。如何使用權重手動預測神經網絡中的數據與karas
使用帶有keras的MNIST圖像我試圖手動編輯我的數據以包含偏差的額外特徵,但是這種努力似乎沒有提供比沒有偏差更好的圖像準確性,使用keras預測方法。我的代碼與我的嘗試一起在下面。
請注意接近底部的兩條評論,用於將keras方法預測用於準確的圖像表示,然後我嘗試通過手動獲取權重和添加偏差來嘗試。
from keras.datasets import mnist
import numpy as np
import time
from keras.models import Sequential
from keras.layers import Dense
import tensorflow as tf
from matplotlib import pyplot as plt
comptime=time.time()
with tf.device('/cpu:0'):
tf.placeholder(tf.float32, shape=(None, 20, 64))
seed = 7
np.random.seed(seed)
model = Sequential()
(x_train, _), (x_test, _) = mnist.load_data()
x_train = x_train.astype('float32')/255.
priorShape_x_train=x_train.shape #prior shape of training set
x_train = x_train.reshape((len(x_train), np.prod(x_train.shape[1:])))
x_test = x_test.reshape((len(x_test), np.prod(x_test.shape[1:])))
x_train_shaped=x_train
model.add(Dense(32, input_dim=784, init='uniform', activation='relu'))
model.add(Dense(784, init='uniform', activation='sigmoid'))
model.compile(loss='binary_crossentropy', optimizer='adadelta', metrics=['accuracy'])
model.fit(x_train[1:2500], x_train[1:2500], nb_epoch=10)
#proper keras prediction
prediction_real=model.predict(x_train[57:58])
prediction_real=prediction_real.reshape((28,28))
#manual weight prediction attempt
x_train=np.hstack([x_train,np.zeros(x_train.shape[0]).reshape(x_train.shape[0],1)]) #add extra column for bias
x_train[:,-1]=1 #add placeholder as 1
weights=np.vstack([model.get_weights()[0],model.get_weights()[1]]) #add trained weights as extra row vector
prediction=np.dot(x_train,weights) #now take dot product.. repeat pattern for next layer
prediction=np.hstack([prediction,np.zeros(prediction.shape[0]).reshape(prediction.shape[0],1)])
prediction[:,-1]=1
weights=np.vstack([model.get_weights()[2],model.get_weights()[3]])
prediction=np.dot(prediction,weights)
prediction=prediction.reshape(priorShape_x_train)
plt.imshow(prediction[57], interpolation='nearest',cmap='gray')
plt.savefig('myprediction.png') #my prediction, not accurate
plt.imshow(prediction_real,interpolation='nearest',cmap='gray')
plt.savefig('realprediction.png') #in-built keras method, accurate