2016-05-17 70 views
4

我想在pySpark mllib中構建一個簡單的自定義估算器。我有here,它可以寫一個自定義的變壓器,但我不知道如何在Estimator上做到這一點。我也不明白什麼@keyword_only做什麼,爲什麼我需要這麼多的制定者和獲得者。 Scikit學習似乎有定製機型see here適當的文件(但pySpark不如何在PySpark中自定義估算器mllib

僞爲例型號代碼:

class NormalDeviation(): 
    def __init__(self, threshold = 3): 
    def fit(x, y=None): 
     self.model = {'mean': x.mean(), 'std': x.std()] 
    def predict(x): 
     return ((x-self.model['mean']) > self.threshold * self.model['std']) 
    def decision_function(x): # does ml-lib support this? 

回答

9

一般來說沒有文檔,因爲作爲星火1.6/2.0最相關的API並不打算是公共的,應在星火2.1.0(見SPARK-7146)更改。

API是比較複雜的,因爲它必須遵循特定的慣例,以使給定TransformerEstimator兼容與Pipeline API。這些方法中的一些可能是讀寫和網格搜索等功能所必需的。其他,如keyword_only只是一個簡單的幫手,而不是嚴格要求。

假設您已經定義了以下的配料插件均值參數:

from pyspark.ml.pipeline import Estimator, Model, Pipeline 
from pyspark.ml.param.shared import * 
from pyspark.sql.functions import avg, stddev_samp 


class HasMean(Params): 

    mean = Param(Params._dummy(), "mean", "mean", 
     typeConverter=TypeConverters.toFloat) 

    def __init__(self): 
     super(HasMean, self).__init__() 

    def setMean(self, value): 
     return self._set(mean=value) 

    def getMean(self): 
     return self.getOrDefault(self.mean) 

標準偏差參數:

class HasStandardDeviation(Params): 

    stddev = Param(Params._dummy(), "stddev", "stddev", 
     typeConverter=TypeConverters.toFloat) 

    def __init__(self): 
     super(HasStandardDeviation, self).__init__() 

    def setStddev(self, value): 
     return self._set(stddev=value) 

    def getStddev(self): 
     return self.getOrDefault(self.stddev) 

和門檻:

class HasCenteredThreshold(Params): 

    centered_threshold = Param(Params._dummy(), 
      "centered_threshold", "centered_threshold", 
      typeConverter=TypeConverters.toFloat) 

    def __init__(self): 
     super(HasCenteredThreshold, self).__init__() 

    def setCenteredThreshold(self, value): 
     return self._set(centered_threshold=value) 

    def getCenteredThreshold(self): 
     return self.getOrDefault(self.centered_threshold) 

您可以創建基本Estimator爲如下:

class NormalDeviation(Estimator, HasInputCol, 
     HasPredictionCol, HasCenteredThreshold): 

    def _fit(self, dataset): 
     c = self.getInputCol() 
     mu, sigma = dataset.agg(avg(c), stddev_samp(c)).first() 
     return (NormalDeviationModel() 
      .setInputCol(c) 
      .setMean(mu) 
      .setStddev(sigma) 
      .setCenteredThreshold(self.getCenteredThreshold()) 
      .setPredictionCol(self.getPredictionCol())) 

class NormalDeviationModel(Model, HasInputCol, HasPredictionCol, 
     HasMean, HasStandardDeviation, HasCenteredThreshold): 

    def _transform(self, dataset): 
     x = self.getInputCol() 
     y = self.getPredictionCol() 
     threshold = self.getCenteredThreshold() 
     mu = self.getMean() 
     sigma = self.getStddev() 

     return dataset.withColumn(y, (dataset[x] - mu) > threshold * sigma) 

最後,可以使用如下:

df = sc.parallelize([(1, 2.0), (2, 3.0), (3, 0.0), (4, 99.0)]).toDF(["id", "x"]) 

normal_deviation = NormalDeviation().setInputCol("x").setCenteredThreshold(1.0) 
model = Pipeline(stages=[normal_deviation]).fit(df) 

model.transform(df).show() 
## +---+----+----------+ 
## | id| x|prediction| 
## +---+----+----------+ 
## | 1| 2.0|  false| 
## | 2| 3.0|  false| 
## | 3| 0.0|  false| 
## | 4|99.0|  true| 
## +---+----+----------+ 
+0

的感謝!所以Estimator的狀態也被認爲是一個參數? –

+0

您是否將估算器的參數調整爲模型參數?如果是這樣,這種設計方式很方便,但對於基本實現來說並不難。 – zero323

+0

好的,任何希望得到一些關於如何堅持像這樣的自定義步驟的建議? –

相關問題