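You can build another SparseTensor of shape (batch_size, num_classes) from the original SparseTensor. For example, if your classes are stored in a single string feature column (separated by spaces), you can use the following: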
import tensorflow as tf  # TensorFlow 1.x API (tf.contrib, tf.Session)
all_classes = ["class1", "class2", "class3"]
classes_column = ["class1 class3", "class1 class2", "class2", "class3"]
table = tf.contrib.lookup.index_table_from_tensor(
    mapping=tf.constant(all_classes)
)
classes = tf.constant(classes_column)
classes = tf.string_split(classes)
idx = table.lookup(classes) # SparseTensor of shape (4, 2), because each of the 4 rows has at most 2 classes
num_items = tf.cast(tf.shape(idx)[0], tf.int64) # num items in batch
num_entries = tf.shape(idx.indices)[0] # num nonzero entries
y = tf.SparseTensor(
    indices=tf.stack([idx.indices[:, 0], idx.values], axis=1),
    values=tf.ones(shape=(num_entries,), dtype=tf.int32),
    dense_shape=(num_items, len(all_classes)),
)
y = tf.sparse_tensor_to_dense(y, validate_indices=False)
with tf.Session() as sess:
    tf.tables_initializer().run()
    print(sess.run(y))
# Outputs:
# [[1 0 1]
#  [1 1 0]
#  [0 1 0]
#  [0 0 1]]
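Here idx is a SparseTensor. The first column of its indices, idx.indices[:, 0], contains the row number within the batch, and its values, idx.values, contain the index of the corresponding class ID. We combine these two to create the new y.indices.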
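To make the index combination concrete, here is a minimal plain-Python sketch of the same idea, with no TensorFlow involved. The names class_to_id, idx_indices, idx_values, y_indices and multi_hot are illustrative stand-ins, not part of any TensorFlow API; they only mimic what the lookup table and the SparseTensor hold for the example data above.

```python
# Plain-Python sketch of how (row, class_id) pairs are formed; not TensorFlow code.
all_classes = ["class1", "class2", "class3"]
classes_column = ["class1 class3", "class1 class2", "class2", "class3"]

# Hypothetical stand-in for the lookup table: class name -> integer ID.
class_to_id = {name: i for i, name in enumerate(all_classes)}

# Mimic the sparse lookup result: one (row, position) index
# and one class-ID value per whitespace-separated token.
idx_indices = []
idx_values = []
for row, cell in enumerate(classes_column):
    for pos, name in enumerate(cell.split()):
        idx_indices.append((row, pos))
        idx_values.append(class_to_id[name])

# Keep the row number (first index column) and pair it with the class ID.
# This corresponds to stacking idx.indices[:, 0] with idx.values.
y_indices = [(row, class_id)
             for (row, _pos), class_id in zip(idx_indices, idx_values)]

# Densify: put a 1 at every (row, class_id) position.
multi_hot = [[0] * len(all_classes) for _ in classes_column]
for row, class_id in y_indices:
    multi_hot[row][class_id] = 1

print(y_indices)
print(multi_hot)
```

The position within each row (the second index column) is discarded, because only the row number and the class ID matter for the multi-hot encoding.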
For a complete implementation of multi-label classification, see "Option 2" at https://stackoverflow.com/a/47671503/507062.