From 30b3e237e24d71d371964404d010d9f8a53f8acc Mon Sep 17 00:00:00 2001
From: huangyuxin
Date: Fri, 24 Sep 2021 10:40:56 +0000
Subject: [PATCH] optimize the 1xt2x

---
 deepspeech/frontend/featurizer/text_featurizer.py  |  3 ---
 examples/1xt2x/aishell/local/download_model.sh     |  2 +-
 examples/1xt2x/aishell/run.sh                      |  3 ++-
 examples/1xt2x/baidu_en8k/local/download_model.sh  |  2 +-
 examples/1xt2x/baidu_en8k/run.sh                   |  3 ++-
 examples/1xt2x/librispeech/local/download_model.sh |  2 +-
 examples/1xt2x/librispeech/run.sh                  |  3 ++-
 examples/1xt2x/src_deepspeech2x/test_model.py      | 14 ++++++++------
 8 files changed, 17 insertions(+), 15 deletions(-)

diff --git a/deepspeech/frontend/featurizer/text_featurizer.py b/deepspeech/frontend/featurizer/text_featurizer.py
index c9f324ff..10ea6924 100644
--- a/deepspeech/frontend/featurizer/text_featurizer.py
+++ b/deepspeech/frontend/featurizer/text_featurizer.py
@@ -84,8 +84,6 @@ class TextFeaturizer():
         tokens = self.tokenize(text)
         ids = []
         for token in tokens:
-            if '<space>' in self.vocab_dict and token == ' ':
-                token = '<space>'
             token = token if token in self.vocab_dict else self.unk
             ids.append(self.vocab_dict[token])
         return ids
@@ -201,7 +199,6 @@ class TextFeaturizer():
         """Load vocabulary from file."""
         vocab_list = load_dict(vocab_filepath, maskctc)
         assert vocab_list is not None
-        assert SPACE in vocab_list

         id2token = dict(
             [(idx, token) for (idx, token) in enumerate(vocab_list)])
diff --git a/examples/1xt2x/aishell/local/download_model.sh b/examples/1xt2x/aishell/local/download_model.sh
index be4fa216..ffa2f810 100644
--- a/examples/1xt2x/aishell/local/download_model.sh
+++ b/examples/1xt2x/aishell/local/download_model.sh
@@ -10,7 +10,7 @@ ckpt_dir=$1
 . ${MAIN_ROOT}/utils/utility.sh

 URL='https://deepspeech.bj.bcebos.com/mandarin_models/aishell_model_v1.8_to_v2.x.tar.gz'
-MD5=4ade113c69ea291b8ce5ec6a03296659
+MD5=87e7577d4bea737dbf3e8daab37aa808
 TARGET=${ckpt_dir}/aishell_model_v1.8_to_v2.x.tar.gz


diff --git a/examples/1xt2x/aishell/run.sh b/examples/1xt2x/aishell/run.sh
index 0898f255..1ccac1c3 100755
--- a/examples/1xt2x/aishell/run.sh
+++ b/examples/1xt2x/aishell/run.sh
@@ -7,6 +7,7 @@ stop_stage=100
 conf_path=conf/deepspeech2.yaml
 avg_num=1
 model_type=offline
+gpus=2

 source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;

@@ -22,6 +23,6 @@ fi

 if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
     # test ckpt avg_n
-    CUDA_VISIBLE_DEVICES=2 ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${v18_ckpt} ${model_type}|| exit -1
+    CUDA_VISIBLE_DEVICES=${gpus} ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${v18_ckpt} ${model_type}|| exit -1
 fi

diff --git a/examples/1xt2x/baidu_en8k/local/download_model.sh b/examples/1xt2x/baidu_en8k/local/download_model.sh
index 54cf4210..a8fbc31e 100644
--- a/examples/1xt2x/baidu_en8k/local/download_model.sh
+++ b/examples/1xt2x/baidu_en8k/local/download_model.sh
@@ -10,7 +10,7 @@ ckpt_dir=$1
 . ${MAIN_ROOT}/utils/utility.sh

 URL='https://deepspeech.bj.bcebos.com/eng_models/baidu_en8k_v1.8_to_v2.x.tar.gz'
-MD5=fdabeb6c96963ac85d9188f0275c6a1b
+MD5=c1676be8505cee436e6f312823e9008c
 TARGET=${ckpt_dir}/baidu_en8k_v1.8_to_v2.x.tar.gz


