2016-02-29 159 views
1

首先設置一個數組元素與序列,這裏是我的代碼:ValueError異常:在Python

y[index] = (num[index])/sum_n

我運行代碼:

"""Softmax.""" 

scores = [3.0, 1.0, 0.2] 

import numpy as np 

def softmax(x): 
    """Compute softmax values for each sets of scores in x.""" 
    num = np.exp(x) 
    score_len = len(x) 
    y = np.array([0]*score_len) 
    sum_n = np.sum(num) 
    #print sum_n 
    for index in range(1,score_len): 
     y[index] = (num[index])/sum_n 
    return y 

print(softmax(scores)) 

錯誤在該行出現:

# Plot softmax curves 
import matplotlib.pyplot as plt 
x = np.arange(-2.0, 6.0, 0.1) 
scores = np.vstack([x, np.ones_like(x), 0.2 * np.ones_like(x)]) 

plt.plot(x, softmax(scores).T, linewidth=2) 
plt.show() 

究竟是怎麼回事?

+0

您是否嘗試過調試嗎? – MSeifert

+0

錯誤很明顯... – Idos

+0

@MSeifert你是什麼意思調試它? –

回答

2

只需要編輯一個print聲明爲「調試」,揭示發生了什麼:

import numpy as np 

def softmax(x): 
    """Compute softmax values for each sets of scores in x.""" 
    num = np.exp(x) 
    score_len = len(x) 
    y = np.array([0]*score_len) 
    sum_n = np.sum(num) 
    #print sum_n 
    for index in range(1,score_len): 
     print((num[index])/sum_n) 
     y[index] = (num[index])/sum_n 
    return y 

x = np.arange(-2.0, 6.0, 0.1) 
scores = np.vstack([x, np.ones_like(x), 0.2 * np.ones_like(x)]) 
softmax(scores).T 

此打印

[ 0.00065504 0.00065504 0.00065504 0.00065504 0.00065504 0.00065504 
    0.00065504 0.00065504 0.00065504 0.00065504 0.00065504 0.00065504 
    0.00065504 0.00065504 0.00065504 0.00065504 0.00065504 0.00065504 
    0.00065504 0.00065504 0.00065504 0.00065504 0.00065504 0.00065504 
    0.00065504 0.00065504 0.00065504 0.00065504 0.00065504 0.00065504 
    0.00065504 0.00065504 0.00065504 0.00065504 0.00065504 0.00065504 
    0.00065504 0.00065504 0.00065504 0.00065504 0.00065504 0.00065504 
    0.00065504 0.00065504 0.00065504 0.00065504 0.00065504 0.00065504 
    0.00065504 0.00065504 0.00065504 0.00065504 0.00065504 0.00065504 
    0.00065504 0.00065504 0.00065504 0.00065504 0.00065504 0.00065504 
    0.00065504 0.00065504 0.00065504 0.00065504 0.00065504 0.00065504 
    0.00065504 0.00065504 0.00065504 0.00065504 0.00065504 0.00065504 
    0.00065504 0.00065504 0.00065504 0.00065504 0.00065504 0.00065504 
    0.00065504 0.00065504] 

所以你想這個數組賦值給另一個數組的一個元素。哪個是不允許的!

有幾種方法可以使它工作。只是改變

y = np.array([0]*score_len) 

到多維數組會工作:

y = np.zeros(score.shape) 

這應該做的伎倆,但我不知道這是否是你的原意。


編輯:

看來你不想多維輸入,所以你只需要改變:

scores = np.vstack([x, np.ones_like(x), 0.2 * np.ones_like(x)]) 

scores = np.hstack([x, np.ones_like(x), 0.2 * np.ones_like(x)]) 

驗證這些形狀通過打印陣列scores.shape確實可以幫助你自己找到這樣的錯誤。第一個沿着第一軸(vstack)和零軸hstack棧(這是你想要的)

+0

我只想要一個單數。 'num [index]'應該是一個單一的數字,'sum_n'也是。爲什麼這個分部導致如此龐大的陣列? –

+0

