2015-10-01 65 views
0

我正在通過Python API使用Spark進行一些數據處理。下面是我的工作之類的簡化一下:在Python類中註冊Spark SQL用戶定義函數

class data_processor(object): 
    def __init__(self,filepath): 
     self.config = Config() # this loads some config options from file 
     self.type_conversions = {int:IntegerType,str:StringType} 
     self.load_data(filepath) 
     self.format_age() 

    def load_data(self,filepath,delim='\x01'): 
     cols = [...] # list of column names 
     types = [int, str, str, ... ] # list of column types 
     user_data = sc.textFile(filepath,use_unicode=False).map(lambda row: [types[i](val) for i,val in enumerate(row.strip().split(delim))]) 
     fields = StructType([StructField(field_name,self.type_conversions[field_type]()) for field_name,field_type in zip(cols,types)]) 
     self.user_data = user_data.toDF(fields) 
     self.user_data.registerTempTable('data') 

    def format_age(self): 
     age_range = self.config.age_range # tuple of (age_min, age_max) 
     age_bins = self.config.age_bins # list of bin boundaries 
     def _format_age(age): 
      if age<age_range[0] or age>age_range[1]: 
       return None 
      else: 
       return np.digitize([age],age_bins)[0] 
     sqlContext.udf.register('format_age', lambda x: _format_age(x), IntegerType()) 

現在,如果我實例化data=data_processor(filepath)類的,我可以做查詢的數據幀就好了。例如,這個作品:

sqlContext.sql("select * from data limit 10").take(1) 

但我很明顯沒有正確設置udf。如果我嘗試,例如,

sqlContext.sql("select age, format_age(age) from data limit 10").take(1) 

我得到一個錯誤:

Py4JJavaError: An error occurred while calling z:org.apache.spark.api.python.PythonRDD.collectAndServe. 

(用長堆棧跟蹤,典型的星火,時間太長,包括在這裏)。

那麼,我究竟做錯了什麼?在像這樣的方法中定義UDF的方法是什麼(最好是作爲類方法)。我知道Spark不喜歡傳遞類對象,因此嵌套結構爲format_age(受this question啓發)。

想法?

回答

1

答案很簡單。您不能使用NumPy數據類型作爲Spark SQL中標準Python類型的直接替換。當您聲明返回類型爲IntegerType時,返回類型np.digitizenumpy.int64而非int

所有你需要做的就是投從_format_age返回的值:

def _format_age(age): 
    ... 
    return int(np.digitize(...))