From 2112ee1ecc6dc6ca2ea0aa5e72223d30de0bf60b Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Mon, 5 Jul 2021 04:08:29 +0000 Subject: [PATCH] fix decoder mask --- deepspeech/exps/u2/model.py | 21 +++++++++++++-------- deepspeech/modules/decoder.py | 7 ++++++- 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/deepspeech/exps/u2/model.py b/deepspeech/exps/u2/model.py index add22591..96bf8d17 100644 --- a/deepspeech/exps/u2/model.py +++ b/deepspeech/exps/u2/model.py @@ -601,17 +601,22 @@ class U2Tester(U2Trainer): # assert isinstance(input_spec, list), type(input_spec) infer_model.eval() #static_model = paddle.jit.to_static(infer_model., input_spec=input_spec) - + + decoder_max_time = 100 + encoder_max_time = None + encoder_model_size = 256 static_model = paddle.jit.to_static( - infer_model.forward_attention_decoder, + infer_model.forward_attention_decoder, input_spec=[ - paddle.static.InputSpec(shape=[1, None],dtype='int32'), - paddle.static.InputSpec(shape=[1],dtype='int32'), - paddle.static.InputSpec(shape=[1, None, 256],dtype='int32'), - ] - ) + paddle.static.InputSpec( + shape=[1, decoder_max_time], dtype='int32'), # tgt + paddle.static.InputSpec(shape=[1], dtype='int32'), # tgt_len + paddle.static.InputSpec( + shape=[1, encoder_max_time, encoder_model_size], + dtype='int32'), # encoder_out + ]) logger.info(f"Export code: {static_model}") - + paddle.jit.save(static_model, self.args.export_path) def run_export(self): diff --git a/deepspeech/modules/decoder.py b/deepspeech/modules/decoder.py index b4eb46f2..c2bcbb48 100644 --- a/deepspeech/modules/decoder.py +++ b/deepspeech/modules/decoder.py @@ -120,7 +120,7 @@ class TransformerDecoder(nn.Module): """ tgt = ys_in_pad # tgt_mask: (B, 1, L) - tgt_mask = (make_non_pad_mask(ys_in_lens).unsqueeze(1)) + tgt_mask = make_non_pad_mask(ys_in_lens).unsqueeze(1) # m: (1, L, L) m = subsequent_mask(tgt_mask.shape[-1]).unsqueeze(0) # tgt_mask: (B, L, L) @@ -164,19 +164,24 @@ class TransformerDecoder(nn.Module): y.shape` is (batch, token) """ x, _ = self.embed(tgt) + new_cache = [] for i, decoder in enumerate(self.decoders): if cache is None: c = None else: c = cache[i] + x, tgt_mask, memory, memory_mask = decoder( x, tgt_mask, memory, memory_mask, cache=c) new_cache.append(x) + if self.normalize_before: y = self.after_norm(x[:, -1]) else: y = x[:, -1] + if self.use_output_layer: y = paddle.log_softmax(self.output_layer(y), axis=-1) + return y, new_cache -- GitLab