From 61e39daccc572429fba85ad6b65259f8c22d1bd1 Mon Sep 17 00:00:00 2001 From: KP <109694228@qq.com> Date: Tue, 7 Dec 2021 18:37:45 +0800 Subject: [PATCH] Optimize model init. --- paddlespeech/cli/asr/infer.py | 4 ++++ paddlespeech/cli/cls/infer.py | 12 ++++++++---- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/paddlespeech/cli/asr/infer.py b/paddlespeech/cli/asr/infer.py index 3ff5be12..30e4bb9c 100644 --- a/paddlespeech/cli/asr/infer.py +++ b/paddlespeech/cli/asr/infer.py @@ -137,6 +137,10 @@ class ASRExecutor(BaseExecutor): """ Init model and other resources from a specific path. """ + if hasattr(self, 'model'): + logger.info('Model had been initialized.') + return + if cfg_path is None or ckpt_path is None: sample_rate_str = '16k' if sample_rate == 16000 else '8k' tag = model_type + '_' + lang + '_' + sample_rate_str diff --git a/paddlespeech/cli/cls/infer.py b/paddlespeech/cli/cls/infer.py index 5b5894fa..c4206f7e 100644 --- a/paddlespeech/cli/cls/infer.py +++ b/paddlespeech/cli/cls/infer.py @@ -128,6 +128,10 @@ class CLSExecutor(BaseExecutor): """ Init model and other resources from a specific path. """ + if hasattr(self, 'model'): + logger.info('Model had been initialized.') + return + if label_file is None or ckpt_path is None: self.res_path = self._get_pretrained_path(model_type) # panns_cnn14 self.cfg_path = os.path.join( @@ -154,9 +158,9 @@ class CLSExecutor(BaseExecutor): # model model_class = dynamic_import(model_type, model_alias) model_dict = paddle.load(self.ckpt_path) - self._model = model_class(extract_embedding=False) - self._model.set_state_dict(model_dict) - self._model.eval() + self.model = model_class(extract_embedding=False) + self.model.set_state_dict(model_dict) + self.model.eval() def preprocess(self, audio_file: Union[str, os.PathLike]): """ @@ -192,7 +196,7 @@ class CLSExecutor(BaseExecutor): """ Model inference and result stored in self.output. """ - self._outputs['logits'] = self._model(self._inputs['feats']) + self._outputs['logits'] = self.model(self._inputs['feats']) def _generate_topk_label(self, result: np.ndarray, topk: int) -> str: assert topk <= len( -- GitLab