2016-11-24 45 views
4

我正在試用Tensorflow's rnn example。 由於在開始時出現了一些問題,我可以運行該示例來訓練ptb,現在我已經訓練了一個模型。如何使用Tensorflow的PTB模型示例?

我該如何正確使用模型來創建句子,而不必再次訓練?

我與像python ptb_word_lm.py --data_path=/home/data/ --model medium --save_path=/home/medium

命令運行它是否有一個例子某處如何使用訓練的模型造句?

+0

你能告訴/告訴我的結果是通過命令「--save_path /家/媒介」目前什麼工作? – LKM

回答

5

1.增加在PTBModel:__init__()函數的最後一行下面的代碼:

self._output_probs = tf.nn.softmax(logits) 

2.增加在PTBModel以下功能:

@property 
def output_probs(self): 
    return self._output_probs 

3.Try運行下面的代碼:

raw_data = reader.ptb_raw_data(FLAGS.data_path) 
train_data, valid_data, test_data, vocabulary, word_to_id, id_to_word = raw_data 

eval_config = get_config() 
eval_config.batch_size = 1 
eval_config.num_steps = 1 

sess = tf.Session() 

initializer = tf.random_uniform_initializer(-eval_config.init_scale, 
              eval_config.init_scale) 
with tf.variable_scope("model", reuse=None, initializer=initializer): 
    mtest = PTBModel(is_training=False, config=eval_config) 

sess.run(tf.initialize_all_variables()) 

saver = tf.train.Saver() 

ckpt = tf.train.get_checkpoint_state('/home/medium') # __YOUR__MODEL__SAVE__PATH__ 
if ckpt and gfile.Exists(ckpt.model_checkpoint_path): 
    msg = 'Reading model parameters from %s' % ckpt.model_checkpoint_path 
    print(msg) 
    saver.restore(sess, ckpt.model_checkpoint_path) 

def pick_from_weight(weight, pows=1.0): 
    weight = weight**pows 
    t = np.cumsum(weight) 
    s = np.sum(weight) 
    return int(np.searchsorted(t, np.random.rand(1) * s)) 

while True: 
    number_of_sentences = 10 # generate 10 sentences one time 
    sentence_cnt = 0 
    text = '\n' 
    end_of_sentence_char = word_to_id['<eos>'] 
    input_char = np.array([[end_of_sentence_char]]) 
    state = sess.run(mtest.initial_state) 
    while sentence_cnt < number_of_sentences: 
     feed_dict = {mtest.input_data: input_char, 
        mtest.initial_state: state} 
     probs, state = sess.run([mtest.output_probs, mtest.final_state], 
             feed_dict=feed_dict) 
     sampled_char = pick_from_weight(probs[0]) 
     if sampled_char == end_of_sentence_char: 
      text += '.\n' 
      sentence_cnt += 1 
     else: 
      text += ' ' + id_to_word[sampled_char] 
     input_char = np.array([[sampled_char]]) 
    print(text) 
    raw_input('press any key to continue ...') 
+0

我收到一個錯誤:運行此代碼時,''PTBModel'對象沒有屬性'_output_probs'。 – smith

相關問題