在keras序列模型中繪製模型損失和模型準確性似乎很簡單。但是如果我們將數據拆分爲X_train
,Y_train
, X_test
, Y_test
,並且使用交叉驗證,他們又怎麼能夠繪製?我收到錯誤,因爲它找不到'val_acc'
。這意味着我不能在測試集上繪製結果。從history.history繪製模型損失和模型準確性Keras sequential
這裏是我的代碼:
# Create the model
def create_model(neurons = 379, init_mode = 'uniform', activation='relu', inputDim = 8040, dropout_rate=1.1, learn_rate=0.001, momentum=0.7, weight_constraint=6): #weight_constraint=
model = Sequential()
model.add(Dense(neurons, input_dim=inputDim, kernel_initializer=init_mode, activation=activation, kernel_constraint=maxnorm(weight_constraint), kernel_regularizer=regularizers.l2(0.002))) #, activity_regularizer=regularizers.l1(0.0001))) # one inner layer
#model.add(Dense(200, input_dim=inputDim, activation=activation)) # second inner layer
#model.add(Dense(60, input_dim=inputDim, activation=activation)) # second inner layer
model.add(Dropout(dropout_rate))
model.add(Dense(1, activation='sigmoid'))
optimizer = RMSprop(lr=learn_rate)
# compile model
model.compile(loss='binary_crossentropy', optimizer='RmSprop', metrics=['accuracy']) #weight_constraint=weight_constraint
return model
model = create_model() #weight constraint= 3 or 4
seed = 7
# Define k-fold cross validation test harness
kfold = StratifiedKFold(n_splits=3, shuffle=True, random_state=seed)
cvscores = []
for train, test in kfold.split(X_train, Y_train):
print("TRAIN:", train, "VALIDATION:", test)
# Fit the model
history = model.fit(X_train, Y_train, epochs=40, batch_size=50, verbose=0)
# Plot Model Loss and Model accuracy
# list all data in history
print(history.history.keys())
# summarize history for accuracy
plt.plot(history.history['acc'])
plt.plot(history.history['val_acc']) # RAISE ERROR
plt.title('model accuracy')
plt.ylabel('accuracy')
plt.xlabel('epoch')
plt.legend(['train', 'test'], loc='upper left')
plt.show()
# summarize history for loss
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss']) #RAISE ERROR
plt.title('model loss')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['train', 'test'], loc='upper left')
plt.show()
我將不勝感激它的一些必要的修改,以獲得這些地塊也爲測試。
我們可以看到您收到的錯誤嗎?關於'val_acc'的一些事情? – cosinepenguin
當然。 plt.plot(history.history ['val_acc']) KeyError:'val_acc'。如果我刪除線條plt.plot(history.history ['val_acc']),它會返回每個交叉驗證數據集(火車)的圖表。 –