diff --git a/speechserving/speechserving/conf/application.yaml b/speechserving/speechserving/conf/application.yaml index 8c4d9bc62d2ca4dac2676c48ca7154e85e414c35..c8d71f2f6ad816e9848096e84c637c0069757594 100644 --- a/speechserving/speechserving/conf/application.yaml +++ b/speechserving/speechserving/conf/application.yaml @@ -10,6 +10,8 @@ port: 8090 # CONFIG FILE # ################################################################## # add engine type (Options: asr, tts) and config file here. + engine_backend: asr: 'conf/asr/asr.yaml' tts: 'conf/tts/tts.yaml' + diff --git a/speechserving/speechserving/conf/asr/asr.yaml b/speechserving/speechserving/conf/asr/asr.yaml index cfa3a68f382b439326b8f2320b4adca92efa4dd5..4c3b0a67e30273681fe765fc2e827f86a21ac380 100644 --- a/speechserving/speechserving/conf/asr/asr.yaml +++ b/speechserving/speechserving/conf/asr/asr.yaml @@ -1,4 +1,7 @@ model: 'conformer_wenetspeech' lang: 'zh' sample_rate: 16000 +cfg_path: +ckpt_path: decode_method: 'attention_rescoring' +force_yes: False diff --git a/speechserving/speechserving/engine/asr/python/asr_engine.py b/speechserving/speechserving/engine/asr/python/asr_engine.py index 8dbc7a3e3156e7c56601c32a090d5c5b0b488422..a18f906a39ea6636ffbd079292458092337b677e 100644 --- a/speechserving/speechserving/engine/asr/python/asr_engine.py +++ b/speechserving/speechserving/engine/asr/python/asr_engine.py @@ -11,29 +11,184 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import io +import os +from typing import List +from typing import Optional +from typing import Union + +import librosa +import paddle +import soundfile from engine.base_engine import BaseEngine -from utils.log import logger +from paddlespeech.cli.asr.infer import ASRExecutor +from paddlespeech.cli.log import logger +from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer +from paddlespeech.s2t.transform.transformation import Transformation +from paddlespeech.s2t.utils.dynamic_import import dynamic_import +from paddlespeech.s2t.utils.utility import UpdateConfig from utils.config import get_config __all__ = ['ASREngine'] +class ASRServerExecutor(ASRExecutor): + def __init__(self): + super().__init__() + pass + + def _check(self, audio_file: str, sample_rate: int, force_yes: bool): + self.sample_rate = sample_rate + if self.sample_rate != 16000 and self.sample_rate != 8000: + logger.error("please input --sr 8000 or --sr 16000") + return False + + logger.info("checking the audio file format......") + try: + audio, audio_sample_rate = soundfile.read( + audio_file, dtype="int16", always_2d=True) + except Exception as e: + logger.exception(e) + logger.error( + "can not open the audio file, please check the audio file format is 'wav'. \n \ + you can try to use sox to change the file format.\n \ + For example: \n \ + sample rate: 16k \n \ + sox input_audio.xx --rate 16k --bits 16 --channels 1 output_audio.wav \n \ + sample rate: 8k \n \ + sox input_audio.xx --rate 8k --bits 16 --channels 1 output_audio.wav \n \ + ") + + logger.info("The sample rate is %d" % audio_sample_rate) + if audio_sample_rate != self.sample_rate: + logger.warning("The sample rate of the input file is not {}.\n \ + The program will resample the wav file to {}.\n \ + If the result does not meet your expectations,\n \ + Please input the 16k 16 bit 1 channel wav file. \ + ".format(self.sample_rate, self.sample_rate)) + self.change_format = True + else: + logger.info("The audio file format is right") + self.change_format = False + + return True + + def preprocess(self, model_type: str, input: Union[str, os.PathLike]): + """ + Input preprocess and return paddle.Tensor stored in self.input. + Input content can be a text(tts), a file(asr, cls) or a streaming(not supported yet). + """ + + audio_file = input + + # Get the object for feature extraction + if "deepspeech2online" in model_type or "deepspeech2offline" in model_type: + audio, _ = self.collate_fn_test.process_utterance( + audio_file=audio_file, transcript=" ") + audio_len = audio.shape[0] + audio = paddle.to_tensor(audio, dtype='float32') + audio_len = paddle.to_tensor(audio_len) + audio = paddle.unsqueeze(audio, axis=0) + # vocab_list = collate_fn_test.vocab_list + self._inputs["audio"] = audio + self._inputs["audio_len"] = audio_len + logger.info(f"audio feat shape: {audio.shape}") + + elif "conformer" in model_type or "transformer" in model_type or "wenetspeech" in model_type: + logger.info("get the preprocess conf") + preprocess_conf = self.config.preprocess_config + preprocess_args = {"train": False} + preprocessing = Transformation(preprocess_conf) + logger.info("read the audio file") + audio, audio_sample_rate = soundfile.read( + audio_file, dtype="int16", always_2d=True) + + if self.change_format: + if audio.shape[1] >= 2: + audio = audio.mean(axis=1, dtype=np.int16) + else: + audio = audio[:, 0] + # pcm16 -> pcm 32 + audio = self._pcm16to32(audio) + audio = librosa.resample(audio, audio_sample_rate, + self.sample_rate) + audio_sample_rate = self.sample_rate + # pcm32 -> pcm 16 + audio = self._pcm32to16(audio) + else: + audio = audio[:, 0] + + logger.info(f"audio shape: {audio.shape}") + # fbank + audio = preprocessing(audio, **preprocess_args) + + audio_len = paddle.to_tensor(audio.shape[0]) + audio = paddle.to_tensor(audio, dtype='float32').unsqueeze(axis=0) + + self._inputs["audio"] = audio + self._inputs["audio_len"] = audio_len + logger.info(f"audio feat shape: {audio.shape}") + + else: + raise Exception("wrong type") + + class ASREngine(BaseEngine): + """ASR server engine + + Args: + metaclass: Defaults to Singleton. + """ + def __init__(self): super(ASREngine, self).__init__() - def init(self, config_file: str): - self.config_file = config_file - self.executor = None + def init(self, config_file: str) -> bool: + """init engine resource + + Args: + config_file (str): config file + + Returns: + bool: init failed or success + """ self.input = None self.output = None - config = get_config(self.config_file) - pass + self.executor = ASRServerExecutor() - def postprocess(self): - pass + try: + self.config = get_config(config_file) + paddle.set_device(paddle.get_device()) + self.executor._init_from_path( + self.config.model, self.config.lang, self.config.sample_rate, + self.config.cfg_path, self.config.decode_method, + self.config.ckpt_path) + except: + logger.info("Initialize ASR server engine Failed.") + return False + + logger.info("Initialize ASR server engine successfully.") + return True - def run(self): - logger.info("start run asr engine") - return "hello world" + def run(self, audio_data): + """engine run + + Args: + audio_data (bytes): base64.b64decode + """ + if self.executor._check( + io.BytesIO(audio_data), self.config.sample_rate, + self.config.force_yes): + logger.info("start run asr engine") + self.executor.preprocess(self.config.model, io.BytesIO(audio_data)) + self.executor.infer(self.config.model) + self.output = self.executor.postprocess() # Retrieve result of asr. + else: + logger.info("file check failed!") + self.output = None + + def postprocess(self): + """postprocess + """ + return self.output diff --git a/speechserving/speechserving/engine/base_engine.py b/speechserving/speechserving/engine/base_engine.py index 36048dcc69be87585c2460542350df39d2679f85..0cc20209479ea7e033943b799a7e161ac21e3b35 100644 --- a/speechserving/speechserving/engine/base_engine.py +++ b/speechserving/speechserving/engine/base_engine.py @@ -18,6 +18,8 @@ from typing import Union from pattern_singleton import Singleton +__all__ = ['BaseEngine'] + class BaseEngine(metaclass=Singleton): """ diff --git a/speechserving/speechserving/engine/engine_factory.py b/speechserving/speechserving/engine/engine_factory.py index 336a9a6f8b007c5f9687f9c3547a450159c86c86..bc0c45656e1fabfee7c4216a6da858a68a66fc39 100644 --- a/speechserving/speechserving/engine/engine_factory.py +++ b/speechserving/speechserving/engine/engine_factory.py @@ -11,13 +11,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Text + from engine.asr.python.asr_engine import ASREngine from engine.tts.python.tts_engine import TTSEngine +__all__ = ['EngineFactory'] + + class EngineFactory(object): @staticmethod - def get_engine(engine_name): + def get_engine(engine_name: Text): if engine_name == 'asr': return ASREngine() elif engine_name == 'tts': diff --git a/speechserving/speechserving/main.py b/speechserving/speechserving/main.py index 3b367418f28d53e773bbddf8ff61d1c03ca5a366..6d4891c7dcb3e7a732efd09d97417d4ad115f21e 100644 --- a/speechserving/speechserving/main.py +++ b/speechserving/speechserving/main.py @@ -12,21 +12,28 @@ # See the License for the specific language governing permissions and # limitations under the License. import argparse + import uvicorn import yaml +from engine.engine_factory import EngineFactory from fastapi import FastAPI - from restful.api import setup_router -from utils.log import logger + from utils.config import get_config -from engine.engine_factory import EngineFactory +from utils.log import logger app = FastAPI( title="PaddleSpeech Serving API", description="Api", version="0.0.1") def init(config): - """ system initialization + """system initialization + + Args: + config (CfgNode): config object + + Returns: + bool: """ # init api api_list = list(config.engine_backend) @@ -34,10 +41,11 @@ def init(config): app.include_router(api_router) # init engine - engine_list = [] + engine_pool = [] for engine in config.engine_backend: - engine_list.append(EngineFactory.get_engine(engine_name=engine)) - engine_list[-1].init(config_file=config.engine_backend[engine]) + engine_pool.append(EngineFactory.get_engine(engine_name=engine)) + if not engine_pool[-1].init(config_file=config.engine_backend[engine]): + return False return True diff --git a/speechserving/speechserving/restful/api.py b/speechserving/speechserving/restful/api.py index c5539f2431027a19e267974599dbd42a1ee58162..bdff935ac977dcff4e14a9dee9d1c030a6581562 100644 --- a/speechserving/speechserving/restful/api.py +++ b/speechserving/speechserving/restful/api.py @@ -12,13 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. from typing import List + from fastapi import APIRouter -from .tts_api import router as tts_router from .asr_api import router as asr_router +from .tts_api import router as tts_router _router = APIRouter() + def setup_router(api_list: List): for api_name in api_list: @@ -30,4 +32,3 @@ def setup_router(api_list: List): pass return _router - diff --git a/speechserving/speechserving/restful/asr_api.py b/speechserving/speechserving/restful/asr_api.py index 9d97b3803755a66cf35ec38a5e26430f9ba31a98..c63cd76c2ef6a2318306e4fd72d02dc0b41720d9 100644 --- a/speechserving/speechserving/restful/asr_api.py +++ b/speechserving/speechserving/restful/asr_api.py @@ -11,16 +11,23 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from fastapi import APIRouter import base64 +import traceback +from typing import Union from engine.asr.python.asr_engine import ASREngine -from .response import ASRResponse -from .request import ASRRequest +from fastapi import APIRouter +from .request import ASRRequest +from .response import ASRResponse +from .response import ErrorResponse +from utils.errors import ErrorCode +from utils.errors import failed_response +from utils.exception import ServerBaseException router = APIRouter() + @router.get('/paddlespeech/asr/help') def help(): """help @@ -28,10 +35,23 @@ def help(): Returns: json: [description] """ - return {'hello': 'world'} + response = { + "success": "True", + "code": 200, + "message": { + "global": "success" + }, + "result": { + "description": "asr server", + "input": "base64 string of wavfile", + "output": "transcription" + } + } + return response -@router.post("/paddlespeech/asr", response_model=ASRResponse) +@router.post( + "/paddlespeech/asr", response_model=Union[ASRResponse, ErrorResponse]) def asr(request_body: ASRRequest): """asr api @@ -41,22 +61,28 @@ def asr(request_body: ASRRequest): Returns: json: [description] """ - # single - asr_engine = ASREngine() - print("asr_engine id :" ,id(asr_engine)) - - asr_results = asr_engine.run() - asr_engine.postprocess() - - json_body = { - "success": True, - "code": 0, - "message": { - "description": "success" - }, - "result": { - "transcription": asr_results - } - } - - return json_body + try: + # single + audio_data = base64.b64decode(request_body.audio) + asr_engine = ASREngine() + asr_engine.run(audio_data) + asr_results = asr_engine.postprocess() + + response = { + "success": True, + "code": 200, + "message": { + "description": "success" + }, + "result": { + "transcription": asr_results + } + } + + except ServerBaseException as e: + response = failed_response(e.error_code, e.msg) + except: + response = failed_response(ErrorCode.SERVER_UNKOWN_ERR) + traceback.print_exc() + + return response diff --git a/speechserving/speechserving/restful/request.py b/speechserving/speechserving/restful/request.py index 9ebc6d752aaee8dc5897091593a3a8c5a6e88511..2be5f0e546dee6c1c042820ac1a3838a446e23ea 100644 --- a/speechserving/speechserving/restful/request.py +++ b/speechserving/speechserving/restful/request.py @@ -19,7 +19,6 @@ from pydantic import BaseModel __all__ = ['ASRRequest', 'TTSRequest'] - #****************************************************************************************/ #************************************ ASR request ***************************************/ #****************************************************************************************/ @@ -31,14 +30,14 @@ class ASRRequest(BaseModel): "audio_format": "wav", "sample_rate": 16000, "lang": "zh_cn", - "ptt":false + "punc":false } """ audio: str audio_format: str sample_rate: int lang: str - ptt: Optional[bool] = None + punc: Optional[bool] = None #****************************************************************************************/ diff --git a/speechserving/speechserving/restful/response.py b/speechserving/speechserving/restful/response.py index db24f5310b18609a5ef3d330cb7b6641d62af41d..ab5e395ba6914482e320d13abf2744e2fef71ec0 100644 --- a/speechserving/speechserving/restful/response.py +++ b/speechserving/speechserving/restful/response.py @@ -86,3 +86,22 @@ class TTSResponse(BaseModel): code: int message: Message result: TTSResult + + +#****************************************************************************************/ +#********************************** Error response **************************************/ +#****************************************************************************************/ +class ErrorResponse(BaseModel): + """ + response example + { + "success": false, + "code": 0, + "message": { + "description": "Unknown error occurred." + } + } + """ + success: bool + code: int + message: Message diff --git a/speechserving/speechserving/restful/tts_api.py b/speechserving/speechserving/restful/tts_api.py index e9dcfa16fa17e736b8ee79fbd8ee170a6fcacc0a..a160e31dc1abc8d16cc2ff99a6a3db3f37296fe5 100644 --- a/speechserving/speechserving/restful/tts_api.py +++ b/speechserving/speechserving/restful/tts_api.py @@ -12,11 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. import traceback +from typing import Union from engine.tts.python.tts_engine import TTSEngine from fastapi import APIRouter from .request import TTSRequest +from .response import ErrorResponse from .response import TTSResponse from utils.errors import ErrorCode from utils.errors import failed_response @@ -47,7 +49,8 @@ def help(): return response -@router.post("/paddlespeech/tts", response_model=TTSResponse) +@router.post( + "/paddlespeech/tts", response_model=Union[TTSResponse, ErrorResponse]) def tts(request_body: TTSRequest): """tts api diff --git a/speechserving/speechserving/utils/__init__.py b/speechserving/speechserving/utils/__init__.py index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..97043fd7ba6885aac81cad5a49924c23c67d4d47 100644 --- a/speechserving/speechserving/utils/__init__.py +++ b/speechserving/speechserving/utils/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/speechserving/speechserving/utils/config.py b/speechserving/speechserving/utils/config.py index 513c16f66c4bab0eab5885c601a4f44b94b8ba28..8c75f536f5de654f1a09fa82187cfef4ef442e90 100644 --- a/speechserving/speechserving/utils/config.py +++ b/speechserving/speechserving/utils/config.py @@ -15,7 +15,7 @@ import yaml from yacs.config import CfgNode -def get_config(config_file): +def get_config(config_file: str): """[summary] Args: diff --git a/speechserving/speechserving/utils/errors.py b/speechserving/speechserving/utils/errors.py index aa858cb083ad8476af35a742f0d4376c5ac13417..17ff75512cd447648ecedf9238809a42743b708c 100644 --- a/speechserving/speechserving/utils/errors.py +++ b/speechserving/speechserving/utils/errors.py @@ -52,6 +52,6 @@ def failed_response(code, msg=""): if not msg: msg = ErrorMsg.get(code, "Unknown error occurred.") - res = {"success": False, "code": int(code), "message": {"global": msg}} + res = {"success": False, "code": int(code), "message": {"description": msg}} return Response(content=json.dumps(res), media_type="application/json") diff --git a/speechserving/speechserving/utils/util.py b/speechserving/speechserving/utils/util.py index cf56857213cec7f922203fe32d079c8f80bd9611..e9104fa2d56283c48304d4676fae19e8dccd1ba5 100644 --- a/speechserving/speechserving/utils/util.py +++ b/speechserving/speechserving/utils/util.py @@ -13,7 +13,7 @@ import base64 -def wav2base64(wav_file): +def wav2base64(wav_file: str): """ read wave file and covert to base64 string """ @@ -23,12 +23,10 @@ def wav2base64(wav_file): return base64_string -def base64towav(base64_string): +def base64towav(base64_string: str): pass - - def self_check(): """ self check resource """ diff --git a/speechserving/tests/16_audio.wav b/speechserving/tests/16_audio.wav new file mode 100644 index 0000000000000000000000000000000000000000..3cfa5074efaea618684e3ca7b497a2b1f33fa7e4 Binary files /dev/null and b/speechserving/tests/16_audio.wav differ diff --git a/speechserving/tests/http_client.py b/speechserving/tests/http_client.py index 3787d764091ca80dd967760ea493bb414b8a7c30..14adb5741989790140fa509bb4e6eeca1b48546f 100644 --- a/speechserving/tests/http_client.py +++ b/speechserving/tests/http_client.py @@ -14,8 +14,8 @@ import requests import json import time import base64 +import io -import argparse def readwav2base64(wav_file): """ @@ -27,7 +27,7 @@ def readwav2base64(wav_file): return base64_string -def main(args): +def main(): """ main func """ @@ -36,11 +36,11 @@ def main(args): # start Timestamp time_start=time.time() - # test_audio_dir = "test_data/16_audio.wav" - # audio = readwav2base64(test_audio_dir) + test_audio_dir = "./16_audio.wav" + audio = readwav2base64(test_audio_dir) data = { - "audio": "exSI6ICJlbiIsCgkgICAgInBvc2l0aW9uIjogImZhbHNlIgoJf", + "audio": audio, "audio_format": "wav", "sample_rate": 16000, "lang": "zh_cn", @@ -55,12 +55,5 @@ def main(args): print(r.json()) - - if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--model_type", action="store", - help="model type: u2, dp2", default="dp2") - args = parser.parse_args() - - main(args) + main()