1
我有訓練DNN網絡的代碼。我不想每次都訓練這個網絡,因爲它使用了太多的時間。我該如何保存模型?如何保存張量流的DNN模型
def train_model(filename, validation_ratio=0.):
# define model to be trained
columns = [tf.contrib.layers.real_valued_column(str(col),
dtype=tf.int8)
for col in FEATURE_COLS]
classifier = tf.contrib.learn.DNNClassifier(
feature_columns=columns,
hidden_units=[100, 100],
n_classes=N_LABELS,
dropout=0.3)
# load and split data
print('Loading training data.')
data = load_batch(filename)
overall_size = data.shape[0]
learn_size = int(overall_size * (1 - validation_ratio))
learn, validation = np.array_split(data, [learn_size])
print('Finished loading data. Samples count = {}'.format(overall_size))
# learning
print('Training using batch of size {}'.format(learn_size))
classifier.fit(input_fn=lambda: pipeline(learn),
steps=learn_size)
if validation_ratio > 0:
validate_model(classifier, learn, validation)
return classifier
運行此功能後,我得到一個DNNClassifier
我想要保存。
沒有你得到的答案?你能分享解決方案嗎? –