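I have a question about using TensorArray. How do I access the elements of a TensorArray with tf.while_loop in TensorFlow?
Problem:
I want to access the elements of a TensorArray from inside a tf.while_loop. Note that I can already read the contents of the TensorArray directly, for example with u1.read(0).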
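For illustration, here is a minimal sketch of the direct read that does work. It assumes the 1.x graph API reached through `tensorflow.compat.v1`; the small array `ta` and its values are made up for this example and are not part of my real code.

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # assumption: use graph mode and tf.Session on a newer install

# Hypothetical three-element array, used only for this illustration.
ta = tf.TensorArray(dtype=tf.float32, size=3, clear_after_read=False)
ta = ta.unstack(tf.constant([[1., 1.], [2., 2.], [3., 3.]]))

# read() returns an ordinary Tensor, so it can be passed to sess.run().
first = ta.read(0)

with tf.Session() as sess:
    first_value = sess.run(first)
    print(first_value)
```

The point is that the result of `read()` is a Tensor, whereas the TensorArray object itself is not.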
My current code:
This is what I have so far:
import numpy as np
import tensorflow as tf

embeds_raw = tf.constant(np.array([
    [1, 1],
    [1, 1],
    [2, 2],
    [3, 3],
    [3, 3],
    [3, 3]
], dtype='float32'))
embeds = tf.Variable(initial_value=embeds_raw)

# Pad the sentence lengths with zeros up to the fixed TensorArray size of 512.
container_variable = tf.zeros([512], dtype=tf.int32, name='container_variable')
sen_len = tf.placeholder('int32', shape=[None], name='sen_len')
# max_l = tf.reduce_max(sen_len)
current_size = tf.shape(sen_len)[0]
padded_sen_len = tf.pad(sen_len, [[0, 512 - current_size]], 'CONSTANT')
added_container_variable = tf.add(container_variable, padded_sen_len)

# Split the embedding rows into one TensorArray element per sentence.
u1 = tf.TensorArray(dtype=tf.float32, size=512, clear_after_read=False)
u1 = u1.split(embeds, added_container_variable)

sentences = []
i = 0

def condition(_i, _t_array):
    return tf.less(_i, current_size)

def body(_i, _t_array):
    sentences.append(_t_array.read(_i))
    return _i + 1, _t_array

idx, arr = tf.while_loop(condition, body, [i, u1])

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # This is the call that fails: arr is a TensorArray, not a Tensor.
    sents = sess.run(arr, feed_dict={sen_len: [2, 1, 3]})
    print(sents)
Error message:
Traceback (most recent call last):
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 267, in __init__
    fetch, allow_tensor=True, allow_operation=True))
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/ops.py", line 2584, in as_graph_element
    return self._as_graph_element_locked(obj, allow_tensor, allow_operation)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/ops.py", line 2673, in _as_graph_element_locked
    % (type(obj).__name__, types_str))
TypeError: Can not convert a TensorArray into a Tensor or Operation.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/ultimateai/Honain/new/ultimateai/exercises/dynamic_reshape.py", line 191, in <module>
    main()
  File "/home/ultimateai/Honain/new/ultimateai/exercises/dynamic_reshape.py", line 187, in main
    variable_container()
  File "/home/ultimateai/Honain/new/ultimateai/exercises/dynamic_reshape.py", line 179, in variable_container
    sents = sess.run(arr, feed_dict={sen_len: [2, 1, 3]})
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 789, in run
    run_metadata_ptr)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 984, in _run
    self._graph, fetches, feed_dict_string, feed_handles=feed_handles)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 410, in __init__
    self._fetch_mapper = _FetchMapper.for_fetch(fetches)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 238, in for_fetch
    return _ElementFetchMapper(fetches, contraction_fn)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 271, in __init__
    % (fetch, type(fetch), str(e)))
TypeError: Fetch argument has invalid type , must be a string or Tensor. (Can not convert a TensorArray into a Tensor or Operation.)
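The traceback says that sess.run() accepts Tensors and Operations but not a TensorArray, so whatever the loop produces has to be turned into a Tensor before it is fetched. The following is only a sketch of one way that might look, not necessarily the right fix for my case. It assumes the 1.x graph API through `tensorflow.compat.v1`, sizes the array from `sen_len` instead of padding to 512, and introduces a hypothetical second array, `sums_ta`, that stores a per-sentence row sum so that every written element has the same shape and stack() can be applied.

```python
import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # assumption: use graph mode and tf.Session on a newer install

embeds = tf.constant(np.array(
    [[1, 1], [1, 1], [2, 2], [3, 3], [3, 3], [3, 3]], dtype='float32'))
sen_len = tf.placeholder(tf.int32, shape=[None], name='sen_len')
n_sentences = tf.shape(sen_len)[0]

# One variable-length element per sentence; the array size must match len(sen_len).
sentences_ta = tf.TensorArray(dtype=tf.float32, size=n_sentences,
                              infer_shape=False, clear_after_read=False)
sentences_ta = sentences_ta.split(embeds, sen_len)

# Hypothetical output array: one fixed-size vector of shape [2] per sentence.
sums_ta = tf.TensorArray(dtype=tf.float32, size=n_sentences)

def condition(i, out_ta):
    return tf.less(i, n_sentences)

def body(i, out_ta):
    # Read the i-th sentence inside the loop and reduce it to a fixed shape.
    sentence = sentences_ta.read(i)
    out_ta = out_ta.write(i, tf.reduce_sum(sentence, axis=0))
    return i + 1, out_ta

_, sums_ta = tf.while_loop(condition, body, [tf.constant(0), sums_ta])

# stack() converts the TensorArray into a regular Tensor that can be fetched.
sentence_sums = sums_ta.stack()

with tf.Session() as sess:
    result = sess.run(sentence_sums, feed_dict={sen_len: [2, 1, 3]})
    print(result)
```

Results are written into a TensorArray that is threaded through the loop variables rather than appended to a Python list, because the loop body is only traced once when the graph is built, so a list append would not run once per iteration.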