2016-11-20 64 views
1

original Tensorflow tutorial包括以下代碼:Tensorflow tf.expand_dims

batch_size = tf.size(labels) 
labels = tf.expand_dims(labels, 1) 
indices = tf.expand_dims(tf.range(0, batch_size, 1), 1) 
concated = tf.concat(1, [indices, labels]) 
onehot_labels = tf.sparse_to_dense(concated, tf.pack([batch_size, NUM_CLASSES]), 1.0, 0.0) 

第二行增加了尺寸的labels張量。然而,labels通過飼料字典餵養,所以它應該已經有形狀[batch_size, NUM_CLASSES]。如果是這樣,那麼爲什麼expand_dims在這裏使用?

回答

2

這篇教程很舊。您引用的是版本0.6,而它們在(本帖子的11-20-2016時間)時爲0.11。所以當時有很多不同的功能v0.6。

反正回答你的問題:

在MNIST標籤剛剛編碼爲數字0-9。然而,損失函數期望標籤被編碼爲一個熱點矢量。

標籤不是[batch_size, NUM_CLASSES]在那個例子中它只是[batch_size]

這可以通過類似的numpy函數來完成。此外,他們還提供了從張力流中的mnist數據集中獲取標籤的功能,作爲一個已經具有所述形狀的熱矢量。