0
下面的代碼創建了一些隱式的張量(問題:如何在 TensorFlow 中訪問隱式張量)。我不知道如何查看這些張量的值:
<tf.Variable 'rnn/basic_lstm_cell/kernel:0' shape=(43, 160) dtype=float32_ref>
<tf.Variable 'rnn/basic_lstm_cell/bias:0' shape=(160,) dtype=float32_ref>
<tf.Variable 'rnn/basic_lstm_cell/kernel/Adagrad:0' shape=(43, 160) dtype=float32_ref>
<tf.Variable 'rnn/basic_lstm_cell/bias/Adagrad:0' shape=(160,) dtype=float32_ref>
<tf.Variable 'softmax/W/Adagrad:0' shape=(40, 10) dtype=float32_ref>
<tf.Variable 'softmax/b/Adagrad:0' shape=(10,) dtype=float32_ref>
這裏是代碼本身。
import tensorflow as tf
import numpy as np
# Hyperparameters of the toy sequence model.
VECTOR_SIZE = 3        # features per time step (last axis of `x`)
SEQUENCE_LENGTH = 5    # time steps per example (axis 1 of `x` and `y`)
BATCH_SIZE = 7         # examples per batch (axis 0 of `x` and `y`)
STATE_SIZE = 40        # LSTM units; also the width of each initial-state tensor
NUM_CLASSES = 10       # number of output classes of the softmax layer
LEARNING_RATE = 0.1    # learning rate handed to the Adagrad optimizer

# Graph inputs: float features and one integer class label per time step.
x = tf.placeholder(tf.float32, [BATCH_SIZE, SEQUENCE_LENGTH, VECTOR_SIZE],
                   name='input_placeholder')
y = tf.placeholder(tf.int32, [BATCH_SIZE, SEQUENCE_LENGTH],
                   name='labels_placeholder')
init_state = tf.zeros([BATCH_SIZE, STATE_SIZE])

# static_rnn consumes a Python list with one tensor per time step, so both
# the inputs and the labels are split along the time axis (axis 1).
rnn_inputs = tf.unstack(x, axis=1)
y_as_list = tf.unstack(y, axis=1)

# The LSTM's kernel and bias are not created explicitly anywhere in this
# script; they appear in the global-variables listing under
# 'rnn/basic_lstm_cell/' once static_rnn has built the cell.
cell = tf.contrib.rnn.BasicLSTMCell(STATE_SIZE, state_is_tuple=True)
rnn_outputs, final_state = tf.contrib.rnn.static_rnn(
    cell, rnn_inputs, initial_state=(init_state, init_state))

# Only the variable creation lives inside the scope, which is what gives the
# variables their 'softmax/W' and 'softmax/b' names.
with tf.variable_scope('softmax'):
    W = tf.get_variable('W', [STATE_SIZE, NUM_CLASSES])
    b = tf.get_variable('b', [NUM_CLASSES],
                        initializer=tf.constant_initializer(0.0))

# One logits / prediction / loss tensor per time step.
logits = [tf.matmul(rnn_output, W) + b for rnn_output in rnn_outputs]
# `predictions` is not referenced again in the visible part of the script.
predictions = [tf.nn.softmax(logit) for logit in logits]
losses = [
    tf.nn.sparse_softmax_cross_entropy_with_logits(labels=label, logits=logit)
    for logit, label in zip(logits, y_as_list)
]

# Mean loss over all time steps and batch entries; minimize() also adds the
# '.../Adagrad' variables that show up in the global-variables listing.
total_loss = tf.reduce_mean(losses)
train_step = tf.train.AdagradOptimizer(LEARNING_RATE).minimize(total_loss)
# Dummy batch: every feature is 1.0 and every label is class 1. The labels
# are created as int32 so they match the dtype of the `y` placeholder.
X = np.ones([BATCH_SIZE, SEQUENCE_LENGTH, VECTOR_SIZE])
Y = np.ones([BATCH_SIZE, SEQUENCE_LENGTH], dtype=np.int32)

saver = tf.train.Saver()

# The context manager closes the session (and frees its resources) even if
# one of the calls below raises.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    # sess.run returns one result per fetch; train_step is an op and yields
    # None, so only the loss value is kept.
    batch_total_loss, _ = sess.run([total_loss, train_step],
                                   feed_dict={x: X, y: Y})
    save_path = saver.save(sess, "/tmp/model.ckpt")

    # List every global variable in the graph, including the ones created
    # implicitly by the LSTM cell and by the optimizer.
    for el in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES):
        print(el)