2017-08-26 57 views
0

所以我在MNIST數據集上使用一個簡單的SGDClassifier(根據動手ML書),我似乎無法弄清楚它的decision_function的行爲。sklearn SGDClassifier的decision_function奇怪的行爲

我改變了original_decision函數的最後一行來專門檢查是否有什麼不同。後綴「check」的變量是decision_function在這裏返回的變量。的代碼:

from sklearn.datasets import fetch_mldata 
mnist = fetch_mldata("MNIST original") 
import numpy as np 
from sklearn.linear_model import SGDClassifier 
from sklearn.utils.extmath import safe_sparse_dot 
from sklearn.utils import check_array 

X, y = mnist["data"], mnist["target"] 
X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:] 

shuffle_index = np.random.permutation(60000) 
X_train, y_train = X_train[shuffle_index], y_train[shuffle_index] 

# converting the problem into a binary classification problem. 
y_train_5 = (y_train == 5) 
y_test_5 = (y_test == 5) 

sgd_clf = SGDClassifier(random_state=42) 
sgd_clf.fit(X_train, y_train_5) 

# modified decision_func to ouput vars 
X_check, coef_check, intcpt_check = sgd_clf.decision_function([X_train[36000]]) 

print((X_check == X_train[36000]).all()) 
print((coef_check == sgd_clf.coef_).all()) 
print((intcpt_check == sgd_clf.intercept_).all()) 

# using same funcs as used by decision_function to calc. 
# i.e. check_array safe_sparse_dot 
X_mod = check_array(X[36000].reshape(1,-1), "csr") 
my_score = safe_sparse_dot(X_mod,sgd_clf.coef_.T) + sgd_clf.intercept_ 
sk_score = safe_sparse_dot(X_check, coef_check.T) + intcpt_check 

print(my_score) 
print(sk_score) 

這裏是輸出(用於單次運行):

True 
True 
True 
[[ 49505.1725926]] 
[[-347904.18757136]] 

這是我給decision_function進行的修改(在註釋原始之前第二最後一行):

scores = safe_sparse_dot(X, self.coef_.T, 
           dense_output=True) + self.intercept_ 
return X, self.coef_, self.intercept_ 
#return scores.ravel() if scores.shape[1] == 1 else scores 

即使所涉及的3個實體(例如X,係數,截距)所有比賽能與我的變量真實,乘法仍然會導致完全不同的結果。

爲什麼會發生這種情況?

編輯:奇怪的是,我發現,如果我註釋掉負責洗牌,即數據集中的兩條線:

shuffle_index = np.random.permutation(60000) 
X_train, y_train = X_train[shuffle_index], y_train[shuffle_index] 

問題消失...

+0

請不要張貼鏈接到數據的屏幕截圖。幫助你很難,而且你很少得到及時的答覆。相反,作爲[最小化,完整和可驗證示例](https://stackoverflow.com/help/mcve)的一部分,將示例數據內聯發佈。 –

回答

0

它可能是由於X_mod使用原始的「X」(無混洗)矩陣,而X_check使用X_train(混洗)索引。

即X_train [36000]!= X [36000]所以當然你的分數不應該是相同的。

+0

即使在意識到排列組合刪除錯誤後,我仍覺得很愚蠢:D哈哈。非常感謝!得到它了。 –