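Python inheritance: new class is not initialized correctly

I'm writing a class named XGB that inherits from XGBClassifier (from the Python library xgboost.sklearn). I wrote an `__init__` and a `fit` method, as follows: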
from xgboost.sklearn import XGBClassifier
from balanceSmote import BalanceSmote
from balance import Balance

class XGB(XGBClassifier):
    def __init__(self, learning_rate=0.5, max_depth=3, colsample_bytree=0.5, n_estimators=300,
                 frac=None, k_neighbors=None, m_neighbors=None, out_step=None):
        # These are the additional arguments that are not in XGBClassifier
        if k_neighbors:
            self.balancingStrategy = 'smote'
            self.k_neighbors = k_neighbors
            self.m_neighbors = m_neighbors
            self.out_step = out_step
        elif frac:
            self.balancingStrategy = 'normal'
            self.frac = frac
        else:
            self.balancingStrategy = 'false'
        # Utilize the motherClass
        super(XGB, self).__init__(seed=500,
                                  learning_rate=learning_rate,
                                  max_depth=max_depth,
                                  colsample_bytree=colsample_bytree,
                                  n_estimators=n_estimators)
Here is my test code:
xgb4 = XGB(learning_rate = 0.1, max_depth = 3, colsample_bytree = 1, n_estimators = 1000)
xgb4.fit(trainData,trainLabel)
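The initialization seems to go fine, but when I try to use `fit()` (a method inherited from XGBClassifier), I get an error message telling me a parameter is missing: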
  File "<ipython-input-3-47344b7fbc76>", line 1, in <module>
    runfile('/Users/celsloaner/Project/SPUDS/code/testSpark.py', wdir='/Users/celsloaner/Project/SPUDS/code')
  File "/anaconda/envs/SPUDS/lib/python3.5/site-packages/spyder/utils/site/sitecustomize.py", line 880, in runfile
    execfile(filename, namespace)
  File "/anaconda/envs/SPUDS/lib/python3.5/site-packages/spyder/utils/site/sitecustomize.py", line 102, in execfile
    exec(compile(f.read(), filename, 'exec'), namespace)
  File "/Users/celsloaner/Project/SPUDS/code/testSpark.py", line 50, in <module>
    xgb4.fit(predictor.trainData,predictor.trainLabel)
  File "/anaconda/envs/SPUDS/lib/python3.5/site-packages/xgboost/sklearn.py", line 396, in fit
    xgb_options = self.get_xgb_params()
  File "/anaconda/envs/SPUDS/lib/python3.5/site-packages/xgboost/sklearn.py", line 177, in get_xgb_params
    xgb_params = self.get_params()
  File "/anaconda/envs/SPUDS/lib/python3.5/site-packages/xgboost/sklearn.py", line 169, in get_params
    if params['missing'] is np.nan:
KeyError: 'missing'
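The problem comes from the parent class, which should already have been initialized correctly. Here is the parent-class function where the problem occurs: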
def get_params(self, deep=False):
    """Get parameter.s"""
    params = super(XGBModel, self).get_params(deep=deep)
    if params['missing'] is np.nan:
        params['missing'] = None  # sklearn doesn't handle nan. see #4725
    if not params.get('eval_metric', True):
        del params['eval_metric']  # don't give as None param to Booster
    return params
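The key goes missing because `super(XGBModel, self).get_params(deep=deep)` delegates to scikit-learn's `BaseEstimator.get_params`, which, as far as I know, takes its keys from the parameter names in the signature of `type(self).__init__`. For an `XGB` instance that is `XGB.__init__`, which does not declare `missing`, so the dict never gets that key even though `self.missing` is set by the parent constructor. The pure-Python sketch below reproduces this; `MiniBaseEstimator`, `MiniXGBModel` and `MiniXGB` are hypothetical, simplified stand-ins for the real classes, and `math.nan` replaces `np.nan` to avoid a NumPy dependency.

```python
import inspect
import math


class MiniBaseEstimator:
    """Hypothetical, simplified stand-in for sklearn's BaseEstimator."""

    @classmethod
    def _get_param_names(cls):
        # Called via an instance, cls is the runtime class (e.g. MiniXGB),
        # so it is *that* class's __init__ signature that gets read.
        sig = inspect.signature(cls.__init__)
        return sorted(
            p.name
            for p in sig.parameters.values()
            if p.name != "self" and p.kind != p.VAR_KEYWORD
        )

    def get_params(self):
        # One entry per name declared in the runtime class's __init__
        return {name: getattr(self, name, None) for name in self._get_param_names()}


class MiniXGBModel(MiniBaseEstimator):
    """Hypothetical stand-in for xgboost's XGBModel."""

    def __init__(self, max_depth=3, missing=None):
        self.max_depth = max_depth
        self.missing = missing if missing is not None else math.nan

    def get_params(self):
        params = super().get_params()
        if params["missing"] is math.nan:  # KeyError if 'missing' is not a declared parameter
            params["missing"] = None
        return params


class MiniXGB(MiniXGBModel):
    """Mirrors the question's XGB: its __init__ does not declare 'missing'."""

    def __init__(self, max_depth=3, frac=None):
        self.frac = frac
        super().__init__(max_depth=max_depth)  # self.missing is still set here


print(MiniXGBModel().get_params())  # {'max_depth': 3, 'missing': None}
try:
    MiniXGB().get_params()
except KeyError as exc:
    print("KeyError:", exc)  # KeyError: 'missing'
```

In this model the parent's `get_params` is not broken on its own: the same code works for `MiniXGBModel` and fails only for the subclass whose constructor signature drops `missing`.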
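The `params` dict is clearly not defined correctly (the key `'missing'` does not exist) when the XGBClassifier initializer is called from the XGB initializer. Do you have any idea what the problem is, or how to track it down?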
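One way to track this kind of problem down is to compare the constructor signatures directly with `inspect.signature`: if `get_params` is built from the subclass's `__init__` signature, any name the parent's `__init__` declares but the subclass's omits will be absent from the result. The helpers `init_param_names` and `missing_from_subclass` and the `Parent`/`Child` classes below are hypothetical; `Child` mirrors the question's `XGB`, which also passes `seed` to the parent without declaring it.

```python
import inspect


def init_param_names(cls):
    """Names declared by cls.__init__, excluding self and **kwargs."""
    params = inspect.signature(cls.__init__).parameters.values()
    return {p.name for p in params if p.name != "self" and p.kind != p.VAR_KEYWORD}


def missing_from_subclass(parent, child):
    """Parameters the parent's __init__ declares but the child's __init__ drops."""
    return init_param_names(parent) - init_param_names(child)


# Hypothetical classes standing in for XGBClassifier and the question's XGB
class Parent:
    def __init__(self, max_depth=3, missing=None, seed=0):
        self.max_depth, self.missing, self.seed = max_depth, missing, seed


class Child(Parent):
    def __init__(self, max_depth=3, frac=None):
        self.frac = frac
        super().__init__(max_depth=max_depth, seed=500)


print(sorted(missing_from_subclass(Parent, Child)))  # ['missing', 'seed']
# With the real classes one would compare, e.g., missing_from_subclass(XGBClassifier, XGB)
```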
Thanks
If `params` has no `'missing'` entry, trying to access it raises a `KeyError`. Change the condition `if params['missing'] is np.nan:` to `if params.get('missing', np.nan) is np.nan:`, and do the same for `eval_metric`. – alfasin
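A short sketch of what the suggested condition does when the key is absent, with `math.nan` standing in for `np.nan` and a plain function (`patched_get_params`, hypothetical) standing in for the patched method: `params.get('missing', np.nan)` falls back to its default, which `is np.nan`, so the branch runs and inserts `'missing': None`. The `KeyError` disappears, but the value then comes from the default rather than from the estimator, and any other parameter `XGB.__init__` does not declare stays absent.

```python
import math


def patched_get_params(params):
    """Apply the suggested .get()-based condition to an already-built params dict."""
    params = dict(params)  # work on a copy
    if params.get("missing", math.nan) is math.nan:
        params["missing"] = None  # also runs when the key was absent
    return params


print(patched_get_params({"max_depth": 3}))  # {'max_depth': 3, 'missing': None}
print(patched_get_params({"max_depth": 3, "missing": -999.0}))  # {'max_depth': 3, 'missing': -999.0}
```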
The problem is `params = super(XGBModel, self).get_params(deep=deep)`, which returns a `dict` without the key you want. –
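I see, but why? The dict is created and managed inside the parent class and I never touched it, so I don't understand why it is not properly defined in the parent class when I use it through the inherited class. – Salamandre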
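As for why: the dict is built by the parent, but from the child's constructor signature, so scikit-learn's estimator conventions apply to the subclass too. Its `__init__` should list every parameter it wants reported as an explicit keyword argument (no `*args` or `**kwargs`), store each one unchanged under the same attribute name, and keep other logic out of `__init__`. With the real classes that would mean adding `missing` (and any other `XGBClassifier` arguments to keep, such as `seed`) to `XGB.__init__` and forwarding them to `super().__init__`. The sketch below applies this to hypothetical stand-in classes; `m_neighbors` and `out_step` are omitted for brevity, and turning `balancingStrategy` into a property is one possible design, not the only one.

```python
import inspect
import math


class MiniBaseEstimator:
    """Hypothetical, simplified stand-in for sklearn's BaseEstimator."""

    def get_params(self):
        # Keys come from the runtime class's __init__ signature
        sig = inspect.signature(type(self).__init__)
        names = [p.name for p in sig.parameters.values() if p.name != "self"]
        return {name: getattr(self, name, None) for name in names}


class MiniXGBModel(MiniBaseEstimator):
    """Hypothetical stand-in for XGBClassifier; its get_params expects 'missing'."""

    def __init__(self, max_depth=3, missing=None, seed=0):
        self.max_depth = max_depth
        self.missing = missing if missing is not None else math.nan
        self.seed = seed

    def get_params(self):
        params = super().get_params()
        if params["missing"] is math.nan:
            params["missing"] = None
        return params


class MiniXGB(MiniXGBModel):
    # Re-declare the parent's parameters (including 'missing') next to the new ones,
    # and store every new argument unconditionally under its own name.
    def __init__(self, max_depth=3, missing=None, seed=500, frac=None, k_neighbors=None):
        self.frac = frac
        self.k_neighbors = k_neighbors
        super().__init__(max_depth=max_depth, missing=missing, seed=seed)

    @property
    def balancingStrategy(self):
        # Derived on access instead of being fixed in __init__
        if self.k_neighbors:
            return "smote"
        if self.frac:
            return "normal"
        return "false"


model = MiniXGB(frac=0.3)
print(model.get_params())
# {'max_depth': 3, 'missing': None, 'seed': 500, 'frac': 0.3, 'k_neighbors': None}
print(model.balancingStrategy)  # normal
```

Because `balancingStrategy` is computed on access, it would stay consistent if the real `set_params(frac=...)` later changed the arguments, which a value fixed in `__init__` would not.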