2016-06-01 116 views
1

我是Python和Sklearn的初學者。想知道我是否在這裏失去了一些東西。我收到以下警告消息:Python Sklearn - 棄用警告

DeprecationWarning:傳遞一維數組作爲數據在0.17 棄用,在0.19 willraise ValueError異常。

下面是代碼:再次

import numpy as np 
import matplotlib.pyplot as plt 
from sklearn.linear_model import SGDClassifier 
from sklearn.datasets.samples_generator import make_blobs 

def plot_sgd_separator(): 
    # we create 50 separable points 
    X, Y = make_blobs(n_samples=50, centers=2,random_state=0, cluster_std=0.60) 
    X = np.array(X).reshape((1, -1)) 


    # fit the model 
    clf = SGDClassifier(loss="hinge", alpha=0.01, 
         n_iter=200, fit_intercept=True) 
    clf.fit(X, Y) 

    # plot the line, the points, and the nearest vectors to the plane 
    xx = np.linspace(-1, 5, 10) 
    yy = np.linspace(-1, 5, 10) 

    X1, X2 = np.meshgrid(xx, yy) 
    Z = np.empty(X1.shape) 
    for (i, j), val in np.ndenumerate(X1): 
     x1 = val 
     x2 = X2[i, j] 
     p = clf.decision_function([x1, x2]) 
     Z[i, j] = p[0] 
    levels = [-1.0, 0.0, 1.0] 
    linestyles = ['dashed', 'solid', 'dashed'] 
    colors = 'k' 

    ax = plt.axes() 
    ax.contour(X1, X2, Z, levels, colors=colors, linestyles=linestyles) 
    ax.scatter(X[:, 0], X[:, 1], c=Y, cmap=plt.cm.Paired) 

    ax.axis('tight') 


if __name__ == '__main__': 
    plot_sgd_separator() 
    plt.show() 

感謝您的關注。順便說一句,我使用Python 3.5.1。

回答

0

如果您閱讀警告消息並進行一些調試,您會意識到警告是由於您對模型的輸入是單維的。你可以看到這個鏈接:Sklearn train model with single sample raises a DeprecationWarning糾正這一點。

我覺得你的代碼還有其他問題。當我運行它時,我看到X和Y中的數據點數量不一樣。 X有100,Y有50,這是一個更嚴重的問題,我覺得這需要先糾正。

1

我想你的問題回答了here,這可能是重複