2017-02-27

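How do I initialize an LSTM cell in TensorFlow?

I am trying to run the code below in TF v1.0, but it raises an error. I created an LSTM cell and then defined a state variable to pass to it, so that lstm_cell can compute its output, but the state cannot be initialized: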

lstm_cell = tf.contrib.rnn.BasicLSTMCell(10) 
# Initial state of the LSTM memory. 
state = tf.zeros([20, lstm_cell.state_size]) 
outputs, states = lstm_cell(x , state) 

Here is the traceback:

ValueError        Traceback (most recent call last) 
<ipython-input-82-4a23eee1acf4> in <module>() 
     1 lstm_cell = tf.contrib.rnn.BasicLSTMCell(10) 
     2 # Initial state of the LSTM memory. 
----> 3 state = tf.zeros([20, lstm_cell.state_size]) 
     4 outputs, states = lstm_cell(x , state) 

/Users/Saeed/anaconda/lib/python2.7/site-packages/tensorflow/python/ops/array_ops.pyc in zeros(shape, dtype, name) 
    1370  output = constant(zero, shape=shape, dtype=dtype, name=name) 
    1371  except (TypeError, ValueError): 
-> 1372  shape = ops.convert_to_tensor(shape, dtype=dtypes.int32, name="shape") 
    1373  output = fill(shape, constant(zero, dtype=dtype), name=name) 
    1374 assert output.dtype.base_dtype == dtype 

/Users/Saeed/anaconda/lib/python2.7/site-packages/tensorflow/python/framework/ops.pyc in convert_to_tensor(value, dtype, name, preferred_dtype) 
    649  name=name, 
    650  preferred_dtype=preferred_dtype, 
--> 651  as_ref=False) 
    652 
    653 

/Users/Saeed/anaconda/lib/python2.7/site-packages/tensorflow/python/framework/ops.pyc in internal_convert_to_tensor(value, dtype, name, as_ref, preferred_dtype) 
    714 
    715   if ret is None: 
--> 716   ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref) 
    717 
    718   if ret is NotImplemented: 

/Users/Saeed/anaconda/lib/python2.7/site-packages/tensorflow/python/framework/constant_op.pyc in _constant_tensor_conversion_function(v, dtype, name, as_ref) 
    174           as_ref=False): 
    175 _ = as_ref 
--> 176 return constant(v, dtype=dtype, name=name) 
    177 
    178 

/Users/Saeed/anaconda/lib/python2.7/site-packages/tensorflow/python/framework/constant_op.pyc in constant(value, dtype, shape, name, verify_shape) 
    163 tensor_value = attr_value_pb2.AttrValue() 
    164 tensor_value.tensor.CopyFrom(
--> 165  tensor_util.make_tensor_proto(value, dtype=dtype, shape=shape, verify_shape=verify_shape)) 
    166 dtype_value = attr_value_pb2.AttrValue(type=tensor_value.tensor.dtype) 
    167 const_tensor = g.create_op(

/Users/Saeed/anaconda/lib/python2.7/site-packages/tensorflow/python/framework/tensor_util.pyc in make_tensor_proto(values, dtype, shape, verify_shape) 
    366  else: 
    367  _AssertCompatible(values, dtype) 
--> 368  nparray = np.array(values, dtype=np_dt) 
    369  # check to them. 
    370  # We need to pass in quantized values as tuples, so don't apply the shape 

ValueError: setting an array element with a sequence. 

Answer


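This happens because in recent versions of TensorFlow, the state_size property of BasicLSTMCell returns an LSTMStateTuple (a Python tuple) by default, since state_is_tuple defaults to True. The shape list passed to tf.zeros therefore becomes [20, (10, 10)], which cannot be converted into an integer shape.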

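If you check the source code, you can see that the two elements of the tuple (c and h, which earlier versions concatenated along the same axis) each hold the same number of units. Both halves must be accounted for when initializing the cell state.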
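The failure can be reproduced without TensorFlow. The last traceback frame calls np.array on the shape list with an integer dtype, and a nested tuple cannot fit in a single int32 slot. The sketch below is a minimal illustration under stated assumptions: LSTMStateTuple here is a hypothetical namedtuple stand-in, not the real tf.contrib.rnn class, and state_size is assumed to be (10, 10) as it would be for BasicLSTMCell(10) with state_is_tuple=True.

```python
from collections import namedtuple

import numpy as np

# Hypothetical stand-in for tf.contrib.rnn.LSTMStateTuple; TensorFlow is not used.
LSTMStateTuple = namedtuple("LSTMStateTuple", ["c", "h"])

num_units = 10
batch_size = 20

# Assumed value of BasicLSTMCell(10).state_size when state_is_tuple=True.
state_size = LSTMStateTuple(c=num_units, h=num_units)

# The shape the question passes to tf.zeros: [20, (10, 10)], a ragged list.
shape = [batch_size, state_size]

# Mimic the last traceback frame: convert the shape with an int32 dtype.
try:
    np.array(shape, dtype=np.int32)
    error = None
except ValueError as exc:
    error = exc

print(type(error).__name__)  # ValueError
```

The exact wording of the NumPy message varies between NumPy versions, but the exception type is ValueError in each case.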

So this should do the trick:

# c and h each have state_size[0] units, so the concatenated state is twice as wide.
state = tf.zeros([20, lstm_cell.state_size[0]*2])
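A NumPy sketch of the shape the answer builds, [20, state_size[0]*2], again using a hypothetical namedtuple in place of LSTMStateTuple rather than TensorFlow itself. Because c and h have equal size, state_size[0]*2 equals sum(state_size), and the resulting [20, 20] zero state splits back into c and h along axis 1. As far as I can tell from the TF 1.0 source, that split is what BasicLSTMCell does only when it is built with state_is_tuple=False; with the default state_is_tuple=True the cell expects a (c, h) pair instead, which lstm_cell.zero_state(20, tf.float32) builds directly. The tuple form is mimicked at the end of the sketch.

```python
from collections import namedtuple

import numpy as np

# Hypothetical stand-in for tf.contrib.rnn.LSTMStateTuple; TensorFlow is not used.
LSTMStateTuple = namedtuple("LSTMStateTuple", ["c", "h"])

num_units = 10
batch_size = 20
state_size = LSTMStateTuple(c=num_units, h=num_units)

# The answer's shape: both halves of the state side by side.
state_shape = [batch_size, state_size[0] * 2]

# Every element is now a plain int, so the int32 conversion succeeds.
int_shape = np.array(state_shape, dtype=np.int32)

# Zero-initialized concatenated state, analogous to tf.zeros(state_shape).
state = np.zeros(state_shape)

# Split along axis 1 into c and h (the state_is_tuple=False layout).
c, h = np.split(state, 2, axis=1)

# Tuple layout (state_is_tuple=True): one zero array per field.
tuple_state = LSTMStateTuple(
    c=np.zeros((batch_size, state_size.c)),
    h=np.zeros((batch_size, state_size.h)),
)

print(state.shape, c.shape, h.shape)  # (20, 20) (20, 10) (20, 10)
```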