2013-12-10 67 views
1

我正在使用Sci-kit Learn的TruncatedSVD算法在稀疏矩陣上執行LSA。我希望轉換後的密集矩陣的數據類型爲float16而不是float64。注意:我不想在轉換之後更改數據類型 - 那時我的計算機將耗盡內存。我想TruncatedSVD.fit()直接返回float16類型的東西 - 我該怎麼做?獲取TruncatedSVD.transform()返回float16而不是float64

在應用轉換之前,我嘗試將原始稀疏矩陣和TruncatedSVD.components_更改爲float16,但輸出數據類型僅爲float32 - 這是一項改進,但並不完全符合我的要求。

+1

你的意思是'fit'還是'transform'? 我認爲Bitwise的回答是正確的。如果你設法讓它爲你工作,你可能需要考慮提交一個pull請求來添加'dtype'參數來控制'TruncatedSVD'精度。 – ogrisel

回答

3

查看代碼,TruncatedSVD使用as_float_array()sklearn.utilsas_float_array()的代碼是here

正如你看到的,文檔指出

新的D型將是np.float32或np.float64

我想你可以破解它是float16(或許刪除使用的as_float_array?),但我不知道會有什麼後果。

你應該考慮的一件事是在這些數值算法中使用較大的變量(例如float64)有助於數值穩定性。如果您正在處理非常大的矩陣,這一點尤其重要。如果你將使用float16,你可能會因數值問題而冒險得到不正確的結果。

相關問題