2016-07-28 72 views
0

我使用星火1.6.2,我有以下數據結構:獲得每組的前N項pySpark

sample = sqlContext.createDataFrame([ 
        (1,['potato','orange','orange']), 
        (1,['potato','orange','yogurt']), 
        (2,['vodka','beer','vodka']), 
        (2,['vodka','beer','juice', 'vinegar']) 

    ],['cat','terms']) 

我想提取每貓前N個最常見的術語。我開發了以下解決方案,似乎可行,但我想看看是否有更好的方法來做到這一點。

from collections import Counter 
def get_top(it, terms=200): 
    c = Counter(it.__iter__()) 
    return [x[0][1] for x in c.most_common(terms)] 

(sample.select('cat',sf.explode('terms')).rdd.map(lambda x: (x.cat, x.col)) 
.groupBy(lambda x: x[0]) 
.map(lambda x: (x[0], get_top(x[1], 2))) 
.collect() 
) 

它提供了以下的輸出:

[(1, ['orange', 'potato']), (2, ['vodka', 'beer'])] 

這是符合我所期待的,但我真的不喜歡,我訴諸使用計數器的事實。我怎樣才能用火花一個人做到這一點?

感謝

回答

1

如果這樣的工作可能是更好張貼這Code Review

就像一個練習,我做了這個沒有櫃檯,但很大程度上,你只是複製相同的功能。

  • 計數(catterm
  • 組的每次出現由cat
  • 排序基於計數和切片的值,以(2)項的數目

代碼:

from operator import add 

(sample.select('cat', sf.explode('terms')) 
.rdd 
.map(lambda x: (x, 1)) 
.reduceByKey(add) 
.groupBy(lambda x: x[0][0]) 
.mapValues(lambda x: [r[1] for r, _ in sorted(x, key=lambda a: -a[1])[:2]]) 
.collect()) 

輸出:

[(1, ['orange', 'potato']), (2, ['vodka', 'beer'])] 
+0

不錯。直到mapValues之前,我已經掌握了大部分內容,但無法找出這一點。謝謝! – browskie

+0

之前,我訴諸計數器是... – browskie

+0

它可能是一個挑戰,跟蹤腳手架結構,提取數據返回。 – AChampion