2017-07-06

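Filtering a vector-type "features" column

I'm working on a program where I need to display specific rows of a dataset based on certain conditions. The conditions apply to the features column I created for a machine learning model. This features column is a vector column, and when I try to filter it by passing a vector value, I get the following error: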

Exception in thread "main" java.lang.RuntimeException: Unsupported literal type class org.apache.spark.ml.linalg.DenseVector
    at org.apache.spark.sql.catalyst.expressions.Literal$.apply(literals.scala:75)
    at org.apache.spark.sql.functions$.lit(functions.scala:101)

This is the filtering code that produces the error:

dataset.where(dataset.col("features").notEqual(datapoint)); //datapoint is a Vector

Is there a way to solve this?

Answer

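Column.notEqual wraps its argument in lit(), and, as the stack trace shows, lit() cannot turn a Vector into a literal, so the comparison fails while the query is being built. You need to create a UDF that compares the vectors instead. The following worked for me (run in spark-shell, where sc and spark.implicits._ are already in scope):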

import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.linalg.{Vector, Vectors} // Spark ML Vector, not scala.collection.immutable.Vector
import org.apache.spark.sql.functions.udf

val df = sc.parallelize(Seq(
    (1, 1, 1), (1, 2, 3), (1, 3, 5), (2, 4, 6), 
    (2, 5, 2), (2, 6, 1), (3, 7, 5), (3, 8, 16), 
    (1, 1, 1))).toDF("c1", "c2", "c3") 

val dfVec = new VectorAssembler() 
    .setInputCols(Array("c1", "c2", "c3")) 
    .setOutputCol("features") 
    .transform(df) 

// Builds a UDF that returns true when a row's vector differs from vec1
def vectors_unequal(vec1: Vector) = udf((vec2: Vector) => !vec1.equals(vec2))

val vecToRemove = Vectors.dense(1,1,1) 

val filtered = dfVec.where(vectors_unequal(vecToRemove)(dfVec.col("features"))) 
val filtered2 = dfVec.filter(vectors_unequal(vecToRemove)($"features")) // Also possible 
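The UDF relies on Vector.equals, which in org.apache.spark.ml.linalg compares vectors by size and values rather than by storage format, so a dense and a sparse vector holding the same numbers compare equal. That matters here because VectorAssembler can emit a sparse vector when that is more compact. A minimal check, assuming the org.apache.spark.ml.linalg classes (shipped in spark-mllib-local, which spark-mllib pulls in) are on the classpath; the object name is just for illustration:

```scala
import org.apache.spark.ml.linalg.{Vector, Vectors}

// Hypothetical object name, used only to demonstrate Vector.equals semantics
object VectorEqualityCheck {
  val dense: Vector = Vectors.dense(1.0, 0.0, 3.0)
  // Same logical values stored sparsely: size 3, non-zeros at indices 0 and 2
  val sparse: Vector = Vectors.sparse(3, Array(0, 2), Array(1.0, 3.0))
  val other: Vector = Vectors.dense(1.0, 1.0, 1.0)

  def main(args: Array[String]): Unit = {
    // equals compares logical contents, so dense vs. sparse storage does not matter
    println(dense.equals(sparse)) // true
    println(dense.equals(other))  // false
  }
}
```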

dfVec.show() yields:

+---+---+---+--------------+
| c1| c2| c3|      features|
+---+---+---+--------------+
|  1|  1|  1| [1.0,1.0,1.0]|
|  1|  2|  3| [1.0,2.0,3.0]|
|  1|  3|  5| [1.0,3.0,5.0]|
|  2|  4|  6| [2.0,4.0,6.0]|
|  2|  5|  2| [2.0,5.0,2.0]|
|  2|  6|  1| [2.0,6.0,1.0]|
|  3|  7|  5| [3.0,7.0,5.0]|
|  3|  8| 16|[3.0,8.0,16.0]|
|  1|  1|  1| [1.0,1.0,1.0]|
+---+---+---+--------------+

filtered.show() yields (both [1.0,1.0,1.0] rows are removed):

+---+---+---+--------------+
| c1| c2| c3|      features|
+---+---+---+--------------+
|  1|  2|  3| [1.0,2.0,3.0]|
|  1|  3|  5| [1.0,3.0,5.0]|
|  2|  4|  6| [2.0,4.0,6.0]|
|  2|  5|  2| [2.0,5.0,2.0]|
|  2|  6|  1| [2.0,6.0,1.0]|
|  3|  7|  5| [3.0,7.0,5.0]|
|  3|  8| 16|[3.0,8.0,16.0]|
+---+---+---+--------------+
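On Spark 3.0 or later, a UDF may not be necessary: org.apache.spark.ml.functions.vector_to_array (added in 3.0) converts the vector column to an array<double> column, which can be compared with an array literal built by typedLit. The following is a sketch under that version assumption, as a standalone application rather than a spark-shell session; the object and method names are hypothetical. With the same nine input rows it should keep the same seven rows as the UDF version.

```scala
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.functions.vector_to_array
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions.{col, typedLit}

// Hypothetical object/method names; assumes Spark 3.0+ (spark-sql and spark-mllib)
object FilterWithoutUdf {
  def run(spark: SparkSession): DataFrame = {
    import spark.implicits._

    val df = Seq(
      (1, 1, 1), (1, 2, 3), (1, 3, 5), (2, 4, 6),
      (2, 5, 2), (2, 6, 1), (3, 7, 5), (3, 8, 16),
      (1, 1, 1)).toDF("c1", "c2", "c3")

    val dfVec = new VectorAssembler()
      .setInputCols(Array("c1", "c2", "c3"))
      .setOutputCol("features")
      .transform(df)

    // typedLit can build an array<double> literal, whereas lit() rejects a Vector
    val toRemove = typedLit(Array(1.0, 1.0, 1.0))

    // Compare the vector as a plain array instead of as a Vector
    dfVec.where(vector_to_array(col("features")) =!= toRemove)
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("filter-vector-column")
      .getOrCreate()
    run(spark).show()
    spark.stop()
  }
}
```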