2016-04-27 100 views
12

在機器學習任務。我們應該得到一組具有約束的隨機w.r.t正態分佈。我們可以通過np.random.normal()得到一個正態分佈號,但它不提供任何綁定參數。我想知道如何做到這一點?如何在numpy的範圍內獲得正態分佈?

+4

當不可t正常的隨機樣本按定義分佈式數據是無界的? – Tom

回答

8

如果你正在尋找的Truncated normal distribution,SciPy的有一個功能叫truncnorm

這種分佈的標準形式是一個標準的正常截斷 區間[A,B] - 注意, a和b在標準法線的域 上定義。要轉換剪輯的值對於特定均值和 標準偏差,使用:

A,B =(myclip_a - my_mean)/ my_std,(myclip_b - my_mean)/ my_std

truncnorm取A和B作爲形狀參數。

>>> from scipy.stats import truncnorm 
>>> truncnorm(a=-2/3., b=2/3., scale=3).rvs(size=10) 
array([-1.83136675, 0.77599978, -0.01276925, 1.87043384, 1.25024188, 
     0.59336279, -0.39343176, 1.9449987 , -1.97674358, -0.31944247]) 

上面的例子是由有界-2和2,並返回10個隨機變元(使用.rvs()方法)

>>> min(truncnorm(a=-2/3., b=2/3., scale=3).rvs(size=10000)) 
-1.9996074381484044 
>>> max(truncnorm(a=-2/3., b=2/3., scale=3).rvs(size=10000)) 
1.9998486576228549 

下面是-6直方圖,6:

enter image description here

+0

爲什麼你不使用truncnorm(a = -2,b = 2,scale = 1) – maple

+2

只是爲了說清楚a和b是形狀參數,否則讀者可能會嘗試-2,2的比例不等於1 ,然後得到外部的隨機值[-2,2] – bakkal

12

參數化truncnorm複雜,所以這裏是轉換參數化的東西更直觀的功能:

from scipy.stats import truncnorm 

def get_truncated_normal(mean=0, sd=1, low=0, upp=10): 
    return truncnorm(
     (low - mean)/sd, (upp - mean)/sd, loc=mean, scale=sd) 


如何使用它?

  1. 實例與參數發生器:意味着標準偏差,和截斷範圍

    >>> X = get_truncated_normal(mean=8, sd=2, low=1, upp=10) 
    
  2. 然後,可以使用X,以產生值:

    >>> X.rvs() 
    6.0491227353928894 
    
  3. 或者,numpy a rray用N產生的值:

    >>> X.rvs(10) 
    array([ 7.70231607, 6.7005871 , 7.15203887, 6.06768994, 7.25153472, 
         5.41384242, 7.75200702, 5.5725888 , 7.38512757, 7.47567455]) 
    

甲視覺例

這裏是三個不同的截短的正態分佈的情節:

X1 = get_truncated_normal(mean=2, sd=1, low=1, upp=10) 
X2 = get_truncated_normal(mean=5.5, sd=1, low=1, upp=10) 
X3 = get_truncated_normal(mean=8, sd=1, low=1, upp=10) 

import matplotlib.pyplot as plt 
fig, ax = plt.subplots(3, sharex=True) 
ax[0].hist(X1.rvs(10000), normed=True) 
ax[1].hist(X2.rvs(10000), normed=True) 
ax[2].hist(X3.rvs(10000), normed=True) 
plt.show() 

enter image description here

+1

精彩的回答,謝謝! – Gabriel

+0

+1。但值得注意的是,如果函數內部立即使用'get_truncated_normal.rvs()',而不是在外部調用該函數,該函數將變得更快。當然,這隻有在你想要隨機抽籤時纔有用 –