2015-09-17 62 views
4

我需要一些幫助讓我的大腦圍繞在spark中設計(高效)markov鏈(通過python)。我已經儘可能地寫了它,但我想出的代碼並不能縮放。基本上,對於各種地圖階段,我編寫了自定義函數,並且它們適用於幾千個序列,但是當我們得到在20,000+(我有一些高達800K)的事情緩慢爬行。PySpark markov模型的算法/編碼幫助

對於那些你不熟悉馬爾科夫moodels,這是它​​的要點..

這是我的數據。我已經在這一點得到了在RDD的實際數據(沒有標頭)。

ID, SEQ 
500, HNL, LNH, MLH, HML 

我們期待在元組序列,所以

(HNL, LNH), (LNH,MLH), etc.. 

,我需要得到這一點..在我返回一個字典(用於數據的每一行),我再序列化和存儲在內存數據庫中。

{500: 
    {HNLLNH : 0.333}, 
    {LNHMLH : 0.333}, 
    {MLHHML : 0.333}, 
    {LNHHNL : 0.000}, 
    etc.. 
} 

因此,在本質上,每個序列與下一個組合(HNL,LNH成爲「HNLLNH」),那麼對於所有的可能的轉變我們指望它們的發生,然後通過的總數劃分(序列的組合)轉換(在這種情況下爲3)並獲得它們的出現頻率。

有上述3個過渡,以及其中的一個是HNLLNH ..所以對於HNLLNH 1/3 = 0.333

作爲邊沒有,我不知道,如果是相關的,但對於價值序列中的每個位置都是有限的。第一個位置(H/M/L),第二個位置(M/L),第三個位置(H,M,L)。

我的代碼之前做過的事情是收集()rdd,並使用我寫的函數將其映射幾次。這些函數首先將字符串轉換爲一個列表,然後將列表[1]與列表[2]合併,然後將列表[2]與列表[3]合併,然後將列表[3]與列表[4]合併,等等。了這樣的事情..

[HNLLNH],[LNHMLH],[MHLHML], etc.. 

那麼接下來的函數創建一個字典出該列表中,使用列表項作爲重點,然後計算的完整列表是關鍵的總ocurrence,len個分(列表)來獲取頻率。然後,我將這本字典包裝在另一個字典中,並附上它的ID號碼(導致第二個代碼塊,在上面)。

就像我說過的,這對小型序列很有效,但對於長度爲100k +的列表來說效果不好。

另外,請記住,這只是一行數據。我必須在10-20k行數據的任何地方執行此操作,數據行在每行500-800,000個序列的長度之間變化。

關於如何可以寫pyspark代碼(使用API​​ map/reduce/agg/etc ..函數)來有效地執行此操作的任何建議?

編輯 代碼如下..可能有意義,從底部開始。請記住我學習這個(Python和星火),因爲我去,我不爲生活做到這一點,所以我的編碼標準都不是很大..

def f(x): 
    # Custom RDD map function 
    # Combines two separate transactions 
    # into a single transition state 

    cust_id = x[0] 
    trans = ','.join(x[1]) 
    y = trans.split(",") 
    s = '' 
    for i in range(len(y)-1): 
     s= s + str(y[i] + str(y[i+1]))+"," 
    return str(cust_id+','+s[:-1]) 

def g(x): 
    # Custom RDD map function 
    # Calculates the transition state probabilities 
    # by adding up state-transition occurrences 
    # and dividing by total transitions 
    cust_id=str(x.split(",")[0]) 
    trans = x.split(",")[1:] 
    temp_list=[] 
    middle = int((len(trans[0])+1)/2) 
    for i in trans: 
     temp_list.append((''.join(i)[:middle], ''.join(i)[middle:])) 

    state_trans = {} 
    for i in temp_list: 
      state_trans[i] = temp_list.count(i)/(len(temp_list)) 

    my_dict = {} 
    my_dict[cust_id]=state_trans 
    return my_dict 


def gen_tsm_dict_spark(lines): 
    # Takes RDD/string input with format CUST_ID(or)PROFILE_ID,SEQ,SEQ,SEQ.... 
    # Returns RDD of dict with CUST_ID and tsm per customer 
    # i.e. {cust_id : { ('NLN', 'LNN') : 0.33, ('HPN', 'NPN') : 0.66} 

    # creates a tuple ([cust/profile_id], [SEQ,SEQ,SEQ]) 
    cust_trans = lines.map(lambda s: (s.split(",")[0],s.split(",")[1:])) 

    with_seq = cust_trans.map(f) 

    full_tsm_dict = with_seq.map(g) 

    return full_tsm_dict 


def main(): 
result = gen_tsm_spark(my_rdd) 

# Insert into DB 
for x in result.collect(): 
    for k,v in x.iteritems(): 
     db_insert(k,v) 
+0

'collect()rdd,並使用我寫的函數將它映射了幾次'你不應該那樣做,你需要始終保持在RDD中,或者你不需要Spark。你能發佈你的實際代碼嗎? – hellpanderrr

+0

是的,我意識到,當我追查我的性能問題,以我自己的職能.. :) 不幸的是代碼被封裝在一個更大的程序,提取它會有點困難和凌亂,但我會盡我所能..只是要記住,我不是一個在任何方式的Python的專業人士.. :)將代碼添加到上面作爲編輯 – nameBrandon

+0

@nameBrandon什麼是db_insert和你在哪裏定義它? – DimKoim

回答

1

你可以嘗試像下面。它很大程度上取決於tooolz,但如果您更喜歡避免外部依賴性,則可以輕鬆地用一些標準Python庫替換它。

from __future__ import division 
from collections import Counter 
from itertools import product 
from toolz.curried import sliding_window, map, pipe, concat 
from toolz.dicttoolz import merge 

# Generate all possible transitions 
defaults = sc.broadcast(dict(map(
    lambda x: ("".join(concat(x)), 0.0), 
    product(product("HNL", "NL", "HNL"), repeat=2)))) 

rdd = sc.parallelize(["500, HNL, LNH, NLH, HNL", "600, HNN, NNN, NNN, HNN, LNH"]) 

def process(line): 
    """ 
    >>> process("000, HHH, LLL, NNN") 
    ('000', {'LLLNNN': 0.5, 'HHHLLL': 0.5}) 
    """ 
    bits = line.split(", ") 
    transactions = bits[1:] 
    n = len(transactions) - 1 
    frequencies = pipe(
     sliding_window(2, transactions), # Get all transitions 
     map(lambda p: "".join(p)), # Joins strings 
     Counter, # Count 
     lambda cnt: {k: v/n for (k, v) in cnt.items()} # Get frequencies 
    ) 
    return bits[0], frequencies 

def store_partition(iter): 
    for (k, v) in iter: 
     db_insert(k, merge([defaults.value, v])) 

rdd.map(process).foreachPartition(store_partition) 

由於您知道所有可能的轉換,我建議使用稀疏表示並忽略零。此外,您可以用稀疏矢量替換字典以減少內存佔用。

+0

3337340283179321,NNN,LPN,NPN,LNN,LNN,NPN,NPN,LNN等..... – nameBrandon

+0

什麼是db_insert及其定義的位置? – DimKoim

+0

@DimKoim'db_insert'來自原始帖子。這不是一個庫函數。 – zero323