2017-08-02

I have a question about using TensorArray: how do I access TensorArray elements inside a tf.while_loop in TensorFlow?

Problem:
I want to access the elements of a TensorArray inside a tf.while_loop. Note that I can read the contents of the TensorArray with, for example, u1.read(0).

My current code:
This is what I have so far:

import numpy as np
import tensorflow as tf

embeds_raw = tf.constant(np.array([
    [1, 1], 
    [1, 1], 
    [2, 2], 
    [3, 3], 
    [3, 3], 
    [3, 3] 
], dtype='float32')) 
embeds = tf.Variable(initial_value=embeds_raw) 
container_variable = tf.zeros([512], dtype=tf.int32, name='container_variable') 
sen_len = tf.placeholder('int32', shape=[None], name='sen_len') 
# max_l = tf.reduce_max(sen_len) 
current_size = tf.shape(sen_len)[0] 
padded_sen_len = tf.pad(sen_len, [[0, 512 - current_size]], 'CONSTANT') 
added_container_variable = tf.add(container_variable, padded_sen_len) 
u1 = tf.TensorArray(dtype=tf.float32, size=512, clear_after_read=False) 
u1 = u1.split(embeds, added_container_variable) 

sentences = [] 
i = 0 

def condition(_i, _t_array): 
    return tf.less(_i, current_size) 

def body(_i, _t_array): 
    sentences.append(_t_array.read(_i)) 
    return _i + 1, _t_array 

idx, arr = tf.while_loop(condition, body, [i, u1]) 

with tf.Session() as sess: 
    sess.run(tf.global_variables_initializer()) 
    sents = sess.run(arr, feed_dict={sen_len: [2, 1, 3]}) 
    print(sents) 

Error message:

Traceback (most recent call last):
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 267, in __init__
    fetch, allow_tensor=True, allow_operation=True))
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/ops.py", line 2584, in as_graph_element
    return self._as_graph_element_locked(obj, allow_tensor, allow_operation)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/ops.py", line 2673, in _as_graph_element_locked
    % (type(obj).__name__, types_str))
TypeError: Can not convert a TensorArray into a Tensor or Operation.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/ultimateai/Honain/new/ultimateai/exercises/dynamic_reshape.py", line 191, in
    main()
  File "/home/ultimateai/Honain/new/ultimateai/exercises/dynamic_reshape.py", line 187, in main
    variable_container()
  File "/home/ultimateai/Honain/new/ultimateai/exercises/dynamic_reshape.py", line 179, in variable_container
    sents = sess.run(arr, feed_dict={sen_len: [2, 1, 3]})
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 789, in run
    run_metadata_ptr)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 984, in _run
    self._graph, fetches, feed_dict_string, feed_handles=feed_handles)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 410, in __init__
    self._fetch_mapper = _FetchMapper.for_fetch(fetches)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 238, in for_fetch
    return _ElementFetchMapper(fetches, contraction_fn)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 271, in __init__
    % (fetch, type(fetch), str(e)))
TypeError: Fetch argument has invalid type , must be a string or Tensor. (Can not convert a TensorArray into a Tensor or Operation.)

Answer


I don't have enough reputation to comment, so I'll write an answer.

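I don't quite understand what your code is meant to do, but the exception occurs because sess.run() can only fetch Tensors or Operations (or their names), and arr is a TensorArray. You could, for example, fetch its concatenation instead: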

sents = sess.run(arr.concat(), feed_dict={sen_len: [2, 1, 3]}) 

Of course, arr.concat() just undoes your split. If you want to get all the values out individually, perhaps:

sents = sess.run([arr.read(i) for i in range(512)], feed_dict={sen_len: [2, 1, 3]}) 
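With the feed [2, 1, 3], only the first three of those 512 reads return any rows. The rest are empty tensors with shape (0, 2). The NumPy snippet below is an analogy, not TensorFlow code. It assumes that TensorArray.split with per-element lengths behaves like slicing consecutive rows, and it mirrors the question's zero-padding of the lengths to 512.

```python
import numpy as np

# Same data as embeds_raw in the question.
embeds = np.array([[1, 1], [1, 1], [2, 2], [3, 3], [3, 3], [3, 3]],
                  dtype=np.float32)
sen_len = [2, 1, 3]

# Pad the lengths with zeros up to 512, like the tf.pad call in the question.
lengths = np.pad(sen_len, (0, 512 - len(sen_len)), mode='constant')

# np.split expects cut positions rather than lengths: the running sum of the
# lengths gives the row where each piece ends (the final cut is implicit).
pieces = np.split(embeds, np.cumsum(lengths)[:-1])

print(len(pieces))                    # 512
print([p.shape for p in pieces[:4]])  # [(2, 2), (1, 2), (3, 2), (0, 2)]
```

Concatenating all the pieces gives back the original 6x2 matrix, which is why arr.concat() only undoes the split.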

But I'm sure there must be a cleaner way than hard-coding 512; presumably that is what your while_loop was meant to do.
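Along those lines, one common pattern works as follows:

- Size the TensorArray from the fed lengths instead of padding them to 512.
- Read one element per iteration inside the loop.
- Write each per-sentence result into a second TensorArray, and stack that array after the loop.

The sketch below assumes TF 1.13+ or TF 2.x, because it uses the tf.compat.v1 graph/session API. The per-sentence mean is only a hypothetical placeholder for whatever the loop is really supposed to compute. Note that body is called when the loop is built, not once per iteration. Appending to a Python list inside it, as sentences.append does in the question, therefore does not collect per-iteration values.

```python
import tensorflow as tf

tf1 = tf.compat.v1  # TF1-style graph/session API (TF 1.13+ and TF 2.x)

graph = tf.Graph()
with graph.as_default():
    embeds = tf.constant([[1, 1], [1, 1], [2, 2], [3, 3], [3, 3], [3, 3]],
                         dtype=tf.float32)
    sen_len = tf1.placeholder(tf.int32, shape=[None], name='sen_len')
    num_sents = tf.shape(sen_len)[0]

    # One element per fed length, so no padding to 512 is needed.
    sents_ta = tf.TensorArray(dtype=tf.float32, size=num_sents,
                              infer_shape=False)
    sents_ta = sents_ta.split(embeds, sen_len)

    # A second TensorArray collects one result per loop iteration.
    out_ta = tf.TensorArray(dtype=tf.float32, size=num_sents)

    def condition(i, out):
        return i < num_sents

    def body(i, out):
        sentence = sents_ta.read(i)  # shape [sen_len[i], 2]
        # Hypothetical per-sentence computation: the mean embedding.
        return i + 1, out.write(i, tf.reduce_mean(sentence, axis=0))

    _, out_ta = tf.while_loop(condition, body, [tf.constant(0), out_ta])
    sentence_means = out_ta.stack()  # shape [num_sents, 2], a real Tensor

with tf1.Session(graph=graph) as sess:
    result = sess.run(sentence_means, feed_dict={sen_len: [2, 1, 3]})
    print(result)
```

Each sentence has a different number of rows, so the sentences themselves cannot be stacked into one tensor. Reducing each one to a fixed shape inside the loop is what makes stack() possible here. Otherwise, fetch arr.concat() as above and split it on the client.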