2017-06-16 106 views
0

我曾希望實現keras中的PointNet(https://arxiv.org/pdf/1612.00593.pdf)的變體,但我無法重複上下文向量(g)的次數可變,所以我可以將它連接起來與前一層缺少上下文(前)的行。我嘗試了Repeat()和keras.backend.Tile()。將張量與keras中的向量合併爲一個向量

input = Input(shape=(None,3)) 
x = TimeDistributed(Dense(128, activation = 'relu'))(input) 
pre = TimeDistributed(Dense(256, activation = 'relu'))(x) 
g = GlobalMaxPooling1D()(pre) 
x = Lambda(merge_on_single, output_shape=(None,512))([pre,g]) 
print(x.shape) 

這是我想出的lambda定義。

def merge_on_single(v): 
#v[0] is variable length tensor, v[1] is the single vector 

return Concatenate()([K.repeat(v[1],K.get_variable_shape(v[0])),v[0]]) 

但出現以下錯誤:

類型錯誤:在列表張量傳遞給「包」作品的「價值」有類型[INT32,INT32]並不都匹配。

UPDATE:

所以我能得到的層不是做給錯誤如下:

input = Input(shape=(None,3)) 

num_point = K.placeholder(input.get_shape()[1].value, dtype=tf.int32) 

#first global feature layer 
x = TimeDistributed(Dense(512, activation = 'relu'))(input) 
x = TimeDistributed(Dense(256, activation = 'relu'))(x) 
g = GlobalMaxPooling1D()(x) 
g = K.reshape(g,(-1,1,256)) 
g = K.tile(x, [1,num_point,1]) 
concat_feat = K.concatenate([x, g]) 

,但現在,我得到以下錯誤:

AttributeError: 'Tensor' object has no attribute '_keras_history' 

回答

0

我懷疑罪魁禍首是K.get_variable_shape(v[0])。由於v[0]的類型爲int32(按照您的錯誤指定),因此當您獲取形狀時,它將返回無。連接要求所有輸入都是相同的類型。