提交 61e39dac 编写于 作者: K KP

Optimize model init.

上级 528c70e5
......@@ -137,6 +137,10 @@ class ASRExecutor(BaseExecutor):
"""
Init model and other resources from a specific path.
"""
if hasattr(self, 'model'):
logger.info('Model had been initialized.')
return
if cfg_path is None or ckpt_path is None:
sample_rate_str = '16k' if sample_rate == 16000 else '8k'
tag = model_type + '_' + lang + '_' + sample_rate_str
......
......@@ -128,6 +128,10 @@ class CLSExecutor(BaseExecutor):
"""
Init model and other resources from a specific path.
"""
if hasattr(self, 'model'):
logger.info('Model had been initialized.')
return
if label_file is None or ckpt_path is None:
self.res_path = self._get_pretrained_path(model_type) # panns_cnn14
self.cfg_path = os.path.join(
......@@ -154,9 +158,9 @@ class CLSExecutor(BaseExecutor):
# model
model_class = dynamic_import(model_type, model_alias)
model_dict = paddle.load(self.ckpt_path)
self._model = model_class(extract_embedding=False)
self._model.set_state_dict(model_dict)
self._model.eval()
self.model = model_class(extract_embedding=False)
self.model.set_state_dict(model_dict)
self.model.eval()
def preprocess(self, audio_file: Union[str, os.PathLike]):
"""
......@@ -192,7 +196,7 @@ class CLSExecutor(BaseExecutor):
"""
Model inference and result stored in self.output.
"""
self._outputs['logits'] = self._model(self._inputs['feats'])
self._outputs['logits'] = self.model(self._inputs['feats'])
def _generate_topk_label(self, result: np.ndarray, topk: int) -> str:
assert topk <= len(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册