你可以用一點與NumPy和RDDS的嘗試。首先一堆進口:
from operator import itemgetter
import numpy as np
from pyspark.statcounter import StatCounter
讓我們來定義幾個變量:
keys = ["key1", "key2", "key3"] # list of key column names
xs = ["x1", "x2", "x3"] # list of column names to compare
y = "y" # name of the reference column
和一些助手:
def as_pair(keys, y, xs):
""" Given key names, y name, and xs names
return a tuple of key, array-of-values"""
key = itemgetter(*keys)
value = itemgetter(y, * xs) # Python 3 syntax
def as_pair_(row):
return key(row), np.array(value(row))
return as_pair_
def init(x):
""" Init function for combineByKey
Initialize new StatCounter and merge first value"""
return StatCounter().merge(x)
def center(means):
"""Center a row value given a
dictionary of mean arrays
"""
def center_(row):
key, value = row
return key, value - means[key]
return center_
def prod(arr):
return arr[0] * arr[1:]
def corr(stddev_prods):
"""Scale the row to get 1 stddev
given a dictionary of stddevs
"""
def corr_(row):
key, value = row
return key, value/stddev_prods[key]
return corr_
,並轉換DataFrame
到對RDD
:
pairs = df.rdd.map(as_pair(keys, y, xs))
接下來,我們計算每組統計:
stats = (pairs
.combineByKey(init, StatCounter.merge, StatCounter.mergeStats)
.collectAsMap())
means = {k: v.mean() for k, v in stats.items()}
注:隨着5000層的功能和7000組,應該在保持這種結構在內存中沒有任何問題。對於較大的數據集,您可能需要使用RDD和join
,但這樣會比較慢。
中心中的數據:
centered = pairs.map(center(means))
計算協方差:
covariance = (centered
.mapValues(prod)
.combineByKey(init, StatCounter.merge, StatCounter.mergeStats)
.mapValues(StatCounter.mean))
最後的相關性:
stddev_prods = {k: prod(v.stdev()) for k, v in stats.items()}
correlations = covariance.map(corr(stddev_prods))
實施例的數據:
df = sc.parallelize([
("a", "b", "c", 0.5, 0.5, 0.3, 1.0),
("a", "b", "c", 0.8, 0.8, 0.9, -2.0),
("a", "b", "c", 1.5, 1.5, 2.9, 3.6),
("d", "e", "f", -3.0, 4.0, 5.0, -10.0),
("d", "e", "f", 15.0, -1.0, -5.0, 10.0),
]).toDF(["key1", "key2", "key3", "y", "x1", "x2", "x3"])
個
結果與DataFrame
:
df.groupBy(*keys).agg(*[corr(y, x) for x in xs]).show()
+----+----+----+-----------+------------------+------------------+
|key1|key2|key3|corr(y, x1)| corr(y, x2)| corr(y, x3)|
+----+----+----+-----------+------------------+------------------+
| d| e| f| -1.0| -1.0| 1.0|
| a| b| c| 1.0|0.9972300220940342|0.6513360726920862|
+----+----+----+-----------+------------------+------------------+
和上述方法提供:
correlations.collect()
[(('a', 'b', 'c'), array([ 1. , 0.99723002, 0.65133607])),
(('d', 'e', 'f'), array([-1., -1., 1.]))]
此解決方案,同時涉及一個位,頗爲彈性和能夠容易地調整以處理不同的數據分佈。 JIT也應該有進一步的提升。
組可以從100行到5-10百萬不等,組數將是7000。 – Harish