2016-11-08 202 views
2

我的raw_data是PTB數據集。 我正在通過以下代碼生成批次。feed_dict中餵食問題(Tensorflow)

def generate_batches(raw_data, batch_size, unrollings): 
    global data_index 
    data_len = len(raw_data) 
    num_batches = data_len // batch_size 
    inputs = [] 
    labels = [] 
    print (num_batches, data_len, batch_size) 
    for j in xrange(unrollings) : 
     inputs.append([]) 
     labels.append([]) 
     for i in xrange(batch_size) : 
     inputs[j].append(raw_data[i + data_index]) 
     labels[j].append(raw_data[i + data_index + 1])  
     data_index = (data_index + batch_size) % len(raw_data) 
    return inputs, labels 

在會話運行中,生成的相同批生產飼料feed_dict,如以下代碼中所示。

for step in xrange(num_steps) : 
batch_inputs, batch_labels = generate_batches(train_dataset, batch_size, unrollings=5) 
feed_dict = dict() 
for i in range(unrollings): 
    feed_dict = {train_inputs : batch_inputs, train_labels : batch_labels} 
    _, l, predictions, lr = session.run([optimizer, loss, train_prediction, learning_rate], feed_dict=feed_dict) 

的培訓投入和標籤如下:

for _ in range(unrollings) : 
train_data.append(tf.placeholder(shape=[batch_size], dtype=tf.int32)) 
train_label.append(tf.placeholder(shape=[batch_size, 1], dtype=tf.float32)) 
train_inputs = train_data[:unrollings] 
train_labels = train_label[:unrollings] 

首先,我得到了錯誤TypeError: unhashable type: 'list'到我轉換batch_input列表使用tuple(batch_input[i])這是在Python dictionary : TypeError: unhashable type: 'list'解釋清楚元組。
解決:然後我得到這個錯誤TypeError: unhashable type: 'numpy.ndarray'

+1

您試圖使用ndarray作爲字典的關鍵字,它應該是字符串 –

+0

謝謝+1我更正了代碼。 – SupposeXYZ

回答

1

我想你是誤解feed_dict是如何工作的。但首先,python dict不接受任何不可關聯的類作爲關鍵的實例。 list和numpy.ndarray都不能用作dict鍵(即使你用一個元組包裝它)。我發現list post解釋關於字典的關鍵。

feed_dict如何工作

在圖形中,應該有象徵意義的張量創建佔位符。假設您的原始數據是2D:(num_samples,num_features),第一個維度對應於樣本的大小,第二個維度對應於特徵的數量。假設標籤是一種熱門編碼,並且總共有num_classes。

train_data = tf.placeholder(shape=[batch_size, num_features], dtype=tf.float32) 
train_labels = tf.placeholder(shape=[batch_size, num_classes], dtype=tf.float32) 
在會話建立feed_dict時

然後,使用這些符號佔位張量的關鍵和採樣batch_data的價值。

feed_dict = {train_data:batch_inputs, train_labels:batch_labels}