2015-11-25 29 views
4

我們正試圖在Spark上透過Python使用MLlib,訓練一個具有指定初始模型的高斯混合模型(GMM)。PySpark 1.5.1的文件說明,應將一個GaussianMixtureModel對象作爲GaussianMixture.train方法的「initialModel」參數傳入。在建立我們自己的初始模型(計劃使用K-means的結果)之前,我們只是想先測試這種情況,所以嘗試用第一次訓練所輸出的GaussianMixtureModel來初始化第二次訓練。但這個簡單的情境卻引發了錯誤。你能幫我們確認這裏發生了什麼事嗎?非常感謝。
紀堯姆

如何使用初始GaussianMixtureModel訓練GMM?

PS:我們使用的是(Py)Spark 1.5.1,搭配Hadoop 2.6。

下面是這個簡單情境的程式碼和錯誤訊息:

from pyspark.mllib.clustering import GaussianMixture 
from numpy import array 
import sys 
import os 
import pyspark 

### Local default options 
K=2 # "k" (int) Set the number of Gaussians in the mixture model. Default: 2 
convergenceTol=1e-3 # "convergenceTol" (double) Set the largest change in log-likelihood at which convergence is considered to have occurred. 
maxIterations=100 # "maxIterations" (int) Set the maximum number of iterations to run. Default: 100 
seed=None # "seed" (long) Set the random seed 
initialModel=None 

### Load and parse the sample data 
data = sc.textFile("gmm_data.txt") # Data from the dummy set here: data/mllib/gmm_data.txt 
parsedData = data.map(lambda line: array([float(x) for x in line.strip().split(' ')])) 
print type(parsedData) 
print type(parsedData.first()) 

### 1st training: Build the GMM 
gmm = GaussianMixture.train(parsedData, K, convergenceTol, 
maxIterations, seed, initialModel) 

# output parameters of model 
for i in range(2): 
    print ("weight = ", gmm.weights[i], "mu = ", gmm.gaussians[i].mu, 
     "sigma = ", gmm.gaussians[i].sigma.toArray()) 

### 2nd training: Re-build a GMM using an initial model 
initialModel = gmm 
print type(initialModel) 
gmm = GaussianMixture.train(parsedData, K, convergenceTol, maxIterations, seed, initialModel) 

以下是輸出與錯誤訊息:

<class 'pyspark.rdd.PipelinedRDD'> 
<type 'numpy.ndarray'> 
('weight = ', 0.51945003367044018, 'mu = ', DenseVector([-0.1045, 
0.0429]), 'sigma = ', array([[ 4.90706817, -2.00676881], 
     [-2.00676881, 1.01143891]])) 
('weight = ', 0.48054996632955982, 'mu = ', DenseVector([0.0722, 
0.0167]), 'sigma = ', array([[ 4.77975653, 1.87624558], 
     [ 1.87624558, 0.91467242]])) 
<class 'pyspark.mllib.clustering.GaussianMixtureModel'> 

--------------------------------------------------------------------------- 
Py4JJavaError        Traceback (most recent call last) 
<ipython-input-30-0008fe75eb61> in <module>() 
    33 initialModel = gmm 
    34 print type(initialModel) 
---> 35 gmm = GaussianMixture.train(parsedData, K, convergenceTol, 
maxIterations, seed, initialModel) # 

/opt/spark/spark-1.5.1-bin-hadoop2.6/python/pyspark/mllib/clustering.pyc 
in train(cls, rdd, k, convergenceTol, maxIterations, seed, 
initialModel) 
    306   java_model = 
callMLlibFunc("trainGaussianMixtureModel", 
rdd.map(_convert_to_vector), 
    307         k, convergenceTol, 
maxIterations, seed, 
--> 308         initialModelWeights, 
initialModelMu, initialModelSigma) 
    309   return GaussianMixtureModel(java_model) 
    310 

