diff --git a/paddlespeech/cli/asr/infer.py b/paddlespeech/cli/asr/infer.py index 3ff5be12e29f3c5d81635caf63851a103efe5ff0..30e4bb9c1a81282a9cbf639f9e6638f274462f32 100644 --- a/paddlespeech/cli/asr/infer.py +++ b/paddlespeech/cli/asr/infer.py @@ -137,6 +137,10 @@ class ASRExecutor(BaseExecutor): """ Init model and other resources from a specific path. """ + if hasattr(self, 'model'): + logger.info('Model had been initialized.') + return + if cfg_path is None or ckpt_path is None: sample_rate_str = '16k' if sample_rate == 16000 else '8k' tag = model_type + '_' + lang + '_' + sample_rate_str diff --git a/paddlespeech/cli/cls/infer.py b/paddlespeech/cli/cls/infer.py index 5b5894fa37881b7a117ce017bea0857b9e164089..c4206f7e5e9b0e33c3f0b4259557ee234f01b9ce 100644 --- a/paddlespeech/cli/cls/infer.py +++ b/paddlespeech/cli/cls/infer.py @@ -128,6 +128,10 @@ class CLSExecutor(BaseExecutor): """ Init model and other resources from a specific path. """ + if hasattr(self, 'model'): + logger.info('Model had been initialized.') + return + if label_file is None or ckpt_path is None: self.res_path = self._get_pretrained_path(model_type) # panns_cnn14 self.cfg_path = os.path.join( @@ -154,9 +158,9 @@ class CLSExecutor(BaseExecutor): # model model_class = dynamic_import(model_type, model_alias) model_dict = paddle.load(self.ckpt_path) - self._model = model_class(extract_embedding=False) - self._model.set_state_dict(model_dict) - self._model.eval() + self.model = model_class(extract_embedding=False) + self.model.set_state_dict(model_dict) + self.model.eval() def preprocess(self, audio_file: Union[str, os.PathLike]): """ @@ -192,7 +196,7 @@ class CLSExecutor(BaseExecutor): """ Model inference and result stored in self.output. """ - self._outputs['logits'] = self._model(self._inputs['feats']) + self._outputs['logits'] = self.model(self._inputs['feats']) def _generate_topk_label(self, result: np.ndarray, topk: int) -> str: assert topk <= len(