diff --git a/deepspeech/models/u2.py b/deepspeech/models/u2.py
index 9337a5bbf3e9400199ed3ca648d78689850082a9..8091dbe14994ef491dea662267287fec400f0bac 100644
--- a/deepspeech/models/u2.py
+++ b/deepspeech/models/u2.py
@@ -399,6 +399,7 @@ class U2BaseModel(nn.Module):
         assert speech.shape[0] == speech_lengths.shape[0]
         assert decoding_chunk_size != 0
         batch_size = speech.shape[0]
+        # Let's assume B = batch_size
         # encoder_out: (B, maxlen, encoder_dim)
         # encoder_mask: (B, 1, Tmax)
@@ -410,10 +411,12 @@ class U2BaseModel(nn.Module):
         # encoder_out_lens = encoder_mask.squeeze(1).sum(1)
         encoder_out_lens = encoder_mask.squeeze(1).astype(paddle.int).sum(1)
         ctc_probs = self.ctc.log_softmax(encoder_out)  # (B, maxlen, vocab_size)
+
         topk_prob, topk_index = ctc_probs.topk(1, axis=2)  # (B, maxlen, 1)
         topk_index = topk_index.view(batch_size, maxlen)  # (B, maxlen)
         pad_mask = make_pad_mask(encoder_out_lens)  # (B, maxlen)
         topk_index = topk_index.masked_fill_(pad_mask, self.eos)  # (B, maxlen)
+
         hyps = [hyp.tolist() for hyp in topk_index]
         hyps = [remove_duplicates_and_blank(hyp) for hyp in hyps]
         return hyps
@@ -449,6 +452,7 @@ class U2BaseModel(nn.Module):
         batch_size = speech.shape[0]
         # For CTC prefix beam search, we only support batch_size=1
         assert batch_size == 1
+        # Let's assume B = batch_size and N = beam_size
         # 1. Encoder forward and get CTC score
         encoder_out, encoder_mask = self._forward_encoder(
@@ -458,7 +462,9 @@ class U2BaseModel(nn.Module):
         maxlen = encoder_out.shape[1]
         ctc_probs = self.ctc.log_softmax(encoder_out)  # (1, maxlen, vocab_size)
         ctc_probs = ctc_probs.squeeze(0)
+
         # cur_hyps: (prefix, (blank_ending_score, none_blank_ending_score))
+        # blank_ending_score and none_blank_ending_score in ln domain
         cur_hyps = [(tuple(), (0.0, -float('inf')))]
         # 2. CTC beam search step by step
         for t in range(0, maxlen):
@@ -498,6 +504,7 @@ class U2BaseModel(nn.Module):
                 key=lambda x: log_add(list(x[1])),
                 reverse=True)
             cur_hyps = next_hyps[:beam_size]
+
         hyps = [(y[0], log_add([y[1][0], y[1][1]])) for y in cur_hyps]
         return hyps, encoder_out
@@ -561,12 +568,13 @@ class U2BaseModel(nn.Module):
         batch_size = speech.shape[0]
         # For attention rescoring we only support batch_size=1
         assert batch_size == 1
-        # encoder_out: (1, maxlen, encoder_dim), len(hyps) = beam_size
+
+        # len(hyps) = beam_size, encoder_out: (1, maxlen, encoder_dim)
         hyps, encoder_out = self._ctc_prefix_beam_search(
             speech, speech_lengths, beam_size, decoding_chunk_size,
             num_decoding_left_chunks, simulate_streaming)
-        assert len(hyps) == beam_size
+
         hyps_pad = pad_sequence([
             paddle.to_tensor(hyp[0], place=device, dtype=paddle.long)
             for hyp in hyps
@@ -576,23 +584,29 @@ class U2BaseModel(nn.Module):
             dtype=paddle.long)  # (beam_size,)
         hyps_pad, _ = add_sos_eos(hyps_pad, self.sos, self.eos, self.ignore_id)
         hyps_lens = hyps_lens + 1  # Add <sos> at begining
+
         encoder_out = encoder_out.repeat(beam_size, 1, 1)
         encoder_mask = paddle.ones(
             (beam_size, 1, encoder_out.shape[1]), dtype=paddle.bool)
+
         decoder_out, _ = self.decoder(
             encoder_out, encoder_mask, hyps_pad,
             hyps_lens)  # (beam_size, max_hyps_len, vocab_size)
+        # decoder score in ln domain
         decoder_out = paddle.nn.functional.log_softmax(decoder_out, axis=-1)
         decoder_out = decoder_out.numpy()
+
         # Only use decoder score for rescoring
         best_score = -float('inf')
         best_index = 0
+        # hyps is List[(Text=List[int], Score=float)], len(hyps)=beam_size
         for i, hyp in enumerate(hyps):
             score = 0.0
             for j, w in enumerate(hyp[0]):
                 score += decoder_out[i][j][w]
+            # last decoder output token is `eos`, for the last decoder input token
             score += decoder_out[i][len(hyp[0])][self.eos]
-            # add ctc score
+            # add ctc score (which is in ln domain)
             score += hyp[1] * ctc_weight
             if score > best_score:
                 best_score = score
diff --git a/deepspeech/modules/encoder.py b/deepspeech/modules/encoder.py
index 0aedea7480a72f627e3f313b04ecfe0add7ad4de..af782fb53fa5459e93bb6e99b2a4202bdb4121a4 100644
--- a/deepspeech/modules/encoder.py
+++ b/deepspeech/modules/encoder.py
@@ -219,11 +219,14 @@ class BaseEncoder(nn.Layer):
         xs, pos_emb, _ = self.embed(
             xs, tmp_masks, offset=offset)  #xs=(B, T, D), pos_emb=(B=1, T, D)
+
         if subsampling_cache is not None:
             cache_size = subsampling_cache.size(1)  #T
             xs = paddle.cat((subsampling_cache, xs), dim=1)
         else:
             cache_size = 0
+
+        # only used when using `RelPositionMultiHeadedAttention`
         pos_emb = self.embed.position_encoding(
             offset=offset - cache_size, size=xs.size(1))
@@ -237,7 +240,7 @@ class BaseEncoder(nn.Layer):
         # Real mask for transformer/conformer layers
         masks = paddle.ones([1, xs.size(1)], dtype=paddle.bool)
-        masks = masks.unsqueeze(1)  #[B=1, C=1, T]
+        masks = masks.unsqueeze(1)  #[B=1, L'=1, T]
         r_elayers_output_cache = []
         r_conformer_cnn_cache = []
         for i, layer in enumerate(self.encoders):
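
Below is a minimal, self-contained sketch (not part of the patch) of how the two per-prefix scores annotated above, blank_ending_score and none_blank_ending_score, both kept in the ln domain, are merged with a log-add, and how attention rescoring then mixes the decoder score with the CTC prefix score via ctc_weight. The log_add here is a standalone re-implementation for illustration rather than the helper imported in u2.py, and the prefix, score, and ctc_weight values are made-up stand-ins.

from math import exp, inf, log


def log_add(args):
    """Stable ln(exp(a) + exp(b) + ...) over a list of ln-domain scores."""
    if all(a == -inf for a in args):
        return -inf
    a_max = max(args)
    return a_max + log(sum(exp(a - a_max) for a in args))


# One beam-search hypothesis: a prefix of token ids (blanks/repeats already
# merged) and its two scores, split by whether the last frame emitted blank.
prefix = (23, 7, 112)
blank_ending_score = -4.2        # ln P(prefix, last frame is blank)
none_blank_ending_score = -5.1   # ln P(prefix, last frame is a non-blank token)

# Total CTC score of the prefix, as in `hyps = [(y[0], log_add([...])) ...]`.
ctc_score = log_add([blank_ending_score, none_blank_ending_score])

# Attention rescoring: the per-token decoder log-probs (plus the final `eos`)
# sum to a decoder score, which is combined with the weighted CTC score.
decoder_score = -3.7             # hypothetical sum of decoder log-probs
ctc_weight = 0.5
rescored = decoder_score + ctc_weight * ctc_score
print(f"prefix={prefix} ctc_score={ctc_score:.4f} rescored={rescored:.4f}")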