/opt/spark/spark-1.5.1-bin-hadoop2.6/python/pyspark/mllib/common.pyc 
in callMLlibFunc(name, *args) 
    128  sc = SparkContext._active_spark_context 
    129  api = getattr(sc._jvm.PythonMLLibAPI(), name) 
--> 130  return callJavaFunc(sc, api, *args) 
    131 
    132 

/opt/spark/spark-1.5.1-bin-hadoop2.6/python/pyspark/mllib/common.pyc 
in callJavaFunc(sc, func, *args) 
    120 def callJavaFunc(sc, func, *args): 
    121  """ Call Java Function """ 
--> 122  args = [_py2java(sc, a) for a in args] 
    123  return _java2py(sc, func(*args)) 
    124 

/opt/spark/spark-1.5.1-bin-hadoop2.6/python/pyspark/mllib/common.pyc 
in _py2java(sc, obj) 
    86  else: 
    87   data = bytearray(PickleSerializer().dumps(obj)) 
---> 88   obj = sc._jvm.SerDe.loads(data) 
    89  return obj 
    90 

/opt/spark/spark-1.5.1-bin-hadoop2.6/python/lib/py4j-0.8.2.1-src.zip/py4j/java_gateway.py 
in __call__(self, *args) 
    536   answer = self.gateway_client.send_command(command) 
    537   return_value = get_return_value(answer, self.gateway_client, 
--> 538     self.target_id, self.name) 
    539 
    540   for temp_arg in temp_args: 

/opt/spark/spark-1.5.1-bin-hadoop2.6/python/pyspark/sql/utils.pyc in 
deco(*a, **kw) 
    34  def deco(*a, **kw): 
    35   try: 
---> 36    return f(*a, **kw) 
    37   except py4j.protocol.Py4JJavaError as e: 
    38    s = e.java_exception.toString() 

/opt/spark/spark-1.5.1-bin-hadoop2.6/python/lib/py4j-0.8.2.1-src.zip/py4j/protocol.py 
in get_return_value(answer, gateway_client, target_id, name) 
    298     raise Py4JJavaError(
    299      'An error occurred while calling {0}{1}{2}.\n'. 
--> 300      format(target_id, '.', name), value) 
    301    else: 
    302     raise Py4JError(

Py4JJavaError: An error occurred while calling 
z:org.apache.spark.mllib.api.python.SerDe.loads. 
: net.razorvine.pickle.PickleException: expected zero arguments for 
construction of ClassDict (for numpy.core.multiarray._reconstruct) 
at net.razorvine.pickle.objects.ClassDictConstructor.construct(ClassDictConstructor.java:23) 
at net.razorvine.pickle.Unpickler.load_reduce(Unpickler.java:701) 
at net.razorvine.pickle.Unpickler.dispatch(Unpickler.java:171) 
at net.razorvine.pickle.Unpickler.load(Unpickler.java:85) 
at net.razorvine.pickle.Unpickler.loads(Unpickler.java:98) 
at org.apache.spark.mllib.api.python.SerDe$.loads(PythonMLLibAPI.scala:1462) 
at org.apache.spark.mllib.api.python.SerDe.loads(PythonMLLibAPI.scala) 
at sun.reflect.GeneratedMethodAccessor31.invoke(Unknown Source) 
at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43) 
at java.lang.reflect.Method.invoke(Method.java:606) 
at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:231) 
at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:379) 
at py4j.Gateway.invoke(Gateway.java:259) 
at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:133) 
at py4j.commands.CallCommand.execute(CallCommand.java:79) 
at py4j.GatewayConnection.run(GatewayConnection.java:207) 
at java.lang.Thread.run(Thread.java:745) 
+1

它看起來像一個錯誤。我爲此開了一個JIRA([SPARK-12006](https://issues.apache.org/jira/browse/SPARK-12006))。 – zero323

+0

太好了,我們將根據pull request的更新進行操作,謝謝zero323 –

回答

相關問題