From 5659bd23867466032300ffad8547ccd14cea8396 Mon Sep 17 00:00:00 2001
From: Hui Zhang
Date: Tue, 6 Apr 2021 07:31:15 +0000
Subject: [PATCH] add u2 model

---
 deepspeech/__init__.py            |  34 ++++++
 deepspeech/models/u2.py           | 183 +++++++++++++++++-------------
 deepspeech/modules/subsampling.py |   2 +-
 deepspeech/utils/utility.py       |   9 +-
 4 files changed, 148 insertions(+), 80 deletions(-)

diff --git a/deepspeech/__init__.py b/deepspeech/__init__.py
index 5a5a06b4..5d840148 100644
--- a/deepspeech/__init__.py
+++ b/deepspeech/__init__.py
@@ -78,7 +78,32 @@ if not hasattr(paddle, 'cat'):
         "override cat of paddle if exists or register, remove this when fixed!")
     paddle.cat = cat
 
+
+########### hack paddle.Tensor #############
+def item(x: paddle.Tensor):
+    if x.dtype == paddle.fluid.core_avx.VarDesc.VarType.FP32:
+        return float(x)
+    else:
+        raise ValueError("not supported")
+
+
+if not hasattr(paddle.Tensor, 'item'):
+    logger.warn(
+        "override item of paddle.Tensor if exists or register, remove this when fixed!"
+    )
+    paddle.Tensor.item = item
+
+
+def func_long(x: paddle.Tensor):
+    return paddle.cast(x, paddle.long)
+
+
+if not hasattr(paddle.Tensor, 'long'):
+    logger.warn(
+        "override long of paddle.Tensor if exists or register, remove this when fixed!"
+    )
+    paddle.Tensor.long = func_long
+
 if not hasattr(paddle.Tensor, 'numel'):
     logger.warn(
         "override numel of paddle.Tensor if exists or register, remove this when fixed!"
@@ -247,6 +272,15 @@ if not hasattr(paddle.Tensor, 'to'):
     logger.warn("register user to to paddle.Tensor, remove this when fixed!")
     setattr(paddle.Tensor, 'to', to)
 
+
+def func_float(x: paddle.Tensor) -> paddle.Tensor:
+    return x.astype(paddle.float)
+
+
+if not hasattr(paddle.Tensor, 'float'):
+    logger.warn("register user float to paddle.Tensor, remove this when fixed!")
+    setattr(paddle.Tensor, 'float', func_float)
+
 ########### hack paddle.nn.functional #############
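Note: the deepspeech/__init__.py changes above back-fill torch-style helpers (item, long, float) onto paddle.Tensor so the ported U2 code in deepspeech/models/u2.py below can keep torch-like call sites. A minimal sketch of the same guard-then-patch pattern, using a hypothetical func_int helper that is not part of this patch:

    import paddle

    def func_int(x: paddle.Tensor) -> paddle.Tensor:
        # Mirror torch's Tensor.int(): cast to int32.
        return paddle.cast(x, paddle.int32)

    # Patch only when the attribute is missing, so a future Paddle
    # release that ships a native Tensor.int wins automatically.
    if not hasattr(paddle.Tensor, 'int'):
        paddle.Tensor.int = func_int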
diff --git a/deepspeech/models/u2.py b/deepspeech/models/u2.py
index de4b7a08..db923001 100644
--- a/deepspeech/models/u2.py
+++ b/deepspeech/models/u2.py
@@ -53,22 +53,21 @@ from deepspeech.utils.ctc_utils import remove_duplicates_and_blank
 
 logger = logging.getLogger(__name__)
 
-__all__ = ['U2Model']
+__all__ = ['U2TransformerModel', 'U2ConformerModel']
 
 
 class U2Model(nn.Module):
     """CTC-Attention hybrid Encoder-Decoder model"""
 
-    def __init__(
-            self,
-            vocab_size: int,
-            encoder: TransformerEncoder,
-            decoder: TransformerDecoder,
-            ctc: CTCDecoder,
-            ctc_weight: float=0.5,
-            ignore_id: int=IGNORE_ID,
-            lsm_weight: float=0.0,
-            length_normalized_loss: bool=False, ):
+    def __init__(self,
+                 vocab_size: int,
+                 encoder: TransformerEncoder,
+                 decoder: TransformerDecoder,
+                 ctc: CTCDecoder,
+                 ctc_weight: float=0.5,
+                 ignore_id: int=IGNORE_ID,
+                 lsm_weight: float=0.0,
+                 length_normalized_loss: bool=False):
         assert 0.0 <= ctc_weight <= 1.0, ctc_weight
 
         super().__init__()
@@ -263,51 +262,54 @@ class U2Model(nn.Module):
             # Stop if all batch and all beam produce eos
             if end_flag.sum() == running_size:
                 break
+
             # 2.1 Forward decoder step
             hyps_mask = subsequent_mask(i).unsqueeze(0).repeat(
                 running_size, 1, 1).to(device)  # (B*N, i, i)
             # logp: (B*N, vocab)
             logp, cache = self.decoder.forward_one_step(
                 encoder_out, encoder_mask, hyps, hyps_mask, cache)
+
             # 2.2 First beam prune: select topk best prob at current time
             top_k_logp, top_k_index = logp.topk(beam_size)  # (B*N, N)
             top_k_logp = mask_finished_scores(top_k_logp, end_flag)
             top_k_index = mask_finished_preds(top_k_index, end_flag, self.eos)
+
             # 2.3 Second beam prune: select topk score with history
             scores = scores + top_k_logp  # (B*N, N), broadcast add
             scores = scores.view(batch_size, beam_size * beam_size)  # (B, N*N)
             scores, offset_k_index = scores.topk(k=beam_size)  # (B, N)
             scores = scores.view(-1, 1)  # (B*N, 1)
+
             # 2.4 Compute base index in top_k_index,
             # regard top_k_index as (B*N*N), regard offset_k_index as (B*N),
             # then find offset_k_index in top_k_index
-            base_k_index = torch.arange(
-                batch_size,
-                device=device).view(-1, 1).repeat([1, beam_size])  # (B, N)
+            base_k_index = paddle.arange(batch_size).view(-1, 1).repeat(
+                [1, beam_size])  # (B, N)
             base_k_index = base_k_index * beam_size * beam_size
             best_k_index = base_k_index.view(-1) + offset_k_index.view(
                 -1)  # (B*N)
 
             # 2.5 Update best hyps
-            best_k_pred = torch.index_select(
-                top_k_index.view(-1), dim=-1, index=best_k_index)  # (B*N)
+            best_k_pred = paddle.index_select(
+                top_k_index.view(-1), index=best_k_index, axis=0)  # (B*N)
             best_hyps_index = best_k_index // beam_size
-            last_best_k_hyps = torch.index_select(
-                hyps, dim=0, index=best_hyps_index)  # (B*N, i)
-            hyps = torch.cat(
+            last_best_k_hyps = paddle.index_select(
+                hyps, index=best_hyps_index, axis=0)  # (B*N, i)
+            hyps = paddle.cat(
                 (last_best_k_hyps, best_k_pred.view(-1, 1)),
                 dim=1)  # (B*N, i+1)
 
             # 2.6 Update end flag
-            end_flag = torch.eq(hyps[:, -1], self.eos).view(-1, 1)
+            end_flag = paddle.eq(hyps[:, -1], self.eos).view(-1, 1)
 
         # 3. Select best of best
         scores = scores.view(batch_size, beam_size)
         # TODO: length normalization
-        best_index = torch.argmax(scores, dim=-1).long()
-        best_hyps_index = best_index + torch.arange(
-            batch_size, dtype=torch.long, device=device) * beam_size
-        best_hyps = torch.index_select(hyps, dim=0, index=best_hyps_index)
+        best_index = paddle.argmax(scores, axis=-1).long()  # (B)
+        best_hyps_index = best_index + paddle.arange(
+            batch_size, dtype=paddle.long) * beam_size
+        best_hyps = paddle.index_select(hyps, index=best_hyps_index, axis=0)
         best_hyps = best_hyps[:, 1:]
         return best_hyps
 
@@ -346,8 +348,8 @@ class U2Model(nn.Module):
         ctc_probs = self.ctc.log_softmax(encoder_out)  # (B, maxlen, vocab_size)
         topk_prob, topk_index = ctc_probs.topk(1, dim=2)  # (B, maxlen, 1)
         topk_index = topk_index.view(batch_size, maxlen)  # (B, maxlen)
-        mask = make_pad_mask(encoder_out_lens)  # (B, maxlen)
-        topk_index = topk_index.masked_fill_(mask, self.eos)  # (B, maxlen)
+        pad_mask = make_pad_mask(encoder_out_lens)  # (B, maxlen)
+        topk_index = topk_index.masked_fill_(pad_mask, self.eos)  # (B, maxlen)
         hyps = [hyp.tolist() for hyp in topk_index]
         hyps = [remove_duplicates_and_blank(hyp) for hyp in hyps]
         return hyps
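Note: ctc_greedy_search above keeps the argmax token per frame, fills the padded frames with eos, then collapses the frame-level sequence. remove_duplicates_and_blank is imported from deepspeech.utils.ctc_utils and is not part of this patch; a minimal sketch of the standard CTC collapse rule it implements (merge adjacent repeats, then drop blanks), assuming blank id 0:

    from typing import List

    def remove_duplicates_and_blank(hyp: List[int],
                                    blank_id: int=0) -> List[int]:
        # e.g. [5, 5, 0, 5, 0, 7] -> [5, 5, 7]: a blank separates genuine
        # repeats, while adjacent duplicates are merged into one token.
        out = []
        prev = None
        for token in hyp:
            if token != prev and token != blank_id:
                out.append(token)
            prev = token
        return out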
@@ -360,7 +362,7 @@ class U2Model(nn.Module):
             decoding_chunk_size: int=-1,
             num_decoding_left_chunks: int=-1,
             simulate_streaming: bool=False,
-    ) -> Tuple[List[List[int]], paddle.Tensor]:
+            blank_id: int=0, ) -> Tuple[List[Tuple[int, float]], paddle.Tensor]:
         """CTC prefix beam search inner implementation
         Args:
             speech (paddle.Tensor): (batch, max_len, feat_dim)
@@ -374,7 +376,7 @@ class U2Model(nn.Module):
             simulate_streaming (bool): whether do encoder forward in a
                 streaming fashion
         Returns:
-            List[List[int]]: nbest results
+            List[Tuple[int, float]]: nbest results, (N,1), (text, likelihood)
             paddle.Tensor: encoder output, (1, max_len, encoder_dim),
                 it will be used for rescoring in attention rescoring mode
         """
@@ -406,7 +408,7 @@ class U2Model(nn.Module):
                 ps = logp[s].item()
                 for prefix, (pb, pnb) in cur_hyps:
                     last = prefix[-1] if len(prefix) > 0 else None
-                    if s == 0:  # blank
+                    if s == blank_id:  # blank
                         n_pb, n_pnb = next_hyps[prefix]
                         n_pb = log_add([n_pb, pb + ps, pnb + ps])
                         next_hyps[prefix] = (n_pb, n_pnb)
@@ -491,7 +493,7 @@ class U2Model(nn.Module):
         """
         assert speech.shape[0] == speech_lengths.shape[0]
         assert decoding_chunk_size != 0
-        device = speech.device
+        device = speech.place
         batch_size = speech.shape[0]
         # For attention rescoring we only support batch_size=1
         assert batch_size == 1
@@ -502,22 +504,22 @@ class U2Model(nn.Module):
         assert len(hyps) == beam_size
 
         hyps_pad = pad_sequence([
-            paddle.tensor(hyp[0], device=device, dtype=torch.long)
+            paddle.to_tensor(hyp[0], place=device, dtype=paddle.long)
             for hyp in hyps
         ], True, self.ignore_id)  # (beam_size, max_hyps_len)
-        hyps_lens = paddle.tensor(
-            [len(hyp[0]) for hyp in hyps], device=device,
-            dtype=torch.long)  # (beam_size,)
+        hyps_lens = paddle.to_tensor(
+            [len(hyp[0]) for hyp in hyps], place=device,
+            dtype=paddle.long)  # (beam_size,)
         hyps_pad, _ = add_sos_eos(hyps_pad, self.sos, self.eos, self.ignore_id)
         hyps_lens = hyps_lens + 1  # Add <sos> at the beginning
         encoder_out = encoder_out.repeat(beam_size, 1, 1)
-        encoder_mask = torch.ones(
-            beam_size, 1, encoder_out.size(1), dtype=torch.bool, device=device)
+        encoder_mask = paddle.ones(
+            [beam_size, 1, encoder_out.size(1)], dtype=paddle.bool)
         decoder_out, _ = self.decoder(
             encoder_out, encoder_mask, hyps_pad,
             hyps_lens)  # (beam_size, max_hyps_len, vocab_size)
-        decoder_out = torch.nn.functional.log_softmax(decoder_out, dim=-1)
-        decoder_out = decoder_out.cpu().numpy()
+        decoder_out = paddle.nn.functional.log_softmax(decoder_out, dim=-1)
+        decoder_out = decoder_out.numpy()
         # Only use decoder score for rescoring
         best_score = -float('inf')
         best_index = 0
@@ -609,56 +611,83 @@ class U2Model(nn.Module):
             hypothesis from ctc prefix beam search and one encoder output
         Args:
             hyps (paddle.Tensor): hyps from ctc prefix beam search, already
-                pad sos at the begining
-            hyps_lens (paddle.Tensor): length of each hyp in hyps
-            encoder_out (paddle.Tensor): corresponding encoder output
+                padded with sos at the beginning, (B, T)
+            hyps_lens (paddle.Tensor): length of each hyp in hyps, (B)
+            encoder_out (paddle.Tensor): corresponding encoder output, (B=1, T, D)
         Returns:
-            paddle.Tensor: decoder output
+            paddle.Tensor: decoder output, (B, L)
         """
         assert encoder_out.size(0) == 1
         num_hyps = hyps.size(0)
         assert hyps_lens.size(0) == num_hyps
         encoder_out = encoder_out.repeat(num_hyps, 1, 1)
-        encoder_mask = torch.ones(
-            num_hyps,
-            1,
-            encoder_out.size(1),
-            dtype=torch.bool,
-            device=encoder_out.device)
-        decoder_out, _ = self.decoder(
-            encoder_out, encoder_mask, hyps,
-            hyps_lens)  # (num_hyps, max_hyps_len, vocab_size)
-        decoder_out = torch.nn.functional.log_softmax(decoder_out, dim=-1)
+        # (B, 1, T)
+        encoder_mask = paddle.ones(
+            [num_hyps, 1, encoder_out.size(1)], dtype=paddle.bool)
+        # (num_hyps, max_hyps_len, vocab_size)
+        decoder_out, _ = self.decoder(encoder_out, encoder_mask, hyps,
+                                      hyps_lens)
+        decoder_out = paddle.nn.functional.log_softmax(decoder_out, dim=-1)
         return decoder_out
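Note: forward_attention_decoder above returns per-token log-probabilities for a batch of hypotheses, and attention_rescoring uses the same kind of decoder output to re-rank the CTC n-best list. The scoring loop itself falls outside this hunk; for reference, a sketch of how decoder log-probs typically score each (token_ids, ctc_score) pair produced by _ctc_prefix_beam_search — an illustration, not necessarily this patch's exact code:

    import numpy as np

    def pick_best(hyps, decoder_out: np.ndarray, eos_id: int,
                  ctc_weight: float=0.0):
        # decoder_out: (beam_size, max_hyps_len, vocab_size) log-probs.
        best_score, best_index = -float('inf'), 0
        for i, (tokens, ctc_score) in enumerate(hyps):
            # Sum the decoder log-prob of every token, plus that of <eos>.
            score = sum(decoder_out[i][j][tok] for j, tok in enumerate(tokens))
            score += decoder_out[i][len(tokens)][eos_id]
            # ctc_weight=0.0 matches the "only use decoder score" comment.
            score += ctc_weight * ctc_score
            if score > best_score:
                best_score, best_index = score, i
        return hyps[best_index][0]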
-def init_asr_model(configs):
-    if configs['cmvn_file'] is not None:
-        mean, istd = load_cmvn(configs['cmvn_file'], configs['is_json_cmvn'])
-        global_cmvn = GlobalCMVN(
-            torch.from_numpy(mean).float(), torch.from_numpy(istd).float())
-    else:
-        global_cmvn = None
-
-    input_dim = configs['input_dim']
-    vocab_size = configs['output_dim']
-
-    encoder_type = configs.get('encoder', 'conformer')
-    if encoder_type == 'conformer':
-        encoder = ConformerEncoder(
-            input_dim, global_cmvn=global_cmvn, **configs['encoder_conf'])
-    else:
-        encoder = TransformerEncoder(
-            input_dim, global_cmvn=global_cmvn, **configs['encoder_conf'])
-
-    decoder = TransformerDecoder(vocab_size,
-                                 encoder.output_size(),
-                                 **configs['decoder_conf'])
-    ctc = CTCDecoder(vocab_size, encoder.output_size())
-    model = U2Model(
-        vocab_size=vocab_size,
-        encoder=encoder,
-        decoder=decoder,
-        ctc=ctc,
-        **configs['model_conf'], )
-    return model
+class U2TransformerModel(U2Model):
+    def __init__(self, configs: dict):
+        if configs['cmvn_file'] is not None:
+            mean, istd = load_cmvn(configs['cmvn_file'],
+                                   configs['is_json_cmvn'])
+            global_cmvn = GlobalCMVN(
+                paddle.to_tensor(mean).float(), paddle.to_tensor(istd).float())
+        else:
+            global_cmvn = None
+
+        input_dim = configs['input_dim']
+        vocab_size = configs['output_dim']
+
+        encoder_type = configs.get('encoder', 'transformer')
+        assert encoder_type == 'transformer'
+        encoder = TransformerEncoder(
+            input_dim, global_cmvn=global_cmvn, **configs['encoder_conf'])
+
+        decoder = TransformerDecoder(vocab_size,
+                                     encoder.output_size(),
+                                     **configs['decoder_conf'])
+        ctc = CTCDecoder(vocab_size, encoder.output_size())
+
+        super().__init__(
+            vocab_size=vocab_size,
+            encoder=encoder,
+            decoder=decoder,
+            ctc=ctc,
+            **configs['model_conf'])
+
+
+class U2ConformerModel(U2Model):
+    def __init__(self, configs: dict):
+        if configs['cmvn_file'] is not None:
+            mean, istd = load_cmvn(configs['cmvn_file'],
+                                   configs['is_json_cmvn'])
+            global_cmvn = GlobalCMVN(
+                paddle.to_tensor(mean).float(), paddle.to_tensor(istd).float())
+        else:
+            global_cmvn = None
+
+        input_dim = configs['input_dim']
+        vocab_size = configs['output_dim']
+
+        encoder_type = configs.get('encoder', 'conformer')
+        assert encoder_type == 'conformer'
+        encoder = ConformerEncoder(
+            input_dim, global_cmvn=global_cmvn, **configs['encoder_conf'])
+
+        decoder = TransformerDecoder(vocab_size,
+                                     encoder.output_size(),
+                                     **configs['decoder_conf'])
+        ctc = CTCDecoder(vocab_size, encoder.output_size())
+
+        super().__init__(
+            vocab_size=vocab_size,
+            encoder=encoder,
+            decoder=decoder,
+            ctc=ctc,
+            **configs['model_conf'])
diff --git a/deepspeech/modules/subsampling.py b/deepspeech/modules/subsampling.py
index a0b80b84..4b0547d4 100644
--- a/deepspeech/modules/subsampling.py
+++ b/deepspeech/modules/subsampling.py
@@ -37,7 +37,7 @@ class BaseSubsampling(nn.Layer):
         self.pos_enc = pos_enc_class
         # window size = (1 + right_context) + (chunk_size - 1) * subsampling_rate
         self.right_context = 0
-        # stride = chunk_size * subsampling_rate
+        # stride = subsampling_rate * chunk_size
         self.subsampling_rate = 1
 
     def position_encoding(self, offset: int, size: int) -> paddle.Tensor:
diff --git a/deepspeech/utils/utility.py b/deepspeech/utils/utility.py
index 96f253b5..5f376c24 100644
--- a/deepspeech/utils/utility.py
+++ b/deepspeech/utils/utility.py
@@ -63,8 +63,13 @@ def add_arguments(argname, type, default, help, argparser, **kwargs):
 
 
 def log_add(args: List[int]) -> float:
-    """
-    Stable log add
+    """Stable log add.
+
+    Args:
+        args (List[int]): log scores
+
+    Returns:
+        float: log of the summed probabilities (log-sum-exp of args)
     """
     if all(a == -float('inf') for a in args):
         return -float('inf')
--
GitLab
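Note: with init_asr_model replaced by the two subclasses, model construction is driven entirely by a config dict. A sketch of the expected shape of that dict; the keys are the ones read in the code above, while the values shown are placeholders:

    # Hypothetical config: encoder_conf/decoder_conf/model_conf contents
    # depend on the respective class signatures and are left empty here.
    configs = {
        'cmvn_file': None,        # or a CMVN stats path, read per 'is_json_cmvn'
        'is_json_cmvn': True,
        'input_dim': 80,          # acoustic feature dimension
        'output_dim': 4233,       # vocabulary size
        'encoder': 'conformer',
        'encoder_conf': {},
        'decoder_conf': {},
        'model_conf': {'ctc_weight': 0.5},
    }
    model = U2ConformerModel(configs)

The log_add helper documented above computes log(sum(exp(args))) stably, e.g. log_add([math.log(0.6), math.log(0.4)]) ≈ log(1.0) = 0.0, which is why the CTC prefix beam search can accumulate path probabilities in the log domain without underflow.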