
How can I generate sample sentences with an LSTM model in Tensorflow?

I am using an LSTM model in Tensorflow. I have already trained and saved the model; now I want to finish the final task of generating sentences. Here is my pseudocode:

# We already have the run_epoch(session, m, data, eval_op, verbose=False) function with a feed_dict like this: 
feed_dict = {m.input_data: x, 
      m.targets: y, 
      m.initial_state: state} 
... 
# train and save model 
... 
# load saved model for generating task 
new_sentence = [START_TOKEN] 
# Here I want to generate a sentence until END_TOKEN is generated. 
while new_sentence[-1] != END_TOKEN: 
    logits = get_logits(model, new_sentence) 
    # get argmax(logits) or sample(logits) 
    next_word = argmax(logits) 
    new_sentence.append(next_word) 
print(new_sentence) 

My question is this: when training, validating, or testing the model, I feed both the inputs and their labels (the inputs offset by one) into the model through feed_dict. But in the generation task I only have the input, which is the sentence generated so far, new_sentence, and no labels to feed.
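For concreteness, here is a minimal sketch of the "offset by one" feed described above (the name tokens is hypothetical, standing for the word ids of one training sequence):

# `tokens` (hypothetical) holds the word ids of one training sequence 
x_input = tokens[:-1]  # inputs: every word except the last 
y_label = tokens[1:]   # labels: the same sequence shifted left by one 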

How can I build a correct get_logits function, or a full generate function?

Answer


When you train, you take the output of the neural network, compute an error from that output and the labels, and create an optimizer that minimizes that error.

To generate a new sentence, you only need to get the output of the neural network (the RNN).

EDIT:

""" 
Placeholders 
""" 

x = tf.placeholder(tf.int32, [batch_size, num_steps], name='input_placeholder') 
y = tf.placeholder(tf.int32, [batch_size, num_steps], name='labels_placeholder') 
init_state = tf.zeros([batch_size, state_size]) 

""" 
RNN Inputs 
""" 

# Turn our x placeholder into a list of one-hot tensors: 
# rnn_inputs is a list of num_steps tensors with shape [batch_size, num_classes] 
x_one_hot = tf.one_hot(x, num_classes) 
rnn_inputs = tf.unpack(x_one_hot, axis=1) 

""" 
Definition of rnn_cell 

This is very similar to the __call__ method on Tensorflow's BasicRNNCell. See: 
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/rnn_cell.py 
""" 
with tf.variable_scope('rnn_cell'): 
    W = tf.get_variable('W', [num_classes + state_size, state_size]) 
    b = tf.get_variable('b', [state_size], initializer=tf.constant_initializer(0.0)) 

def rnn_cell(rnn_input, state): 
    with tf.variable_scope('rnn_cell', reuse=True): 
        W = tf.get_variable('W', [num_classes + state_size, state_size]) 
        b = tf.get_variable('b', [state_size], initializer=tf.constant_initializer(0.0)) 
    return tf.tanh(tf.matmul(tf.concat(1, [rnn_input, state]), W) + b) 

state = init_state 
rnn_outputs = [] 
for rnn_input in rnn_inputs: 
    state = rnn_cell(rnn_input, state) 
    rnn_outputs.append(state) 
final_state = rnn_outputs[-1] 

#logits and predictions 
with tf.variable_scope('softmax'): 
    W = tf.get_variable('W', [state_size, num_classes]) 
    b = tf.get_variable('b', [num_classes], initializer=tf.constant_initializer(0.0)) 
logits = [tf.matmul(rnn_output, W) + b for rnn_output in rnn_outputs] 
predictions = [tf.nn.softmax(logit) for logit in logits] 

# Turn our y placeholder into a list labels 
y_as_list = [tf.squeeze(i, squeeze_dims=[1]) for i in tf.split(1, num_steps, y)] 

#losses and train_step 
losses = [tf.nn.sparse_softmax_cross_entropy_with_logits(logit, label) 
          for logit, label in zip(logits, y_as_list)] 
total_loss = tf.reduce_mean(losses) 
train_step = tf.train.AdagradOptimizer(learning_rate).minimize(total_loss) 
def train(verbose=True): 
    with tf.Session() as sess: 
        # initialize (or load) the model variables 
        sess.run(tf.initialize_all_variables()) 
        training_losses = [] 
        # gen_epochs(num_epochs, num_steps) is assumed to yield batches of (X, Y) pairs 
        for idx, epoch in enumerate(gen_epochs(num_epochs, num_steps)): 
            training_loss = 0 
            training_state = np.zeros((batch_size, state_size)) 
            if verbose: 
                print("\nEPOCH", idx) 
            for step, (X, Y) in enumerate(epoch): 
                tr_losses, training_loss_, training_state, _ = \
                    sess.run([losses, 
                              total_loss, 
                              final_state, 
                              train_step], 
                             feed_dict={x: X, y: Y, init_state: training_state}) 
                training_loss += training_loss_ 
                if step % 100 == 0 and step > 0: 
                    if verbose: 
                        print("Average loss at step", step, 
                              "for last 100 steps:", training_loss / 100) 
                    training_losses.append(training_loss / 100) 
                    training_loss = 0 
        # save the model here, e.g. with tf.train.Saver().save(sess, ...) 

def generate_seq(): 
    # Assumption: the graph above is rebuilt with batch_size = 1 and 
    # num_steps = 1, so we can feed one word at a time and carry the 
    # hidden state forward between steps. 
    with tf.Session() as sess: 
        # load the saved model, e.g. with tf.train.Saver().restore(sess, ...) 
        new_sentence = [START_TOKEN] 
        state = np.zeros((1, state_size)) 
        # generate words until END_TOKEN is produced 
        while new_sentence[-1] != END_TOKEN: 
            # feed only the last word; no labels are needed at generation time 
            feed = {x: np.asarray([[new_sentence[-1]]]), init_state: state} 
            # predictions[-1] is the softmax distribution over the next word 
            probs, state = sess.run([predictions[-1], final_state], feed) 
            # take the argmax, or sample from probs (see the sketch below) 
            next_word = np.argmax(probs[0]) 
            new_sentence.append(next_word) 
        print(new_sentence) 
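
If the argmax produces repetitive sentences, you can sample the next word from the predicted distribution instead. A minimal sketch using numpy, as a drop-in replacement for the np.argmax line above:

def sample(probs): 
    # probs: 1-D array of next-word probabilities that sums to 1 
    return np.random.choice(len(probs), p=probs) 

# inside the generation loop, instead of np.argmax: 
# next_word = sample(probs[0]) 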

What would that look like in code? –


See the code above. –
