2016-08-18 67 views
7

Tensorflow教程包括使用tf.expand_dims將「批量維」添加到張量中。我已經閱讀了這個功能的文檔,但它對我仍然很神祕。有誰知道在什麼情況下必須使用?Tensorflow:何時使用tf.expand_dims?

我的代碼如下。我的意圖是根據預測箱和實際箱之間的距離來計算損失。 (例如,如果predictedBin = 10truthBin = 7,則binDistanceLoss = 3)。

batch_size = tf.size(truthValues_placeholder) 
labels = tf.expand_dims(truthValues_placeholder, 1) 
predictedBin = tf.argmax(logits) 
binDistanceLoss = tf.abs(tf.sub(labels, logits)) 

在這種情況下,我需要申請tf.expand_dimspredictedBinbinDistanceLoss?提前致謝。

回答

16

expand_dims將不添加或減少張量中的元素,它只是通過將1添加到尺寸來更改形狀。例如,具有10個元素的矢量可以被視爲10x1矩陣。

我遇到的情況使用expand_dims是當我試圖建立一個ConvNet分類灰度圖像。灰度圖像將作爲大小爲[320, 320]的矩陣加載。但是,tf.nn.conv2d需要輸入爲[batch, in_height, in_width, in_channels],其中in_channels維度在我的數據中丟失,在這種情況下應爲1。所以我用expand_dims來增加一個維度。

就你而言,我認爲你不需要expand_dims

6

要添加到Da Tong的答案,您可能需要同時擴展多個維度。例如,如果您正在對等級1的向量執行TensorFlow的conv1d操作,則需要爲它們提供等級三。

執行expand_dims幾次是可讀的,但可能會在計算圖中引入一些開銷。你可以用reshape得到一個班輪相同的功能:

import tensorflow as tf 

# having some tensor of rank 1, it could be an audio signal, a word vector... 
tensor = tf.ones(100) 
print(tensor.get_shape()) # => (100,) 

# expand its dimensionality to fit into conv2d 
tensor_expand = tf.expand_dims(tensor, 0) 
tensor_expand = tf.expand_dims(tensor_expand, 0) 
tensor_expand = tf.expand_dims(tensor_expand, -1) 
print(tensor_expand.get_shape()) # => (1, 1, 100, 1) 

# do the same in one line with reshape 
tensor_reshape = tf.reshape(tensor, [1, 1, tensor.get_shape().as_list()[0],1]) 
print(tensor_reshape.get_shape()) # => (1, 1, 100, 1) 

注:如果你得到的錯誤TypeError: Failed to convert object of type <type 'list'> to Tensor.,試圖通過tf.shape(x)[0]而不是x.get_shape()[0]的建議here

希望它有幫助!
乾杯,
安德烈斯

+0

你有沒有運行任何測試,看看是否在做一個'reshape'是不是做,比如說,兩個或三個'expand_dims'更快? – Nathan

+0

不是真的!我查看了[sources](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/array_ops.py),但無法理解gen_array_ops的位置,所以我可以說得不好......會對看到一些測試感興趣 –