編輯了答案。我認爲你使用''vstack''''hstack''會更合適。 – MSeifert

+0

'np.exp()'在開始時對列表做了什麼?我想要一個返回列表中每個元素的'exp()'的列表。 –

1

這初始化數組的好方法:

y = np.array([0]*score_len) 

也好像

y = np.zeros((n,m)) 

其中nm是最終產品的2個維度。我從你的另一個問題中假設你想y是2d(畢竟你後面做了.T)。

注意傳遞給函數的scores的形狀。並且在迭代時,包括:。它可以是可選的,但你需要它保持尺寸直在自己的腦海:

y[index,:] = (num[index,:])/sum_n 

總之 - 專注於瞭解如何使用多維數組的工作 - 如何創建它們,以及如何對其進行索引,如何在沒有迭代的情況下使用它們,以及如果需要的話如何正確迭代。

0

這應該很好地工作和快速

scores = [3.0, 1.0, 0.2] 

import numpy as np 


def softmax(x): 

    num = np.exp(x) 
    score_len = len(x) 

    y = np.zeros(score_len, object) # or => np.asarray([None]*score_len) 
    sum_n = np.sum(num) 

    for i in range(score_len): 
     y[i] = num[i]/sum_n 

    return y 


print(softmax(scores)) 

x = np.arange(-2.0, 6.0, 0.1) 
scores = np.vstack([x, np.ones_like(x), 0.2 * np.ones_like(x)]) 

printout = softmax(scores).T 

print(printout) 

輸出:

[0.8360188027814407 0.11314284146556011 0.050838355752999158] 

[ array([ 3.26123038e-05, 3.60421698e-05, 3.98327578e-05, 
     4.40220056e-05, 4.86518403e-05, 5.37685990e-05, 
     5.94234919e-05, 6.56731151e-05, 7.25800169e-05, 
     8.02133239e-05, 8.86494329e-05, 9.79727751e-05, 
     1.08276662e-04, 1.19664218e-04, 1.32249413e-04, 
     1.46158206e-04, 1.61529798e-04, 1.78518035e-04, 
     1.97292941e-04, 2.18042421e-04, 2.40974142e-04, 
     2.66317614e-04, 2.94326482e-04, 3.25281069e-04, 
     3.59491177e-04, 3.97299194e-04, 4.39083515e-04, 
     4.85262332e-04, 5.36297817e-04, 5.92700751e-04, 
     6.55035633e-04, 7.23926331e-04, 8.00062328e-04, 
     8.84205618e-04, 9.77198335e-04, 1.07997118e-03, 
     1.19355274e-03, 1.31907978e-03, 1.45780861e-03, 
     1.61112768e-03, 1.78057146e-03, 1.96783579e-03, 
     2.17479489e-03, 2.40352006e-03, 2.65630048e-03, 
     2.93566604e-03, 3.24441273e-03, 3.58563059e-03, 
     3.96273465e-03, 4.37949910e-03, 4.84009504e-03, 
     5.34913227e-03, 5.91170543e-03, 6.53344491e-03, 
     7.22057331e-03, 7.97996764e-03, 8.81922816e-03, 
     9.74675448e-03, 1.07718296e-02, 1.19047128e-02, 
     1.31567424e-02, 1.45404491e-02, 1.60696814e-02, 
     1.77597446e-02, 1.96275532e-02, 2.16918010e-02, 
     2.39731477e-02, 2.64944256e-02, 2.92808687e-02, 
     3.23603645e-02, 3.57637337e-02, 3.95250385e-02, 
     4.36819230e-02, 4.82759910e-02, 5.33532213e-02, 
     5.89644285e-02, 6.51657716e-02, 7.20193157e-02, 
     7.95936532e-02, 8.79645908e-02]) 
array([ 0.00065504, 0.00065504, 0.00065504, 0.00065504, 0.00065504, 
     0.00065504, 0.00065504, 0.00065504, 0.00065504, 0.00065504, 
     0.00065504, 0.00065504, 0.00065504, 0.00065504, 0.00065504, 
     0.00065504, 0.00065504, 0.00065504, 0.00065504, 0.00065504, 
     0.00065504, 0.00065504, 0.00065504, 0.00065504, 0.00065504, 
     0.00065504, 0.00065504, 0.00065504, 0.00065504, 0.00065504, 
     0.00065504, 0.00065504, 0.00065504, 0.00065504, 0.00065504, 
     0.00065504, 0.00065504, 0.00065504, 0.00065504, 0.00065504, 
     0.00065504, 0.00065504, 0.00065504, 0.00065504, 0.00065504, 
     0.00065504, 0.00065504, 0.00065504, 0.00065504, 0.00065504, 
     0.00065504, 0.00065504, 0.00065504, 0.00065504, 0.00065504, 
     0.00065504, 0.00065504, 0.00065504, 0.00065504, 0.00065504, 
     0.00065504, 0.00065504, 0.00065504, 0.00065504, 0.00065504, 
     0.00065504, 0.00065504, 0.00065504, 0.00065504, 0.00065504, 
     0.00065504, 0.00065504, 0.00065504, 0.00065504, 0.00065504, 
     0.00065504, 0.00065504, 0.00065504, 0.00065504, 0.00065504]) 
array([ 0.00029433, 0.00029433, 0.00029433, 0.00029433, 0.00029433, 
     0.00029433, 0.00029433, 0.00029433, 0.00029433, 0.00029433, 
     0.00029433, 0.00029433, 0.00029433, 0.00029433, 0.00029433, 
     0.00029433, 0.00029433, 0.00029433, 0.00029433, 0.00029433, 
     0.00029433, 0.00029433, 0.00029433, 0.00029433, 0.00029433, 
     0.00029433, 0.00029433, 0.00029433, 0.00029433, 0.00029433, 
     0.00029433, 0.00029433, 0.00029433, 0.00029433, 0.00029433, 
     0.00029433, 0.00029433, 0.00029433, 0.00029433, 0.00029433, 
     0.00029433, 0.00029433, 0.00029433, 0.00029433, 0.00029433, 
     0.00029433, 0.00029433, 0.00029433, 0.00029433, 0.00029433, 
     0.00029433, 0.00029433, 0.00029433, 0.00029433, 0.00029433, 
     0.00029433, 0.00029433, 0.00029433, 0.00029433, 0.00029433, 
     0.00029433, 0.00029433, 0.00029433, 0.00029433, 0.00029433, 
     0.00029433, 0.00029433, 0.00029433, 0.00029433, 0.00029433, 
     0.00029433, 0.00029433, 0.00029433, 0.00029433, 0.00029433, 
     0.00029433, 0.00029433, 0.00029433, 0.00029433, 0.00029433])] 
+0

你能否修復你的代碼格式? – roelofs