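TypeError from tensorlayer ConcatLayer due to tf.concat API change

I'm using tensorlayer's ConcatLayer to concatenate two inputs: one is an embedding and the other is an additional input. tl.layers.ConcatLayer always runs into a TypeError: "Expected int32, got list containing Tensors of type '_Message' instead."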
It seems the change to the tf.concat() API might be the cause, but I'm using TF 1.2.0 + tensorlayer 1.5.1 (Python 2.7.13 | Anaconda 4.3.0).
Can anyone help? Thanks, -Wei
Network design:
emb_net = tl.layers.EmbeddingInputlayer(
    inputs = x,
    vocabulary_size = VOCAB_SIZE,
    embedding_size = FLAGS.embedding_size,
    E_init = tf.random_uniform_initializer(
        -FLAGS.init_scale, FLAGS.init_scale),
    name = 'embedding')
word_bound = tl.layers.InputLayer(
    inputs = x_wb,
    name = 'word_boundary')
network = tl.layers.ConcatLayer(
    layer = [emb_net, word_bound],
    concat_dim = 1,
    name = 'concat_layer')
Error message:
  File "./tensorlayer_lstm_classifier.py", line 147, in do_training
    reuse=None)
  File "./tensorlayer_lstm_classifier.py", line 53, in inference
    name = 'concat_layer')
  File "/Users/lin/anaconda2/lib/python2.7/site-packages/tensorlayer/layers.py", line 4717, in __init__
    self.outputs = tf.concat(concat_dim, self.inputs, name=name)
  File "/Users/lin/anaconda2/lib/python2.7/site-packages/tensorflow/python/ops/array_ops.py", line 1043, in concat
    dtype=dtypes.int32).get_shape(
  File "/Users/lin/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 676, in convert_to_tensor
    as_ref=False)
  File "/Users/lin/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 741, in internal_convert_to_tensor
    ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)
  File "/Users/lin/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/constant_op.py", line 113, in _constant_tensor_conversion_function
    return constant(v, dtype=dtype, name=name)
  File "/Users/lin/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/constant_op.py", line 102, in constant
    tensor_util.make_tensor_proto(value, dtype=dtype, shape=shape, verify_shape=verify_shape))
  File "/Users/lin/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/tensor_util.py", line 374, in make_tensor_proto
    _AssertCompatible(values, dtype)
  File "/Users/lin/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/tensor_util.py", line 302, in _AssertCompatible
    (dtype.name, repr(mismatch), type(mismatch).__name__))
TypeError: Expected int32, got list containing Tensors of type '_Message' instead.
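The traceback shows the ConcatLayer line calling tf.concat(concat_dim, self.inputs, name=name), i.e. axis first and the list of tensors second, whereas in TF 1.x the signature is tf.concat(values, axis, name=...). The toy model below is pure Python, not TensorFlow; to_int32 and concat are hypothetical stand-ins, and each "tensor" is just a flat Python list. It sketches why that ordering produces this kind of TypeError: the list of inputs lands in the axis slot, and converting it to int32 fails before any concatenation happens (the traceback likewise fails inside the int32 conversion in array_ops.concat).

```python
# Toy model (pure Python, NOT TensorFlow) of the tf.concat argument-order change.
# Assumption: TF >= 1.0 order is concat(values, axis); pre-1.0 code calls
# concat(axis, values), as the ConcatLayer line in the traceback does.


def to_int32(value):
    """Simplified stand-in for converting `axis` to an int32 scalar."""
    if isinstance(value, int) and not isinstance(value, bool):
        return value
    raise TypeError("Expected int32, got %s of type '%s' instead."
                    % (repr(value), type(value).__name__))


def concat(values, axis, name=None):
    """Toy concat using the TF >= 1.0 order: values first, then axis.

    The axis is converted first, so a list passed in the axis slot fails
    before anything is joined. Only axis 0 on flat lists is supported.
    """
    axis = to_int32(axis)
    if axis != 0:
        raise ValueError("toy version only supports axis=0")
    out = []
    for v in values:
        out.extend(v)
    return out


print(concat([[1, 2], [3]], 0))  # new order -> [1, 2, 3]

try:
    concat(0, [[1, 2], [3]])  # old order: the list ends up in the axis slot
except TypeError as e:
    print(e)  # Expected int32, got [[1, 2], [3]] of type 'list' instead.
```

In the real call the same mismatch happens with TensorFlow tensors instead of plain lists, which is why the message mentions a list containing Tensors.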
The axis argument has been corrected to concat_dim = 2, since the inputs to the concat are three-dimensional.
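With 3-D inputs, every dimension except the concat axis has to match, and the sizes along the concat axis are summed. A small sketch of that shape rule, assuming (hypothetically) that the embedding output is [batch, num_steps, embedding_size] and x_wb is fed as [batch, num_steps, wb_size]; concat_shape and the concrete sizes are illustrative only and come from neither the post nor TensorLayer:

```python
# Sketch of the static-shape rule for concatenation (pure Python).
# Hypothetical sizes below; the real ones depend on FLAGS and on x_wb.


def concat_shape(shapes, axis):
    """Return the shape obtained by concatenating `shapes` along `axis`.

    All inputs must share the same rank and agree on every dimension
    except `axis`, whose sizes are added together.
    """
    rank = len(shapes[0])
    if not 0 <= axis < rank:
        raise ValueError("axis %d out of range for rank %d" % (axis, rank))
    out = list(shapes[0])
    for s in shapes[1:]:
        if len(s) != rank:
            raise ValueError("rank mismatch: %r vs %r" % (shapes[0], s))
        for i in range(rank):
            if i != axis and s[i] != out[i]:
                raise ValueError("dim %d mismatch: %r vs %r" % (i, shapes[0], s))
        out[axis] += s[axis]
    return tuple(out)


emb_shape = (32, 20, 128)  # hypothetical [batch, num_steps, embedding_size]
wb_shape = (32, 20, 1)     # hypothetical [batch, num_steps, wb_size]

print(concat_shape([emb_shape, wb_shape], 2))  # (32, 20, 129)

try:
    concat_shape([emb_shape, wb_shape], 1)  # last dims 128 vs 1 differ
except ValueError as e:
    print(e)
```

Concatenating along axis 1 would only work if the embedding size and the word-boundary feature size happened to be equal, which is why axis 2 (the feature axis) is the natural choice here.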