
Multidimentional lstm tensorflow

Can someone suggest improvements to my implementation of a multidimensional LSTM?

It is very slow and uses a lot of memory.

import tensorflow as tf


class MultiDimentionalLSTMCell(tf.nn.rnn_cell.RNNCell):
    """
    Adapted from TF's BasicLSTMCell to use Layer Normalization.
    Note that state_is_tuple is always True.
    """

    def __init__(self, num_units, forget_bias=1.0, activation=tf.nn.tanh):
        self._num_units = num_units
        self._forget_bias = forget_bias
        self._activation = activation

    @property
    def state_size(self):
        return tf.nn.rnn_cell.LSTMStateTuple(self._num_units, self._num_units)

    @property
    def output_size(self):
        return self._num_units

    def __call__(self, inputs, state, scope=None):
        """Long short-term memory cell (LSTM).
        @param inputs: (batch, n)
        @param state: the states and hidden units of the two cells
        """
        with tf.variable_scope(scope or type(self).__name__):
            c1, c2, h1, h2 = state

            # change bias argument to False since LN will add bias via shift
            concat = tf.nn.rnn_cell._linear([inputs, h1, h2], 5 * self._num_units, False)

            i, j, f1, f2, o = tf.split(1, 5, concat)

            # add layer normalization to each gate
            #i = ln(i, scope='i/')
            #j = ln(j, scope='j/')
            #f1 = ln(f1, scope='f1/')
            #f2 = ln(f2, scope='f2/')
            #o = ln(o, scope='o/')

            new_c = (c1 * tf.nn.sigmoid(f1 + self._forget_bias) +
                     c2 * tf.nn.sigmoid(f2 + self._forget_bias) +
                     tf.nn.sigmoid(i) * self._activation(j))

            # add layer_normalization in calculation of new hidden state
            new_h = self._activation(ln(new_c, scope='new_h/')) * tf.nn.sigmoid(o)
            new_state = tf.nn.rnn_cell.LSTMStateTuple(new_c, new_h)

            return new_h, new_state


def MultidimentionalRNN(rnn_size, input_data, sh, dims=None, scopeN="layer1"):
    """Implements a naive multidimensional recurrent neural network.

    @param rnn_size: the number of hidden units
    @param input_data: the data to process, of shape [batch, h, w, channels]
    @param sh: [height, width] of the windows
    @param dims: dimensions along which to reverse the input data, e.g.
        dims=[False, True, True, False] => True means reverse that dimension
    @param scopeN: the scope

    returns [batch, h/sh[0], w/sh[1], rnn_size], the output of the lstm
    """
    with tf.variable_scope("MultiDimentionalLSTMCell-" + scopeN):
        cell = MultiDimentionalLSTMCell(rnn_size)

    shape = input_data.get_shape().as_list()
    # add paddings so height and width are multiples of the window size
    # todo:
    # y = tf.cond(condition > 0, lambda: tf.matmul(x, W) + b, lambda: tf.matmul(x, W) - b)
    if shape[1] % sh[0] != 0:
        offset = tf.zeros([shape[0], sh[0] - (shape[1] % sh[0]), shape[2], shape[3]])
        input_data = tf.concat(1, [input_data, offset])
        shape = input_data.get_shape().as_list()
    if shape[2] % sh[1] != 0:
        offset = tf.zeros([shape[0], shape[1], sh[1] - (shape[2] % sh[1]), shape[3]])
        input_data = tf.concat(2, [input_data, offset])
        shape = input_data.get_shape().as_list()

    h, w = int(shape[1] / sh[0]), int(shape[2] / sh[1])
    features = sh[1] * sh[0] * shape[3]
    batch_size = shape[0]

    x = tf.reshape(input_data, [batch_size, h, w, features])
    if dims is not None:
        x = tf.reverse(x, dims)
    x = tf.transpose(x, [1, 2, 0, 3])
    x = tf.reshape(x, [-1, features])
    x = tf.split(0, h * w, x)
    states = []
    outputs = []
    # todo: add seq_len 2D (have to add paddings after)
    # use tf.get_variable()
    # result = tf.while_loop(condition, body, [x])
    with tf.variable_scope("MultiDimentionalRnn-" + scopeN) as scope:
        for i, inputs in enumerate(x):
            # stateUp = tf.cond(i >= w, lambda: states[i-w], lambda: cell.zero_state(batch_size, tf.float32))
            stateUp = states[i - w] if i >= w else cell.zero_state(batch_size, tf.float32)
            # stateLast = tf.cond(i % w > 0, lambda: states[i-1], lambda: cell.zero_state(batch_size, tf.float32))
            stateLast = states[i - 1] if i % w > 0 else cell.zero_state(batch_size, tf.float32)

            currentState = stateUp[0], stateLast[0], stateUp[1], stateLast[1]
            out, state = cell(inputs, currentState)
            states.append(state)
            outputs.append(out)
            scope.reuse_variables()
    outputs = tf.pack(outputs, axis=0)

    y = tf.reshape(outputs, [h, w, batch_size, rnn_size])
    y = tf.transpose(y, [2, 0, 1, 3])
    if dims is not None:
        y = tf.reverse(y, dims)

    return y
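A minimal usage sketch of the function above (the placeholder shape, batch size, and hidden size here are illustrative assumptions, not part of the original code):

# Hypothetical usage: a batch of 16 single-channel 32x32 images, 2x2 windows.
images = tf.placeholder(tf.float32, [16, 32, 32, 1])   # [batch, h, w, channels]
y = MultidimentionalRNN(rnn_size=64, input_data=images, sh=[2, 2], scopeN="layer1")
# y has shape [16, 16, 16, 64]: one hidden vector per 2x2 window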

I have the same problem. I think the right way to save memory and space is to build the graph dynamically with tf.while. For a 1D RNN this is straightforward, but I am having trouble building the 2D equivalent. – Ben


I asked a question about my attempt, which you may or may not find useful, [here](http://stackoverflow.com/questions/42313828/dynamic-graphs-in-tensorflow). – Ben
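For reference, a minimal sketch of the 1D case Ben describes, driving a single RNN cell with tf.while_loop and a TensorArray (same TF 0.12/1.0-era API as the code above; the function name, argument names, and the choice of a cell with a plain tensor state, e.g. tf.nn.rnn_cell.BasicRNNCell, are illustrative assumptions, not code from the question):

# Rough sketch of a 1D RNN unrolled with tf.while_loop instead of a Python loop.
# inputs is time-major [n_steps, batch, features] with a static shape,
# n_steps and batch_size are Python ints.
def rnn_1d_while_loop(cell, inputs, n_steps, batch_size):
    inputs_ta = tf.TensorArray(dtype=tf.float32, size=n_steps, name='in_ta')
    inputs_ta = inputs_ta.unpack(inputs)
    outputs_ta = tf.TensorArray(dtype=tf.float32, size=n_steps, name='out_ta')

    def body(t, outputs_ta, state):
        out, new_state = cell(inputs_ta.read(t), state)
        outputs_ta = outputs_ta.write(t, out)
        return t + 1, outputs_ta, new_state

    def condition(t, outputs_ta, state):
        return tf.less(t, n_steps)

    _, outputs_ta, final_state = tf.while_loop(
        condition, body,
        [tf.constant(0), outputs_ta, cell.zero_state(batch_size, tf.float32)])
    # outputs: [n_steps, batch, rnn_size]
    return outputs_ta.pack(), final_state

The answer below applies the same TensorArray pattern in two dimensions: every state is written to a TensorArray so that step t can read both its upper neighbour (index t-w) and its left neighbour (index t-1).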

Answer

import tensorflow as tf


def ln(tensor, scope=None, epsilon=1e-5): 
    """ Layer normalizes a 2D tensor along its second axis """ 
    assert(len(tensor.get_shape()) == 2) 
    m, v = tf.nn.moments(tensor, [1], keep_dims=True) 
    if not isinstance(scope, str): 
     scope = '' 
    with tf.variable_scope(scope + 'layer_norm'): 
     scale = tf.get_variable('scale', 
           shape=[tensor.get_shape()[1]], 
           initializer=tf.constant_initializer(1)) 
     shift = tf.get_variable('shift', 
           shape=[tensor.get_shape()[1]], 
           initializer=tf.constant_initializer(0)) 
    LN_initial = (tensor - m)/tf.sqrt(v + epsilon) 

    return LN_initial * scale + shift 


class MultiDimentionalLSTMCell(tf.nn.rnn_cell.RNNCell): 
    """ 
    Adapted from TF's BasicLSTMCell to use Layer Normalization. 
    Note that state_is_tuple is always True. 
    """ 

    def __init__(self, num_units, forget_bias=0.0, activation=tf.nn.tanh): 
     self._num_units = num_units 
     self._forget_bias = forget_bias 
     self._activation = activation 

    @property 
    def state_size(self): 
     return tf.nn.rnn_cell.LSTMStateTuple(self._num_units, self._num_units) 

    @property 
    def output_size(self): 
     return self._num_units 

    def __call__(self, inputs, state, scope=None): 
     """Long short-term memory cell (LSTM). 
     @param inputs: (batch, n)
     @param state: the states and hidden units of the two cells
     """ 
     with tf.variable_scope(scope or type(self).__name__): 
      c1,c2,h1,h2 = state 

      # change bias argument to False since LN will add bias via shift 
      concat = tf.nn.rnn_cell._linear([inputs, h1, h2], 5 * self._num_units, False) 

      i, j, f1, f2, o = tf.split(1, 5, concat) 

      # add layer normalization to each gate 
      i = ln(i, scope = 'i/') 
      j = ln(j, scope = 'j/') 
      f1 = ln(f1, scope = 'f1/') 
      f2 = ln(f2, scope = 'f2/') 
      o = ln(o, scope = 'o/') 

      new_c = (c1 * tf.nn.sigmoid(f1 + self._forget_bias) + 
        c2 * tf.nn.sigmoid(f2 + self._forget_bias) + tf.nn.sigmoid(i) * 
        self._activation(j)) 

      # add layer_normalization in calculation of new hidden state 
      new_h = self._activation(ln(new_c, scope = 'new_h/')) * tf.nn.sigmoid(o) 
      new_state = tf.nn.rnn_cell.LSTMStateTuple(new_c, new_h) 

      return new_h, new_state 


def multiDimentionalRNN_whileLoop(rnn_size,input_data,sh,dims=None,scopeN="layer1"): 
     """Implements naive multidimentional recurent neural networks 

     @param rnn_size: the hidden units 
     @param input_data: the data to process of shape [batch,h,w,chanels] 
     @param sh: [heigth,width] of the windows 
     @param dims: dimentions to reverse the input data,eg. 
      dims=[False,True,True,False] => true means reverse dimention 
     @param scopeN : the scope 

     returns [batch,h/sh[0],w/sh[1],chanels*sh[0]*sh[1]] the output of the lstm 
     """ 
     with tf.variable_scope("MultiDimentionalLSTMCell-"+scopeN): 
      cell = MultiDimentionalLSTMCell(rnn_size) 

      shape = input_data.get_shape().as_list() 

      if shape[1]%sh[0] != 0: 
       offset = tf.zeros([shape[0], sh[0]-(shape[1]%sh[0]), shape[2], shape[3]]) 
       input_data = tf.concat(1,[input_data,offset]) 
       shape = input_data.get_shape().as_list() 
      if shape[2]%sh[1] != 0: 
       offset = tf.zeros([shape[0], shape[1], sh[1]-(shape[2]%sh[1]), shape[3]]) 
       input_data = tf.concat(2,[input_data,offset]) 
       shape = input_data.get_shape().as_list() 

      h,w = int(shape[1]/sh[0]),int(shape[2]/sh[1]) 
      features = sh[1]*sh[0]*shape[3] 
      batch_size = shape[0] 

      x = tf.reshape(input_data, [batch_size,h,w, features]) 
      if dims is not None: 
       assert dims[0] == False and dims[3] == False 
       x = tf.reverse(x, dims) 
      x = tf.transpose(x, [1,2,0,3]) 
      x = tf.reshape(x, [-1, features]) 
      x = tf.split(0, h*w, x)  

      sequence_length = tf.ones(shape=(batch_size,), dtype=tf.int32)*shape[0] 
      inputs_ta = tf.TensorArray(dtype=tf.float32, size=h*w,name='input_ta') 
      inputs_ta = inputs_ta.unpack(x) 
      states_ta = tf.TensorArray(dtype=tf.float32, size=h*w+1,name='state_ta',clear_after_read=False) 
      outputs_ta = tf.TensorArray(dtype=tf.float32, size=h*w,name='output_ta') 

      states_ta = states_ta.write(h*w, tf.nn.rnn_cell.LSTMStateTuple(tf.zeros([batch_size,rnn_size], tf.float32), 
                 tf.zeros([batch_size,rnn_size], tf.float32))) 
      def getindex1(t,w): 
       return tf.cond(tf.less_equal(tf.constant(w),t), 
           lambda:t-tf.constant(w), 
           lambda:tf.constant(h*w)) 
      def getindex2(t,w): 
       return tf.cond(tf.less(tf.constant(0),tf.mod(t,tf.constant(w))), 
           lambda:t-tf.constant(1), 
           lambda:tf.constant(h*w)) 

      time = tf.constant(0) 

      def body(time, outputs_ta, states_ta): 
       constant_val = tf.constant(0) 
       stateUp = tf.cond(tf.less_equal(tf.constant(w),time), 
            lambda: states_ta.read(getindex1(time,w)), 
            lambda: states_ta.read(h*w)) 
       stateLast = tf.cond(tf.less(constant_val,tf.mod(time,tf.constant(w))), 
            lambda: states_ta.read(getindex2(time,w)), 
            lambda: states_ta.read(h*w)) 

       currentState = stateUp[0],stateLast[0],stateUp[1],stateLast[1] 
       out , state = cell(inputs_ta.read(time),currentState) 
       outputs_ta = outputs_ta.write(time,out) 
       states_ta = states_ta.write(time,state) 
       return time + 1, outputs_ta, states_ta 

      def condition(time,outputs_ta,states_ta): 
       return tf.less(time , tf.constant(h*w)) 

      result , outputs_ta, states_ta = tf.while_loop(condition, body, [time,outputs_ta,states_ta] 
                  ,parallel_iterations=1) 


      outputs = outputs_ta.pack() 
      states = states_ta.pack() 

      y = tf.reshape(outputs, [h,w,batch_size,rnn_size]) 
      y = tf.transpose(y, [2,0,1,3]) 
      if dims is not None: 
       y = tf.reverse(y, dims) 

      return y,states 
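A minimal usage sketch of the while_loop version (the placeholder shape, sizes, and the zero-valued feed are illustrative assumptions, not part of the answer):

import numpy as np

# Hypothetical usage: same shapes as before, 2x2 windows.
images = tf.placeholder(tf.float32, [16, 32, 32, 1])   # [batch, h, w, channels]
y, states = multiDimentionalRNN_whileLoop(rnn_size=64, input_data=images,
                                           sh=[2, 2], scopeN="layer1")
# y: [16, 16, 16, 64] -- one hidden vector per 2x2 block

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    out = sess.run(y, feed_dict={images: np.zeros((16, 32, 32, 1), np.float32)})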
Does it relate to the [tensorflow-multi-dimensional-lstm project](https://github.com/philipperemy/tensorflow-multi-dimensional-lstm)? – yvs

There is a credit statement in the readme –

I have read it. But I want to know whether this code is part of that project or a different one. – yvs