From 5b446f63219af2acc3352554d6e90dcc04c00da4 Mon Sep 17 00:00:00 2001 From: Jackwaterveg <87408988+Jackwaterveg@users.noreply.github.com> Date: Wed, 15 Dec 2021 11:54:52 +0800 Subject: [PATCH] [Config]clear the u2 decode config for asr (#1107) * clear the u2 decode config * rename the vocab_filepath and cmvn_path --- paddlespeech/cli/asr/infer.py | 14 ++++------- paddlespeech/cli/st/infer.py | 9 +------- .../s2t/exps/deepspeech2/bin/test_hub.py | 2 +- paddlespeech/s2t/exps/deepspeech2/model.py | 2 +- paddlespeech/s2t/exps/u2/bin/test_wav.py | 8 +------ paddlespeech/s2t/exps/u2/model.py | 8 +------ paddlespeech/s2t/exps/u2_kaldi/model.py | 8 +------ paddlespeech/s2t/exps/u2_st/model.py | 14 ----------- .../frontend/featurizer/speech_featurizer.py | 2 +- .../frontend/featurizer/text_featurizer.py | 23 ++++++++++--------- paddlespeech/s2t/models/lm/dataset.py | 2 +- paddlespeech/s2t/models/u2/u2.py | 17 +++----------- paddlespeech/s2t/models/u2_st/u2_st.py | 14 ----------- paddlespeech/s2t/transform/cmvn.py | 16 ++++++++----- 14 files changed, 37 insertions(+), 102 deletions(-) diff --git a/paddlespeech/cli/asr/infer.py b/paddlespeech/cli/asr/infer.py index 00f21293..e020b501 100644 --- a/paddlespeech/cli/asr/infer.py +++ b/paddlespeech/cli/asr/infer.py @@ -179,7 +179,7 @@ class ASRExecutor(BaseExecutor): self.collate_fn_test = SpeechCollator.from_config(self.config) text_feature = TextFeaturizer( unit_type=self.config.collator.unit_type, - vocab_filepath=self.config.collator.vocab_filepath, + vocab=self.config.collator.vocab_filepath, spm_model_prefix=self.config.collator.spm_model_prefix) self.config.model.input_dim = self.collate_fn_test.feature_size self.config.model.output_dim = text_feature.vocab_size @@ -192,7 +192,7 @@ class ASRExecutor(BaseExecutor): res_path, self.config.collator.spm_model_prefix) text_feature = TextFeaturizer( unit_type=self.config.collator.unit_type, - vocab_filepath=self.config.collator.vocab_filepath, + vocab=self.config.collator.vocab_filepath, spm_model_prefix=self.config.collator.spm_model_prefix) self.config.model.input_dim = self.config.collator.feat_dim self.config.model.output_dim = text_feature.vocab_size @@ -279,7 +279,7 @@ class ASRExecutor(BaseExecutor): audio = paddle.to_tensor(audio, dtype='float32').unsqueeze(axis=0) text_feature = TextFeaturizer( unit_type=self.config.collator.unit_type, - vocab_filepath=self.config.collator.vocab_filepath, + vocab=self.config.collator.vocab_filepath, spm_model_prefix=self.config.collator.spm_model_prefix) self._inputs["audio"] = audio self._inputs["audio_len"] = audio_len @@ -295,7 +295,7 @@ class ASRExecutor(BaseExecutor): """ text_feature = TextFeaturizer( unit_type=self.config.collator.unit_type, - vocab_filepath=self.config.collator.vocab_filepath, + vocab=self.config.collator.vocab_filepath, spm_model_prefix=self.config.collator.spm_model_prefix) cfg = self.config.decoding audio = self._inputs["audio"] @@ -321,13 +321,7 @@ class ASRExecutor(BaseExecutor): audio_len, text_feature=text_feature, decoding_method=cfg.decoding_method, - lang_model_path=cfg.lang_model_path, - beam_alpha=cfg.alpha, - beam_beta=cfg.beta, beam_size=cfg.beam_size, - cutoff_prob=cfg.cutoff_prob, - cutoff_top_n=cfg.cutoff_top_n, - num_processes=cfg.num_proc_bsearch, ctc_weight=cfg.ctc_weight, decoding_chunk_size=cfg.decoding_chunk_size, num_decoding_left_chunks=cfg.num_decoding_left_chunks, diff --git a/paddlespeech/cli/st/infer.py b/paddlespeech/cli/st/infer.py index 6bb82821..fd32e3b4 100644 --- a/paddlespeech/cli/st/infer.py +++ b/paddlespeech/cli/st/infer.py @@ -180,7 +180,7 @@ class STExecutor(BaseExecutor): res_path, self.config.collator.spm_model_prefix) self.text_feature = TextFeaturizer( unit_type=self.config.collator.unit_type, - vocab_filepath=self.config.collator.vocab_filepath, + vocab=self.config.collator.vocab_filepath, spm_model_prefix=self.config.collator.spm_model_prefix) self.config.model.input_dim = self.config.collator.feat_dim self.config.model.output_dim = self.text_feature.vocab_size @@ -292,14 +292,7 @@ class STExecutor(BaseExecutor): audio_len, text_feature=self.text_feature, decoding_method=cfg.decoding_method, - lang_model_path=None, - beam_alpha=cfg.alpha, - beam_beta=cfg.beta, beam_size=cfg.beam_size, - cutoff_prob=cfg.cutoff_prob, - cutoff_top_n=cfg.cutoff_top_n, - num_processes=cfg.num_proc_bsearch, - ctc_weight=cfg.ctc_weight, word_reward=cfg.word_reward, decoding_chunk_size=cfg.decoding_chunk_size, num_decoding_left_chunks=cfg.num_decoding_left_chunks, diff --git a/paddlespeech/s2t/exps/deepspeech2/bin/test_hub.py b/paddlespeech/s2t/exps/deepspeech2/bin/test_hub.py index b8544dc2..cf2ca0d6 100644 --- a/paddlespeech/s2t/exps/deepspeech2/bin/test_hub.py +++ b/paddlespeech/s2t/exps/deepspeech2/bin/test_hub.py @@ -41,7 +41,7 @@ class DeepSpeech2Tester_hub(): self.audio_file = args.audio_file self.collate_fn_test = SpeechCollator.from_config(config) self._text_featurizer = TextFeaturizer( - unit_type=config.collator.unit_type, vocab_filepath=None) + unit_type=config.collator.unit_type, vocab=None) def compute_result_transcripts(self, audio, audio_len, vocab_list, cfg): result_transcripts = self.model.decode( diff --git a/paddlespeech/s2t/exps/deepspeech2/model.py b/paddlespeech/s2t/exps/deepspeech2/model.py index 3e4ff1a8..a0b69d64 100644 --- a/paddlespeech/s2t/exps/deepspeech2/model.py +++ b/paddlespeech/s2t/exps/deepspeech2/model.py @@ -286,7 +286,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): def __init__(self, config, args): super().__init__(config, args) self._text_featurizer = TextFeaturizer( - unit_type=config.collator.unit_type, vocab_filepath=None) + unit_type=config.collator.unit_type, vocab=None) def ordid2token(self, texts, texts_len): """ ord() id to chr() chr """ diff --git a/paddlespeech/s2t/exps/u2/bin/test_wav.py b/paddlespeech/s2t/exps/u2/bin/test_wav.py index a9450129..556316ec 100644 --- a/paddlespeech/s2t/exps/u2/bin/test_wav.py +++ b/paddlespeech/s2t/exps/u2/bin/test_wav.py @@ -44,7 +44,7 @@ class U2Infer(): self.text_feature = TextFeaturizer( unit_type=config.collator.unit_type, - vocab_filepath=config.collator.vocab_filepath, + vocab=config.collator.vocab_filepath, spm_model_prefix=config.collator.spm_model_prefix) paddle.set_device('gpu' if self.args.ngpu > 0 else 'cpu') @@ -91,13 +91,7 @@ class U2Infer(): ilen, text_feature=self.text_feature, decoding_method=cfg.decoding_method, - lang_model_path=cfg.lang_model_path, - beam_alpha=cfg.alpha, - beam_beta=cfg.beta, beam_size=cfg.beam_size, - cutoff_prob=cfg.cutoff_prob, - cutoff_top_n=cfg.cutoff_top_n, - num_processes=cfg.num_proc_bsearch, ctc_weight=cfg.ctc_weight, decoding_chunk_size=cfg.decoding_chunk_size, num_decoding_left_chunks=cfg.num_decoding_left_chunks, diff --git a/paddlespeech/s2t/exps/u2/model.py b/paddlespeech/s2t/exps/u2/model.py index d448021c..404058ed 100644 --- a/paddlespeech/s2t/exps/u2/model.py +++ b/paddlespeech/s2t/exps/u2/model.py @@ -437,7 +437,7 @@ class U2Tester(U2Trainer): super().__init__(config, args) self.text_feature = TextFeaturizer( unit_type=self.config.collator.unit_type, - vocab_filepath=self.config.collator.vocab_filepath, + vocab=self.config.collator.vocab_filepath, spm_model_prefix=self.config.collator.spm_model_prefix) self.vocab_list = self.text_feature.vocab_list @@ -469,13 +469,7 @@ class U2Tester(U2Trainer): audio_len, text_feature=self.text_feature, decoding_method=cfg.decoding_method, - lang_model_path=cfg.lang_model_path, - beam_alpha=cfg.alpha, - beam_beta=cfg.beta, beam_size=cfg.beam_size, - cutoff_prob=cfg.cutoff_prob, - cutoff_top_n=cfg.cutoff_top_n, - num_processes=cfg.num_proc_bsearch, ctc_weight=cfg.ctc_weight, decoding_chunk_size=cfg.decoding_chunk_size, num_decoding_left_chunks=cfg.num_decoding_left_chunks, diff --git a/paddlespeech/s2t/exps/u2_kaldi/model.py b/paddlespeech/s2t/exps/u2_kaldi/model.py index 43e31a60..9b8274ad 100644 --- a/paddlespeech/s2t/exps/u2_kaldi/model.py +++ b/paddlespeech/s2t/exps/u2_kaldi/model.py @@ -393,7 +393,7 @@ class U2Tester(U2Trainer): super().__init__(config, args) self.text_feature = TextFeaturizer( unit_type=self.config.collator.unit_type, - vocab_filepath=self.config.collator.vocab_filepath, + vocab=self.config.collator.vocab_filepath, spm_model_prefix=self.config.collator.spm_model_prefix) self.vocab_list = self.text_feature.vocab_list @@ -425,13 +425,7 @@ class U2Tester(U2Trainer): audio_len, text_feature=self.text_feature, decoding_method=cfg.decoding_method, - lang_model_path=cfg.lang_model_path, - beam_alpha=cfg.alpha, - beam_beta=cfg.beta, beam_size=cfg.beam_size, - cutoff_prob=cfg.cutoff_prob, - cutoff_top_n=cfg.cutoff_top_n, - num_processes=cfg.num_proc_bsearch, ctc_weight=cfg.ctc_weight, decoding_chunk_size=cfg.decoding_chunk_size, num_decoding_left_chunks=cfg.num_decoding_left_chunks, diff --git a/paddlespeech/s2t/exps/u2_st/model.py b/paddlespeech/s2t/exps/u2_st/model.py index 3ec2c920..a3b39df7 100644 --- a/paddlespeech/s2t/exps/u2_st/model.py +++ b/paddlespeech/s2t/exps/u2_st/model.py @@ -437,14 +437,7 @@ class U2STTester(U2STTrainer): audio_len, text_feature=text_feature, decoding_method=cfg.decoding_method, - lang_model_path=cfg.lang_model_path, - beam_alpha=cfg.alpha, - beam_beta=cfg.beta, beam_size=cfg.beam_size, - cutoff_prob=cfg.cutoff_prob, - cutoff_top_n=cfg.cutoff_top_n, - num_processes=cfg.num_proc_bsearch, - ctc_weight=cfg.ctc_weight, word_reward=cfg.word_reward, decoding_chunk_size=cfg.decoding_chunk_size, num_decoding_left_chunks=cfg.num_decoding_left_chunks, @@ -475,14 +468,7 @@ class U2STTester(U2STTrainer): audio_len, text_feature=text_feature, decoding_method=cfg.decoding_method, - lang_model_path=cfg.lang_model_path, - beam_alpha=cfg.alpha, - beam_beta=cfg.beta, beam_size=cfg.beam_size, - cutoff_prob=cfg.cutoff_prob, - cutoff_top_n=cfg.cutoff_top_n, - num_processes=cfg.num_proc_bsearch, - ctc_weight=cfg.ctc_weight, word_reward=cfg.word_reward, decoding_chunk_size=cfg.decoding_chunk_size, num_decoding_left_chunks=cfg.num_decoding_left_chunks, diff --git a/paddlespeech/s2t/frontend/featurizer/speech_featurizer.py b/paddlespeech/s2t/frontend/featurizer/speech_featurizer.py index 591df96e..9dc86829 100644 --- a/paddlespeech/s2t/frontend/featurizer/speech_featurizer.py +++ b/paddlespeech/s2t/frontend/featurizer/speech_featurizer.py @@ -55,7 +55,7 @@ class SpeechFeaturizer(): self.text_feature = TextFeaturizer( unit_type=unit_type, - vocab_filepath=vocab_filepath, + vocab=vocab_filepath, spm_model_prefix=spm_model_prefix, maskctc=maskctc) self.vocab_size = self.text_feature.vocab_size diff --git a/paddlespeech/s2t/frontend/featurizer/text_featurizer.py b/paddlespeech/s2t/frontend/featurizer/text_featurizer.py index 812be6e4..0c0fa5e2 100644 --- a/paddlespeech/s2t/frontend/featurizer/text_featurizer.py +++ b/paddlespeech/s2t/frontend/featurizer/text_featurizer.py @@ -13,6 +13,7 @@ # limitations under the License. """Contains the text featurizer class.""" from pprint import pformat +from typing import Union import sentencepiece as spm @@ -31,11 +32,7 @@ __all__ = ["TextFeaturizer"] class TextFeaturizer(): - def __init__(self, - unit_type, - vocab_filepath, - spm_model_prefix=None, - maskctc=False): + def __init__(self, unit_type, vocab, spm_model_prefix=None, maskctc=False): """Text featurizer, for processing or extracting features from text. Currently, it supports char/word/sentence-piece level tokenizing and conversion into @@ -44,7 +41,7 @@ class TextFeaturizer(): Args: unit_type (str): unit type, e.g. char, word, spm - vocab_filepath (str): Filepath to load vocabulary for token indices conversion. + vocab Option[str, list]: Filepath to load vocabulary for token indices conversion, or vocab list. spm_model_prefix (str, optional): spm model prefix. Defaults to None. """ assert unit_type in ('char', 'spm', 'word') @@ -52,12 +49,12 @@ class TextFeaturizer(): self.unk = UNK self.maskctc = maskctc - if vocab_filepath: + if vocab: self.vocab_dict, self._id2token, self.vocab_list, self.unk_id, self.eos_id, self.blank_id = self._load_vocabulary_from_file( - vocab_filepath, maskctc) + vocab, maskctc) self.vocab_size = len(self.vocab_list) else: - logger.warning("TextFeaturizer: not have vocab file.") + logger.warning("TextFeaturizer: not have vocab file or vocab list.") if unit_type == 'spm': spm_model = spm_model_prefix + '.model' @@ -207,9 +204,13 @@ class TextFeaturizer(): return decode(tokens) - def _load_vocabulary_from_file(self, vocab_filepath: str, maskctc: bool): + def _load_vocabulary_from_file(self, vocab: Union[str, list], + maskctc: bool): """Load vocabulary from file.""" - vocab_list = load_dict(vocab_filepath, maskctc) + if isinstance(vocab, list): + vocab_list = vocab + else: + vocab_list = load_dict(vocab, maskctc) assert vocab_list is not None logger.debug(f"Vocab: {pformat(vocab_list)}") diff --git a/paddlespeech/s2t/models/lm/dataset.py b/paddlespeech/s2t/models/lm/dataset.py index 4059dfe2..25a47be6 100644 --- a/paddlespeech/s2t/models/lm/dataset.py +++ b/paddlespeech/s2t/models/lm/dataset.py @@ -42,7 +42,7 @@ class TextCollatorSpm(): assert (vocab_filepath is not None) self.text_featurizer = TextFeaturizer( unit_type=unit_type, - vocab_filepath=vocab_filepath, + vocab=vocab_filepath, spm_model_prefix=spm_model_prefix) self.eos_id = self.text_featurizer.eos_id self.blank_id = self.text_featurizer.blank_id diff --git a/paddlespeech/s2t/models/u2/u2.py b/paddlespeech/s2t/models/u2/u2.py index 8053ed3a..fdcab018 100644 --- a/paddlespeech/s2t/models/u2/u2.py +++ b/paddlespeech/s2t/models/u2/u2.py @@ -717,13 +717,7 @@ class U2BaseModel(ASRInterface, nn.Layer): feats_lengths: paddle.Tensor, text_feature: Dict[str, int], decoding_method: str, - lang_model_path: str, - beam_alpha: float, - beam_beta: float, beam_size: int, - cutoff_prob: float, - cutoff_top_n: int, - num_processes: int, ctc_weight: float=0.0, decoding_chunk_size: int=-1, num_decoding_left_chunks: int=-1, @@ -737,13 +731,7 @@ class U2BaseModel(ASRInterface, nn.Layer): decoding_method (str): decoding mode, e.g. 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring' - lang_model_path (str): lm path. - beam_alpha (float): lm weight. - beam_beta (float): length penalty. beam_size (int): beam size for search - cutoff_prob (float): for prune. - cutoff_top_n (int): for prune. - num_processes (int): ctc_weight (float, optional): ctc weight for attention rescoring decode mode. Defaults to 0.0. decoding_chunk_size (int, optional): decoding chunk size. Defaults to -1. <0: for decoding, use full chunk. @@ -839,12 +827,13 @@ class U2Model(U2DecodeModel): def __init__(self, configs: dict): vocab_size, encoder, decoder, ctc = U2Model._init_from_config(configs) + model_conf = configs.get('model_conf', dict()) super().__init__( vocab_size=vocab_size, encoder=encoder, decoder=decoder, ctc=ctc, - **configs['model_conf']) + **model_conf) @classmethod def _init_from_config(cls, configs: dict): @@ -893,7 +882,7 @@ class U2Model(U2DecodeModel): **configs['decoder_conf']) # ctc decoder and ctc loss - model_conf = configs['model_conf'] + model_conf = configs.get('model_conf', dict()) dropout_rate = model_conf.get('ctc_dropout_rate', 0.0) grad_norm_type = model_conf.get('ctc_grad_norm_type', None) ctc = CTCDecoder( diff --git a/paddlespeech/s2t/models/u2_st/u2_st.py b/paddlespeech/s2t/models/u2_st/u2_st.py index 3a23804f..8b07e389 100644 --- a/paddlespeech/s2t/models/u2_st/u2_st.py +++ b/paddlespeech/s2t/models/u2_st/u2_st.py @@ -522,14 +522,7 @@ class U2STBaseModel(nn.Layer): feats_lengths: paddle.Tensor, text_feature: Dict[str, int], decoding_method: str, - lang_model_path: str, - beam_alpha: float, - beam_beta: float, beam_size: int, - cutoff_prob: float, - cutoff_top_n: int, - num_processes: int, - ctc_weight: float=0.0, word_reward: float=0.0, decoding_chunk_size: int=-1, num_decoding_left_chunks: int=-1, @@ -543,14 +536,7 @@ class U2STBaseModel(nn.Layer): decoding_method (str): decoding mode, e.g. 'fullsentence', 'simultaneous' - lang_model_path (str): lm path. - beam_alpha (float): lm weight. - beam_beta (float): length penalty. beam_size (int): beam size for search - cutoff_prob (float): for prune. - cutoff_top_n (int): for prune. - num_processes (int): - ctc_weight (float, optional): ctc weight for attention rescoring decode mode. Defaults to 0.0. decoding_chunk_size (int, optional): decoding chunk size. Defaults to -1. <0: for decoding, use full chunk. >0: for decoding, use fixed chunk size as set. diff --git a/paddlespeech/s2t/transform/cmvn.py b/paddlespeech/s2t/transform/cmvn.py index aa1e6b44..2db0070b 100644 --- a/paddlespeech/s2t/transform/cmvn.py +++ b/paddlespeech/s2t/transform/cmvn.py @@ -168,13 +168,17 @@ class GlobalCMVN(): norm_means=True, norm_vars=True, std_floor=1.0e-20): - self.cmvn_path = cmvn_path + # cmvn_path: Option[str, dict] + cmvn = cmvn_path + self.cmvn = cmvn self.norm_means = norm_means self.norm_vars = norm_vars self.std_floor = std_floor - - with open(cmvn_path) as f: - cmvn_stats = json.load(f) + if isinstance(cmvn, dict): + cmvn_stats = cmvn + else: + with open(cmvn) as f: + cmvn_stats = json.load(f) self.count = cmvn_stats['frame_num'] self.mean = np.array(cmvn_stats['mean_stat']) / self.count self.square_sums = np.array(cmvn_stats['var_stat']) @@ -183,8 +187,8 @@ class GlobalCMVN(): def __repr__(self): return f"""{self.__class__.__name__}( - cmvn_path={self.cmvn_path}, - norm_means={self.norm_means}, + cmvn_path={self.cmvn}, + norm_means={self.norm_means}, norm_vars={self.norm_vars},)""" def __call__(self, x, uttid=None): -- GitLab