提交 43f4d47b 编写于 作者: H huangyuxin

add the call in infer.py

......@@ -14,6 +14,7 @@
import os
from abc import ABC
from abc import abstractmethod
from typing import List
from typing import Union
import paddle
......@@ -64,3 +65,17 @@ class BaseExecutor(ABC):
Output postprocess and return human-readable results such as texts and audio files.
"""
pass
@abstractmethod
def execute(self, argv: List[str]) -> bool:
"""
Command line entry.
"""
pass
@abstractmethod
def __call__(self, *arg, **kwargs):
"""
Python API to call an executor.
"""
pass
......@@ -18,13 +18,14 @@ from typing import List
from typing import Optional
from typing import Union
import soundfile
import paddle
from paddlespeech.cli.executor import BaseExecutor
from paddlespeech.cli.utils import cli_register
from paddlespeech.cli.utils import download_and_decompress
from paddlespeech.cli.utils import logger
from paddlespeech.cli.utils import MODEL_HOME
import soundfile
from ..executor import BaseExecutor
from ..utils import cli_register
from ..utils import download_and_decompress
from ..utils import logger
from ..utils import MODEL_HOME
from paddlespeech.s2t.exps.u2.config import get_cfg_defaults
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.io.collator import SpeechCollator
......@@ -55,29 +56,6 @@ model_alias = {
"wenetspeech": "paddlespeech.s2t.models.u2:U2Model",
}
pretrain_model_alias = {
"ds2_online_zn": [
"https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/aishell_ds2_online_cer8.00_release.tar.gz",
"", ""
],
"ds2_offline_zn": [
"https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/ds2.model.tar.gz",
"", ""
],
"transformer_zn": [
"https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/transformer.model.tar.gz",
"", ""
],
"conformer_zn": [
"https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/conformer.model.tar.gz",
"", ""
],
"wenetspeech_zn": [
"https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/conformer.model.tar.gz",
"conf/conformer.yaml", "exp/conformer/checkpoints/wenetspeech"
],
}
@cli_register(
name='paddlespeech.s2t', description='Speech to text infer command.')
......@@ -107,7 +85,6 @@ class S2TExecutor(BaseExecutor):
self.parser.add_argument(
'--input',
type=str,
default="../Downloads/asr-demo-1.wav",
help='Audio file to recognize.')
self.parser.add_argument(
'--device',
......@@ -155,7 +132,9 @@ class S2TExecutor(BaseExecutor):
res_path = os.path.dirname(
os.path.dirname(os.path.abspath(self.cfg_path)))
# Enter the path of model root
os.chdir(res_path)
#Init body.
parser_args = self.parser_args
paddle.set_device(parser_args.device)
......@@ -206,7 +185,7 @@ class S2TExecutor(BaseExecutor):
config = self.config
audio_file = input
#print("audio_file", audio_file)
logger.info("audio_file"+ audio_file)
logger.info("audio_file" + audio_file)
self.sr = config.collator.target_sample_rate
......@@ -307,7 +286,11 @@ class S2TExecutor(BaseExecutor):
return self.result_transcripts
def execute(self, argv: List[str]) -> bool:
"""
Command line entry.
"""
self.parser_args = self.parser.parse_args(argv)
print(self.parser_args)
model = self.parser_args.model
lang = self.parser_args.lang
......@@ -317,17 +300,20 @@ class S2TExecutor(BaseExecutor):
device = self.parser_args.device
try:
self._init_from_path(model, lang, config, ckpt_path)
self.preprocess(audio_file)
self.infer()
res = self.postprocess() # Retrieve result of s2t.
logger.info(res)
res = self(model, lang, config, ckpt_path, audio_file, device)
print(res)
return True
except Exception as e:
print(e)
return False
def __call__(self, model, lang, config, ckpt_path, audio_file, device):
"""
Python API to call an executor.
"""
self._init_from_path(model, lang, config, ckpt_path)
self.preprocess(audio_file)
self.infer()
res = self.postprocess() # Retrieve result of s2t.
if __name__ == "__main__":
exe = S2TExecutor()
exe.execute('')
return res
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册