From c2eb9ad20513783aec0ebcd39b00bbf0d3d9ac44 Mon Sep 17 00:00:00 2001 From: WilliamZhang06 Date: Mon, 14 Feb 2022 11:26:06 +0800 Subject: [PATCH] added asr engine and fixed bugs, test=doc --- .../speechserving/conf/application.yaml | 2 + speechserving/speechserving/conf/asr/asr.yaml | 4 +- .../engine/asr/python/asr_engine.py | 67 ++++++++++------- .../speechserving/engine/base_engine.py | 2 + .../speechserving/engine/engine_factory.py | 4 +- speechserving/speechserving/main.py | 22 ++++-- speechserving/speechserving/restful/api.py | 5 +- .../speechserving/restful/asr_api.py | 73 +++++++++++++------ .../speechserving/restful/request.py | 5 +- .../speechserving/restful/response.py | 19 +++++ .../speechserving/restful/tts_api.py | 5 +- speechserving/speechserving/utils/__init__.py | 13 ++++ speechserving/speechserving/utils/config.py | 2 +- speechserving/speechserving/utils/errors.py | 2 +- speechserving/speechserving/utils/util.py | 6 +- 15 files changed, 159 insertions(+), 72 deletions(-) diff --git a/speechserving/speechserving/conf/application.yaml b/speechserving/speechserving/conf/application.yaml index 8c4d9bc6..c8d71f2f 100644 --- a/speechserving/speechserving/conf/application.yaml +++ b/speechserving/speechserving/conf/application.yaml @@ -10,6 +10,8 @@ port: 8090 # CONFIG FILE # ################################################################## # add engine type (Options: asr, tts) and config file here. + engine_backend: asr: 'conf/asr/asr.yaml' tts: 'conf/tts/tts.yaml' + diff --git a/speechserving/speechserving/conf/asr/asr.yaml b/speechserving/speechserving/conf/asr/asr.yaml index 39df2548..4c3b0a67 100644 --- a/speechserving/speechserving/conf/asr/asr.yaml +++ b/speechserving/speechserving/conf/asr/asr.yaml @@ -1,7 +1,7 @@ model: 'conformer_wenetspeech' lang: 'zh' sample_rate: 16000 -cfg_path: "/home/users/zhangyinhui/.paddlespeech/models/conformer_wenetspeech-zh-16k/asr1_conformer_wenetspeech_ckpt_0.1.1.model.tar/model.yaml" -ckpt_path: "/home/users/zhangyinhui/.paddlespeech/models/conformer_wenetspeech-zh-16k/asr1_conformer_wenetspeech_ckpt_0.1.1.model.tar/exp/conformer/checkpoints/wenetspeech" +cfg_path: +ckpt_path: decode_method: 'attention_rescoring' force_yes: False diff --git a/speechserving/speechserving/engine/asr/python/asr_engine.py b/speechserving/speechserving/engine/asr/python/asr_engine.py index bb1596af..e8289332 100644 --- a/speechserving/speechserving/engine/asr/python/asr_engine.py +++ b/speechserving/speechserving/engine/asr/python/asr_engine.py @@ -11,23 +11,23 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import paddle import io -import soundfile import os -import librosa from typing import List from typing import Optional from typing import Union -from paddlespeech.cli.log import logger +import librosa +import paddle +import soundfile +from engine.base_engine import BaseEngine + from paddlespeech.cli.asr.infer import ASRExecutor +from paddlespeech.cli.log import logger from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer from paddlespeech.s2t.transform.transformation import Transformation from paddlespeech.s2t.utils.dynamic_import import dynamic_import from paddlespeech.s2t.utils.utility import UpdateConfig - -from engine.base_engine import BaseEngine from utils.config import get_config __all__ = ['ASREngine'] @@ -141,42 +141,55 @@ class ASREngine(BaseEngine): Args: metaclass: Defaults to Singleton. """ + def __init__(self): super(ASREngine, self).__init__() - def init(self, config_file: str): + def init(self, config_file: str) -> bool: + """init engine resource + + Args: + config_file (str): config file + Returns: + bool: init failed or success + """ + self.input = None + self.output = None self.executor = ASRServerExecutor() - self.config = get_config(config_file) - paddle.set_device(paddle.get_device()) - self.executor._init_from_path( - self.config.model, - self.config.lang, - self.config.sample_rate, - self.config.cfg_path, - self.config.decode_method, - self.config.ckpt_path) + try: + self.config = get_config(config_file) + paddle.set_device(paddle.get_device()) + self.executor._init_from_path( + self.config.model, self.config.lang, self.config.sample_rate, + self.config.cfg_path, self.config.decode_method, + self.config.ckpt_path) + except: + logger.info("Initialize ASR server engine Failed.") + return False logger.info("Initialize ASR server engine successfully.") - - self.input = None - self.output = None + return True def run(self, audio_data): + """engine run - if self.executor._check(io.BytesIO(audio_data), self.config.sample_rate, self.config.force_yes): + Args: + audio_data (bytes): base64.b64decode + """ + if self.executor._check( + io.BytesIO(audio_data), self.config.sample_rate, + self.config.force_yes): + logger.info("start run asr engine") self.executor.preprocess(self.config.model, io.BytesIO(audio_data)) self.executor.infer(self.config.model) - self.output = self.executor.postprocess() # Retrieve result of asr. + self.output = self.executor.postprocess() # Retrieve result of asr. else: logger.info("file check failed!") - - logger.info("start run asr engine") + self.output = None def postprocess(self): - + """postprocess + """ return self.output - - - diff --git a/speechserving/speechserving/engine/base_engine.py b/speechserving/speechserving/engine/base_engine.py index 36048dcc..0cc20209 100644 --- a/speechserving/speechserving/engine/base_engine.py +++ b/speechserving/speechserving/engine/base_engine.py @@ -18,6 +18,8 @@ from typing import Union from pattern_singleton import Singleton +__all__ = ['BaseEngine'] + class BaseEngine(metaclass=Singleton): """ diff --git a/speechserving/speechserving/engine/engine_factory.py b/speechserving/speechserving/engine/engine_factory.py index 336a9a6f..2b9f9db7 100644 --- a/speechserving/speechserving/engine/engine_factory.py +++ b/speechserving/speechserving/engine/engine_factory.py @@ -14,10 +14,12 @@ from engine.asr.python.asr_engine import ASREngine from engine.tts.python.tts_engine import TTSEngine +__all__ = ['EngineFactory'] + class EngineFactory(object): @staticmethod - def get_engine(engine_name): + def get_engine(engine_name: str): if engine_name == 'asr': return ASREngine() elif engine_name == 'tts': diff --git a/speechserving/speechserving/main.py b/speechserving/speechserving/main.py index 3b367418..6d4891c7 100644 --- a/speechserving/speechserving/main.py +++ b/speechserving/speechserving/main.py @@ -12,21 +12,28 @@ # See the License for the specific language governing permissions and # limitations under the License. import argparse + import uvicorn import yaml +from engine.engine_factory import EngineFactory from fastapi import FastAPI - from restful.api import setup_router -from utils.log import logger + from utils.config import get_config -from engine.engine_factory import EngineFactory +from utils.log import logger app = FastAPI( title="PaddleSpeech Serving API", description="Api", version="0.0.1") def init(config): - """ system initialization + """system initialization + + Args: + config (CfgNode): config object + + Returns: + bool: """ # init api api_list = list(config.engine_backend) @@ -34,10 +41,11 @@ def init(config): app.include_router(api_router) # init engine - engine_list = [] + engine_pool = [] for engine in config.engine_backend: - engine_list.append(EngineFactory.get_engine(engine_name=engine)) - engine_list[-1].init(config_file=config.engine_backend[engine]) + engine_pool.append(EngineFactory.get_engine(engine_name=engine)) + if not engine_pool[-1].init(config_file=config.engine_backend[engine]): + return False return True diff --git a/speechserving/speechserving/restful/api.py b/speechserving/speechserving/restful/api.py index c5539f24..bdff935a 100644 --- a/speechserving/speechserving/restful/api.py +++ b/speechserving/speechserving/restful/api.py @@ -12,13 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. from typing import List + from fastapi import APIRouter -from .tts_api import router as tts_router from .asr_api import router as asr_router +from .tts_api import router as tts_router _router = APIRouter() + def setup_router(api_list: List): for api_name in api_list: @@ -30,4 +32,3 @@ def setup_router(api_list: List): pass return _router - diff --git a/speechserving/speechserving/restful/asr_api.py b/speechserving/speechserving/restful/asr_api.py index ab2c8048..6ac647bc 100644 --- a/speechserving/speechserving/restful/asr_api.py +++ b/speechserving/speechserving/restful/asr_api.py @@ -11,16 +11,23 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from fastapi import APIRouter import base64 +import traceback +from typing import Union from engine.asr.python.asr_engine import ASREngine -from .response import ASRResponse -from .request import ASRRequest +from fastapi import APIRouter +from .request import ASRRequest +from .response import ASRResponse +from .response import ErrorResponse +from utils.errors import ErrorCode +from utils.errors import failed_response +from utils.exception import ServerBaseException router = APIRouter() + @router.get('/paddlespeech/asr/help') def help(): """help @@ -28,10 +35,23 @@ def help(): Returns: json: [description] """ - return {'hello': 'world'} + response = { + "success": "True", + "code": 200, + "message": { + "global": "success" + }, + "result": { + "description": "tts server", + "input": "base64 string of wavfile", + "output": "transcription" + } + } + return response -@router.post("/paddlespeech/asr", response_model=ASRResponse) +@router.post( + "/paddlespeech/asr", response_model=Union[ASRResponse, ErrorResponse]) def asr(request_body: ASRRequest): """asr api @@ -41,21 +61,28 @@ def asr(request_body: ASRRequest): Returns: json: [description] """ - audio_data = base64.b64decode(request_body.audio) - # single - asr_engine = ASREngine() - asr_engine.run(audio_data) - asr_results = asr_engine.postprocess() - - json_body = { - "success": True, - "code": 0, - "message": { - "description": "success" - }, - "result": { - "transcription": asr_results - } - } - - return json_body + try: + # single + audio_data = base64.b64decode(request_body.audio) + asr_engine = ASREngine() + asr_engine.run(audio_data) + asr_results = asr_engine.postprocess() + + response = { + "success": True, + "code": 200, + "message": { + "description": "success" + }, + "result": { + "transcription": asr_results + } + } + + except ServerBaseException as e: + response = failed_response(e.error_code, e.msg) + except: + response = failed_response(ErrorCode.SERVER_UNKOWN_ERR) + traceback.print_exc() + + return response diff --git a/speechserving/speechserving/restful/request.py b/speechserving/speechserving/restful/request.py index 9ebc6d75..2be5f0e5 100644 --- a/speechserving/speechserving/restful/request.py +++ b/speechserving/speechserving/restful/request.py @@ -19,7 +19,6 @@ from pydantic import BaseModel __all__ = ['ASRRequest', 'TTSRequest'] - #****************************************************************************************/ #************************************ ASR request ***************************************/ #****************************************************************************************/ @@ -31,14 +30,14 @@ class ASRRequest(BaseModel): "audio_format": "wav", "sample_rate": 16000, "lang": "zh_cn", - "ptt":false + "punc":false } """ audio: str audio_format: str sample_rate: int lang: str - ptt: Optional[bool] = None + punc: Optional[bool] = None #****************************************************************************************/ diff --git a/speechserving/speechserving/restful/response.py b/speechserving/speechserving/restful/response.py index db24f531..ab5e395b 100644 --- a/speechserving/speechserving/restful/response.py +++ b/speechserving/speechserving/restful/response.py @@ -86,3 +86,22 @@ class TTSResponse(BaseModel): code: int message: Message result: TTSResult + + +#****************************************************************************************/ +#********************************** Error response **************************************/ +#****************************************************************************************/ +class ErrorResponse(BaseModel): + """ + response example + { + "success": false, + "code": 0, + "message": { + "description": "Unknown error occurred." + } + } + """ + success: bool + code: int + message: Message diff --git a/speechserving/speechserving/restful/tts_api.py b/speechserving/speechserving/restful/tts_api.py index e9dcfa16..a160e31d 100644 --- a/speechserving/speechserving/restful/tts_api.py +++ b/speechserving/speechserving/restful/tts_api.py @@ -12,11 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. import traceback +from typing import Union from engine.tts.python.tts_engine import TTSEngine from fastapi import APIRouter from .request import TTSRequest +from .response import ErrorResponse from .response import TTSResponse from utils.errors import ErrorCode from utils.errors import failed_response @@ -47,7 +49,8 @@ def help(): return response -@router.post("/paddlespeech/tts", response_model=TTSResponse) +@router.post( + "/paddlespeech/tts", response_model=Union[TTSResponse, ErrorResponse]) def tts(request_body: TTSRequest): """tts api diff --git a/speechserving/speechserving/utils/__init__.py b/speechserving/speechserving/utils/__init__.py index e69de29b..97043fd7 100644 --- a/speechserving/speechserving/utils/__init__.py +++ b/speechserving/speechserving/utils/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/speechserving/speechserving/utils/config.py b/speechserving/speechserving/utils/config.py index 513c16f6..8c75f536 100644 --- a/speechserving/speechserving/utils/config.py +++ b/speechserving/speechserving/utils/config.py @@ -15,7 +15,7 @@ import yaml from yacs.config import CfgNode -def get_config(config_file): +def get_config(config_file: str): """[summary] Args: diff --git a/speechserving/speechserving/utils/errors.py b/speechserving/speechserving/utils/errors.py index aa858cb0..17ff7551 100644 --- a/speechserving/speechserving/utils/errors.py +++ b/speechserving/speechserving/utils/errors.py @@ -52,6 +52,6 @@ def failed_response(code, msg=""): if not msg: msg = ErrorMsg.get(code, "Unknown error occurred.") - res = {"success": False, "code": int(code), "message": {"global": msg}} + res = {"success": False, "code": int(code), "message": {"description": msg}} return Response(content=json.dumps(res), media_type="application/json") diff --git a/speechserving/speechserving/utils/util.py b/speechserving/speechserving/utils/util.py index cf568572..e9104fa2 100644 --- a/speechserving/speechserving/utils/util.py +++ b/speechserving/speechserving/utils/util.py @@ -13,7 +13,7 @@ import base64 -def wav2base64(wav_file): +def wav2base64(wav_file: str): """ read wave file and covert to base64 string """ @@ -23,12 +23,10 @@ def wav2base64(wav_file): return base64_string -def base64towav(base64_string): +def base64towav(base64_string: str): pass - - def self_check(): """ self check resource """ -- GitLab