0
從例子https://github.com/timediv/speechT,我試圖適應與LSTM網絡一起使用,但失敗請幫忙。我嘗試了許多組合,但我總是有錯誤,即輸入必須是序列或其他。爲了語音識別的目的,我需要實現LSTM網絡,並且在嘗試了幾個星期後,我仍然陷入了編碼問題。任何人都可以幫助我提供使用LSTM網絡的例子,樣本會很好。用LSTM網絡重複使用speechT例子
class InputBatchLoader(BaseInputLoader):
def __init__(self, input_size, batch_size, data_generator_creator, max_steps=None):
super().__init__(input_size)
self.batch_size = batch_size
self.data_generator_creator = data_generator_creator
self.steps_left = max_steps
with tf.device("/cpu:0"):
with tf.device("/cpu:0"):
# Define input and label placeholders
self.inputs = tf.placeholder(tf.float32, [batch_size, None, input_size], name='inputs')
self.sequence_lengths = tf.placeholder(tf.int32, [batch_size], name='sequence_lengths')
self.labels = tf.sparse_placeholder(tf.int32, name='labels')
# Queue for inputs and labels
self.queue = tf.FIFOQueue(dtypes=[tf.float32, tf.int32, tf.string],
capacity=100)
# queues do not support sparse tensors yet, we need to serialize...
serialized_labels = tf.serialize_many_sparse(self.labels)
self.enqueue_op = self.queue.enqueue([self.inputs,
self.sequence_lengths,
serialized_labels])
class Wav2LetterLSTMModel(SpeechModel): #Add Sep 14, 2017 to create LSTM model
def __init__(self, input_loader: BaseInputLoader, input_size: int, num_classes: int):
super().__init__(input_loader, input_size, num_classes)
def _create_network(self, num_classes):
cellsize = 256
num_layers = 3
inputs = self.inputs
lstm_cell = rnn.BasicLSTMCell(cellsize, forget_bias=1.0)
outputs, states = tf.nn.dynamic_rnn(lstm_cell, inputs, dtype=tf.float32)
return tf.transpose(outputs, (1, 0, 2))
def create_default_model(flags, input_size: int, speech_input: BaseInputLoader) -> SpeechModel:
model = Wav2LetterLSTMModel(input_loader=speech_input,
input_size=input_size,
num_classes=speecht.vocabulary.SIZE + 1) #Add Sep 14, 2017, to use LSTM model
# TODO how can we restore only selected variables so we do not need to always create the full network?
if flags.command == 'train':
model.add_training_ops(learning_rate=flags.learning_rate,
learning_rate_decay_factor=flags.learning_rate_decay_factor,
max_gradient_norm=flags.max_gradient_norm,
momentum=flags.momentum)
model.add_decoding_ops()
elif flags.command == 'export':
model.add_training_ops()
model.add_decoding_ops()
else:
model.add_training_ops()
model.add_decoding_ops(language_model=flags.language_model,
lm_weight=flags.lm_weight,
word_count_weight=flags.word_count_weight,
valid_word_count_weight=flags.valid_word_count_weight)
model.finalize(log_dir=flags.log_dir,
run_name=flags.run_name,
run_type=flags.run_type)
return model
其實,我們如何養活self.inputs = tf.placeholder(tf.float32,[batch_size時,無,input_size],名稱= '輸入')到輸出,states = tf.nn.dynamic_rnn(lstm_cell,inputs,dtype = tf.float32) –