未验证 提交 4841f942 编写于 作者: Y YangZhou 提交者: GitHub

Merge pull request #2421 from THUzyt21/Deploy-fast-text-model-for-cli

[CLI]Deploy fast text model for cli
......@@ -42,3 +42,7 @@
```bash
paddlespeech text --task punc --input 今天的天气真不错啊你下午有空吗我想约你一起去吃饭
```
- Faster Punctuation Restoration
```bash
paddlespeech text --task punc --input 今天的天气真不错啊你下午有空吗我想约你一起去吃饭 --model ernie_linear_p3_wudao_fast
```
......@@ -43,3 +43,7 @@
```bash
paddlespeech text --task punc --input 今天的天气真不错啊你下午有空吗我想约你一起去吃饭
```
- 快速标点恢复
```bash
paddlespeech text --task punc --input 今天的天气真不错啊你下午有空吗我想约你一起去吃饭 --model ernie_linear_p3_wudao_fast
```
......@@ -20,10 +20,13 @@ from typing import Optional
from typing import Union
import paddle
import yaml
from yacs.config import CfgNode
from ..executor import BaseExecutor
from ..log import logger
from ..utils import stats_wrapper
from paddlespeech.text.models.ernie_linear import ErnieLinear
__all__ = ['TextExecutor']
......@@ -139,6 +142,66 @@ class TextExecutor(BaseExecutor):
self.model.eval()
# Initialize newer models (e.g. the "fast" punctuation models).
def _init_from_path_new(self,
task: str='punc',
model_type: str='ernie_linear_p7_wudao',
lang: str='zh',
cfg_path: Optional[os.PathLike]=None,
ckpt_path: Optional[os.PathLike]=None,
vocab_file: Optional[os.PathLike]=None):
if hasattr(self, 'model'):
logger.debug('Model had been initialized.')
return
self.task = task
if cfg_path is None or ckpt_path is None or vocab_file is None:
tag = '-'.join([model_type, task, lang])
self.task_resource.set_task_model(tag, version=None)
self.cfg_path = os.path.join(
self.task_resource.res_dir,
self.task_resource.res_dict['cfg_path'])
self.ckpt_path = os.path.join(
self.task_resource.res_dir,
self.task_resource.res_dict['ckpt_path'])
self.vocab_file = os.path.join(
self.task_resource.res_dir,
self.task_resource.res_dict['vocab_file'])
else:
self.cfg_path = os.path.abspath(cfg_path)
self.ckpt_path = os.path.abspath(ckpt_path)
self.vocab_file = os.path.abspath(vocab_file)
model_name = model_type[:model_type.rindex('_')]
if self.task == 'punc':
# punc list
self._punc_list = []
with open(self.vocab_file, 'r') as f:
for line in f:
self._punc_list.append(line.strip())
# model
with open(self.cfg_path) as f:
config = CfgNode(yaml.safe_load(f))
self.model = ErnieLinear(**config["model"])
_, tokenizer_class = self.task_resource.get_model_class(model_name)
state_dict = paddle.load(self.ckpt_path)
self.model.set_state_dict(state_dict["main_params"])
self.model.eval()
#tokenizer: fast version: ernie-3.0-mini-zh slow version:ernie-1.0
if 'fast' not in model_type:
self.tokenizer = tokenizer_class.from_pretrained('ernie-1.0')
else:
self.tokenizer = tokenizer_class.from_pretrained(
'ernie-3.0-mini-zh')
else:
raise NotImplementedError
def _clean_text(self, text):
text = text.lower()
text = re.sub('[^A-Za-z0-9\u4e00-\u9fa5]', '', text)
......@@ -179,7 +242,7 @@ class TextExecutor(BaseExecutor):
else:
raise NotImplementedError
def postprocess(self) -> Union[str, os.PathLike]:
def postprocess(self, isNewTrainer: bool=False) -> Union[str, os.PathLike]:
"""
Output postprocess and return human-readable results such as texts and audio files.
"""
......@@ -192,13 +255,13 @@ class TextExecutor(BaseExecutor):
input_ids[1:seq_len - 1])
labels = preds[1:seq_len - 1].tolist()
assert len(tokens) == len(labels)
if isNewTrainer:
self._punc_list = [0] + self._punc_list
text = ''
for t, l in zip(tokens, labels):
text += t
if l != 0: # Non punc.
text += self._punc_list[l]
return text
else:
raise NotImplementedError
......@@ -255,10 +318,20 @@ class TextExecutor(BaseExecutor):
"""
Python API to call an executor.
"""
paddle.set_device(device)
self._init_from_path(task, model, lang, config, ckpt_path, punc_vocab)
self.preprocess(text)
self.infer()
res = self.postprocess() # Retrieve result of text task.
# Legacy models: use the original initialization path.
if model in ['ernie_linear_p7_wudao', 'ernie_linear_p3_wudao']:
paddle.set_device(device)
self._init_from_path(task, model, lang, config, ckpt_path,
punc_vocab)
self.preprocess(text)
self.infer()
res = self.postprocess() # Retrieve result of text task.
# All other models: use the new initialization path.
else:
paddle.set_device(device)
self._init_from_path_new(task, model, lang, config, ckpt_path,
punc_vocab)
self.preprocess(text)
self.infer()
res = self.postprocess(isNewTrainer=True)
return res
......@@ -51,6 +51,10 @@ model_alias = {
"paddlespeech.text.models:ErnieLinear",
"paddlenlp.transformers:ErnieTokenizer"
],
"ernie_linear_p3_wudao": [
"paddlespeech.text.models:ErnieLinear",
"paddlenlp.transformers:ErnieTokenizer"
],
# ---------------------------------
# -------------- TTS --------------
......
......@@ -529,7 +529,7 @@ text_dynamic_pretrained_models = {
'ckpt/model_state.pdparams',
'vocab_file':
'punc_vocab.txt',
},
}
},
"ernie_linear_p3_wudao-punc-zh": {
'1.0': {
......@@ -543,8 +543,22 @@ text_dynamic_pretrained_models = {
'ckpt/model_state.pdparams',
'vocab_file':
'punc_vocab.txt',
},
}
},
"ernie_linear_p3_wudao_fast-punc-zh": {
'1.0': {
'url':
'https://paddlespeech.bj.bcebos.com/text/ernie_linear_p3_wudao_fast-punc-zh.tar.gz',
'md5':
'c93f9594119541a5dbd763381a751d08',
'cfg_path':
'ckpt/model_config.json',
'ckpt_path':
'ckpt/model_state.pdparams',
'vocab_file':
'punc_vocab.txt',
}
}
}
# ---------------------------------
......
......@@ -7,7 +7,7 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/cat.wav https://paddlespe
paddlespeech cls --input ./cat.wav --topk 10
# Punctuation_restoration
paddlespeech text --input 今天的天气真不错啊你下午有空吗我想约你一起去吃饭
paddlespeech text --input 今天的天气真不错啊你下午有空吗我想约你一起去吃饭 --model ernie_linear_p3_wudao_fast
# Speech_recognition
wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespeech.bj.bcebos.com/PaddleAudio/en.wav
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册