2014-04-04

Entropy and probability in Theano

I wrote some simple Python code to compute the entropy of a set, and I am trying to write the same thing in Theano.

import math 

# this computes the probabilities of each element in the set 
def prob(values): 
    return [float(values.count(v))/len(values) for v in values] 

# this computes the entropy 
def entropy(values): 
    p = prob(values) 
    return -sum([v*math.log(v) for v in p]) 

I am trying to write the equivalent code in Theano, but I am not sure how to do it:

import theano 
import theano.tensor as T 

v = T.vector('v') # I create a symbolic vector to represent my initial values 
p = T.vector('p') # The same for the probabilities 

# this is my attempt to compute the probabilities which would feed vector p 
theano.scan(fn=prob,outputs_info=p,non_sequences=v,n_steps=len(values)) 

# considering the previous step would work, the entropy is just 
e = -T.sum(p*T.log(p)) 
entropy = theano.function([values],e) 

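However, the scan line is incorrect and I get tons of errors. I am not sure whether there is a simple way to do this (compute the entropy of a vector), or whether I have to put more effort into the scan function. Any ideas?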


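Theano cannot do computations on Python lists. You have to update your code to use ndarrays. Start by doing it with NumPy only; that alone should already speed up your code. – nouiz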
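As a rough illustration of the NumPy-only step suggested in this comment, the question's two functions might be ported as below. The names `prob_np` and `entropy_np` are hypothetical, and the sketch assumes the input is a flat list of numbers. It deliberately keeps the question's formula unchanged (natural log, one term per element), so it should return the same number as the list-based version.

```python
import numpy as np

def prob_np(values):
    """For each element, the fraction of entries in `values` equal to it."""
    values = np.asarray(values)
    # Pairwise comparison via broadcasting: entry (i, j) is True when
    # values[i] == values[j]. Summing over axis 0 counts, for each element,
    # how many entries share its value.
    counts = (values[:, None] == values[None, :]).sum(axis=0)
    return counts / float(values.size)

def entropy_np(values):
    """Same formula as the question's entropy(): one term per element."""
    p = prob_np(values)
    return -np.sum(p * np.log(p))

print(prob_np([0, 1, 1]))
print(entropy_np([0, 1, 1]))
```

The pairwise comparison builds an n-by-n boolean matrix, so this is only a sketch for small inputs; for long vectors, something like `np.unique(values, return_counts=True)` would likely use far less memory.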

Answer


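In addition to the point raised by nouiz, p should not be declared as a T.vector, because it is the result of a computation on your vector of values. Also, you do not need scan to compute something like the entropy. Scan introduces computational overhead, so it should only be used when there is no other way to compute what you want, or to reduce memory usage. You can take the following approach: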

import theano
import theano.tensor as T

values = T.vector('values') 
nb_values = values.shape[0] 

# For every element in 'values', obtain the total number of times 
# its value occurs in 'values'. 
# NOTE : I've done the broadcasting a bit more explicitly than 
# needed, for clarity. 
freqs = T.eq(values[:,None], values[None, :]).sum(0).astype("float32") 

# Compute a vector containing, for every element of 'values', the 
# probability of that element's value in the vector 'values'. 
# NOTE : these probabilities do *not* sum to 1, because 'probs' has one 
# entry per element of 'values', not one entry per distinct value, so a 
# value that occurs k times has its probability listed k times. For 
# instance, if 'values' is [1, 1, 0] then 'probs' will be 
# [2/3, 2/3, 1/3], because the value 1 has probability 2/3 and the 
# value 0 has probability 1/3 in 'values'. 
probs = freqs/nb_values 

entropy = -T.sum(T.log2(probs)/nb_values) 
fct = theano.function([values], entropy) 

# Will output 0.918296... 
print(fct([0, 1, 1]))
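It may not be obvious why dividing each log-probability by nb_values gives the entropy. A value x that occurs c times in a vector of length n contributes c identical terms log2(p(x))/n to the sum, and c/n is exactly p(x), so the per-element sum collapses to the usual -sum over distinct values of p(x)*log2(p(x)). The NumPy sketch below checks this numerically. It is only an illustration, not part of the Theano code; the function names are hypothetical and the input is assumed to be a flat list of numbers.

```python
import numpy as np

def entropy_per_element(values):
    """Mirrors the Theano expression above: one term per element."""
    values = np.asarray(values)
    n = values.size
    # For each element, count how many entries share its value.
    freqs = (values[:, None] == values[None, :]).sum(axis=0)
    probs = freqs / float(n)
    return -np.sum(np.log2(probs) / n)

def entropy_per_distinct_value(values):
    """Textbook form: one term p(x) * log2(p(x)) per distinct value x."""
    values = np.asarray(values)
    _, counts = np.unique(values, return_counts=True)
    probs = counts / float(values.size)
    return -np.sum(probs * np.log2(probs))

sample = [0, 1, 1]
# Both calls should print approximately 0.918296 for this sample.
print(entropy_per_element(sample))
print(entropy_per_distinct_value(sample))
```

Note that this result is in bits (base-2 log) and counts each distinct value once, so it will generally differ from the question's list-based function, which uses the natural log and adds one p*log(p) term per element.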