未验证 提交 57dcd0d1 编写于 作者: Z Zhao Yuting 提交者: GitHub

Update infer.py

change the infer in order to implement the new faster model for text
上级 b627666c
...@@ -20,10 +20,13 @@ from typing import Optional ...@@ -20,10 +20,13 @@ from typing import Optional
from typing import Union from typing import Union
import paddle import paddle
import yaml
from yacs.config import CfgNode
from ..executor import BaseExecutor from ..executor import BaseExecutor
from ..log import logger from ..log import logger
from ..utils import stats_wrapper from ..utils import stats_wrapper
from paddlespeech.text.models.ernie_linear import ErnieLinear
__all__ = ['TextExecutor'] __all__ = ['TextExecutor']
...@@ -139,6 +142,66 @@ class TextExecutor(BaseExecutor): ...@@ -139,6 +142,66 @@ class TextExecutor(BaseExecutor):
self.model.eval() self.model.eval()
#init new models
def _init_from_path_new(self,
task: str='punc',
model_type: str='ernie_linear_p7_wudao',
lang: str='zh',
cfg_path: Optional[os.PathLike]=None,
ckpt_path: Optional[os.PathLike]=None,
vocab_file: Optional[os.PathLike]=None):
if hasattr(self, 'model'):
logger.debug('Model had been initialized.')
return
self.task = task
if cfg_path is None or ckpt_path is None or vocab_file is None:
tag = '-'.join([model_type, task, lang])
self.task_resource.set_task_model(tag, version=None)
self.cfg_path = os.path.join(
self.task_resource.res_dir,
self.task_resource.res_dict['cfg_path'])
self.ckpt_path = os.path.join(
self.task_resource.res_dir,
self.task_resource.res_dict['ckpt_path'])
self.vocab_file = os.path.join(
self.task_resource.res_dir,
self.task_resource.res_dict['vocab_file'])
else:
self.cfg_path = os.path.abspath(cfg_path)
self.ckpt_path = os.path.abspath(ckpt_path)
self.vocab_file = os.path.abspath(vocab_file)
model_name = model_type[:model_type.rindex('_')]
if self.task == 'punc':
# punc list
self._punc_list = []
with open(self.vocab_file, 'r') as f:
for line in f:
self._punc_list.append(line.strip())
# model
with open(self.cfg_path) as f:
config = CfgNode(yaml.safe_load(f))
self.model = ErnieLinear(**config["model"])
_, tokenizer_class = self.task_resource.get_model_class(model_name)
state_dict = paddle.load(self.ckpt_path)
self.model.set_state_dict(state_dict["main_params"])
self.model.eval()
#tokenizer: fast version: ernie-3.0-mini-zh slow version:ernie-1.0
if 'fast' not in model_type:
self.tokenizer = tokenizer_class.from_pretrained('ernie-1.0')
else:
self.tokenizer = tokenizer_class.from_pretrained(
'ernie-3.0-mini-zh')
else:
raise NotImplementedError
def _clean_text(self, text): def _clean_text(self, text):
text = text.lower() text = text.lower()
text = re.sub('[^A-Za-z0-9\u4e00-\u9fa5]', '', text) text = re.sub('[^A-Za-z0-9\u4e00-\u9fa5]', '', text)
...@@ -179,7 +242,7 @@ class TextExecutor(BaseExecutor): ...@@ -179,7 +242,7 @@ class TextExecutor(BaseExecutor):
else: else:
raise NotImplementedError raise NotImplementedError
def postprocess(self) -> Union[str, os.PathLike]: def postprocess(self, isNewTrainer: bool=False) -> Union[str, os.PathLike]:
""" """
Output postprocess and return human-readable results such as texts and audio files. Output postprocess and return human-readable results such as texts and audio files.
""" """
...@@ -192,13 +255,13 @@ class TextExecutor(BaseExecutor): ...@@ -192,13 +255,13 @@ class TextExecutor(BaseExecutor):
input_ids[1:seq_len - 1]) input_ids[1:seq_len - 1])
labels = preds[1:seq_len - 1].tolist() labels = preds[1:seq_len - 1].tolist()
assert len(tokens) == len(labels) assert len(tokens) == len(labels)
if isNewTrainer:
self._punc_list = [0] + self._punc_list
text = '' text = ''
for t, l in zip(tokens, labels): for t, l in zip(tokens, labels):
text += t text += t
if l != 0: # Non punc. if l != 0: # Non punc.
text += self._punc_list[l] text += self._punc_list[l]
return text return text
else: else:
raise NotImplementedError raise NotImplementedError
...@@ -255,10 +318,20 @@ class TextExecutor(BaseExecutor): ...@@ -255,10 +318,20 @@ class TextExecutor(BaseExecutor):
""" """
Python API to call an executor. Python API to call an executor.
""" """
#Here is old version models
if model in ['ernie_linear_p7_wudao', 'ernie_linear_p3_wudao']:
paddle.set_device(device) paddle.set_device(device)
self._init_from_path(task, model, lang, config, ckpt_path, punc_vocab) self._init_from_path(task, model, lang, config, ckpt_path,
punc_vocab)
self.preprocess(text) self.preprocess(text)
self.infer() self.infer()
res = self.postprocess() # Retrieve result of text task. res = self.postprocess() # Retrieve result of text task.
#Add new way to infer
else:
paddle.set_device(device)
self._init_from_path_new(task, model, lang, config, ckpt_path,
punc_vocab)
self.preprocess(text)
self.infer()
res = self.postprocess(isNewTrainer=True)
return res return res
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册