提交 61e39dac 编写于 作者: K KP

Optimize model init.

上级 528c70e5
...@@ -137,6 +137,10 @@ class ASRExecutor(BaseExecutor): ...@@ -137,6 +137,10 @@ class ASRExecutor(BaseExecutor):
""" """
Init model and other resources from a specific path. Init model and other resources from a specific path.
""" """
if hasattr(self, 'model'):
logger.info('Model had been initialized.')
return
if cfg_path is None or ckpt_path is None: if cfg_path is None or ckpt_path is None:
sample_rate_str = '16k' if sample_rate == 16000 else '8k' sample_rate_str = '16k' if sample_rate == 16000 else '8k'
tag = model_type + '_' + lang + '_' + sample_rate_str tag = model_type + '_' + lang + '_' + sample_rate_str
......
...@@ -128,6 +128,10 @@ class CLSExecutor(BaseExecutor): ...@@ -128,6 +128,10 @@ class CLSExecutor(BaseExecutor):
""" """
Init model and other resources from a specific path. Init model and other resources from a specific path.
""" """
if hasattr(self, 'model'):
logger.info('Model had been initialized.')
return
if label_file is None or ckpt_path is None: if label_file is None or ckpt_path is None:
self.res_path = self._get_pretrained_path(model_type) # panns_cnn14 self.res_path = self._get_pretrained_path(model_type) # panns_cnn14
self.cfg_path = os.path.join( self.cfg_path = os.path.join(
...@@ -154,9 +158,9 @@ class CLSExecutor(BaseExecutor): ...@@ -154,9 +158,9 @@ class CLSExecutor(BaseExecutor):
# model # model
model_class = dynamic_import(model_type, model_alias) model_class = dynamic_import(model_type, model_alias)
model_dict = paddle.load(self.ckpt_path) model_dict = paddle.load(self.ckpt_path)
self._model = model_class(extract_embedding=False) self.model = model_class(extract_embedding=False)
self._model.set_state_dict(model_dict) self.model.set_state_dict(model_dict)
self._model.eval() self.model.eval()
def preprocess(self, audio_file: Union[str, os.PathLike]): def preprocess(self, audio_file: Union[str, os.PathLike]):
""" """
...@@ -192,7 +196,7 @@ class CLSExecutor(BaseExecutor): ...@@ -192,7 +196,7 @@ class CLSExecutor(BaseExecutor):
""" """
Model inference and result stored in self.output. Model inference and result stored in self.output.
""" """
self._outputs['logits'] = self._model(self._inputs['feats']) self._outputs['logits'] = self.model(self._inputs['feats'])
def _generate_topk_label(self, result: np.ndarray, topk: int) -> str: def _generate_topk_label(self, result: np.ndarray, topk: int) -> str:
assert topk <= len( assert topk <= len(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册