我有標籤的NumPy的數組:IndexError索引的二維數組與一維數組(NumPy的)
labels = np.ndarray(10000, dtype=np.float32)
在數組中的元素看起來像:
print(labels[1:5])
Output: [ 9. 9. 4. 1.]
我想將它們轉換成一個熱編碼的標籤,我用下面的代碼:
one_hot_labels = np.eye(10)[labels]
我得到以下錯誤:
IndexError Traceback (most recent call last)
<ipython-input-21-dccf85afc031> in <module>()
1
----> 2 s=np.eye(10)[labels]
IndexError: arrays used as indices must be of integer (or boolean) type
我該如何解決這個問題?
你確定標籤和火車標籤是一樣的嗎? –
你需要使用整數值作爲索引:'one_hot_labels = np.eye(10)[labels.astype(int)]' – JohanL
@JohanL謝謝。它的工作原理 – Jayanth