2017-06-20 115 views
0

我收到以下錯誤:如何迭代tensorflow中的張量?

TypeError: 'Tensor' object is not iterable.

我試圖用一個佔位符和FIFOQueue養活數據。但是這裏的問題是我無法批量處理數據。任何人都可以提供解決方案嗎?

我是TensorFlow中的新成員,混淆了佔位符和張量的概念。

下面是代碼:

#-*- coding:utf-8 -*- 
import tensorflow as tf 
import sys 

q = tf.FIFOQueue(1000,tf.string) 
label_ph = tf.placeholder(tf.string,name="label") 
enqueue_op = q.enqueue_many(label_ph) 
qr = tf.train.QueueRunner(q,enqueue_op) 
m = q.dequeue() 

sess_conf = tf.ConfigProto() 
sess_conf.gpu_options.allow_growth = True 
sess = tf.Session(config=sess_conf) 
sess.run(tf.global_variables_initializer()) 
coord = tf.train.Coordinator() 
tf.train.start_queue_runners(coord=coord, sess=sess) 

image_batch = tf.train.batch(
     m,batch_size=3, 
     enqueue_many=True, 
     capacity=9 
     ) 

for i in range(0, 10): 
    print "-------------------------" 
    #print(sess.run(q.dequeue())) 
    a = ['a','b','c','a1','b1','c1','a','b','c2','a','b','c3',] 
    sess.run(enqueue_op,{label_ph:a}) 
    b = sess.run(m) 
    print b 
q.close() 
coord.request_stop() 

回答

0

我認爲你正在運行到同樣的問題我。當您運行會話時,您實際上無法訪問數據,您可以改爲訪問數據圖。所以你應該像圖中的節點那樣考慮張量對象,而不是像你可以做的事情那樣的大塊數據。如果你想對圖做些事情,你必須調用tf。*函數,或者在sess.run()調用中獲取變量。當你這樣做時,tensorflow會找出如何根據它的依賴關係獲取數據並運行計算。

至於你的問題,看看這個頁面上的QueueRunner例子。 https://www.tensorflow.org/programmers_guide/threading_and_queues

另一種方法,你可以做到這一點(這是我切換到)是你可以洗牌你的數據在CPU上,然後一次複製它。然後,您可以跟蹤自己正在進行的步驟,並獲取該步驟的數據。我幫助保持gpu數據並減少內存副本。

all_shape = [num_batches, batch_size, data_len] 
    local_shape = [batch_size, data_len] 

    ## declare the data variables 
    model.all_data = tf.Variable(tf.zeros(all_shape), dtype=tf.float32) 
    model.step_data=tf.Variable(tf.zeros(local_shape), dtype=tf.float32) 
    model.step = tf.Variable(0, dtype=tf.float32, trainable=False, name="step") 

    ## then for each step in the epoch grab the data 
    index = tf.to_int32(model.step) 
    model.step_data = model.all_data[index] 

    ## inc the step counter 
    tf.assign_add(model.step, 1.0, name="inc_counter") 
+0

我需要使用批處理。那麼你能否提供批量解決方案? – JerryWind

+0

我上面的代碼是一般的想法。您需要[num_batches,batch_size,data_len]的3D張量,然後爲每個批次抓取所需的切片。 – ReverseFall