2016-01-28

I wrote this code for scikit-learn image classification:

# Import datasets, classifiers and performance metrics 
from sklearn import datasets, svm, metrics 
import matplotlib.image as mpimg 

imgs=[[mpimg.imread('sci/img/1.jpg'),mpimg.imread('sci/img/2.jpg')],[mpimg.imread('sci/img/3.jpg'),mpimg.imread('sci/img/4.jpg')]] 
targ=[1,2] 

# To apply a classifier on this data, we need to flatten the images to
# turn the data into a (samples, features) matrix:
n_samples = len(imgs) 
data = imgs.reshape((n_samples, -1)) 

# Create a classifier: a support vector classifier 
classifier = svm.SVC(gamma=0.001) 

# Train the classifier on all samples
classifier.fit(data, targ) 

# Predict on the same samples (no separate test set is held out here):
expected = targ 
predicted = classifier.predict(data) 

print("Classification report for classifier %s:\n%s\n" 
     % (classifier, metrics.classification_report(expected, predicted))) 
print("Confusion matrix:\n%s" % metrics.confusion_matrix(expected, predicted)) 

When I run it, I get this error:

AttributeError: 'list' object has no attribute 'reshape'

I think I built the image array incorrectly. How can I fix it?

Answer

data = imgs.reshape((n_samples, -1)) 

Here you are calling the method reshape on a Python list, and plain Python lists have no such method.
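A minimal sketch reproducing the problem and the fix. The small zero- and one-filled arrays are hypothetical stand-ins for the images that mpimg.imread would return; they are not taken from the question.

```python
import numpy as np

# Hypothetical stand-ins for mpimg.imread results: two 4x4 RGB "images".
img_a = np.zeros((4, 4, 3))
img_b = np.ones((4, 4, 3))

as_list = [img_a, img_b]
print(hasattr(as_list, "reshape"))  # False: plain lists have no reshape method

# Converting the list to a NumPy array stacks the images along a new first axis.
as_array = np.array(as_list)        # shape (2, 4, 4, 3)
flat = as_array.reshape((len(as_list), -1))
print(flat.shape)                   # (2, 48), since 4 * 4 * 3 = 48
```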

However, reshape is a method of NumPy arrays, so imgs should be a numpy array. You should therefore replace

imgs = [[mpimg.imread('sci/img/1.jpg'), mpimg.imread('sci/img/2.jpg')],[mpimg.imread('sci/img/3.jpg'), mpimg.imread('sci/img/4.jpg')]] 

with

import numpy as np 
imgs = np.array([[mpimg.imread('sci/img/1.jpg'), mpimg.imread('sci/img/2.jpg')], [mpimg.imread('sci/img/3.jpg'), mpimg.imread('sci/img/4.jpg')]])
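A self-contained sketch of the whole pipeline after the fix, assuming synthetic noise arrays in place of the JPEG files (the helper fake_image is hypothetical). It keeps the nested layout from the question, where each sample is a pair of images. Note that np.array only yields a regular numeric array when every image has the same shape, so images of different sizes would need resizing first.

```python
import numpy as np
from sklearn import metrics, svm

# Hypothetical stand-ins for mpimg.imread(...) results: 8x8 RGB arrays.
# All images share one shape so np.array can stack them into a regular array.
rng = np.random.default_rng(0)

def fake_image(level):
    """Hypothetical helper: an 8x8x3 array of noise around `level`."""
    return level + 0.05 * rng.standard_normal((8, 8, 3))

# Same nested layout as in the question: 2 samples, each made of 2 images.
imgs = np.array([[fake_image(0.2), fake_image(0.2)],
                 [fake_image(0.8), fake_image(0.8)]])
targ = [1, 2]

n_samples = len(imgs)                   # 2
data = imgs.reshape((n_samples, -1))    # flatten each sample into one row
print(data.shape)                       # (2, 384), since 2 * 8 * 8 * 3 = 384

classifier = svm.SVC(gamma=0.001)
classifier.fit(data, targ)
predicted = classifier.predict(data)    # predicting on the training data itself
print(metrics.confusion_matrix(targ, predicted))
```

With only two samples this only demonstrates that the shapes now line up; a meaningful evaluation would need more images and a held-out test set.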