0
我的文本分類問題的工作,使用管道看起來像這樣的訓練數據scikit學習網格搜索:我如何做是一個迭代
self.full_classifier = Pipeline([
('vectorize', CountVectorizer()),
('tf-idf', TfidfTransformer()),
('classifier', SVC(kernel='linear', class_weight='balanced'))
])
完整的語料庫過大以適應內存,但足夠小,在矢量化步驟後我沒有內存問題。我可以用
self.full_classifier.fit(
self._all_data (max_samples=train_data_length),
self.dataset.head(train_data_length)['target'].values
)
其中self._all_data是產生每個訓練樣例文件(而self.dataset只是包括文件ID和目標)的迭代器成功地適應分類。在這裏,max_samples是可選的,我正在使用它來對訓練/測試數據進行拆分。我現在想用gridsearch優化參數,其中我使用這個代碼:
parameters = {
'vectorize__stop_words': (None, 'english'),
'tfidf__use_idf': (True, False),
'classifier__class_weight': (None, 'balanced')
}
gridsearch_classifier = GridSearchCV(self.full_classifier, parameters, n_jobs=-1)
gridsearch_classifier.fit(self._all_data(), self.dataset['target'].values)
我的問題是,這會產生以下錯誤:
TypeError: Expected sequence or array-like, got <type 'generator'>
與在gridsearch_classifier回溯指點。 (然後到scikit的代碼中,在_num_samples(x)中引發錯誤。因爲它可以適合一個生成器作爲輸入,所以我想知道是否有一種方法可以在網格搜索中做到這一點,我目前丟失。 任何幫助表示讚賞!
謝謝,這是有道理的。我會考慮通過實現一個到達數據庫的__getitem__來僞造一個列表 – Leo