diff --git a/examples/1xt2x/baidu_en8k/run.sh b/examples/1xt2x/baidu_en8k/run.sh
index c0f9ae45..b7f69f6b 100755
--- a/examples/1xt2x/baidu_en8k/run.sh
+++ b/examples/1xt2x/baidu_en8k/run.sh
@@ -7,6 +7,7 @@ stop_stage=100
 conf_path=conf/deepspeech2.yaml
 avg_num=1
 model_type=offline
+gpus=0

 source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;

@@ -22,6 +23,6 @@ fi

 if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
     # test ckpt avg_n
-    CUDA_VISIBLE_DEVICES=0 ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${v18_ckpt} ${model_type}|| exit -1
+    CUDA_VISIBLE_DEVICES=${gpus} ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${v18_ckpt} ${model_type}|| exit -1
 fi

diff --git a/examples/1xt2x/librispeech/local/download_model.sh b/examples/1xt2x/librispeech/local/download_model.sh
index 2c388efa..375d6640 100644
--- a/examples/1xt2x/librispeech/local/download_model.sh
+++ b/examples/1xt2x/librispeech/local/download_model.sh
@@ -10,7 +10,7 @@ ckpt_dir=$1
 . ${MAIN_ROOT}/utils/utility.sh

 URL='https://deepspeech.bj.bcebos.com/eng_models/librispeech_v1.8_to_v2.x.tar.gz'
-MD5=7b0f582fe2f5a840b840e7ee52246bc5
+MD5=a06d9aadb560ea113984dc98d67232c8
 TARGET=${ckpt_dir}/librispeech_v1.8_to_v2.x.tar.gz


diff --git a/examples/1xt2x/librispeech/run.sh b/examples/1xt2x/librispeech/run.sh
index 4671c40d..8c667de2 100755
--- a/examples/1xt2x/librispeech/run.sh
+++ b/examples/1xt2x/librispeech/run.sh
@@ -7,6 +7,7 @@ stop_stage=100
 conf_path=conf/deepspeech2.yaml
 avg_num=1
 model_type=offline
+gpus=1

 source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;

@@ -22,5 +23,5 @@ fi

 if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
     # test ckpt avg_n
-    CUDA_VISIBLE_DEVICES=1 ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${v18_ckpt} ${model_type}|| exit -1
+    CUDA_VISIBLE_DEVICES=${gpus} ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${v18_ckpt} ${model_type}|| exit -1
 fi
diff --git a/examples/1xt2x/src_deepspeech2x/test_model.py b/examples/1xt2x/src_deepspeech2x/test_model.py
index 9f0cd836..203a3bac 100644
--- a/examples/1xt2x/src_deepspeech2x/test_model.py
+++ b/examples/1xt2x/src_deepspeech2x/test_model.py
@@ -26,6 +26,7 @@ from src_deepspeech2x.models.ds2 import DeepSpeech2InferModel
 from src_deepspeech2x.models.ds2 import DeepSpeech2Model
 from yacs.config import CfgNode

+from deepspeech.frontend.featurizer.text_featurizer import TextFeaturizer
 from deepspeech.io.collator import SpeechCollator
 from deepspeech.io.dataset import ManifestDataset
 from deepspeech.io.sampler import SortagradBatchSampler
@@ -38,7 +39,6 @@ from deepspeech.utils import error_rate
 from deepspeech.utils import layer_tools
 from deepspeech.utils import mp_tools
 from deepspeech.utils.log import Log
-#from deepspeech.utils.log import Autolog

 logger = Log(__name__).getlog()

@@ -272,6 +272,9 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
         return default

     def __init__(self, config, args):
+
+        self._text_featurizer = TextFeaturizer(
+            unit_type=config.collator.unit_type, vocab_filepath=None)
         super().__init__(config, args)

     def ordid2token(self, texts, texts_len):
@@ -296,9 +299,6 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
         error_rate_func = error_rate.cer if cfg.error_rate_type == 'cer' else error_rate.wer

         vocab_list = self.test_loader.collate_fn.vocab_list
-        if "<space>" in vocab_list:
-            space_id = vocab_list.index("<space>")
-            vocab_list[space_id] = " "

         target_transcripts = self.ordid2token(texts, texts_len)

@@ -337,6 +337,10 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
             cutoff_prob=cfg.cutoff_prob,
             cutoff_top_n=cfg.cutoff_top_n,
             num_processes=cfg.num_proc_bsearch)
+        result_transcripts = [
+            self._text_featurizer.detokenize(item)
+            for item in result_transcripts
+        ]
         return result_transcripts

     @mp_tools.rank_zero_only
@@ -367,8 +371,6 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
                 error_rate_type, num_ins, num_ins, errors_sum / len_refs)
             logger.info(msg)

-        # self.autolog.report()
-
     def run_test(self):
         self.resume_or_scratch()
         try:
--
GitLab