
Implementing attention in TensorFlow

I want to check whether my attention implementation in TensorFlow is correct.

Basically, I am using the attention described in https://arxiv.org/pdf/1509.06664v1.pdf (just the basic attention, not word-by-word attention). So far I have implemented it without using the last hidden state h_N.
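
For reference, after dropping the h_N term, my reading of the relevant equations from the paper is roughly the following (the simplification is mine, so it may not match the paper exactly):

    M = \tanh(W^{Y} Y)                      with Y \in R^{d \times L}, W^{Y} \in R^{d \times d}
    \alpha = \mathrm{softmax}(w^{\top} M)   with w \in R^{d}, \alpha \in R^{1 \times L}
    r = Y \alpha^{\top}                     with r \in R^{d}

where Y is the matrix of hidden states, d = dim, and L = seq_len.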

def attention(self, hidden_states, name, reuse=False):
    '''
    hidden_states (inputs) are [seq_len x batch_size x dim].
    Returns r, the representation of the hidden states weighted by the attention vector a.

    Note: I do not use the h_N vector and I also skip the last projection layer.
    '''
    shape = hidden_states.get_shape().as_list()  # [seq_len, batch_size, dim]

    with tf.variable_scope('attention_{}'.format(name), reuse=reuse):
        initializer = tf.random_uniform_initializer()

        # Initialize parameters
        weights_Y = tf.get_variable("weights_Y", [self.in_size, self.in_size], initializer=initializer)
        weights_w = tf.get_variable("weights_w", [self.in_size, 1], initializer=initializer)

        # Equation => M = tanh(W^{Y}Y)
        tmp = tf.reshape(hidden_states, [-1, shape[2]])  # [seq_len*batch, dim]
        Y = tf.matmul(tmp, weights_Y)                    # [seq_len*batch, dim]
        Y = tf.reshape(Y, [shape[0], -1, shape[2]])      # [seq_len, batch, dim]
        Y = tf.tanh(Y, name='M_matrix')

        # Equation => a = softmax(Y w^T)
        Y = tf.reshape(Y, [-1, shape[2]])                # [seq_len*batch, dim]
        a = tf.matmul(Y, weights_w)                      # [seq_len*batch, 1]
        a = tf.reshape(a, [-1, shape[0]])                # [batch, seq_len]
        a = tf.nn.softmax(a, name='attention_vector')

        # Equation => r = Ya^T
        # This is where I weight all hidden states by the attention vector
        a = tf.expand_dims(a, 2)                         # [batch, seq_len, 1]
        H = tf.transpose(hidden_states, [1, 2, 0])       # [batch, dim, seq_len]
        r = tf.matmul(H, a, name='R_vector')             # [batch, dim, 1]
        r = tf.reshape(r, [-1, shape[2]])                # [batch, dim]

        # I skip the last projection layer since I do not use h_N
        return r

This graph compiles, runs, and trains correctly (the loss is going down, etc.), but performance is lower than I expected. I would be grateful for any insight into whether I am doing this right.

In particular:

1) For the multiplications of a [?, seq_len, dim] tensor by a [dim, dim] matrix: is it correct to use tf.reshape to go from [?, seq_len, dim] to [-1, dim], apply matmul against the [dim, dim] matrix, and then reshape back to [?, seq_len, dim] after the matmul? (A minimal sketch of both patterns follows after question 2.)

2) I noticed that I end up with a (?, seq_len) attention vector, so I need to compute (?, seq_len) x (?, dim, seq_len). Is it correct to take the (?, seq_len) tensor, expand_dims it to (?, seq_len, 1), and then do a matmul (which I believe was batch_matmul in earlier versions)? See the sketch below.
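
To make both questions concrete, here is a minimal standalone sketch of the two patterns I mean (the names X, W_sketch, a, H and the concrete sizes are only for illustration, using the batch-major [?, seq_len, dim] layout from the questions, not my actual code):

import tensorflow as tf

seq_len, dim = 5, 4

# Question 1: multiply a [?, seq_len, dim] tensor by a [dim, dim] matrix
X = tf.placeholder(tf.float32, [None, seq_len, dim])
W = tf.get_variable("W_sketch", [dim, dim])

flat = tf.reshape(X, [-1, dim])              # [batch*seq_len, dim]
proj = tf.matmul(flat, W)                    # [batch*seq_len, dim]
proj = tf.reshape(proj, [-1, seq_len, dim])  # back to [batch, seq_len, dim]

# Question 2: weight [?, dim, seq_len] hidden states by a [?, seq_len] attention vector
a = tf.placeholder(tf.float32, [None, seq_len])
H = tf.placeholder(tf.float32, [None, dim, seq_len])

a3 = tf.expand_dims(a, 2)   # [batch, seq_len, 1]
r = tf.matmul(H, a3)        # [batch, dim, 1] -- tf.matmul batches over the leading axis in TF 1.0
r = tf.squeeze(r, axis=2)   # [batch, dim]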

Thanks in advance!

Answers


tf.einsum in TF 1.0. I am not sure whether tf.einsum is an efficient implementation, but it makes the computation quite elegant:

import tensorflow as tf 
import numpy as np 

batch_size = 3 
seq_len = 5 
dim = 2 
# [batch_size x seq_len x dim] -- hidden states 
Y = tf.constant(np.random.randn(batch_size, seq_len, dim), tf.float32) 
# [batch_size x dim]   -- h_N 
h = tf.constant(np.random.randn(batch_size, dim), tf.float32) 

initializer = tf.random_uniform_initializer() 
W = tf.get_variable("weights_Y", [dim, dim], initializer=initializer) 
w = tf.get_variable("weights_w", [dim], initializer=initializer) 

# [batch_size x seq_len x dim] -- tanh(W^{Y}Y) 
M = tf.tanh(tf.einsum("aij,jk->aik", Y, W)) 
# [batch_size x seq_len]  -- softmax(Y w^T) 
a = tf.nn.softmax(tf.einsum("aij,j->ai", M, w)) 
# [batch_size x dim]   -- Ya^T 
r = tf.einsum("aij,ai->aj", Y, a) 

with tf.Session() as sess: 
    sess.run(tf.global_variables_initializer()) 
    a_val, r_val = sess.run([a, r]) 
    print("a:", a_val, "\nr:", r_val) 

Correct me if I'm wrong, but I believe an attention mechanism needs both the encoder and decoder outputs to work. I don't think it is just a softmax applied to the outputs of the encoder network. – YellowPillow