2017-06-06

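Python inheritance: new class not initialized correctly

I am writing a class named XGB that inherits from XGBClassifier (from the Python library xgboost.sklearn). I wrote an __init__ function, shown below: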

from xgboost.sklearn import XGBClassifier
from balanceSmote import BalanceSmote
from balance import Balance


class XGB(XGBClassifier):

    def __init__(self, learning_rate=0.5, max_depth=3, colsample_bytree=0.5, n_estimators=300,
                 frac=None, k_neighbors=None, m_neighbors=None, out_step=None):

        # These are the additional arguments that are not in XGBClassifier
        if k_neighbors:
            self.balancingStrategy = 'smote'
            self.k_neighbors = k_neighbors
            self.m_neighbors = m_neighbors
            self.out_step = out_step
        elif frac:
            self.balancingStrategy = 'normal'
            self.frac = frac
        else:
            self.balancingStrategy = 'false'

        # Utilize the motherClass
        super(XGB, self).__init__(seed=500,
                                  learning_rate=learning_rate,
                                  max_depth=max_depth,
                                  colsample_bytree=colsample_bytree,
                                  n_estimators=n_estimators)
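
scikit-learn's developer guide asks estimators to list every settable parameter as an explicit keyword argument of `__init__` and to store each one unmodified under an attribute of the same name, with no other logic. The `__init__` above instead stores `frac` or `k_neighbors` only conditionally and computes `balancingStrategy` eagerly. A minimal sketch of the extra-parameter handling in that convention follows; `BalancingParams` is a hypothetical stand-in class (not part of xgboost or scikit-learn) that derives the strategy on demand through a property.

```python
class BalancingParams:
    """Hypothetical stand-in: store the extra arguments unmodified and
    derive the balancing strategy on demand instead of inside __init__."""

    def __init__(self, frac=None, k_neighbors=None, m_neighbors=None, out_step=None):
        # Store every argument as-is, so attribute lookups by parameter
        # name (as scikit-learn's get_params does) always succeed.
        self.frac = frac
        self.k_neighbors = k_neighbors
        self.m_neighbors = m_neighbors
        self.out_step = out_step

    @property
    def balancingStrategy(self):
        # Same decision order as the question's __init__.
        if self.k_neighbors:
            return "smote"
        if self.frac:
            return "normal"
        return "false"


print(BalancingParams(k_neighbors=5).balancingStrategy)  # smote
print(BalancingParams(frac=0.5).balancingStrategy)  # normal
```

Because the strategy is recomputed from the stored attributes, changing `k_neighbors` or `frac` after construction (for example through scikit-learn's `set_params`, once the class is a real estimator) keeps it consistent.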

Here is my test code:

xgb4 = XGB(learning_rate = 0.1, max_depth = 3, colsample_bytree = 1, n_estimators = 1000) 

xgb4.fit(trainData,trainLabel) 

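The initialization seems to go fine, but when I try to use fit() (a method inherited from XGBClassifier), I get an error message telling me that a parameter is missing: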

  File "<ipython-input-3-47344b7fbc76>", line 1, in <module>
    runfile('/Users/celsloaner/Project/SPUDS/code/testSpark.py', wdir='/Users/celsloaner/Project/SPUDS/code')
  File "/anaconda/envs/SPUDS/lib/python3.5/site-packages/spyder/utils/site/sitecustomize.py", line 880, in runfile
    execfile(filename, namespace)
  File "/anaconda/envs/SPUDS/lib/python3.5/site-packages/spyder/utils/site/sitecustomize.py", line 102, in execfile
    exec(compile(f.read(), filename, 'exec'), namespace)
  File "/Users/celsloaner/Project/SPUDS/code/testSpark.py", line 50, in <module>
    xgb4.fit(predictor.trainData,predictor.trainLabel)
  File "/anaconda/envs/SPUDS/lib/python3.5/site-packages/xgboost/sklearn.py", line 396, in fit
    xgb_options = self.get_xgb_params()
  File "/anaconda/envs/SPUDS/lib/python3.5/site-packages/xgboost/sklearn.py", line 177, in get_xgb_params
    xgb_params = self.get_params()
  File "/anaconda/envs/SPUDS/lib/python3.5/site-packages/xgboost/sklearn.py", line 169, in get_params
    if params['missing'] is np.nan:
KeyError: 'missing'

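The problem comes from the parent class, which should already have been initialized correctly. Here is the parent-class function in question: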

def get_params(self, deep=False):
    """Get parameter.s"""
    params = super(XGBModel, self).get_params(deep=deep)
    if params['missing'] is np.nan:
        params['missing'] = None  # sklearn doesn't handle nan. see #4725
    if not params.get('eval_metric', True):
        del params['eval_metric']  # don't give as None param to Booster
    return params
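
The `super(XGBModel, self).get_params(deep=deep)` call in this excerpt delegates to scikit-learn's `BaseEstimator.get_params`, which does not read a dictionary stored by the parent: it collects parameter names from the signature of `type(self).__init__` (that is, the subclass's `__init__`) and then looks each name up as an attribute. A simplified sketch of that lookup follows; `MiniBaseEstimator`, `Parent` and `Child` are hypothetical stand-ins written for illustration, not the real scikit-learn or xgboost classes.

```python
import inspect


class MiniBaseEstimator:
    """Simplified stand-in for sklearn.base.BaseEstimator's introspection."""

    @classmethod
    def _get_param_names(cls):
        # cls is the runtime class (the subclass), so its own __init__
        # signature decides which names are reported.
        sig = inspect.signature(cls.__init__)
        return sorted(
            name
            for name, p in sig.parameters.items()
            if name != "self" and p.kind != p.VAR_KEYWORD
        )

    def get_params(self):
        # Look up each advertised name as an instance attribute.
        return {name: getattr(self, name, None) for name in self._get_param_names()}


class Parent(MiniBaseEstimator):
    # Hypothetical stand-in for XGBClassifier with two of its parameters.
    def __init__(self, max_depth=3, missing=None):
        self.max_depth = max_depth
        self.missing = missing


class Child(Parent):
    # Like the question's XGB: 'missing' is not in this signature.
    def __init__(self, max_depth=3, frac=None):
        self.frac = frac
        super().__init__(max_depth=max_depth)


params = Child().get_params()
print(sorted(params))  # ['frac', 'max_depth']
print("missing" in params)  # False, so params['missing'] would raise KeyError
```

The attribute `missing` is still set on the instance by the parent's `__init__`; it is simply never reported, because the subclass signature does not mention it.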

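The dictionary params is evidently not defined correctly (the key 'missing' does not exist) when XGBClassifier's __init__ is called from XGB's __init__. Do you have any idea what the problem is, or how to track it down?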

Thanks

If `params` has no entry such as `'missing'`, you will get a `KeyError` when you try to access it. Change the condition `if params['missing'] is np.nan:` to `if params.get('missing', np.nan) is np.nan:`, and do the same for `'eval_metric'`. – alfasin
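
A minimal sketch of the suggested change applied to a plain dictionary. It uses `math.nan` in place of `np.nan` so it runs without NumPy, and `fix_params` is a hypothetical helper, not an xgboost function. It only avoids the `KeyError` by filling in `None`; it does not make `XGB` report the parent's other parameters.

```python
import math


def fix_params(params):
    # dict.get returns the fallback (the math.nan object itself) instead of
    # raising KeyError, so an absent 'missing' key is treated like NaN.
    if params.get("missing", math.nan) is math.nan:
        params["missing"] = None
    # Same pattern as the existing eval_metric check: drop falsy values.
    if not params.get("eval_metric", True):
        del params["eval_metric"]
    return params


print(fix_params({"max_depth": 3}))  # {'max_depth': 3, 'missing': None}
```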

The problem is `params = super(XGBModel, self).get_params(deep=deep)`, which returns a `dict` that doesn't contain what you want. –

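I see, but why? The dictionary is created and managed inside the parent class and I didn't touch it, so I don't understand why it is not properly defined in the parent class when I use it from the inheriting class. – Salamandre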

Answer

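Well, it works when I initialize all of the parent class's parameters, even though the parent class's constructor should already give them default values: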

class XGB(XGBClassifier):

    def __init__(self, max_depth=3, learning_rate=0.1,
                 n_estimators=100, silent=True,
                 objective="binary:logistic",
                 nthread=-1, gamma=0, min_child_weight=1,
                 max_delta_step=0, subsample=1, colsample_bytree=1, colsample_bylevel=1,
                 reg_alpha=0, reg_lambda=1, scale_pos_weight=1,
                 base_score=0.5, seed=0, missing=None,
                 frac=None, k_neighbors=None, m_neighbors=None, out_step=None):

        if k_neighbors:
            self.balancingStrategy = 'smote'
            self.k_neighbors = k_neighbors
            self.m_neighbors = m_neighbors
            self.out_step = out_step
        elif frac:
            self.balancingStrategy = 'normal'
            self.frac = frac
        else:
            self.balancingStrategy = 'false'

        # Pass every parent parameter by keyword so the call does not depend on argument order
        super(XGB, self).__init__(max_depth=max_depth, learning_rate=learning_rate,
                                  n_estimators=n_estimators, silent=silent, objective=objective,
                                  nthread=nthread, gamma=gamma, min_child_weight=min_child_weight,
                                  max_delta_step=max_delta_step, subsample=subsample,
                                  colsample_bytree=colsample_bytree, colsample_bylevel=colsample_bylevel,
                                  reg_alpha=reg_alpha, reg_lambda=reg_lambda,
                                  scale_pos_weight=scale_pos_weight, base_score=base_score,
                                  seed=seed, missing=missing)

Not sure I understand the logic here, but it's always good to know :)
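
The fix works because scikit-learn-style `get_params` only reports names that appear in the subclass's own `__init__` signature, and the version above re-declares every parent parameter, including `missing`. To track down which parent parameters a subclass fails to re-declare, one can compare the two signatures with the standard-library `inspect` module. A minimal sketch, assuming single inheritance; `undeclared_parent_params` and the three toy classes are hypothetical illustrations, not xgboost code.

```python
import inspect


def undeclared_parent_params(cls):
    """Names accepted by the immediate parent's __init__ but not by cls.__init__."""
    parent = cls.__mro__[1]  # immediate base class (single inheritance assumed)
    own = set(inspect.signature(cls.__init__).parameters)
    inherited = set(inspect.signature(parent.__init__).parameters)
    return sorted(inherited - own)


# Hypothetical toy classes standing in for XGBClassifier and the two XGB versions.
class Parent:
    def __init__(self, max_depth=3, seed=0, missing=None):
        self.max_depth, self.seed, self.missing = max_depth, seed, missing


class NarrowChild(Parent):  # shaped like the question's XGB
    def __init__(self, max_depth=3, frac=None):
        self.frac = frac
        super().__init__(max_depth=max_depth, seed=500)


class FullChild(Parent):  # shaped like the answer's XGB
    def __init__(self, max_depth=3, seed=0, missing=None, frac=None):
        self.frac = frac
        super().__init__(max_depth=max_depth, seed=seed, missing=missing)


print(undeclared_parent_params(NarrowChild))  # ['missing', 'seed']
print(undeclared_parent_params(FullChild))  # []
```

Calling the same helper on the real `XGB` class would list whatever `XGBClassifier.__init__` accepts in the installed xgboost version that `XGB.__init__` omits.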