2017-07-18 79 views
0

我想問一下tf.one_hot()函數是否支持SparseTensor作爲「indices」參數。我想要做一個多標籤分類(每個例子都有多個標籤),這需要計算一個cross-trape丟失。tf.one_hot()是否支持SparseTensor作爲索引參數?

我試圖直接把SparseTensor在「指數」參數,但它提出了以下錯誤:

類型錯誤:未能類型的對象轉換爲張量。內容:SparseTensor(indices = Tensor(「read_batch_features/fifo_queue_dequeue:106」,shape =(?, 2),dtype = int64,device =/job:worker),values = Tensor(「string_to_index_Lookup:0」,shape =(? ,),dtype = int64,device =/job:worker),dense_shape = Tensor(「read_batch_features/fifo_queue_dequeue:108」,shape =(2,),dtype = int64,device =/job:worker))。考慮將元素轉換爲支持的類型。

任何有關可能原因的建議?

謝謝。

回答

0

one_hot不支持SparseTensor作爲indices參數。您可以通過稀疏張量的索引/值張量作爲索引參數,這可能會解決您的問題。

0

您可以從最初的SparseTensor構建另一個形狀爲(batch_size, num_classes)的SparseTensor。例如,如果你把你的班級在一個字符串特徵柱(用空格隔開),可以使用下列內容:

import tensorflow as tf 

all_classes = ["class1", "class2", "class3"] 
classes_column = ["class1 class3", "class1 class2", "class2", "class3"] 

table = tf.contrib.lookup.index_table_from_tensor(
    mapping=tf.constant(all_classes) 
) 
classes = tf.constant(classes_column) 
classes = tf.string_split(classes) 
idx = table.lookup(classes) # SparseTensor of shape (4, 2), because each of the 4 rows has at most 2 classes 
num_items = tf.cast(tf.shape(idx)[0], tf.int64) # num items in batch 
num_entries = tf.shape(idx.indices)[0] # num nonzero entries 

y = tf.SparseTensor(
    indices=tf.stack([idx.indices[:, 0], idx.values], axis=1), 
    values=tf.ones(shape=(num_entries,), dtype=tf.int32), 
    dense_shape=(num_items, len(all_classes)), 
) 
y = tf.sparse_tensor_to_dense(y, validate_indices=False) 

with tf.Session() as sess: 
    tf.tables_initializer().run() 
    print(sess.run(y)) 

    # Outputs: 
    # [[1 0 1] 
    # [1 1 0] 
    # [0 1 0] 
    # [0 0 1]] 

這裏idx是SparseTensor。其索引idx.indices[:, 0]的第一列包含批次的行號,其值idx.values包含相關類ID的索引。我們結合這兩個來創建新的y.indices

要全面實施多標籤分類,請參見https://stackoverflow.com/a/47671503/507062的「選項2」