2015-08-14 30 views
9

我有一個整數類標籤的字節張量,例如來自MNIST數據集。在火炬中如何從整數標籤列表創建單熱張量?

1 
7 
5 
[torch.ByteTensor of size 3] 

如何使用它來創建1熱矢量的張量?

1 0 0 0 0 0 0 0 0 0 
0 0 0 0 0 0 1 0 0 0 
0 0 0 0 1 0 0 0 0 0 
[torch.DoubleTensor of size 3x10] 

我知道我可以用一個循環做到這一點,但我不知道是否有任何聰明的火炬索引,將讓這對我在一個單一的線。

回答

13
indices = torch.LongTensor{1,7,5}:view(-1,1) 
one_hot = torch.zeros(3, 10) 
one_hot:scatter(2, indices, 1) 

你可以找到在torch/torch7 github readmescatter的文件(在主分支)。

2

的另一種方法是從單位矩陣洗牌行:

indicies = torch.LongTensor{1,7,5} 
one_hot = torch.eye(10):index(1, indicies) 

這不是我的主意,我發現它在karpathy/char-rnn