2015-04-22 96 views
2

首先,手頭有問題。我正在編寫一個scikit-learn類的包裝,並且遇到了正確的語法問題。我想實現的是fit_transform功能,從而改變輸入僅輕微的重寫,然後調用它的super - 方法與新的參數:如何在Python中正確覆蓋和調用超級方法

from sklearn.feature_extraction.text import TfidfVectorizer 

class TidfVectorizerWrapper(TfidfVectorizer): 
    def __init__(self): 
     TfidfVectorizer.__init__(self) # is this even necessary? 

    def fit_transform(self, x, y=None, **fit_params): 
     x = [content.split('\t')[0] for content in x] # filtering the input 
     return TfidfVectorizer.fit_transform(self, x, y, fit_params) 
          # this is the critical part, my IDE tells me for 
          # fit_params: 'unexpected arguments' 

程序崩潰所有的地方,從一個Multiprocessing exception,沒有告訴我任何有用的東西。我如何正確地做到這一點?

附加信息:之所以我需要用這種方式來包裝它,是因爲我使用sklearn.pipeline.FeatureUnion來收集我的特徵提取器,然後將它們放入sklearn.pipeline.Pipeline。這樣做的結果是,我只能跨所有特徵提取器提供單個數據集 - 但不同的提取器需要不同的數據。我的解決方案是以容易分離的格式提供數據,並在不同的提取器中過濾不同的部分。如果這個問題有更好的解決方案,我也很樂意聽到。

編輯1: 添加**解壓字典似乎沒有改變什麼: Screenshot

編輯2: 我剛剛解決剩下的問題 - 我需要刪除構造函數重載。顯然,通過試圖調用父構造函數,希望讓所有實例變量都能正確啓動,我的確做了相反的事情。我的包裝不知道它可以預期什麼樣的參數。一旦我刪除了多餘的電話,一切都完美了。

+1

你嘗試更換'返回TfidfVectorizer.fit_transform(個體經營,X,Y,fit_params)'和'返回TfidfVectorizer.fit_transform (self,x,y,** fit_params)'? – user3012759

+0

也取決於'y'在方法簽名中的位置,可能需要「命名」'y':'返回TfidfVectorizer.fit_transform(self,x,y = y,** fit_params)' –

+0

@KlausD。同意,但fit_params幾乎肯定不是正確的東西傳入,因爲它是一個字典 – user3012759

回答

3

忘記解壓縮fit_params這是作爲dict傳遞,你想通過它作爲keyword arguments需要解包運算符**

from sklearn.feature_extraction.text import TfidfVectorizer 

class TidfVectorizerWrapper(TfidfVectorizer): 

    def fit_transform(self, x, y=None, **fit_params): 
     x = [content.split('\t')[0] for content in x] # filtering the input 
     return TfidfVectorizer.fit_transform(self, x, y, **fit_params) 

是instaed調用TfidfVectorizerfit_transform直接就可以通過super方法調用重載版本的一件事

from sklearn.feature_extraction.text import TfidfVectorizer 

class TidfVectorizerWrapper(TfidfVectorizer): 

    def fit_transform(self, x, y=None, **fit_params): 
     x = [content.split('\t')[0] for content in x] # filtering the input 
     return super(TidfVectorizerWrapper, self).fit_transform(x, y, **fit_params) 

要理解它檢查下面的例子

def foo1(**kargs): 
    print kargs 

def foo2(**kargs): 
    foo1(**kargs) 
    print 'foo2' 

def foo3(**kargs): 
    foo1(kargs) 
    print 'foo3' 

foo1(a=1, b=2) 

它打印字典{'a': 1, 'b': 2}

foo2(a=1, b=2) 

版畫既字典和foo2,但

foo3(a=1, b=2) 

引發錯誤,因爲我們發送的位置參數等於我們的字典foo1,不接受這樣的事情。然而,我們可以做

def foo4(**kargs): 
    foo1(x=kargs) 
    print 'foo4' 

,工作正常,但打印新的字典{'x': {'a': 1, 'b': 2}}

+0

可能值得添加一個或兩個關於'超()'回答其他問題OP有 – user3012759

+0

我嘗試了你的建議,我仍然收到完全相同的錯誤。我在主帖中也添加了一條評論。 – Arne

+0

我確實刪除了構造函數,也許它應該更加強調下劃線 – lejlot