2016-07-01 90 views
1

假設我想對從另一個矩陣中的條目定義的分佈中採樣的每個條目採樣一個矩陣。我展開我的矩陣並將map_fn應用於每個元素。使用相對較小的矩陣(128 x 128),以下給出了幾個PoolAllocator警告(GTX TITAN Black),並且不會在任何合理的時間內進行訓練。用map_fn進行元素採樣緩慢

def sample(x): 
    samples = tf.map_fn(lambda z: 
         tf.random_normal([1], mean=z, 
             stddev=tf.sqrt(z * (1 - z))), 
         tf.reshape(x, [-1])) # apply to each element 

    return tf.cond(is_training, lambda: tf.reshape(samples, shape=tf.shape(x)), 
        lambda: tf.tanh(x)) 

有沒有更好的方法來應用像這樣的元素操作?

回答

1

如果您可以使用張量一次操作而不是像tf.map_fn這樣的元素操作,那麼您的代碼運行速度會快得多。

這裏它看起來像你想從每個元素的正態分佈中抽樣,其中分佈的參數對於輸入張量中的每個值是不同的。嘗試是這樣的:

def sample(x): 
    samples = tf.random_normal(shape=[128, 128]) * tf.sqrt(x * (1 - x)) + x 

tf.random_normal()默認生成與平均值是0.0,標準偏差1.0正態分佈。您可以使用逐點張量操作來確定每個元素的標準偏差(通過乘法)和平均值(通過加法)。事實上,如果你看看tf.random_normal()是如何實現的,那麼它正是在內部完成的。

(您可能還做得更好使用Python的條件來區分測試時間訓練。)

如果你打算做這樣的事情很多,你可以在GitHub上提交功能請求,要求概括tf.random_normal接受meanstddev的更一般形狀的張量。我認爲沒有理由不支持。

希望有幫助!

0

請參閱tensorflow.contrib.distributions模塊,該模塊有Normal類和sample方法可以爲您完成此操作。