diff --git a/speechserving/speechserving/conf/application.yaml b/speechserving/speechserving/conf/application.yaml index 8c4d9bc62d2ca4dac2676c48ca7154e85e414c35..c8d71f2f6ad816e9848096e84c637c0069757594 100644 --- a/speechserving/speechserving/conf/application.yaml +++ b/speechserving/speechserving/conf/application.yaml @@ -10,6 +10,8 @@ port: 8090 # CONFIG FILE # ################################################################## # add engine type (Options: asr, tts) and config file here. + engine_backend: asr: 'conf/asr/asr.yaml' tts: 'conf/tts/tts.yaml' + diff --git a/speechserving/speechserving/conf/asr/asr.yaml b/speechserving/speechserving/conf/asr/asr.yaml index 39df2548771517f2bd4ee3764516b7aec50375b1..4c3b0a67e30273681fe765fc2e827f86a21ac380 100644 --- a/speechserving/speechserving/conf/asr/asr.yaml +++ b/speechserving/speechserving/conf/asr/asr.yaml @@ -1,7 +1,7 @@ model: 'conformer_wenetspeech' lang: 'zh' sample_rate: 16000 -cfg_path: "/home/users/zhangyinhui/.paddlespeech/models/conformer_wenetspeech-zh-16k/asr1_conformer_wenetspeech_ckpt_0.1.1.model.tar/model.yaml" -ckpt_path: "/home/users/zhangyinhui/.paddlespeech/models/conformer_wenetspeech-zh-16k/asr1_conformer_wenetspeech_ckpt_0.1.1.model.tar/exp/conformer/checkpoints/wenetspeech" +cfg_path: +ckpt_path: decode_method: 'attention_rescoring' force_yes: False diff --git a/speechserving/speechserving/engine/asr/python/asr_engine.py b/speechserving/speechserving/engine/asr/python/asr_engine.py index bb1596af44de1dd849b64eefeee020d8b4df5f17..e8289332167bf6cb31db08de54c66860f7504522 100644 --- a/speechserving/speechserving/engine/asr/python/asr_engine.py +++ b/speechserving/speechserving/engine/asr/python/asr_engine.py @@ -11,23 +11,23 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import paddle import io -import soundfile import os -import librosa from typing import List from typing import Optional from typing import Union -from paddlespeech.cli.log import logger +import librosa +import paddle +import soundfile +from engine.base_engine import BaseEngine + from paddlespeech.cli.asr.infer import ASRExecutor +from paddlespeech.cli.log import logger from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer from paddlespeech.s2t.transform.transformation import Transformation from paddlespeech.s2t.utils.dynamic_import import dynamic_import from paddlespeech.s2t.utils.utility import UpdateConfig - -from engine.base_engine import BaseEngine from utils.config import get_config __all__ = ['ASREngine'] @@ -141,42 +141,55 @@ class ASREngine(BaseEngine): Args: metaclass: Defaults to Singleton. """ + def __init__(self): super(ASREngine, self).__init__() - def init(self, config_file: str): + def init(self, config_file: str) -> bool: + """init engine resource + + Args: + config_file (str): config file + Returns: + bool: init failed or success + """ + self.input = None + self.output = None self.executor = ASRServerExecutor() - self.config = get_config(config_file) - paddle.set_device(paddle.get_device()) - self.executor._init_from_path( - self.config.model, - self.config.lang, - self.config.sample_rate, - self.config.cfg_path, - self.config.decode_method, - self.config.ckpt_path) + try: + self.config = get_config(config_file) + paddle.set_device(paddle.get_device()) + self.executor._init_from_path( + self.config.model, self.config.lang, self.config.sample_rate, + self.config.cfg_path, self.config.decode_method, + self.config.ckpt_path) + except: + logger.info("Initialize ASR server engine Failed.") + return False logger.info("Initialize ASR server engine successfully.") - - self.input = None - self.output = None + return True def run(self, audio_data): + """engine run - if self.executor._check(io.BytesIO(audio_data), self.config.sample_rate, self.config.force_yes): + Args: + audio_data (bytes): base64.b64decode + """ + if self.executor._check( + io.BytesIO(audio_data), self.config.sample_rate, + self.config.force_yes): + logger.info("start run asr engine") self.executor.preprocess(self.config.model, io.BytesIO(audio_data)) self.executor.infer(self.config.model) - self.output = self.executor.postprocess() # Retrieve result of asr. + self.output = self.executor.postprocess() # Retrieve result of asr. else: logger.info("file check failed!") - - logger.info("start run asr engine") + self.output = None def postprocess(self): - + """postprocess + """ return self.output - - - diff --git a/speechserving/speechserving/engine/base_engine.py b/speechserving/speechserving/engine/base_engine.py index 36048dcc69be87585c2460542350df39d2679f85..0cc20209479ea7e033943b799a7e161ac21e3b35 100644 --- a/speechserving/speechserving/engine/base_engine.py +++ b/speechserving/speechserving/engine/base_engine.py @@ -18,6 +18,8 @@ from typing import Union from pattern_singleton import Singleton +__all__ = ['BaseEngine'] + class BaseEngine(metaclass=Singleton): """ diff --git a/speechserving/speechserving/engine/engine_factory.py b/speechserving/speechserving/engine/engine_factory.py index 336a9a6f8b007c5f9687f9c3547a450159c86c86..2b9f9db70d1712cc681e5d0068304915c7479c7a 100644 --- a/speechserving/speechserving/engine/engine_factory.py +++ b/speechserving/speechserving/engine/engine_factory.py @@ -14,10 +14,12 @@ from engine.asr.python.asr_engine import ASREngine from engine.tts.python.tts_engine import TTSEngine +__all__ = ['EngineFactory'] + class EngineFactory(object): @staticmethod - def get_engine(engine_name): + def get_engine(engine_name: str): if engine_name == 'asr': return ASREngine() elif engine_name == 'tts': diff --git a/speechserving/speechserving/main.py b/speechserving/speechserving/main.py index 3b367418f28d53e773bbddf8ff61d1c03ca5a366..6d4891c7dcb3e7a732efd09d97417d4ad115f21e 100644 --- a/speechserving/speechserving/main.py +++ b/speechserving/speechserving/main.py @@ -12,21 +12,28 @@ # See the License for the specific language governing permissions and # limitations under the License. import argparse + import uvicorn import yaml +from engine.engine_factory import EngineFactory from fastapi import FastAPI - from restful.api import setup_router -from utils.log import logger + from utils.config import get_config -from engine.engine_factory import EngineFactory +from utils.log import logger app = FastAPI( title="PaddleSpeech Serving API", description="Api", version="0.0.1") def init(config): - """ system initialization + """system initialization + + Args: + config (CfgNode): config object + + Returns: + bool: """ # init api api_list = list(config.engine_backend) @@ -34,10 +41,11 @@ def init(config): app.include_router(api_router) # init engine - engine_list = [] + engine_pool = [] for engine in config.engine_backend: - engine_list.append(EngineFactory.get_engine(engine_name=engine)) - engine_list[-1].init(config_file=config.engine_backend[engine]) + engine_pool.append(EngineFactory.get_engine(engine_name=engine)) + if not engine_pool[-1].init(config_file=config.engine_backend[engine]): + return False return True diff --git a/speechserving/speechserving/restful/api.py b/speechserving/speechserving/restful/api.py index c5539f2431027a19e267974599dbd42a1ee58162..bdff935ac977dcff4e14a9dee9d1c030a6581562 100644 --- a/speechserving/speechserving/restful/api.py +++ b/speechserving/speechserving/restful/api.py @@ -12,13 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. from typing import List + from fastapi import APIRouter -from .tts_api import router as tts_router from .asr_api import router as asr_router +from .tts_api import router as tts_router _router = APIRouter() + def setup_router(api_list: List): for api_name in api_list: @@ -30,4 +32,3 @@ def setup_router(api_list: List): pass return _router - diff --git a/speechserving/speechserving/restful/asr_api.py b/speechserving/speechserving/restful/asr_api.py index ab2c8048af4f1e00caeb5ac3c7ed2037be79e962..6ac647bc9a67a9b31dbcf1f8a80539a1a22f6bee 100644 --- a/speechserving/speechserving/restful/asr_api.py +++ b/speechserving/speechserving/restful/asr_api.py @@ -11,16 +11,23 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from fastapi import APIRouter import base64 +import traceback +from typing import Union from engine.asr.python.asr_engine import ASREngine -from .response import ASRResponse -from .request import ASRRequest +from fastapi import APIRouter +from .request import ASRRequest +from .response import ASRResponse +from .response import ErrorResponse +from utils.errors import ErrorCode +from utils.errors import failed_response +from utils.exception import ServerBaseException router = APIRouter() + @router.get('/paddlespeech/asr/help') def help(): """help @@ -28,10 +35,23 @@ def help(): Returns: json: [description] """ - return {'hello': 'world'} + response = { + "success": "True", + "code": 200, + "message": { + "global": "success" + }, + "result": { + "description": "tts server", + "input": "base64 string of wavfile", + "output": "transcription" + } + } + return response -@router.post("/paddlespeech/asr", response_model=ASRResponse) +@router.post( + "/paddlespeech/asr", response_model=Union[ASRResponse, ErrorResponse]) def asr(request_body: ASRRequest): """asr api @@ -41,21 +61,28 @@ def asr(request_body: ASRRequest): Returns: json: [description] """ - audio_data = base64.b64decode(request_body.audio) - # single - asr_engine = ASREngine() - asr_engine.run(audio_data) - asr_results = asr_engine.postprocess() - - json_body = { - "success": True, - "code": 0, - "message": { - "description": "success" - }, - "result": { - "transcription": asr_results - } - } - - return json_body + try: + # single + audio_data = base64.b64decode(request_body.audio) + asr_engine = ASREngine() + asr_engine.run(audio_data) + asr_results = asr_engine.postprocess() + + response = { + "success": True, + "code": 200, + "message": { + "description": "success" + }, + "result": { + "transcription": asr_results + } + } + + except ServerBaseException as e: + response = failed_response(e.error_code, e.msg) + except: + response = failed_response(ErrorCode.SERVER_UNKOWN_ERR) + traceback.print_exc() + + return response diff --git a/speechserving/speechserving/restful/request.py b/speechserving/speechserving/restful/request.py index 9ebc6d752aaee8dc5897091593a3a8c5a6e88511..2be5f0e546dee6c1c042820ac1a3838a446e23ea 100644 --- a/speechserving/speechserving/restful/request.py +++ b/speechserving/speechserving/restful/request.py @@ -19,7 +19,6 @@ from pydantic import BaseModel __all__ = ['ASRRequest', 'TTSRequest'] - #****************************************************************************************/ #************************************ ASR request ***************************************/ #****************************************************************************************/ @@ -31,14 +30,14 @@ class ASRRequest(BaseModel): "audio_format": "wav", "sample_rate": 16000, "lang": "zh_cn", - "ptt":false + "punc":false } """ audio: str audio_format: str sample_rate: int lang: str - ptt: Optional[bool] = None + punc: Optional[bool] = None #****************************************************************************************/ diff --git a/speechserving/speechserving/restful/response.py b/speechserving/speechserving/restful/response.py index db24f5310b18609a5ef3d330cb7b6641d62af41d..ab5e395ba6914482e320d13abf2744e2fef71ec0 100644 --- a/speechserving/speechserving/restful/response.py +++ b/speechserving/speechserving/restful/response.py @@ -86,3 +86,22 @@ class TTSResponse(BaseModel): code: int message: Message result: TTSResult + + +#****************************************************************************************/ +#********************************** Error response **************************************/ +#****************************************************************************************/ +class ErrorResponse(BaseModel): + """ + response example + { + "success": false, + "code": 0, + "message": { + "description": "Unknown error occurred." + } + } + """ + success: bool + code: int + message: Message diff --git a/speechserving/speechserving/restful/tts_api.py b/speechserving/speechserving/restful/tts_api.py index e9dcfa16fa17e736b8ee79fbd8ee170a6fcacc0a..a160e31dc1abc8d16cc2ff99a6a3db3f37296fe5 100644 --- a/speechserving/speechserving/restful/tts_api.py +++ b/speechserving/speechserving/restful/tts_api.py @@ -12,11 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. import traceback +from typing import Union from engine.tts.python.tts_engine import TTSEngine from fastapi import APIRouter from .request import TTSRequest +from .response import ErrorResponse from .response import TTSResponse from utils.errors import ErrorCode from utils.errors import failed_response @@ -47,7 +49,8 @@ def help(): return response -@router.post("/paddlespeech/tts", response_model=TTSResponse) +@router.post( + "/paddlespeech/tts", response_model=Union[TTSResponse, ErrorResponse]) def tts(request_body: TTSRequest): """tts api diff --git a/speechserving/speechserving/utils/__init__.py b/speechserving/speechserving/utils/__init__.py index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..97043fd7ba6885aac81cad5a49924c23c67d4d47 100644 --- a/speechserving/speechserving/utils/__init__.py +++ b/speechserving/speechserving/utils/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/speechserving/speechserving/utils/config.py b/speechserving/speechserving/utils/config.py index 513c16f66c4bab0eab5885c601a4f44b94b8ba28..8c75f536f5de654f1a09fa82187cfef4ef442e90 100644 --- a/speechserving/speechserving/utils/config.py +++ b/speechserving/speechserving/utils/config.py @@ -15,7 +15,7 @@ import yaml from yacs.config import CfgNode -def get_config(config_file): +def get_config(config_file: str): """[summary] Args: diff --git a/speechserving/speechserving/utils/errors.py b/speechserving/speechserving/utils/errors.py index aa858cb083ad8476af35a742f0d4376c5ac13417..17ff75512cd447648ecedf9238809a42743b708c 100644 --- a/speechserving/speechserving/utils/errors.py +++ b/speechserving/speechserving/utils/errors.py @@ -52,6 +52,6 @@ def failed_response(code, msg=""): if not msg: msg = ErrorMsg.get(code, "Unknown error occurred.") - res = {"success": False, "code": int(code), "message": {"global": msg}} + res = {"success": False, "code": int(code), "message": {"description": msg}} return Response(content=json.dumps(res), media_type="application/json") diff --git a/speechserving/speechserving/utils/util.py b/speechserving/speechserving/utils/util.py index cf56857213cec7f922203fe32d079c8f80bd9611..e9104fa2d56283c48304d4676fae19e8dccd1ba5 100644 --- a/speechserving/speechserving/utils/util.py +++ b/speechserving/speechserving/utils/util.py @@ -13,7 +13,7 @@ import base64 -def wav2base64(wav_file): +def wav2base64(wav_file: str): """ read wave file and covert to base64 string """ @@ -23,12 +23,10 @@ def wav2base64(wav_file): return base64_string -def base64towav(base64_string): +def base64towav(base64_string: str): pass - - def self_check(): """ self check resource """