diff --git a/deepspeech/exps/u2/model.py b/deepspeech/exps/u2/model.py
index 96bf8d17f71c06cb073510bcead07d1fa21db2d4..6297f1854385ca8913f3e2e73d5b66d87ba188eb 100644
--- a/deepspeech/exps/u2/model.py
+++ b/deepspeech/exps/u2/model.py
@@ -610,10 +610,11 @@ class U2Tester(U2Trainer):
             input_spec=[
                 paddle.static.InputSpec(
                     shape=[1, decoder_max_time], dtype='int32'),  # tgt
-                paddle.static.InputSpec(shape=[1], dtype='int32'),  # tgt_len
+                paddle.static.InputSpec(
+                    shape=[1, decoder_max_time], dtype='bool'),  # tgt_mask
                 paddle.static.InputSpec(
                     shape=[1, encoder_max_time, encoder_model_size],
-                    dtype='int32'),  # encoder_out
+                    dtype='float32'),  # encoder_out
             ])
         logger.info(f"Export code: {static_model}")
diff --git a/deepspeech/models/u2.py b/deepspeech/models/u2.py
index c3d93d8a67349b9a705f5f6073f24c64075ff973..c6a33ad589405ba59916ed661b789e8d8d1b94a4 100644
--- a/deepspeech/models/u2.py
+++ b/deepspeech/models/u2.py
@@ -927,3 +927,32 @@ class U2InferModel(U2Model):
             decoding_chunk_size=decoding_chunk_size,
             num_decoding_left_chunks=num_decoding_left_chunks,
             simulate_streaming=simulate_streaming)
+
+    def forward_attention_decoder(
+            self,
+            hyps: paddle.Tensor,
+            hyps_masks: paddle.Tensor,
+            encoder_out: paddle.Tensor) -> paddle.Tensor:
+        """Export interface for C++ call: forward the decoder with multiple
+            hypotheses from CTC prefix beam search and one encoder output.
+        Args:
+            hyps (paddle.Tensor): hyps from CTC prefix beam search, already
+                padded with <sos> at the beginning, (B, U)
+            hyps_masks (paddle.Tensor): mask of valid tokens in each hyp, (B, U)
+            encoder_out (paddle.Tensor): corresponding encoder output, (B=1, T, D)
+        Returns:
+            paddle.Tensor: decoder output (log-probabilities), (B, U, V)
+        """
+        assert encoder_out.shape[0] == 1
+        num_hyps = hyps.shape[0]
+        assert hyps_masks.shape[0] == num_hyps
+        # encoder_out = encoder_out.repeat(num_hyps, 1, 1)
+        encoder_out = encoder_out.tile([num_hyps, 1, 1])
+        # (B, 1, T)
+        encoder_mask = paddle.ones(
+            [num_hyps, 1, encoder_out.shape[1]], dtype=paddle.bool)
+        # (num_hyps, max_hyps_len, vocab_size)
+        decoder_out, _ = self.decoder.export(encoder_out, encoder_mask, hyps,
+                                             hyps_masks)
+        decoder_out = paddle.nn.functional.log_softmax(decoder_out, axis=-1)
+        return decoder_out
diff --git a/deepspeech/modules/decoder.py b/deepspeech/modules/decoder.py
index c2bcbb48a38c4b277368f013da0b97c413663fb7..e3ab80a239f1aa576faccca665360a441498824c 100644
--- a/deepspeech/modules/decoder.py
+++ b/deepspeech/modules/decoder.py
@@ -185,3 +185,45 @@ class TransformerDecoder(nn.Module):
         y = paddle.log_softmax(self.output_layer(y), axis=-1)
 
         return y, new_cache
+
+    def export(
+            self,
+            memory: paddle.Tensor,
+            memory_mask: paddle.Tensor,
+            ys_in_pad: paddle.Tensor,
+            ys_in_mask: paddle.Tensor) -> Tuple[paddle.Tensor, paddle.Tensor]:
+        """Forward decoder; mask-based variant used for export.
+        Args:
+            memory: encoded memory, float32 (batch, maxlen_in, feat)
+            memory_mask: encoder memory mask, (batch, 1, maxlen_in)
+            ys_in_pad: padded input token ids, int64 (batch, maxlen_out)
+            ys_in_mask: input mask of this batch, (batch, maxlen_out)
+        Returns:
+            (tuple): tuple containing:
+                x: decoded token scores before softmax,
+                    (batch, maxlen_out, vocab_size) if use_output_layer is True
+                olens: (batch, )
+        """
+        tgt = ys_in_pad
+        # tgt_mask: (B, 1, L)
+        tgt_mask = ys_in_mask.unsqueeze(1)
+        # m: (1, L, L)
+        m = subsequent_mask(tgt_mask.shape[-1]).unsqueeze(0)
+        # tgt_mask: (B, L, L)
+        # TODO(Hui Zhang): not support & for tensor
+        # tgt_mask = tgt_mask & m
+        tgt_mask = tgt_mask.logical_and(m)
+
+        x, _ = self.embed(tgt)
+        for layer in self.decoders:
+            x, tgt_mask, memory, memory_mask = layer(x, tgt_mask, memory,
+                                                     memory_mask)
+        if self.normalize_before:
+            x = self.after_norm(x)
+        if self.use_output_layer:
+            x = self.output_layer(x)
+
+        # TODO(Hui Zhang): reduce_sum not support bool type
+        # olens = tgt_mask.sum(1)
+        olens = tgt_mask.astype(paddle.int64).sum(1)
+        return x, olens
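
The first hunk switches the exported decoder signature from (tgt, tgt_len, encoder_out) to (tgt, tgt_mask, encoder_out). A sketch of what the surrounding export call plausibly looks like; `infer_model`, the concrete sizes, and the save path are illustrative assumptions, not values from this patch:

```python
import paddle

# Hypothetical values standing in for config-derived sizes.
decoder_max_time, encoder_max_time, encoder_model_size = 100, 500, 256

static_model = paddle.jit.to_static(
    infer_model.forward_attention_decoder,
    input_spec=[
        paddle.static.InputSpec(
            shape=[1, decoder_max_time], dtype='int32'),  # tgt
        paddle.static.InputSpec(
            shape=[1, decoder_max_time], dtype='bool'),  # tgt_mask
        paddle.static.InputSpec(
            shape=[1, encoder_max_time, encoder_model_size],
            dtype='float32'),  # encoder_out
    ])
paddle.jit.save(static_model, './export/u2_decoder')
```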
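For reference, a minimal driver sketch for `forward_attention_decoder` during attention rescoring. `model`, `encoder_out`, `hyps`, `sos_id`, and `ignore_id` are assumed names for illustration; only the (B, U) layout with `<sos>` prepended comes from the docstring in the patch:

```python
import numpy as np
import paddle

def build_rescore_inputs(hyps, sos_id, ignore_id=-1):
    """Pad CTC prefix beam search hypotheses with <sos> and build the
    (B, U) bool mask that forward_attention_decoder expects."""
    batch = len(hyps)
    max_len = max(len(h) for h in hyps) + 1  # +1 for <sos>
    hyps_pad = np.full([batch, max_len], ignore_id, dtype='int64')
    hyps_mask = np.zeros([batch, max_len], dtype='bool')
    for i, h in enumerate(hyps):
        hyps_pad[i, 0] = sos_id
        hyps_pad[i, 1:1 + len(h)] = h
        hyps_mask[i, :1 + len(h)] = True
    return paddle.to_tensor(hyps_pad), paddle.to_tensor(hyps_mask)

# hyps: token-id lists from CTC prefix beam search; encoder_out: (1, T, D).
# decoder_out: (B, U, V) log-probabilities, one row of scores per hypothesis.
hyps_pad, hyps_mask = build_rescore_inputs(hyps, sos_id)
decoder_out = model.forward_attention_decoder(hyps_pad, hyps_mask, encoder_out)
```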
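And a stand-alone illustration of the mask algebra inside `TransformerDecoder.export`: the (B, 1, L) padding mask is broadcast against a (1, L, L) causal mask via `logical_and` to produce the (B, L, L) `tgt_mask`. The local `subsequent_mask` below is re-derived for the demo as a lower-triangular bool matrix, not the implementation from `deepspeech.modules.mask`:

```python
import paddle

def subsequent_mask(size: int) -> paddle.Tensor:
    # Lower-triangular causal mask; position i may attend to positions <= i.
    return paddle.tril(paddle.ones([size, size], dtype='int32')).astype('bool')

# (B=2, L=3) padding mask: the second sequence has one padded position.
ys_in_mask = paddle.to_tensor([[True, True, True],
                               [True, True, False]])
tgt_mask = ys_in_mask.unsqueeze(1)   # (B, 1, L)
m = subsequent_mask(3).unsqueeze(0)  # (1, L, L)
tgt_mask = tgt_mask.logical_and(m)   # (B, L, L): causal AND non-pad
print(tgt_mask.astype('int64'))
# Row i of each (L, L) slice enables attention to tokens 0..i that are not padding.
```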