2017-09-25 80 views
2

我想評估一個單側截斷正態分佈的不同值的分位數和不同值的未截斷平均值。爲了提高效率,我想使用numpy廣播而不是Python循環。使用numpy廣播與scipy truncnorm

對於最小重複的例子,假設三個位數欲評價是[3.0, 2.0, 1.0],相應未截斷平均值是[6.0, 5.0, 4.0],該下限截止是在1.5,並且未截短標準偏差爲3.0

評估這些單獨工作如預期。如果我運行

import numpy as np 
from scipy.stats import truncnorm 
print truncnorm.logpdf(3.0, a=(1.5-6.0)/3.0, b=np.inf, loc=6.0, scale=3.0) 
print truncnorm.logpdf(2.0, a=(1.5-5.0)/3.0, b=np.inf, loc=5.0, scale=3.0) 
print truncnorm.logpdf(1.0, a=(1.5-4.0)/3.0, b=np.inf, loc=4.0, scale=3.0) 

我得到

-2.44840736626 
-2.3878150686 
-inf 

(最後一個值是-inf因爲1.0小於截止)。同時使用numpy廣播兩個值也按預期工作。如果我運行

print truncnorm.logpdf(
    np.array([3.0, 2.0]), 
    a=(1.5-np.array([6.0, 5.0]))/3.0, 
    b=np.inf, 
    loc=np.array([6.0, 5.0]), 
    scale=3.0 
) 
print truncnorm.logpdf(
    np.array([2.0, 1.0]), 
    a=(1.5-np.array([5.0, 4.0]))/3.0, 
    b=np.inf, 
    loc=np.array([5.0, 4.0]), 
    scale=3.0 
) 

我得到

[-2.44840737 -2.38781507] 
[-2.38781507  -inf] 

不過,如果我嘗試運行,以評估在時間三個值:

print truncnorm.logpdf(
    np.array([3.0, 2.0, 1.0]), 
    a=(1.5-np.array([6.0, 5.0, 4.0]))/3.0, 
    b=np.inf, 
    loc=np.array([6.0, 5.0, 4.0]), 
    scale=3.0 
) 

我得到一個錯誤:

Traceback (most recent call last): 
    File "truncnorm_error.py", line 25, in <module> 
    scale=3.0 
    File "C:\Python27\lib\site-packages\scipy\stats\_distn_infrastructure.py", line 1701, in logpdf 
    place(output, cond, self._logpdf(*goodargs) - log(scale)) 
    File "C:\Python27\lib\site-packages\scipy\stats\_continuous_distns.py", line 4853, in _logpdf 
    return _norm_logpdf(x) - self._logdelta 
ValueError: operands could not be broadcast together with shapes (2,) (3,) 

我錯過了什麼?我使用Python 2.7,numpy 1.13和scipy 0.19。

+0

看起來像一個錯誤。你可以通過https://github.com/scipy/scipy/issues創建一個問題(點擊大綠色的「新問題」按鈕)。 –

回答

0

這是行不通的原因,因爲logpdf檢查分位數,以確保它們大於臨界值。如果你的值小於截斷值,顯然它適用於大小爲1和2,但不適用於3.所以這可能是錯誤。

如果您提供的值大於截斷值,則可以正常工作。例如,這個工程,我把分位數從1.0改爲1.6:

print truncnorm.logpdf(
    np.array([3.0, 2.0, 1.6]), 
    a=(1.5-np.array([6.0, 5.0, 4.0]))/3.0, 
    b=np.inf, 
    loc=np.array([6.0, 5.0, 4.0]), 
    scale=3.0) 
+0

是的。我找到了同樣的東西。當其中一個分位數下降到截止點以下時,向量長度大於2時觸發該行爲。然而,我並不想在自己的代碼中添加額外的邏輯來處理截斷(因爲它會增加計算開銷),奇怪的是,只有在向量長度大於2時纔會出現這種情況。 – tcquinn

+0

看着' scipy'代碼,它看起來像函數'logpdf'用'-inf'填充它的輸出向量,並且只計算在截斷範圍內的值。當你有一個標量時,它直接返回。如果您有一個數值範圍都在範圍內和範圍外的數組,它具有一個選擇範圍內的值的函數。很顯然,當數組大小大於2時,這個函數會混淆。 –

0

謝謝,所有。在此期間,我滾我自己:

def left_truncnorm_logpdf(x, untruncated_mean, untruncated_std_dev, left_cutoff): 
    f = np.array(np.subtract(stats.norm.logpdf(x, loc=untruncated_mean, scale=untruncated_std_dev), 
          np.log(1 - stats.norm.cdf(left_cutoff, loc=untruncated_mean, scale=untruncated_std_dev)))) 
    f[x < left_cutoff] = -np.inf 
    return f 

這是不雅,我敢肯定它有問題,但它似乎對我的工作的目的(例如,它正確地播放載體論點xuntruncated_mean)。