回答

7

你必須clear the threshold第一,這僅適用於二進制分類:

from pyspark.mllib.classification import LogisticRegressionWithLBFGS, LogisticRegressionModel 
from pyspark.mllib.regression import LabeledPoint 

parsed_data = [LabeledPoint(0.0, [4.6,3.6,1.0,0.2]), 
       LabeledPoint(0.0, [5.7,4.4,1.5,0.4]), 
       LabeledPoint(1.0, [6.7,3.1,4.4,1.4]), 
       LabeledPoint(0.0, [4.8,3.4,1.6,0.2]), 
       LabeledPoint(1.0, [4.4,3.2,1.3,0.2])] 

model = LogisticRegressionWithLBFGS.train(sc.parallelize(parsed_data)) 
model.threshold 
# 0.5 
model.predict(parsed_data[2].features) 
# 1 

model.clearThreshold() 
model.predict(parsed_data[2].features) 
# 0.9873840020002339 
+0

從文檔我找不到一種方法來做同樣的多類分類。你知道這是否可能嗎?我認爲唯一的辦法是做一個手動1對所有 –

+0

不,對不起... – desertnaut

+0

@MpizosDimitris,這需要改變實際功能。我剛纔在Scala中實現這一點,可以爲新的問題提供答案 –

0

我相信這個問題是在計算概率得分的預測整個訓練集。如果是這樣,我做了以下計算。不知道後仍然有效,但這是howI這樣做:

#get the original training data before it was converted to rows of LabelPoint. 
#let us assume it is otd (of type spark DataFrame) 
#let us extract the featureset as rdd by: 
fs=otd.rdd.map(lambda x:x[1:]) # assuming label is col 0. 

#the below is just a sample way of creating a Labelpoint rows.. 
parsedData= otd.rdd.map(lambda x: reg.LabeledPoint(int(x[0]-1),x[1:])) 

# now convert otd to a panda DataFrame as: 
ptd= otd.toPandas() 
m= ptd.shape[0] 
# train and get the model 
model=LogisticRegressionWithLBFGS.train(trainingData,numClasses=10) 


#Now store the model.predict rdd structures 
predict=model.predict(fs) 
pr=predict.collect() 

correct=0 
correct = ((ptd.label-1) == (pr)).sum() 
print((correct/m) *100) 

注意上面的是多級分類。

+0

@desertnaut,請看看這是否合理。 – sunny

+0

1)'trainingdata'沒有定義2)'fs'沒有用到3)不清楚你的代碼的結果是什麼,並且它確實提供了概率;這就是爲什麼提供虛擬數據並展示結果的好習慣,正如我所做的那樣4)「toPandas」不是一個好主意,因爲它只適用於'小'數據集(你甚至不需要Spark )5)問題在ML大多已解決:https://stackoverflow.com/questions/43631031/pyspark-how-to-get-classification-probabilities-from-multilayerperceptronclassi/43643426#43643426 – desertnaut

+0

@desertnaut,我跑這我們正在另一篇文章討論的數據集代碼。 fs作爲參數傳遞給預測。我的訓練數據是一個5000x400的矩陣,其中包含用於多分類分類器的標籤1-10。這是一個手寫數字,其中包含1-10的數字。我理解到Pandas()並不高效,但目標是計算概率。 – sunny