2016-02-05 21 views
4

大約有計算一個熱的嵌入與TensorFlow幾個堆棧溢出的問題,這裏是接受的解決方案:這是TensorFlow中的一種熱門編碼嗎?或因任何原因有缺陷?

num_labels = 10 
sparse_labels = tf.reshape(label_batch, [-1, 1]) 
derived_size = tf.shape(label_batch)[0] 
indices = tf.reshape(tf.range(0, derived_size, 1), [-1, 1]) 
concated = tf.concat(1, [indices, sparse_labels]) 
outshape = tf.reshape(tf.concat(0, [derived_size, [num_labels]]), [-1]) 
labels = tf.sparse_to_dense(concated, outshape, 1.0, 0.0) 

這是一個官方教程幾乎相同的代碼:https://www.tensorflow.org/versions/0.6.0/tutorials/mnist/tf/index.html

要我似乎自從tf.nn.embedding_lookup存在,它可能更有效率。這是一個使用這個版本,它支持任意形狀的輸入:

def one_hot(inputs, num_classes): 
    with tf.device('/cpu:0'): 
     table = tf.constant(np.identity(num_classes, dtype=np.float32)) 
     embeddings = tf.nn.embedding_lookup(table, inputs) 
    return embeddings 

您是否期望此實現更快?是否有其他原因存在缺陷?

+0

'劣質'是一種主觀的質量。你能以客觀的方式表達出來嗎?例如時間,記憶,產生錯誤;可以衡量的東西。 –

+1

你對TensorFlow開發人員的要求是顯而易見的,但這不是我的要求。我發現TensorFlow的例子很實用:不止一次,我認爲我正在改進一些東西,後來認識到他們在設計時非常小心(儘管缺乏文檔)。對我而言,這種單熱編碼器更好(更可讀,更通用,可能更快),但是我要求以我沒有看到的方式查看它是否有缺陷。 – rd11

+0

我明白。我添加了評論,因爲StackOverflow對問題有特定的要求,如果新人看到這個問題,他們可能會開始提出更多不被允許的主觀問題。 TensorFlow標籤更加寬容,但如果標準沒有保留,標籤將變得毫無價值,我想長期使用它。 –

回答

18

您的問題中的one_hot()函數看起來正確。但是,我們不建議用這種方式編寫代碼的原因在於它的內存效率很低,爲。爲了理解爲什麼,假設您的批量大小爲32和1,000,000個類。

  • 在本教程中提出的版本,最大的張量將成爲tf.sparse_to_dense()的結果,這將是32 x 1000000

  • one_hot()函數的問題中,最大的張量將是np.identity(1000000)的結果,即4TB。當然,分配張量可能不會成功。即使類的數量要小得多,它仍然會浪費內存來顯式存儲所有這些零。— TensorFlow不會自動將您的數據轉換爲稀疏表示,即使這樣做可能是有利的。

最後,我想爲最近添加到開源存儲庫的新功能提供插件,並將在下一版本中提供。 tf.nn.sparse_softmax_cross_entropy_with_logits()允許您指定一個整數向量作爲標籤,並且不需要構建密集的單熱表示。它應該更有效率地爲解決大量的類。

+2

由mrry的回答提醒,我已經改變了現有的MNIST convolutional.py和CIFAR模型以使用這個新的操作,所以它們作爲一個例子很有用:https://github.com/tensorflow/tensorflow/blob/master/ tensorflow/models/image/cifar10/cifar10.py#L270 – dga

+0

鏈接無法正常工作。 – LoveMeow

+0

應該再次工作! – mrry

相關問題