2015-10-06 65 views
2

我有一個單一的3D數值數據文件,我從塊中讀取(因爲在塊中讀取比單個索引快)。例如說有一個MxNx30陣列中「文件」,我會創建一個這樣的RDD:在Pyspark的RDD分區中分割數組

def read(ind): 
    f = customFileOpener(file) 
    return f['data'][:,:,ind[0]:ind[-1]+1] 

indices = [[0,9],[10,19],[20,29]] 
rdd = sc.parallelize(indices,3).map(lambda v:read(v)) 
rdd.count() 

所以各3個分區的大小爲MxNx10的numpy.ndarray元件。

現在,我想分割每個分區中的每個元素,我有10個元素,每個元素是一個MxN數組。我試着用flatMap()用於此目的,但得到的錯誤「NoneType對象不是可迭代」:

def splitArr(arr): 
    Nmid = arr.shape[-1] 
    out = [] 
    for i in range(0,Nmid): 
     out.append(arr[...,i]) 
    return out 

rdd2 = rdd.flatMap(lambda v: splitArr(v)) 
rdd2.count() 

什麼是做這種正確的方法是什麼?關鍵點是(a)我需要從文件中以塊讀取數據和(b)拆分數據,因此元素的大小爲MxN(最好保留分區結構)。

回答

2

據我瞭解你的描述是這樣的應該做的伎倆:

rdd.flatMap(lambda arr: (x for x in np.rollaxis(arr, 2))) 

或者如果你喜歡一個單獨的函數:

def splitArr(arr): 
    for x in np.rollaxis(arr, 2): 
     yield x 

rdd.flatMap(splitArr) 
+0

我明白了,我應該從可迭代猜測錯誤。將尺寸移動到數組的前面並用rollaxis分割,然後迭代這些元素。正是我想要的,非常感謝。 – Michael