2017-05-19 57 views
0

我很難在SQL數據框中乘以列的元素。在pyspark中乘以稀疏向量的行SQL DataFrame

sv1 = Vectors.sparse(3, [0, 2], [1.0, 3.0]) 
sv2 = Vectors.sparse(3, [0, 1], [2.0, 4.0]) 

def xByY(x,y): 
    return np.multiply(x,y) 

print(xByY(sv1, sv2)) 

上述工作。

但是下面沒有。

xByY_udf = udf(xByY) 

tempDF = sqlContext.createDataFrame([(sv1, sv2), (sv1, sv2)], ('v1', 'v2')) 
tempDF.show() 

print(tempDF.select(xByY_udf('v1', 'v2')).show()) 

非常感謝!

+0

什麼是你的錯誤? –

回答

0

如果你希望你的udf返回一個SparseVector,我們首先需要修改你的函數的輸出,二組的輸出模式的udfVectorUDT()

要聲明SparseVector,我們需要原始陣列的大小以及非零元素的索引。我們可以發現這些使用len()和list解析,如果乘法的中間結果是list

from pyspark.ml.linalg import Vectors, VectorUDT 

def xByY(x,y): 
    res = np.multiply(x,y).tolist() 
    vec_args = len(res), [i for i,x in enumerate(res) if x != 0], [x for x in res if x != 0] 
    return Vectors.sparse(*vec_args) 

現在我們可以宣佈我們udf並對其進行測試:

xByY_udf = udf(xByY, VectorUDT()) 
tempDF.select(xByY_udf('v1', 'v2')).show() 
+-------------+ 
| xByY(v1, v2)| 
+-------------+ 
|(3,[0],[2.0])| 
|(3,[0],[2.0])| 
+-------------+ 
+0

謝謝!這(幾乎)作品!我如何得到一個稀疏矢量? –

+0

請參閱更新@ f.g。 – mtoto

+0

非常出色,謝謝。 (UDF輸出模式的規範是我總是需要仔細考慮的)。 –