2017-10-22 59 views
0
def model_fn(features, labels, mode, params): 
    """Model function for Estimator.""" 

    # Connect the first hidden layer to input layer 
    # (features["x"]) with relu activation 
    first_hidden_layer = tf.layers.dense(features["x"], 10, activation=tf.nn.relu) 

    # Connect the second hidden layer to first hidden layer with relu 
    second_hidden_layer = tf.layers.dense(
     first_hidden_layer, 10, activation=tf.nn.relu) 

    # Connect the output layer to second hidden layer (no activation fn) 
    output_layer = tf.layers.dense(second_hidden_layer, 1) 

    # Reshape output layer to 1-dim Tensor to return predictions 
    predictions = tf.reshape(output_layer, [-1]) 

    # Provide an estimator spec for `ModeKeys.PREDICT`. 
    if mode == tf.estimator.ModeKeys.PREDICT: 
    return tf.estimator.EstimatorSpec(
     mode=mode, 
     predictions={"ages": predictions}) 

    # Calculate loss using mean squared error 
    loss = tf.losses.mean_squared_error(labels, predictions) 

    # Calculate root mean squared error as additional eval metric 
    eval_metric_ops = { 
     "rmse": tf.metrics.root_mean_squared_error(
      tf.cast(labels, tf.float64), predictions) 
    } 

    optimizer = tf.train.GradientDescentOptimizer(
     learning_rate=params["learning_rate"]) 
    train_op = optimizer.minimize(
     loss=loss, global_step=tf.train.get_global_step()) 

    # Provide an estimator spec for `ModeKeys.EVAL` and `ModeKeys.TRAIN` modes. 
    return tf.estimator.EstimatorSpec(
     mode=mode, 
     loss=loss, 
     train_op=train_op, 
     eval_metric_ops=eval_metric_ops) 

以上是Tensorflow的Estimator使用的model_fn的示例。在Tensorflow的估計器中,model_fn被多次調用時它是如何工作的?

正如教程中提到的,這個model_fn可以在不同的上下文中調用(train,predict,evaluate)。不過,我有點糊塗了,因爲每次model_fn被調用時,而不是重用現有的圖,它似乎創建一個新的圖形(或創建圖中的新節點)

例如,首先我在TRAIN模式下調用model_fn,然後用PREDICT模式調用model_fn。我怎樣才能確保PREDICT正在重新使用訓練值的權重?

回答

相關問題