From b9e3e49305983ff1b07d8d649dcadebfb1a71e32 Mon Sep 17 00:00:00 2001
From: Hui Zhang
Date: Wed, 15 Jun 2022 07:48:14 +0000
Subject: [PATCH] refactor stream asr and fix ds2 stream bug

---
 demos/streaming_asr_server/test.sh            |   2 +-
 .../asr/online/{ => python}/asr_engine.py     | 160 ++++++++++--------
 paddlespeech/server/engine/engine_factory.py  |   5 +-
 3 files changed, 96 insertions(+), 71 deletions(-)
 rename paddlespeech/server/engine/asr/online/{ => python}/asr_engine.py (96%)

diff --git a/demos/streaming_asr_server/test.sh b/demos/streaming_asr_server/test.sh
index f3075454..f09068d4 100755
--- a/demos/streaming_asr_server/test.sh
+++ b/demos/streaming_asr_server/test.sh
@@ -4,7 +4,7 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav
 # read the wav and pass it to only streaming asr service
 # If `127.0.0.1` is not accessible, you need to use the actual service IP address.
 # python3 websocket_client.py --server_ip 127.0.0.1 --port 8290 --wavfile ./zh.wav
-paddlespeech_client asr_online --server_ip 127.0.0.1 --port 8290 --input ./zh.wav
+paddlespeech_client asr_online --server_ip 127.0.0.1 --port 8090 --input ./zh.wav
 
 # read the wav and call streaming and punc service
 # If `127.0.0.1` is not accessible, you need to use the actual service IP address.
diff --git a/paddlespeech/server/engine/asr/online/asr_engine.py b/paddlespeech/server/engine/asr/online/python/asr_engine.py
similarity index 96%
rename from paddlespeech/server/engine/asr/online/asr_engine.py
rename to paddlespeech/server/engine/asr/online/python/asr_engine.py
index f230b8b9..9801a6fc 100644
--- a/paddlespeech/server/engine/asr/online/asr_engine.py
+++ b/paddlespeech/server/engine/asr/online/python/asr_engine.py
@@ -121,13 +121,14 @@ class PaddleASRConnectionHanddler:
         raise ValueError(f"Not supported: {self.model_type}")
 
     def model_reset(self):
-        if "deepspeech2" in self.model_type:
-            return
-
         # cache for audio and feat
         self.remained_wav = None
         self.cached_feat = None
+
+        if "deepspeech2" in self.model_type:
+            return
+
         ## conformer
         # cache for conformer online
         self.subsampling_cache = None
@@ -697,6 +698,67 @@ class ASRServerExecutor(ASRExecutor):
         self.task_resource = CommonTaskResource(
             task='asr', model_format='dynamic', inference_mode='online')
 
+    def update_config(self) -> None:
+        if "deepspeech2" in self.model_type:
+            with UpdateConfig(self.config):
+                # download lm
+                self.config.decode.lang_model_path = os.path.join(
+                    MODEL_HOME, 'language_model',
+                    self.config.decode.lang_model_path)
+
+                lm_url = self.task_resource.res_dict['lm_url']
+                lm_md5 = self.task_resource.res_dict['lm_md5']
+                logger.info(f"Start to load language model {lm_url}")
+                self.download_lm(
+                    lm_url,
+                    os.path.dirname(self.config.decode.lang_model_path), lm_md5)
+        elif "conformer" in self.model_type or "transformer" in self.model_type:
+            with UpdateConfig(self.config):
+                logger.info("start to create the stream conformer asr engine")
+                # update the decoding method
+                if self.decode_method:
+                    self.config.decode.decoding_method = self.decode_method
+                # update num_decoding_left_chunks
+                if self.num_decoding_left_chunks:
+                    assert self.num_decoding_left_chunks == -1 or self.num_decoding_left_chunks >= 0, "num_decoding_left_chunks should be -1 or >=0"
+                    self.config.decode.num_decoding_left_chunks = self.num_decoding_left_chunks
+                # we only support the ctc_prefix_beam_search and attention_rescoring decoding methods
+                # Generally we set the decoding_method to attention_rescoring
+                if self.config.decode.decoding_method not in [
+                        "ctc_prefix_beam_search", "attention_rescoring"
"attention_rescoring" + ]: + logger.info( + "we set the decoding_method to attention_rescoring") + self.config.decode.decoding_method = "attention_rescoring" + + assert self.config.decode.decoding_method in [ + "ctc_prefix_beam_search", "attention_rescoring" + ], f"we only support ctc_prefix_beam_search and attention_rescoring dedoding method, current decoding method is {self.config.decode.decoding_method}" + else: + raise Exception(f"not support: {self.model_type}") + + def init_model(self) -> None: + if "deepspeech2" in self.model_type : + # AM predictor + logger.info("ASR engine start to init the am predictor") + self.am_predictor = init_predictor( + model_file=self.am_model, + params_file=self.am_params, + predictor_conf=self.am_predictor_conf) + elif "conformer" in self.model_type or "transformer" in self.model_type : + # load model + # model_type: {model_name}_{dataset} + model_name = self.model_type[:self.model_type.rindex('_')] + logger.info(f"model name: {model_name}") + model_class = self.task_resource.get_model_class(model_name) + model = model_class.from_config(self.config) + self.model = model + self.model.set_state_dict(paddle.load(self.am_model)) + self.model.eval() + else: + raise Exception(f"not support: {self.model_type}") + + def _init_from_path(self, model_type: str=None, am_model: Optional[os.PathLike]=None, @@ -718,8 +780,13 @@ class ASRServerExecutor(ASRExecutor): self.model_type = model_type self.sample_rate = sample_rate + self.decode_method = decode_method + self.num_decoding_left_chunks = num_decoding_left_chunks + # conf for paddleinference predictor or onnx + self.am_predictor_conf = am_predictor_conf logger.info(f"model_type: {self.model_type}") + sample_rate_str = '16k' if sample_rate == 16000 else '8k' tag = model_type + '-' + lang + '-' + sample_rate_str self.task_resource.set_task_model(model_tag=tag) @@ -763,62 +830,10 @@ class ASRServerExecutor(ASRExecutor): vocab=self.config.vocab_filepath, spm_model_prefix=self.config.spm_model_prefix) - if "deepspeech2" in model_type: - with UpdateConfig(self.config): - # download lm - self.config.decode.lang_model_path = os.path.join( - MODEL_HOME, 'language_model', - self.config.decode.lang_model_path) - - lm_url = self.task_resource.res_dict['lm_url'] - lm_md5 = self.task_resource.res_dict['lm_md5'] - logger.info(f"Start to load language model {lm_url}") - self.download_lm( - lm_url, - os.path.dirname(self.config.decode.lang_model_path), lm_md5) - - # AM predictor - logger.info("ASR engine start to init the am predictor") - self.am_predictor_conf = am_predictor_conf - self.am_predictor = init_predictor( - model_file=self.am_model, - params_file=self.am_params, - predictor_conf=self.am_predictor_conf) - - elif "conformer" in model_type or "transformer" in model_type: - with UpdateConfig(self.config): - logger.info("start to create the stream conformer asr engine") - # update the decoding method - if decode_method: - self.config.decode.decoding_method = decode_method - # update num_decoding_left_chunks - if num_decoding_left_chunks: - assert num_decoding_left_chunks == -1 or num_decoding_left_chunks >= 0, f"num_decoding_left_chunks should be -1 or >=0" - self.config.decode.num_decoding_left_chunks = num_decoding_left_chunks - # we only support ctc_prefix_beam_search and attention_rescoring dedoding method - # Generally we set the decoding_method to attention_rescoring - if self.config.decode.decoding_method not in [ - "ctc_prefix_beam_search", "attention_rescoring" - ]: - logger.info( - "we set the decoding_method to 
-                    self.config.decode.decoding_method = "attention_rescoring"
-
-                assert self.config.decode.decoding_method in [
-                    "ctc_prefix_beam_search", "attention_rescoring"
-                ], f"we only support ctc_prefix_beam_search and attention_rescoring dedoding method, current decoding method is {self.config.decode.decoding_method}"
-
-            # load model
-            model_name = model_type[:model_type.rindex(
-                '_')]  # model_type: {model_name}_{dataset}
-            logger.info(f"model name: {model_name}")
-            model_class = self.task_resource.get_model_class(model_name)
-            model = model_class.from_config(self.config)
-            self.model = model
-            self.model.set_state_dict(paddle.load(self.am_model))
-            self.model.eval()
-        else:
-            raise Exception(f"not support: {model_type}")
+        self.update_config()
+
+        # AM predictor
+        self.init_model()
 
         logger.info(f"create the {model_type} model success")
         return True
@@ -835,6 +850,22 @@ class ASREngine(BaseEngine):
         super(ASREngine, self).__init__()
         logger.info("create the online asr engine resource instance")
 
+    def init_model(self) -> bool:
+        if not self.executor._init_from_path(
+                model_type=self.config.model_type,
+                am_model=self.config.am_model,
+                am_params=self.config.am_params,
+                lang=self.config.lang,
+                sample_rate=self.config.sample_rate,
+                cfg_path=self.config.cfg_path,
+                decode_method=self.config.decode_method,
+                num_decoding_left_chunks=self.config.num_decoding_left_chunks,
+                am_predictor_conf=self.config.am_predictor_conf):
+            return False
+        return True
+
     def init(self, config: dict) -> bool:
         """init engine resource
@@ -860,16 +891,7 @@ class ASREngine(BaseEngine):
 
         logger.info(f"paddlespeech_server set the device: {self.device}")
 
-        if not self.executor._init_from_path(
-                model_type=self.config.model_type,
-                am_model=self.config.am_model,
-                am_params=self.config.am_params,
-                lang=self.config.lang,
-                sample_rate=self.config.sample_rate,
-                cfg_path=self.config.cfg_path,
-                decode_method=self.config.decode_method,
-                num_decoding_left_chunks=self.config.num_decoding_left_chunks,
-                am_predictor_conf=self.config.am_predictor_conf):
+        if not self.init_model():
             logger.error(
                 "Init the ASR server occurs error, please check the server configuration yaml"
             )
diff --git a/paddlespeech/server/engine/engine_factory.py b/paddlespeech/server/engine/engine_factory.py
index 5fdaacce..019e4684 100644
--- a/paddlespeech/server/engine/engine_factory.py
+++ b/paddlespeech/server/engine/engine_factory.py
@@ -26,7 +26,10 @@ class EngineFactory(object):
             from paddlespeech.server.engine.asr.python.asr_engine import ASREngine
             return ASREngine()
         elif engine_name == 'asr' and engine_type == 'online':
-            from paddlespeech.server.engine.asr.online.asr_engine import ASREngine
+            from paddlespeech.server.engine.asr.online.python.asr_engine import ASREngine
+            return ASREngine()
+        elif engine_name == 'asr' and engine_type == 'online-onnx':
+            from paddlespeech.server.engine.asr.online.onnx.asr_engine import ASREngine
             return ASREngine()
         elif engine_name == 'tts' and engine_type == 'inference':
             from paddlespeech.server.engine.tts.paddleinference.tts_engine import TTSEngine
-- 
GitLab
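
Usage note: after this patch the pure-Python streaming engine lives under
paddlespeech/server/engine/asr/online/python/ and EngineFactory gains an
'online-onnx' branch. A minimal sketch of the dispatch, assuming the
two-argument EngineFactory.get_engine(engine_name, engine_type) that the
branches above test against (the engine_type values come from the server
config yaml):

    from paddlespeech.server.engine.engine_factory import EngineFactory

    # 'online' selects the Python streaming engine (asr/online/python/asr_engine.py)
    engine = EngineFactory.get_engine(engine_name='asr', engine_type='online')

    # 'online-onnx' selects the ONNX streaming engine (asr/online/onnx/asr_engine.py)
    onnx_engine = EngineFactory.get_engine(engine_name='asr', engine_type='online-onnx')

The model_reset() change is the ds2 stream bug fix: the old code returned
early for deepspeech2 models before remained_wav and cached_feat were
cleared, so audio and features cached from one utterance leaked into the
next; the early return now happens only after both caches are reset.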