4

我想了解Spark OneHotEncoder默認放棄最後一個類別的理由。爲什麼Spark的OneHotEncoder默認放棄最後一個類別?

例如:

>>> fd = spark.createDataFrame([(1.0, "a"), (1.5, "a"), (10.0, "b"), (3.2, "c")], ["x","c"]) 
>>> ss = StringIndexer(inputCol="c",outputCol="c_idx") 
>>> ff = ss.fit(fd).transform(fd) 
>>> ff.show() 
+----+---+-----+ 
| x| c|c_idx| 
+----+---+-----+ 
| 1.0| a| 0.0| 
| 1.5| a| 0.0| 
|10.0| b| 1.0| 
| 3.2| c| 2.0| 
+----+---+-----+ 

默認情況下,OneHotEncoder將下降的最後一類:

>>> oe = OneHotEncoder(inputCol="c_idx",outputCol="c_idx_vec") 
>>> fe = oe.transform(ff) 
>>> fe.show() 
+----+---+-----+-------------+ 
| x| c|c_idx| c_idx_vec| 
+----+---+-----+-------------+ 
| 1.0| a| 0.0|(2,[0],[1.0])| 
| 1.5| a| 0.0|(2,[0],[1.0])| 
|10.0| b| 1.0|(2,[1],[1.0])| 
| 3.2| c| 2.0| (2,[],[])| 
+----+---+-----+-------------+ 

當然,這種行爲是可以改變的:

>>> oe.setDropLast(False) 
>>> fl = oe.transform(ff) 
>>> fl.show() 
+----+---+-----+-------------+ 
| x| c|c_idx| c_idx_vec| 
+----+---+-----+-------------+ 
| 1.0| a| 0.0|(3,[0],[1.0])| 
| 1.5| a| 0.0|(3,[0],[1.0])| 
|10.0| b| 1.0|(3,[1],[1.0])| 
| 3.2| c| 2.0|(3,[2],[1.0])| 
+----+---+-----+-------------+ 

問題::

  • 在什麼情況下默認行爲是理想的?
  • 盲目致電setDropLast(False)可能會忽略哪些問題?
  • 作者在文檔中的含義如下:

最後一類不是默認(經由dropLast配置),因爲它使矢量條目總結爲一個,並因此線性依賴包括在內。

+4

我建議你搜索文獻/關於'虛擬變量文章(和線性迴歸) – Aeck

+0

@Aeck謝謝!看起來像虛擬變量陷阱絕對是這個答案問題,如果有人在意寫一點關於它的... – Corey

回答

1

根據該文檔它是保持柱獨立:

一個一熱編碼器類索引的列映射到二進制矢量的一列 ,具有至多一個單一一個 表示輸入類別索引。例如,對於5個類別,輸入值爲2.0的 將映射到[0.0,0.0,1.0, 0.0]的輸出向量。默認情況下不包括最後一個類別(可通過OneHotEncoder!.dropLast進行配置,因爲它使向量條目總和最大爲 ,因此線性相關,因此輸入值4.0映射到 [0.0,0.0,0.0,0.0]請注意,這是從不同scikit學習的 OneHotEncoder,這使所有類別。輸出向量 稀疏。

https://spark.apache.org/docs/1.5.2/api/java/org/apache/spark/ml/feature/OneHotEncoder.html

+1

哈哈,是最懶惰和願意寫* *的點*。如果有人查找這個答案,這裏有一些更多的信息。分類特徵導致有效的攔截。如果您包含一般攔截詞,那麼最小化者可以添加例如0.5到截距和-0.5到所有類別以獲得相同的成本函數值。爲了避免這種退化,請刪除攔截幷包含所有類別。 – Corey

+0

除此之外: 對於Scala api,在邏輯迴歸分類器中使用'.setFitIntercept(false)'來刪除包含所有類別時的攔截! – aMKa

相關問題