2014-01-20 74 views
0

沒有收斂我有3個高斯混合但是不管我多麼調整先驗我不能讓後意味着從他們的前值移動..混合高斯在pyMC3

k = 3 

n1 = 1000 
n2 = 1000 
n3 = 1000 

n = n1+n2+n3 

mean1 = 17.3 
mean2 = 42.0 
mean3 = 31.0 

precision = 0.1 

sigma = np.sqrt(1/precision) 

print "Standard deviation: %s" % sigma 

data1 = np.random.normal(mean1,sigma,n1) 
data2 = np.random.normal(mean2,sigma,n2) 
data3 = np.random.normal(mean3,sigma,n3) 

data = np.concatenate([data1 , data2, data3]) 

hist(data, bins=200, color="k", histtype="stepfilled", alpha=0.8) 
plt.title("Histogram of the dataset") 
plt.ylim([0, None]) 

with pm.Model() as model: 
    dd = pm.Dirichlet('dd', a=np.array([float(n/k) for i in range(k)]), shape=k) 
    sd = pm.Uniform('precs', lower=1, upper=5, shape=k) 
    means = pm.Normal('means', [25, 30, 35], 0.01, shape=k) 
    category = pm.Categorical('category', p=dd, shape=n) 

    points = pm.Normal('obs', 
        means[category], 
        sd=sd[category], 
        observed=data) 
    tr = pm.sample(100000, step=pm.Metropolis()) 
    pm.traceplot(tr, vars=['means', 'precs', 'dd']) 

輸出:

Standard deviation: 3.16227766017 
[-----------------100%-----------------] 100000 of 100000 complete in 157.2 sec 

正如你可以看到有沒有收斂和手段不從初始值移動 histogram of data traceplots not converging

回答