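I have an RNN model that I trained on the PTB example using TensorFlow's ptb_word_lm. Below is code in which I try to print a few generated samples to test the trained model. When I run it, I get the error TypeError: 'Tensor' object is not callable on this line: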
probs, state = sess.run([mtest.output_probs(), mtest._final_state], feed_dict=feed_dict)
TensorFlow Session.run(): 'Tensor' object is not callable
What exactly is causing this error?
Here is the code:
import numpy as np
import os
import tensorflow as tf
from ptb_word_lm import *
from tensorflow.models.rnn.ptb import reader
from tensorflow.python.platform import gfile

data_path = "/home/usr/simple-examples/data/"
raw_data = reader.ptb_raw_data(data_path)
train_data, valid_data, test_data, vocabulary = raw_data
test_path = os.path.join(data_path, "ptb.test.txt")
word_to_id = reader._build_vocab(test_path)

eval_config = get_config()
eval_config.batch_size = 1
eval_config.num_steps = 1

sess = tf.Session()
initializer = tf.random_uniform_initializer(-eval_config.init_scale,
                                            eval_config.init_scale)
test_input = PTBInput(config=eval_config, data=test_data, name="TestInput")
with tf.variable_scope("model", reuse=None, initializer=initializer):
    mtest = PTBModel(is_training=False, config=eval_config, input_=test_input)

sess.run(tf.initialize_all_variables())

saver = tf.train.import_meta_graph('/home/usr/models/medium/model.ckpt-50979.meta')
ckpt = tf.train.get_checkpoint_state('/home/usr/models/medium/')
if ckpt and gfile.Exists(ckpt.model_checkpoint_path):
    msg = 'Reading model parameters from %s' % ckpt.model_checkpoint_path
    print(msg)
    saver.restore(sess, ckpt.model_checkpoint_path)

def pick_from_weight(weight, pows=1.0):
    weight = weight**pows
    t = np.cumsum(weight)
    s = np.sum(weight)
    return int(np.searchsorted(t, np.random.rand(1) * s))

while True:
    number_of_sentences = 10
    sentence_cnt = 0
    text = '\n'
    end_of_sentence_char = word_to_id['<eos>']
    input_char = np.array([[end_of_sentence_char]])
    state = sess.run(mtest.initial_state)
    for attr in mtest.__dict__:
        print attr
    print 'all attributes above'
    while sentence_cnt < number_of_sentences:
        feed_dict = {mtest._input: input_char,
                     mtest.initial_state: state}
        probs, state = sess.run([mtest.output_probs(), mtest._final_state], feed_dict=feed_dict)
        print 'after state'
        sampled_char = pick_from_weight(probs[0])
        print sampled_char
        if sampled_char == end_of_sentence_char:
            text += '.\n'
            sentence_cnt += 1
        else:
            text += ' ' + id_to_word[sampled_char]
        input_char = np.array([[sampled_char]])
    print(text)
    raw_input('press any key to continue ...')
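The error message points at the cause. Assume `output_probs` is something you added to `PTBModel` (the tutorial version does not define it) and exposed as a `@property`, the same way the tutorial exposes `initial_state` and `final_state`. In that case `mtest.output_probs` already evaluates to a `tf.Tensor`, and the trailing `()` then tries to call that tensor. The plain-Python sketch below reproduces the mechanism with hypothetical stand-in classes (`FakeTensor`, `Model`) so it runs without TensorFlow. Under that assumption, the line would presumably become `sess.run([mtest.output_probs, mtest.final_state], feed_dict=feed_dict)`.

```python
class FakeTensor:
    """Hypothetical stand-in for tf.Tensor: an ordinary object with no __call__."""

    def __init__(self, name):
        self.name = name


class Model:
    """Hypothetical model that exposes a tensor through @property,
    the way the tutorial's PTBModel exposes initial_state and final_state."""

    def __init__(self):
        self._output_probs = FakeTensor("output_probs")

    @property
    def output_probs(self):
        # Attribute access already returns the tensor object itself.
        return self._output_probs


model = Model()

# Correct: no parentheses, so we get the (fake) tensor to hand to sess.run.
probs_tensor = model.output_probs

# Wrong: model.output_probs is evaluated first, then Python tries to call it.
message = None
try:
    model.output_probs()
except TypeError as err:
    message = str(err)

print(message)  # 'FakeTensor' object is not callable
```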
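After I made the two changes, I got another error: TypeError: Cannot interpret feed_dict key as Tensor: Can not convert a PTBInput into a Tensor. – smith
This seems to be because 'mtest._input' is [actually](https://github.com/tensorflow/models/blob/520b557e095b008bfb023da1c749b3d0eabc521c/tutorials/rnn/ptb/ptb_word_lm.py#L101) a reference to your 'test_input'. I'd guess 'mtest.input.input_data' would work. – sunside
Thanks a lot. I'm still running into errors further on in the code, but that was the line I was stuck on. Thanks for your time. – smith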
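sunside's point can be illustrated without TensorFlow. `PTBInput` is an ordinary Python object that holds the input tensor in its `input_data` attribute. Using the wrapper itself as a feed key therefore fails, while using the tensor it holds is accepted. The classes and `check_feed_keys` below are hypothetical stand-ins that only mimic the rule that feed keys must be tensors. They are not TensorFlow's implementation.

```python
class FakeTensor:
    """Hypothetical stand-in for a tf.Tensor."""

    def __init__(self, name):
        self.name = name


class FakePTBInput:
    """Hypothetical stand-in for PTBInput: it *holds* a tensor in
    input_data but is not a tensor itself."""

    def __init__(self):
        self.input_data = FakeTensor("input_data")


def check_feed_keys(feed_dict):
    """Hypothetical check mimicking Session.run's requirement that every
    feed_dict key be a tensor; the message mirrors the one smith reported."""
    for key in feed_dict:
        if not isinstance(key, FakeTensor):
            raise TypeError("Cannot interpret feed_dict key as Tensor: "
                            "Can not convert a %s into a Tensor." % type(key).__name__)
    return True


test_input = FakePTBInput()

# Like feeding mtest._input: the key is the wrapper object, so it is rejected.
try:
    check_feed_keys({test_input: [[0]]})
except TypeError as err:
    print(err)  # Cannot interpret feed_dict key as Tensor: Can not convert a FakePTBInput into a Tensor.

# Like feeding mtest.input.input_data: the key is the tensor inside the wrapper.
print(check_feed_keys({test_input.input_data: [[0]]}))  # True
```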
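smith mentions further errors after this line. One of them is visible in the posted code: `id_to_word` is used in the sampling loop but never defined, unless the `from ptb_word_lm import *` star import happens to provide it. The stdlib-only sketch below covers the sampling step. It inverts `word_to_id` into `id_to_word` and reimplements `pick_from_weight` with `itertools.accumulate` and `bisect` instead of numpy. The toy vocabulary and probabilities are hypothetical, and this is an illustration of the technique rather than a drop-in replacement.

```python
import bisect
import itertools
import random


def pick_from_weight(weights, pows=1.0, rng=random):
    """Return an index sampled with probability proportional to weights**pows.

    Stdlib version of the question's numpy helper: cumulative sum, then a
    uniform draw in [0, total), then a binary search for where it lands.
    """
    scaled = [w ** pows for w in weights]
    cumulative = list(itertools.accumulate(scaled))
    target = rng.random() * cumulative[-1]
    # bisect_right skips zero-weight entries; min() guards against float rounding.
    return min(bisect.bisect_right(cumulative, target), len(cumulative) - 1)


# Hypothetical toy vocabulary; in the question word_to_id comes from reader._build_vocab.
word_to_id = {"<eos>": 0, "the": 1, "cat": 2}
# Invert the mapping so sampled ids can be turned back into words.
id_to_word = {i: w for w, i in word_to_id.items()}

probs = [0.0, 0.0, 1.0]  # hypothetical distribution with all mass on "cat"
print(id_to_word[pick_from_weight(probs)])  # cat
```

In the real code, `word_to_id` should probably come from the same file the model was trained on. `reader.ptb_raw_data` builds its vocabulary from ptb.train.txt, whereas the question builds it from ptb.test.txt, so the ids may not line up with the model's output indices.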