2016-02-28 37 views
0

在我的函數中,有時我得到的結果是2D形式的一維numpy數組,因此它的形狀是nx1(n,1)。其他時候,我可能會得到它的形式1xn array.shape =(1,n)測試Numpy數組以查看它是否爲列形式

其他時候,我只得到一個numpy數組,其形狀是(n,)。

當我運行下面的測試中,我得到一方面的錯誤,並且在另一假陽性(因爲一個形狀屬性的長度總是大於1,顯然):

y_predicted = forest.predict(testX) 
    if y_predicted.shape[1] != None: 
     y_predicted = y_predicted.T[0] 

y_predicted = forest.predict(testX) 
    if len(y_predicted.shape) > 1: 
     y_predicted = y_predicted.T[0] 

我只是需要確保y的最終形狀總是在形式(N),而不是(N,1)或(1,N)...

+0

'squeeze','ravel'和'flatten'都會做這個工作;但請閱讀他們的文檔,以便了解他們的差異。 – hpaulj

回答

3

你應該使用numpy.squeeze

numpy.squeeze(a)從數組形狀中刪除單維條目。

例子:

>>> x = np.array([[1,2,3]]) 
>>> x.shape 
(1, 3) 
>>> np.squeeze(x) 
array([1, 2, 3]) 
>>> np.squeeze(x).shape 
(3,) 
+0

很好的答案。不知道這種方法。 –

相關問題