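tensorflow - CTC loss decreases, but decoder output is empty

I'm using TensorFlow's `ctc_loss` and `ctc_greedy_decoder`. When I train the model, the cost computed by `ctc_loss` goes down, but when I decode, the output is always empty. Is it possible for this to happen? My code is below.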
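I'm wondering whether I'm preprocessing my data correctly. I'm predicting phone sequences from frames of fbank features. There are 48 phones (48 classes), and each frame has 69 features. I set `num_classes` to 49, so the logits have shape `(max_time_steps, num_samples, 49)`. In my sparse target tensor the values range from 0 to 47 (48 is reserved for the blank). I never added blanks to my data, and I don't think I should. (Or should I be doing something like that?)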
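A minimal sketch of the label layout described above, assuming the targets are fed to `tf.sparse_placeholder` as an `(indices, values, dense_shape)` triple. `sparse_tuple_from`, `NUM_PHONES`, and `BLANK` are hypothetical names, not part of the question's code. The labels stay in 0..47 and no blank is inserted; with `num_classes = 49`, TensorFlow's `ctc_loss` reserves the last index, 48, for the blank.

```python
import numpy as np

NUM_PHONES = 48               # real labels 0..47 (from the question)
NUM_CLASSES = NUM_PHONES + 1  # 49; ctc_loss treats index NUM_CLASSES - 1 as the blank
BLANK = NUM_CLASSES - 1       # 48, never written into the targets


def sparse_tuple_from(sequences, dtype=np.int32):
    """Pack a list of label sequences into (indices, values, dense_shape).

    Hypothetical helper: the triple can be fed to a tf.sparse_placeholder.
    Sequences must contain only real labels, never BLANK.
    """
    indices, values = [], []
    for batch_idx, seq in enumerate(sequences):
        for t, label in enumerate(seq):
            if not 0 <= label < BLANK:
                raise ValueError(f"label {label} outside 0..{BLANK - 1}")
            indices.append((batch_idx, t))
            values.append(label)
    indices = np.asarray(indices, dtype=np.int64).reshape(-1, 2)
    values = np.asarray(values, dtype=dtype)
    max_len = max((len(s) for s in sequences), default=0)
    dense_shape = np.asarray([len(sequences), max_len], dtype=np.int64)
    return indices, values, dense_shape
```

In the training loop this would be used as `targets: sparse_tuple_from(batch_labels)`, where `batch_labels` is a list of per-utterance phone-index lists.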
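During training, the cost decreases after every iteration and epoch, but the edit distance never does. In fact, it stays at 1, because the decoder almost always predicts an empty sequence. Is there something I'm doing wrong?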
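Why the error sits at exactly 1 can be seen with a small NumPy re-implementation of best-path (greedy) decoding and a length-normalized edit distance. The function names are hypothetical and this is not TensorFlow's code, only a sketch of the same logic as `ctc_greedy_decoder(merge_repeated=True)` and `tf.edit_distance(normalize=True)`. If the blank has the highest score on every frame, the decoder emits nothing, and the distance from an empty hypothesis to any non-empty target, divided by the target length, is exactly 1. A network that predicts blank almost everywhere is a commonly observed state early in CTC training.

```python
import numpy as np


def greedy_ctc_decode(logits, blank, merge_repeated=True):
    """Best-path decoding for one utterance.

    logits: array of shape (time, num_classes). Takes the argmax class per
    frame, optionally merges consecutive repeats, then drops blanks.
    """
    best = np.argmax(logits, axis=1)
    out, prev = [], None
    for k in best:
        k = int(k)
        if merge_repeated and k == prev:
            prev = k
            continue
        prev = k
        if k != blank:
            out.append(k)
    return out


def normalized_edit_distance(hyp, ref):
    """Levenshtein distance divided by len(ref) (ref assumed non-empty)."""
    prev = list(range(len(ref) + 1))
    for i, h in enumerate(hyp, start=1):
        cur = [i]
        for j, r in enumerate(ref, start=1):
            cur.append(min(prev[j] + 1,              # delete h
                           cur[j - 1] + 1,           # insert r
                           prev[j - 1] + (h != r)))  # substitute or match
        prev = cur
    return prev[-1] / len(ref)
```

For example, with 4 classes (blank = 3), logits whose largest entry is the blank in every frame decode to `[]`, giving a normalized distance of 1.0 against a target such as `[0, 1]`, while a frame-wise best path of `[0, 0, 3, 1, 1]` decodes to `[0, 1]`.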
```python
import pickle

import tensorflow as tf  # TensorFlow 1.x API (tf.placeholder, tf.contrib, tf.Session)

# Assumed to be defined elsewhere: num_features (69), num_classes (49), num_hidden,
# num_layers, initial_learning_rate, epochs, batch_size, get_data, get_next_batch.

graph = tf.Graph()
with graph.as_default():
    # inputs: [batch_size, max_time, num_features]
    inputs = tf.placeholder(tf.float32, [None, None, num_features])
    # Sparse labels in 0..num_classes - 2; index num_classes - 1 is CTC's blank.
    targets = tf.sparse_placeholder(tf.int32)
    # Length of each input sequence: [batch_size]
    seq_len = tf.placeholder(tf.int32, [None])
    seq_len_t = tf.placeholder(tf.int32, [None])

    # One LSTMCell object per layer; [cell] * num_layers would reuse a single cell.
    cells = [tf.contrib.rnn.LSTMCell(num_hidden) for _ in range(num_layers)]
    stack = tf.contrib.rnn.MultiRNNCell(cells)
    outputs, _ = tf.nn.dynamic_rnn(stack, inputs, seq_len, dtype=tf.float32)

    input_shape = tf.shape(inputs)
    # Flatten batch and time so one affine layer maps every frame to class scores.
    outputs = tf.reshape(outputs, [-1, num_hidden])
    W = tf.Variable(tf.truncated_normal([num_hidden, num_classes], stddev=0.1))
    b = tf.Variable(tf.constant(0., shape=[num_classes]))
    logits = tf.matmul(outputs, W) + b
    logits = tf.reshape(logits, [input_shape[0], -1, num_classes])
    # ctc_loss and ctc_greedy_decoder expect time-major logits: [max_time, batch, num_classes]
    logits = tf.transpose(logits, (1, 0, 2))

    loss = tf.nn.ctc_loss(targets, logits, seq_len)
    cost = tf.reduce_mean(loss)
    decoded, log_probabilities = tf.nn.ctc_greedy_decoder(logits, seq_len, merge_repeated=True)
    optimizer = tf.train.MomentumOptimizer(initial_learning_rate, 0.1).minimize(cost)
    # Label error rate: edit distance normalized by the target length.
    err = tf.reduce_mean(tf.edit_distance(tf.cast(decoded[0], tf.int32), targets))
    saver = tf.train.Saver()

with tf.Session(graph=graph) as session:
    X, Y, ids, seq_length, label_to_int, int_to_label = get_data('train')
    session.run(tf.global_variables_initializer())
    print(seq_length)
    # Ceiling division, so no empty batch is produced when len(X) % batch_size == 0.
    num_batches = (len(X) + batch_size - 1) // batch_size
    for epoch in range(epochs):
        print('epoch ' + str(epoch))
        for batch in range(num_batches):
            input_X, target_input, seq_length_X = get_next_batch(batch, X, Y, seq_length, batch_size)
            feed = {inputs: input_X,
                    targets: target_input,
                    seq_len: seq_length_X}
            _, print_cost, print_er = session.run([optimizer, cost, err], feed_dict=feed)
            print('epoch ' + str(epoch) + ' batch ' + str(batch) +
                  ' cost: ' + str(print_cost) + ' er: ' + str(print_er))
        save_path = saver.save(session, '/tmp/model.ckpt')
        print('model saved')

    X_t, ids_t, seq_length_t = get_data('test')
    # Decoding needs only the inputs and their lengths; no targets are fed.
    feed_t = {inputs: X_t, seq_len: seq_length_t}
    print(X.shape)
    print(X_t.shape)
    print(type(seq_length_t[0]))
    de, lo = session.run([decoded[0], log_probabilities], feed_dict=feed_t)
    with open('predict.pickle', 'wb') as f:
        pickle.dump((de, lo), f)
```
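Is the network fully trained (has the training error plateaued)? Empty predictions are usually seen at the start of training. For example, search for "the interesting blank label in CTC". And no, you don't need to add blanks to your target sequences; the blanks are used only internally by CTC. – Harry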
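To illustrate the point that blanks are internal to CTC, here is a minimal NumPy sketch of the standard CTC forward (alpha) recursion; the function names are hypothetical and this is not TensorFlow's implementation. The target is passed without blanks, and the function itself builds the extended sequence with a blank before, between, and after every label. It works on plain probabilities with no log-space scaling, so it is only meant for tiny examples.

```python
import numpy as np


def collapse(path, blank):
    """CTC's many-to-one mapping: merge repeated symbols, then drop blanks."""
    out, prev = [], None
    for k in path:
        if k != prev and k != blank:
            out.append(k)
        prev = k
    return out


def ctc_neg_log_likelihood(probs, target, blank):
    """-log p(target | probs) via the CTC forward recursion.

    probs: (T, C) per-frame softmax outputs; target: label list WITHOUT blanks.
    """
    ext = [blank]
    for label in target:
        ext += [label, blank]          # l' = b, l1, b, l2, ..., b
    T, S = probs.shape[0], len(ext)
    alpha = np.zeros((T, S))
    alpha[0, 0] = probs[0, blank]      # start in the leading blank...
    if S > 1:
        alpha[0, 1] = probs[0, ext[1]]  # ...or directly on the first label
    for t in range(1, T):
        for s in range(S):
            a = alpha[t - 1, s]
            if s >= 1:
                a += alpha[t - 1, s - 1]
            # A blank may be skipped only between two different labels.
            if s >= 2 and ext[s] != blank and ext[s] != ext[s - 2]:
                a += alpha[t - 1, s - 2]
            alpha[t, s] = a * probs[t, ext[s]]
    # Valid paths end on the last label or on the trailing blank.
    total = alpha[T - 1, S - 1] + (alpha[T - 1, S - 2] if S > 1 else 0.0)
    return float(-np.log(total))
```

For small `T` and `C`, the result can be checked against brute force: sum the probability of every frame-level path whose `collapse` equals the target. Repeated labels such as `[0, 0]` only get probability mass from paths with a blank between the two 0s, which is exactly why CTC inserts blanks on its own.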