2017-10-15 39 views
0

我正在嘗試使用模型函數構建一個非常簡單的模型,如下所示,其中模型函數的輸入和輸出是[img,labels]和損失。 我很困惑,爲什麼這個代碼不工作,如果輸出不能是損失。 Model函數應該如何工作以及我們應該在什麼時候使用Model函數?謝謝。在keras模型函數中使用損失

sess = tf.Session() 
K.set_session(sess) 
K.set_learning_phase(1) 
img = Input((784,),name='img') 
labels = Input((10,),name='labels') 
# img = tf.placeholder(tf.float32, shape=(None, 784)) 
# labels = tf.placeholder(tf.float32, shape=(None, 10)) 

x = Dense(128, activation='relu')(img) 
x = Dropout(0.5)(x) 
x = Dense(128, activation='relu')(x) 
x = Dropout(0.5)(x) 
preds = Dense(10, activation='softmax')(x) 

from keras.losses import binary_crossentropy 
#loss = tf.reduce_mean(categorical_crossentropy(labels, preds)) 
loss = binary_crossentropy(labels, preds) 
print(type(loss)) 
model = Model([img,labels], loss, name='squeezenet') 
model.summary() 
+0

損失是通過compile()提供的。您可以從[documentation](https://keras.io/getting-started/functional-api-guide/)中找到一些示例。 –

回答

3

作爲@宇陽指出的,損失與compile()指定。 如果你仔細想想,這是有道理的,因爲你的模型的真實輸出是你的預測,而不是損失,損失只用於訓練模型。

您的網絡的工作示例:

import keras 
from keras.optimizers import Adam 
from keras.models import Model 
from keras.layers import Input, Dense, Dropout 
from keras.losses import categorical_crossentropy 

img = Input((784,),name='img') 

x = Dense(128, activation='relu')(img) 
x = Dropout(0.5)(x) 
x = Dense(128, activation='relu')(x) 
x = Dropout(0.5)(x) 
preds = Dense(10, activation='softmax')(x) 

model = Model(inputs=img, outputs=preds, name='squeezenet') 


model.compile(optimizer=Adam(), 
       loss=categorical_crossentropy, 
       metrics=['acc']) 

model.summary() 

輸出:

_________________________________________________________________ 
Layer (type)     Output Shape    Param # 
================================================================= 
img (InputLayer)    (None, 784)    0   
_________________________________________________________________ 
dense_32 (Dense)    (None, 128)    100480  
_________________________________________________________________ 
dropout_21 (Dropout)   (None, 128)    0   
_________________________________________________________________ 
dense_33 (Dense)    (None, 128)    16512  
_________________________________________________________________ 
dropout_22 (Dropout)   (None, 128)    0   
_________________________________________________________________ 
dense_34 (Dense)    (None, 10)    1290  
================================================================= 
Total params: 118,282 
Trainable params: 118,282 
Non-trainable params: 0 
_________________________________________________________________ 

隨着MNIST數據集:

from keras.datasets import mnist 
from keras.utils import to_categorical 

(x_train, y_train), (x_test, y_test) = mnist.load_data() 

x_train = x_train.reshape(-1, 784) 
y_train = to_categorical(y_train, num_classes=10) 
x_test = x_test.reshape(-1, 784) 
y_test = to_categorical(y_test, num_classes=10) 

model.fit(x_train, y_train, epochs=10, validation_data=(x_test, y_test)) 

輸出:

Train on 60000 samples, validate on 10000 samples 
Epoch 1/10 
60000/60000 [==============================] - 4s - loss: 12.2797 - acc: 0.2360 - val_loss: 11.0902 - val_acc: 0.3116 
Epoch 2/10 
60000/60000 [==============================] - 4s - loss: 10.4161 - acc: 0.3527 - val_loss: 8.7122 - val_acc: 0.4589 
Epoch 3/10 
60000/60000 [==============================] - 4s - loss: 9.5797 - acc: 0.4051 - val_loss: 8.9226 - val_acc: 0.4460 
Epoch 4/10 
60000/60000 [==============================] - 4s - loss: 9.2017 - acc: 0.4285 - val_loss: 8.0564 - val_acc: 0.4998 
Epoch 5/10 
60000/60000 [==============================] - 4s - loss: 8.8558 - acc: 0.4501 - val_loss: 8.0878 - val_acc: 0.4980 
Epoch 6/10 
60000/60000 [==============================] - 5s - loss: 8.8239 - acc: 0.4521 - val_loss: 8.2495 - val_acc: 0.4880 
Epoch 7/10 
60000/60000 [==============================] - 4s - loss: 8.7842 - acc: 0.4547 - val_loss: 7.7146 - val_acc: 0.5211 
Epoch 8/10 
60000/60000 [==============================] - 4s - loss: 8.7395 - acc: 0.4575 - val_loss: 7.7944 - val_acc: 0.5163 
Epoch 9/10 
60000/60000 [==============================] - 5s - loss: 8.7109 - acc: 0.4593 - val_loss: 7.8235 - val_acc: 0.5145 
Epoch 10/10 
60000/60000 [==============================] - 4s - loss: 8.4927 - acc: 0.4729 - val_loss: 7.5933 - val_acc: 0.5288 
相關問題