2016-01-17 66 views
8

使用星火1.5和Scala 2.10.6與行字段過濾器火花數據幀是字符串

我試圖通過現場「標籤」這是一個字符串數組來過濾數據幀的數組。尋找具有'私人'標籤的所有行。

val report = df.select("*") 
    .where(df("tags").contains("private")) 

越來越:

異常線程 「main」 org.apache.spark.sql.AnalysisException: 無法解析 '包含(標籤專用)' 由於數據類型不匹配: 說法1需要字符串類型,但是,「標籤」是數組 類型。

過濾方法更適合嗎?

更新:

數據從卡桑德拉適配器,但一個小例子來顯示我想要做的,也得到了上面的錯誤是:

def testData (sc: SparkContext): DataFrame = { 
    val stringRDD = sc.parallelize(Seq(""" 
     { "name": "ed", 
     "tags": ["red", "private"] 
     }""", 
     """{ "name": "fred", 
     "tags": ["public", "blue"] 
     }""") 
    ) 
    val sqlContext = new org.apache.spark.sql.SQLContext(sc) 
    import sqlContext.implicits._ 
    sqlContext.read.json(stringRDD) 
    } 
    def run(sc: SparkContext) { 
    val df1 = testData(sc) 
    df1.show() 
    val report = df1.select("*") 
     .where(df1("tags").contains("private")) 
    report.show() 
    } 

更新:標籤陣列可以是任何長度和「私人」標籤可以是在任何位置

更新:一個解決方案,它的工作原理:UDF

val filterPriv = udf {(tags: mutable.WrappedArray[String]) => tags.contains("private")} 
val report = df1.filter(filterPriv(df1("tags"))) 
+0

發佈您的數據樣本以及如何創建df –

+1

一種選擇是構建UDF。 –

+1

那麼,查看源代碼後(因爲'Column.contains'的scaladoc只說「包含其他元素」,這不是很有啓發性),我看到'Column.contains'構造了一個'org.apache的實例.spark.sql.catalyst.expressions.Contains'它說「一個函數,如果字符串'left'包含字符串'right',則返回true。所以看起來'df1(「tags」)。contains'在這種情況下無法做到我們想要的。我不知道有什麼替代建議。 '...表達式中也有一個'ArrayContains',但'Column'似乎沒有使用它。 –

回答

13

我想如果你使用where(array_contains(...))它將工作。這裏是我的結果:

scala> import org.apache.spark.SparkContext 
import org.apache.spark.SparkContext 

scala> import org.apache.spark.sql.DataFrame 
import org.apache.spark.sql.DataFrame 

scala> def testData (sc: SparkContext): DataFrame = { 
    |  val stringRDD = sc.parallelize(Seq 
    |  ("""{ "name": "ned", "tags": ["blue", "big", "private"] }""", 
    |  """{ "name": "albert", "tags": ["private", "lumpy"] }""", 
    |  """{ "name": "zed", "tags": ["big", "private", "square"] }""", 
    |  """{ "name": "jed", "tags": ["green", "small", "round"] }""", 
    |  """{ "name": "ed", "tags": ["red", "private"] }""", 
    |  """{ "name": "fred", "tags": ["public", "blue"] }""")) 
    |  val sqlContext = new org.apache.spark.sql.SQLContext(sc) 
    |  import sqlContext.implicits._ 
    |  sqlContext.read.json(stringRDD) 
    | } 
testData: (sc: org.apache.spark.SparkContext)org.apache.spark.sql.DataFrame 

scala> 
    | val df = testData (sc) 
df: org.apache.spark.sql.DataFrame = [name: string, tags: array<string>] 

scala> val report = df.select ("*").where (array_contains (df("tags"), "private")) 
report: org.apache.spark.sql.DataFrame = [name: string, tags: array<string>] 

scala> report.show 
+------+--------------------+ 
| name|    tags| 
+------+--------------------+ 
| ned|[blue, big, private]| 
|albert| [private, lumpy]| 
| zed|[big, private, sq...| 
| ed|  [red, private]| 
+------+--------------------+ 

需要注意的是,如果你寫where(array_contains(df("tags"), "private"))的工作,但如果你寫where(df("tags").array_contains("private"))(更直接類似於您最初寫的東西),它失敗array_contains is not a member of org.apache.spark.sql.Column。看看Column的源代碼,我看到有一些東西需要處理contains(爲此構建一個Contains實例),但不是array_contains。也許這是一個疏忽。

0

您可以使用序號來引用json數組,例如在你的情況下df("tags")(0)。這裏是工作示例

scala> val stringRDD = sc.parallelize(Seq(""" 
    |  { "name": "ed", 
    |   "tags": ["private"] 
    |  }""", 
    |  """{ "name": "fred", 
    |   "tags": ["public"] 
    |  }""") 
    | ) 
stringRDD: org.apache.spark.rdd.RDD[String] = ParallelCollectionRDD[87] at parallelize at <console>:22 

scala> import sqlContext.implicits._ 
import sqlContext.implicits._ 

scala> sqlContext.read.json(stringRDD) 
res28: org.apache.spark.sql.DataFrame = [name: string, tags: array<string>] 

scala> val df=sqlContext.read.json(stringRDD) 
df: org.apache.spark.sql.DataFrame = [name: string, tags: array<string>] 

scala> df.columns 
res29: Array[String] = Array(name, tags) 

scala> df.dtypes 
res30: Array[(String, String)] = Array((name,StringType), (tags,ArrayType(StringType,true))) 

scala> val report = df.select("*").where(df("tags")(0).contains("private")) 
report: org.apache.spark.sql.DataFrame = [name: string, tags: array<string>] 

scala> report.show 
+----+-------------+ 
|name|   tags| 
+----+-------------+ 
| ed|List(private)| 
+----+-------------+ 
+0

謝謝。如果pos是固定的,但不是。我應該讓測試數據更復雜一些,數組中可以有任意數量的標籤,位置是任意的。 – navicore

+0

@navicore那麼你的代碼應該工作。檢查我的更新 –

+0

有趣,我錯過了一些東西,看起來像我正在做的,但得到的錯誤。雙重檢查火花版本現在... – navicore