2015-05-14

I started using Spark DataFrames, and I need to be able to pivot the data to create multiple columns out of one column with multiple rows. There is built-in functionality for this in Scalding, and I believe in Pandas in Python, but I can't find anything for the new Spark DataFrame. How do I pivot a Spark DataFrame?

I assume I can write a custom function of some sort to do this, but I'm not even sure how to start, especially since I am a novice with Spark. Any suggestions on built-in functionality that does this, or on how to write something in Scala, would be greatly appreciated.

See this [similar question](http://stackoverflow.com/questions/30260015/reshaping-pivoting-data-in-spark-rdd-and-or-spark-dataframes/) where I posted a native Spark approach that doesn't need the column/category names to be known ahead of time. – patricksurry

Answers


As mentioned by @user2000823, Spark provides a `pivot` function as of version 1.6. The general syntax looks as follows:

df 
    .groupBy(grouping_columns) 
    .pivot(pivot_column, [values]) 
    .agg(aggregate_expressions) 
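Conceptually, `groupBy(...).pivot(...).agg(...)` reshapes long data into wide data: one output row per group, one output column per distinct pivot value, each cell holding the aggregate. A minimal plain-Python sketch of that computation on a toy dataset (no Spark required; row and column names are illustrative):

```python
from collections import defaultdict

def pivot_avg(rows, group_key, pivot_key, value_key):
    """Group rows by group_key, spread pivot_key values into columns,
    and aggregate value_key with a mean -- the shape of
    groupBy(...).pivot(...).agg(avg(...))."""
    cells = defaultdict(list)
    for row in rows:
        cells[(row[group_key], row[pivot_key])].append(row[value_key])
    result = defaultdict(dict)
    for (group, col), vals in cells.items():
        result[group][col] = sum(vals) / len(vals)
    return dict(result)

rows = [
    {"origin": "EWR", "hour": 5, "arr_delay": 11},
    {"origin": "EWR", "hour": 5, "arr_delay": 13},
    {"origin": "JFK", "hour": 5, "arr_delay": 33},
]
print(pivot_avg(rows, "origin", "hour", "arr_delay"))
# {'EWR': {5: 12.0}, 'JFK': {5: 33.0}}
```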

Usage examples using the nycflights13 data in csv format:

Python

from pyspark.sql.functions import avg 

flights = (sqlContext 
    .read 
    .format("csv") 
    .options(inferSchema="true", header="true") 
    .load("flights.csv") 
    .na.drop()) 

flights.registerTempTable("flights") 
sqlContext.cacheTable("flights") 

gexprs = ("origin", "dest", "carrier") 
aggexpr = avg("arr_delay") 

flights.count() 
## 336776 

%timeit -n10 flights.groupBy(*gexprs).pivot("hour").agg(aggexpr).count() 
## 10 loops, best of 3: 1.03 s per loop 

Scala

val flights = sqlContext 
    .read 
    .format("csv") 
    .options(Map("inferSchema" -> "true", "header" -> "true")) 
    .load("flights.csv") 

flights 
    .groupBy($"origin", $"dest", $"carrier") 
    .pivot("hour") 
    .agg(avg($"arr_delay")) 

Java

import static org.apache.spark.sql.functions.*; 
import org.apache.spark.sql.*; 

Dataset<Row> df = spark.read().format("csv") 
     .option("inferSchema", "true") 
     .option("header", "true") 
     .load("flights.csv"); 

df.groupBy(col("origin"), col("dest"), col("carrier")) 
     .pivot("hour") 
     .agg(avg(col("arr_delay"))); 

R/SparkR

library(magrittr) 

flights <- read.df("flights.csv", source="csv", header=TRUE, inferSchema=TRUE) 

flights %>% 
    groupBy("origin", "dest", "carrier") %>% 
    pivot("hour") %>% 
    agg(avg(column("arr_delay"))) 

R/sparklyr

library(dplyr) 

flights <- spark_read_csv(sc, "flights", "flights.csv") 

avg.arr.delay <- function(gdf) { 
    expr <- invoke_static(
     sc, 
     "org.apache.spark.sql.functions", 
     "avg", 
     "arr_delay" 
    ) 
    gdf %>% invoke("agg", expr, list()) 
} 

flights %>% 
    sdf_pivot(origin + dest + carrier ~ hour, fun.aggregate=avg.arr.delay) 

Example data

"year","month","day","dep_time","sched_dep_time","dep_delay","arr_time","sched_arr_time","arr_delay","carrier","flight","tailnum","origin","dest","air_time","distance","hour","minute","time_hour" 
2013,1,1,517,515,2,830,819,11,"UA",1545,"N14228","EWR","IAH",227,1400,5,15,2013-01-01 05:00:00 
2013,1,1,533,529,4,850,830,20,"UA",1714,"N24211","LGA","IAH",227,1416,5,29,2013-01-01 05:00:00 
2013,1,1,542,540,2,923,850,33,"AA",1141,"N619AA","JFK","MIA",160,1089,5,40,2013-01-01 05:00:00 
2013,1,1,544,545,-1,1004,1022,-18,"B6",725,"N804JB","JFK","BQN",183,1576,5,45,2013-01-01 05:00:00 
2013,1,1,554,600,-6,812,837,-25,"DL",461,"N668DN","LGA","ATL",116,762,6,0,2013-01-01 06:00:00 
2013,1,1,554,558,-4,740,728,12,"UA",1696,"N39463","EWR","ORD",150,719,5,58,2013-01-01 05:00:00 
2013,1,1,555,600,-5,913,854,19,"B6",507,"N516JB","EWR","FLL",158,1065,6,0,2013-01-01 06:00:00 
2013,1,1,557,600,-3,709,723,-14,"EV",5708,"N829AS","LGA","IAD",53,229,6,0,2013-01-01 06:00:00 
2013,1,1,557,600,-3,838,846,-8,"B6",79,"N593JB","JFK","MCO",140,944,6,0,2013-01-01 06:00:00 
2013,1,1,558,600,-2,753,745,8,"AA",301,"N3ALAA","LGA","ORD",138,733,6,0,2013-01-01 06:00:00 

Performance considerations

In general, pivoting is an expensive operation: it requires a wide shuffle, and if you know the set of pivot values in advance, it is worth passing them explicitly as the second argument to `pivot`.
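One reason: when the `values` argument is omitted, Spark must first compute the distinct values of the pivot column (an extra pass over the data) before it can determine the output schema. A plain-Python sketch of the two code paths (illustrative only, not Spark internals):

```python
def pivot_columns(rows, pivot_key, values=None):
    """Determine the output columns for a pivot.  Without an explicit
    `values` list, a full pass over the data is needed first."""
    if values is None:
        values = sorted({row[pivot_key] for row in rows})  # extra pass
    return list(values)

rows = [{"hour": 6}, {"hour": 5}, {"hour": 6}]
print(pivot_columns(rows, "hour"))          # discovered: [5, 6]
print(pivot_columns(rows, "hour", [5, 6]))  # supplied, no extra pass
```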


I got around this by writing a for loop to dynamically create a SQL query. Say I have:

id tag value 
1 US 50 
1 UK 100 
1 Can 125 
2 US 75 
2 UK 150 
2 Can 175 

and I want:

id US UK Can 
1 50 100 125 
2 75 150 175 

I can create a list of the values I want to pivot on, and then build a string containing the SQL query I need.

val countries = List("US", "UK", "Can") 
val numCountries = countries.length - 1 

var query = "select *, " 
for (i <- 0 to numCountries-1) { 
    query += """case when tag = """" + countries(i) + """" then value else 0 end as """ + countries(i) + ", " 
} 
query += """case when tag = """" + countries.last + """" then value else 0 end as """ + countries.last + " from myTable" 

myDataFrame.registerTempTable("myTable") 
val myDF1 = sqlContext.sql(query) 

I can create similar queries to then do the aggregation. It's not a very elegant solution, but it works and is flexible for any list of values, which can also be passed in as an argument when your code is called.
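The same query-building loop can be sketched compactly in Python; this is only an illustration of the string construction (table and column names taken from the example above), with each tag quoted as a SQL string literal so it is compared as a value rather than resolved as a column name:

```python
countries = ["US", "UK", "Can"]

# Quote each tag as a SQL string literal ('US'), not a bare identifier,
# so Spark SQL compares it as a value instead of a column name.
cases = ", ".join(
    "case when tag = '{0}' then value else 0 end as {0}".format(c)
    for c in countries
)
query = "select *, " + cases + " from myTable"
print(query)
```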

I tried to reproduce your example, but I get an "org.apache.spark.sql.AnalysisException: cannot resolve 'US' given input columns id, tag, value" – user299791

That was down to the quoting. If you look at the generated text string you get `case when tag = US`, so Spark thinks `US` is a column name rather than a text value. What you really want to see is `case when tag = 'US'`. I have edited the answer above so the quotes are set correctly. –

But as also mentioned above, this functionality is now native to Spark via the `pivot` command. –


I've solved a similar problem using dataframes with the following steps:

Create columns for all your countries, with 'value' as the value:

import org.apache.spark.sql.functions._ 
val countries = List("US", "UK", "Can") 
val countryValue = udf{(countryToCheck: String, countryInRow: String, value: Long) => 
    if(countryToCheck == countryInRow) value else 0 
} 
val countryFuncs = countries.map{ country => (dataFrame: DataFrame) => dataFrame.withColumn(country, countryValue(lit(country), col("tag"), col("value"))) } 
val dfWithCountries = Function.chain(countryFuncs)(df).drop("tag").drop("value") 

Your dataframe `dfWithCountries` will look like this:

+--+--+---+---+ 
|id|US| UK|Can| 
+--+--+---+---+ 
| 1|50| 0| 0| 
| 1| 0|100| 0| 
| 1| 0| 0|125| 
| 2|75| 0| 0| 
| 2| 0|150| 0| 
| 2| 0| 0|175| 
+--+--+---+---+ 

Now you can sum all the values together for the result you want:

dfWithCountries.groupBy("id").sum(countries: _*).show 

Result:

+--+-------+-------+--------+ 
|id|SUM(US)|SUM(UK)|SUM(Can)| 
+--+-------+-------+--------+ 
| 1|  50| 100|  125| 
| 2|  75| 150|  175| 
+--+-------+-------+--------+ 

It's not a very elegant solution, though. I had to create a chain of functions to add all the columns. Also, if I have lots of countries, I expand my temporary data set to a very wide set with lots of zeroes.
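The widen-then-sum idea itself is easy to state outside Spark; a minimal plain-Python sketch on the toy data (illustrative only, and it skips the intermediate zero-filled rows by accumulating directly):

```python
from collections import defaultdict

def widen_then_sum(rows, countries):
    """One zero-initialized column per country per id, summed per id --
    the same result as the zero-fill-then-groupBy-sum approach."""
    totals = defaultdict(lambda: {c: 0 for c in countries})
    for row in rows:
        totals[row["id"]][row["tag"]] += row["value"]
    return dict(totals)

rows = [
    {"id": 1, "tag": "US", "value": 50},
    {"id": 1, "tag": "UK", "value": 100},
    {"id": 1, "tag": "Can", "value": 125},
    {"id": 2, "tag": "US", "value": 75},
]
print(widen_then_sum(rows, ["US", "UK", "Can"]))
# {1: {'US': 50, 'UK': 100, 'Can': 125}, 2: {'US': 75, 'UK': 0, 'Can': 0}}
```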


Initially I adopted Al M's solution. Later I took the same idea and rewrote it as a transpose function.

This method transposes any df rows into columns, for any data format, using key and value columns.

Input csv:

id,tag,value 
1,US,50a 
1,UK,100 
1,Can,125 
2,US,75 
2,UK,150 
2,Can,175 

Output:

+--+---+---+---+ 
|id| UK| US|Can| 
+--+---+---+---+ 
| 2|150| 75|175| 
| 1|100|50a|125| 
+--+---+---+---+ 

Transpose method:

def transpose(hc: HiveContext, df: DataFrame, compositeId: List[String], key: String, value: String) = { 
    // distinct key values become the new column names 
    val distinctCols = df.select(key).distinct.map { r => r(0) }.collect().toList 

    // pair each composite id with a one-entry map of key -> value 
    val rdd = df.map { row => 
        (compositeId.collect { case id => row.getAs(id).asInstanceOf[Any] }, 
         scala.collection.mutable.Map(row.getAs(key).asInstanceOf[Any] -> row.getAs(value).asInstanceOf[Any])) 
    } 
    // merge all single-entry maps that share the same composite id 
    val pairRdd = rdd.reduceByKey(_ ++ _) 
    val rowRdd = pairRdd.map(r => dynamicRow(r, distinctCols)) 
    hc.createDataFrame(rowRdd, getSchema(df.schema, compositeId, (key, distinctCols))) 
} 

private def dynamicRow(r: (List[Any], scala.collection.mutable.Map[Any, Any]), colNames: List[Any]) = { 
    val cols = colNames.collect { case col => r._2.getOrElse(col.toString(), null) } 
    val array = r._1 ++ cols 
    Row(array: _*) 
} 

private def getSchema(srcSchema: StructType, idCols: List[String], distinctCols: (String, List[Any])): StructType = { 
    val idSchema = idCols.map { idCol => srcSchema.apply(idCol) } 
    val colSchema = srcSchema.apply(distinctCols._1) 
    val colsSchema = distinctCols._2.map { col => StructField(col.asInstanceOf[String], colSchema.dataType, colSchema.nullable) } 
    StructType(idSchema ++ colsSchema) 
} 

Main snippet:

import java.util.Date 
import org.apache.spark.SparkConf 
import org.apache.spark.SparkContext 
import org.apache.spark.sql.Row 
import org.apache.spark.sql.DataFrame 
import org.apache.spark.sql.types.StructType 
import org.apache.spark.sql.hive.HiveContext 
import org.apache.spark.sql.types.StructField 


... 
... 
def main(args: Array[String]): Unit = { 

    val conf = new SparkConf() // configure app name/master as needed 
    val sc = new SparkContext(conf) 
    val sqlContext = new org.apache.spark.sql.SQLContext(sc) 
    val dfdata1 = sqlContext.read.format("com.databricks.spark.csv").option("header", "true").option("inferSchema", "true") 
    .load("data.csv") 
    dfdata1.show() 
    val dfOutput = transpose(new HiveContext(sc), dfdata1, List("id"), "tag", "value") 
    dfOutput.show 

} 
This method transposes rows into columns... – Jaigates


There is a simple and elegant built-in solution:

scala> spark.sql("select * from k_tags limit 10").show() 
+---------------+-------------+------+ 
|   imsi|   name| value| 
+---------------+-------------+------+ 
|246021000000000|   age| 37| 
|246021000000000|  gender|Female| 
|246021000000000|   arpu| 22| 
|246021000000000| DeviceType| Phone| 
|246021000000000|DataAllowance| 6GB| 
+---------------+-------------+------+ 

scala> spark.sql("select * from k_tags limit 10").groupBy($"imsi").pivot("name").agg(min($"value")).show() 
+---------------+-------------+----------+---+----+------+ 
|   imsi|DataAllowance|DeviceType|age|arpu|gender| 
+---------------+-------------+----------+---+----+------+ 
|246021000000000|   6GB|  Phone| 37| 22|Female| 
|246021000000001|   1GB|  Phone| 72| 10| Male| 
+---------------+-------------+----------+---+----+------+ 