2017-06-07 86 views
1

我有訓練DNN網絡的代碼。我不想每次都訓練這個網絡,因爲它使用了太多的時間。我該如何保存模型?如何保存張量流的DNN模型

def train_model(filename, validation_ratio=0.): 
    # define model to be trained 
    columns = [tf.contrib.layers.real_valued_column(str(col), 
                dtype=tf.int8) 
       for col in FEATURE_COLS] 
    classifier = tf.contrib.learn.DNNClassifier(
     feature_columns=columns, 
     hidden_units=[100, 100], 
     n_classes=N_LABELS, 
     dropout=0.3) 

    # load and split data 
    print('Loading training data.') 
    data = load_batch(filename) 
    overall_size = data.shape[0] 
    learn_size = int(overall_size * (1 - validation_ratio)) 
    learn, validation = np.array_split(data, [learn_size]) 
    print('Finished loading data. Samples count = {}'.format(overall_size)) 

    # learning 
    print('Training using batch of size {}'.format(learn_size)) 
    classifier.fit(input_fn=lambda: pipeline(learn), 
        steps=learn_size) 

    if validation_ratio > 0: 
     validate_model(classifier, learn, validation) 

    return classifier 

運行此功能後,我得到一個DNNClassifier我想要保存。

+0

沒有你得到的答案?你能分享解決方案嗎? –

回答