From a7285d5e1fe1375e42843b000af6c25bd047a1ef Mon Sep 17 00:00:00 2001
From: Xinghai Sun
Date: Mon, 21 Aug 2017 21:54:28 +0800
Subject: [PATCH] Add GRU support.

---
 deep_speech_2/demo_server.py |  6 ++++
 deep_speech_2/evaluate.py    |  6 ++++
 deep_speech_2/infer.py       |  6 ++++
 deep_speech_2/layer.py       | 64 ++++++++++++++++++++++++++++++++----
 deep_speech_2/model.py       |  9 ++---
 deep_speech_2/train.py       |  8 ++++-
 deep_speech_2/tune.py        |  6 ++++
 7 files changed, 93 insertions(+), 12 deletions(-)

diff --git a/deep_speech_2/demo_server.py b/deep_speech_2/demo_server.py
index c7e7e94a..60d97239 100644
--- a/deep_speech_2/demo_server.py
+++ b/deep_speech_2/demo_server.py
@@ -66,6 +66,11 @@ parser.add_argument(
     default=512,
     type=int,
     help="RNN layer cell number. (default: %(default)s)")
+parser.add_argument(
+    "--use_gru",
+    default=True,
+    type=bool,
+    help="Use GRU or simple RNN. (default: %(default)s)")
 parser.add_argument(
     "--use_gpu",
     default=True,
@@ -199,6 +204,7 @@ def start_server():
         num_conv_layers=args.num_conv_layers,
         num_rnn_layers=args.num_rnn_layers,
         rnn_layer_size=args.rnn_layer_size,
+        use_gru=args.use_gru,
         pretrained_model_path=args.model_filepath)
 
     # prepare ASR inference handler
diff --git a/deep_speech_2/evaluate.py b/deep_speech_2/evaluate.py
index 82dcec3c..2f87abbd 100644
--- a/deep_speech_2/evaluate.py
+++ b/deep_speech_2/evaluate.py
@@ -38,6 +38,11 @@ parser.add_argument(
     default=512,
     type=int,
     help="RNN layer cell number. (default: %(default)s)")
+parser.add_argument(
+    "--use_gru",
+    default=True,
+    type=bool,
+    help="Use GRU or simple RNN. (default: %(default)s)")
 parser.add_argument(
     "--use_gpu",
     default=True,
@@ -142,6 +147,7 @@ def evaluate():
         num_conv_layers=args.num_conv_layers,
         num_rnn_layers=args.num_rnn_layers,
         rnn_layer_size=args.rnn_layer_size,
+        use_gru=args.use_gru,
         pretrained_model_path=args.model_filepath)
 
     error_rate_func = cer if args.error_rate_type == 'cer' else wer
diff --git a/deep_speech_2/infer.py b/deep_speech_2/infer.py
index 43643cde..91b08932 100644
--- a/deep_speech_2/infer.py
+++ b/deep_speech_2/infer.py
@@ -33,6 +33,11 @@ parser.add_argument(
     default=512,
     type=int,
     help="RNN layer cell number. (default: %(default)s)")
+parser.add_argument(
+    "--use_gru",
+    default=True,
+    type=bool,
+    help="Use GRU or simple RNN. (default: %(default)s)")
 parser.add_argument(
     "--use_gpu",
     default=True,
@@ -143,6 +148,7 @@ def infer():
         num_conv_layers=args.num_conv_layers,
         num_rnn_layers=args.num_rnn_layers,
         rnn_layer_size=args.rnn_layer_size,
+        use_gru=args.use_gru,
         pretrained_model_path=args.model_filepath)
     result_transcripts = ds2_model.infer_batch(
         infer_data=infer_data,
diff --git a/deep_speech_2/layer.py b/deep_speech_2/layer.py
index 3b492645..1b1a5810 100644
--- a/deep_speech_2/layer.py
+++ b/deep_speech_2/layer.py
@@ -57,7 +57,7 @@ def bidirectional_simple_rnn_bn_layer(name, input, size, act):
     # input-hidden weights shared across bi-direcitonal rnn.
     input_proj = paddle.layer.fc(
         input=input, size=size, act=paddle.activation.Linear(), bias_attr=False)
-    # batch norm is only performed on input-state projection
+    # batch norm is only performed on input-state projection
     input_proj_bn = paddle.layer.batch_norm(
         input=input_proj, act=paddle.activation.Linear())
     # forward and backward in time
@@ -68,6 +68,38 @@ def bidirectional_simple_rnn_bn_layer(name, input, size, act):
     return paddle.layer.concat(input=[forward_simple_rnn, backward_simple_rnn])
 
 
+def bidirectional_gru_bn_layer(name, input, size, act):
+    """Bidirectional GRU layer with sequence-wise batch normalization.
+    The batch normalization is only performed on input-state weights.
+
+    :param name: Name of the layer.
+    :type name: string
+    :param input: Input layer.
+    :type input: LayerOutput
+    :param size: Number of RNN cells.
+    :type size: int
+    :param act: Activation type.
+    :type act: BaseActivation
+    :return: Bidirectional GRU layer.
+    :rtype: LayerOutput
+    """
+    # input-hidden weights shared across bi-directional rnn.
+    input_proj = paddle.layer.fc(
+        input=input,
+        size=size * 3,
+        act=paddle.activation.Linear(),
+        bias_attr=False)
+    # batch norm is only performed on input-state projection
+    input_proj_bn = paddle.layer.batch_norm(
+        input=input_proj, act=paddle.activation.Linear())
+    # forward and backward in time
+    forward_gru = paddle.layer.grumemory(
+        input=input_proj_bn, act=act, reverse=False)
+    backward_gru = paddle.layer.grumemory(
+        input=input_proj_bn, act=act, reverse=True)
+    return paddle.layer.concat(input=[forward_gru, backward_gru])
+
+
 def conv_group(input, num_stacks):
     """Convolution group with stacked convolution layers.
 
@@ -83,7 +115,7 @@ def conv_group(input, num_stacks):
         filter_size=(11, 41),
         num_channels_in=1,
         num_channels_out=32,
-        stride=(3, 2),
+        stride=(2, 2),
         padding=(5, 20),
         act=paddle.activation.BRelu())
     for i in xrange(num_stacks - 1):
@@ -100,7 +132,7 @@ def conv_group(input, num_stacks):
     return conv, output_num_channels, output_height
 
 
-def rnn_group(input, size, num_stacks):
+def rnn_group(input, size, num_stacks, use_gru):
     """RNN group with stacked bidirectional simple RNN layers.
 
     :param input: Input layer.
@@ -109,13 +141,25 @@ def rnn_group(input, size, num_stacks):
     :type size: int
     :param num_stacks: Number of stacked rnn layers.
     :type num_stacks: int
+    :param use_gru: Use GRU if set True. Use simple RNN if set False.
+    :type use_gru: bool
     :return: Output layer of the RNN group.
     :rtype: LayerOutput
     """
     output = input
     for i in xrange(num_stacks):
-        output = bidirectional_simple_rnn_bn_layer(
-            name=str(i), input=output, size=size, act=paddle.activation.BRelu())
+        if use_gru:
+            output = bidirectional_gru_bn_layer(
+                name=str(i),
+                input=output,
+                size=size,
+                act=paddle.activation.BRelu())
+        else:
+            output = bidirectional_simple_rnn_bn_layer(
+                name=str(i),
+                input=output,
+                size=size,
+                act=paddle.activation.BRelu())
     return output
 
 
@@ -124,7 +168,8 @@ def deep_speech2(audio_data,
                  dict_size,
                  num_conv_layers=2,
                  num_rnn_layers=3,
-                 rnn_size=256):
+                 rnn_size=256,
+                 use_gru=True):
     """
     The whole DeepSpeech2 model structure (a simplified version).
 
@@ -140,6 +185,8 @@ def deep_speech2(audio_data,
     :type num_rnn_layers: int
     :param rnn_size: RNN layer size (number of RNN cells).
     :type rnn_size: int
+    :param use_gru: Use GRU if set True. Use simple RNN if set False.
+    :type use_gru: bool
     :return: A tuple of an output unnormalized log probability layer (
              before softmax) and a ctc cost layer.
     :rtype: tuple of LayerOutput
@@ -157,7 +204,10 @@ def deep_speech2(audio_data,
         block_y=conv_group_height)
     # rnn group
     rnn_group_output = rnn_group(
-        input=conv2seq, size=rnn_size, num_stacks=num_rnn_layers)
+        input=conv2seq,
+        size=rnn_size,
+        num_stacks=num_rnn_layers,
+        use_gru=use_gru)
     fc = paddle.layer.fc(
         input=rnn_group_output,
         size=dict_size + 1,
diff --git a/deep_speech_2/model.py b/deep_speech_2/model.py
index 99412e59..eec971c0 100644
--- a/deep_speech_2/model.py
+++ b/deep_speech_2/model.py
@@ -30,9 +30,9 @@ class DeepSpeech2Model(object):
     """
 
     def __init__(self, vocab_size, num_conv_layers, num_rnn_layers,
-                 rnn_layer_size, pretrained_model_path):
+                 rnn_layer_size, use_gru, pretrained_model_path):
         self._create_network(vocab_size, num_conv_layers, num_rnn_layers,
-                             rnn_layer_size)
+                             rnn_layer_size, use_gru)
         self._create_parameters(pretrained_model_path)
         self._inferer = None
         self._loss_inferer = None
@@ -226,7 +226,7 @@ class DeepSpeech2Model(object):
             gzip.open(model_path))
 
     def _create_network(self, vocab_size, num_conv_layers, num_rnn_layers,
-                        rnn_layer_size):
+                        rnn_layer_size, use_gru):
         """Create data layers and model network."""
         # paddle.data_type.dense_array is used for variable batch input.
         # The size 161 * 161 is only an placeholder value and the real shape
@@ -243,4 +243,5 @@ class DeepSpeech2Model(object):
             dict_size=vocab_size,
             num_conv_layers=num_conv_layers,
             num_rnn_layers=num_rnn_layers,
-            rnn_size=rnn_layer_size)
+            rnn_size=rnn_layer_size,
+            use_gru=use_gru)
diff --git a/deep_speech_2/train.py b/deep_speech_2/train.py
index 262d8bf0..8e95d7bc 100644
--- a/deep_speech_2/train.py
+++ b/deep_speech_2/train.py
@@ -37,9 +37,14 @@ parser.add_argument(
     help="RNN layer number. (default: %(default)s)")
 parser.add_argument(
     "--rnn_layer_size",
-    default=512,
+    default=1280,
     type=int,
     help="RNN layer cell number. (default: %(default)s)")
+parser.add_argument(
+    "--use_gru",
+    default=True,
+    type=bool,
+    help="Use GRU or simple RNN. (default: %(default)s)")
 parser.add_argument(
     "--adam_learning_rate",
     default=5e-4,
@@ -170,6 +175,7 @@ def train():
         num_conv_layers=args.num_conv_layers,
         num_rnn_layers=args.num_rnn_layers,
         rnn_layer_size=args.rnn_layer_size,
+        use_gru=args.use_gru,
         pretrained_model_path=args.init_model_path)
     ds2_model.train(
         train_batch_reader=train_batch_reader,
diff --git a/deep_speech_2/tune.py b/deep_speech_2/tune.py
index 328d67a1..8a9b5b61 100644
--- a/deep_speech_2/tune.py
+++ b/deep_speech_2/tune.py
@@ -34,6 +34,11 @@ parser.add_argument(
     default=512,
     type=int,
     help="RNN layer cell number. (default: %(default)s)")
+parser.add_argument(
+    "--use_gru",
+    default=True,
+    type=bool,
+    help="Use GRU or simple RNN. (default: %(default)s)")
 parser.add_argument(
     "--use_gpu",
     default=True,
@@ -158,6 +163,7 @@ def tune():
         num_conv_layers=args.num_conv_layers,
         num_rnn_layers=args.num_rnn_layers,
         rnn_layer_size=args.rnn_layer_size,
+        use_gru=args.use_gru,
         pretrained_model_path=args.model_filepath)
 
     # create grid for search
--
GitLab
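
Note on usage: the new use_gru option is threaded from the command-line flags
through DeepSpeech2Model._create_network() into deep_speech2() and rnn_group().
When it is True, every stacked layer becomes a batch-normalized bidirectional
GRU built on a shared input projection of width size * 3 (covering the update
gate, reset gate, and candidate state, which is the pre-projected input width
paddle.layer.grumemory expects) instead of the batch-normalized simple RNN.
The snippet below is a hypothetical usage sketch, not part of the patch: it
assumes the PaddlePaddle v2 API and the "from model import DeepSpeech2Model"
layout that the repository's own scripts use, and the vocabulary size, layer
size, and pretrained_model_path=None (taken here to mean "start from randomly
initialized parameters") are illustrative placeholders.

    # Hypothetical usage sketch (not part of the patch): build the DeepSpeech2
    # network with the new batch-normalized bidirectional GRU stack.
    import paddle.v2 as paddle
    from model import DeepSpeech2Model  # deep_speech_2/model.py as patched above

    paddle.init(use_gpu=False, trainer_count=1)

    ds2_model = DeepSpeech2Model(
        vocab_size=29,                # placeholder; normally the dataset's vocab size
        num_conv_layers=2,
        num_rnn_layers=3,
        rnn_layer_size=1280,          # matches the new train.py default
        use_gru=True,                 # False falls back to the simple-RNN stack
        pretrained_model_path=None)   # assumed: no checkpoint, random initialization

As in the existing simple-RNN layer, batch normalization is applied only to the
input-to-hidden projection (sequence-wise), not to the recurrent
hidden-to-hidden weights, so switching between the two cell types leaves the
rest of the model topology unchanged.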