diff --git a/demos/speech_server/start_multi_progress_server.py b/demos/speech_server/start_multi_progress_server.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e86befb7d31c5109ff2c44ac4777d38e9af752b
--- /dev/null
+++ b/demos/speech_server/start_multi_progress_server.py
@@ -0,0 +1,69 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import argparse
+import sys
+import warnings
+
+import uvicorn
+from fastapi import FastAPI
+from starlette.middleware.cors import CORSMiddleware
+
+from paddlespeech.server.engine.engine_pool import init_engine_pool
+from paddlespeech.server.restful.api import setup_router as setup_http_router
+from paddlespeech.server.utils.config import get_config
+from paddlespeech.server.ws.api import setup_router as setup_ws_router
+warnings.filterwarnings("ignore")
+
+app = FastAPI(
+    title="PaddleSpeech Serving API", description="Api", version="0.0.1")
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"])
+
+# change yaml file here
+config_file = "./conf/application.yaml"
+config = get_config(config_file)
+
+# init engine
+if not init_engine_pool(config):
+    print("Failed to init engine.")
+    sys.exit(-1)
+
+# get api_router
+api_list = list(engine.split("_")[0] for engine in config.engine_list)
+if config.protocol == "websocket":
+    api_router = setup_ws_router(api_list)
+elif config.protocol == "http":
+    api_router = setup_http_router(api_list)
+else:
+    raise Exception("unsupported protocol")
+
+# app needs to operate outside the main function
+app.include_router(api_router)
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(add_help=True)
+    parser.add_argument(
+        "--workers", type=int, help="workers of server", default=1)
+    args = parser.parse_args()
+
+    uvicorn.run(
+        "start_multi_progress_server:app",
+        host=config.host,
+        port=config.port,
+        debug=True,
+        workers=args.workers)
diff --git a/paddlespeech/server/bin/paddlespeech_server.py b/paddlespeech/server/bin/paddlespeech_server.py
index e59f17d38820374a620e8d9eb78daf060412130d..a3a29fef962bf09b66df4e3c562dcab1790b5a79 100644
--- a/paddlespeech/server/bin/paddlespeech_server.py
+++ b/paddlespeech/server/bin/paddlespeech_server.py
@@ -26,6 +26,7 @@ from ..util import cli_server_register
 from ..util import stats_wrapper
 from paddlespeech.cli.log import logger
 from paddlespeech.server.engine.engine_pool import init_engine_pool
+from paddlespeech.server.engine.engine_warmup import warm_up
 from paddlespeech.server.restful.api import setup_router as setup_http_router
 from paddlespeech.server.utils.config import get_config
 from paddlespeech.server.ws.api import setup_router as setup_ws_router
@@ -86,6 +87,11 @@ class ServerExecutor(BaseExecutor):
         if not init_engine_pool(config):
             return False
 
+        # warm up
+        for engine_and_type in config.engine_list:
+            if not warm_up(engine_and_type):
+                return False
+
         return True
 
     def 
execute(self, argv: List[str]) -> bool: diff --git a/paddlespeech/server/engine/asr/paddleinference/asr_engine.py b/paddlespeech/server/engine/asr/paddleinference/asr_engine.py index e275f1088f648df62947ded43f297cbb8d2c70c2..4234e1e2d41b40c5d896ebf86e7781f34e24c95c 100644 --- a/paddlespeech/server/engine/asr/paddleinference/asr_engine.py +++ b/paddlespeech/server/engine/asr/paddleinference/asr_engine.py @@ -30,7 +30,7 @@ from paddlespeech.server.engine.base_engine import BaseEngine from paddlespeech.server.utils.paddle_predictor import init_predictor from paddlespeech.server.utils.paddle_predictor import run_model -__all__ = ['ASREngine'] +__all__ = ['ASREngine', 'PaddleASRConnectionHandler'] class ASRServerExecutor(ASRExecutor): @@ -50,7 +50,7 @@ class ASRServerExecutor(ASRExecutor): """ Init model and other resources from a specific path. """ - + self.max_len = 50 sample_rate_str = '16k' if sample_rate == 16000 else '8k' tag = model_type + '-' + lang + '-' + sample_rate_str if cfg_path is None or am_model is None or am_params is None: @@ -172,10 +172,23 @@ class ASREngine(BaseEngine): Returns: bool: init failed or success """ - self.input = None - self.output = None self.executor = ASRServerExecutor() self.config = config + self.engine_type = "inference" + + try: + if self.config.am_predictor_conf.device is not None: + self.device = self.config.am_predictor_conf.device + else: + self.device = paddle.get_device() + + paddle.set_device(self.device) + except Exception as e: + logger.error( + "Set device failed, please check if device is already used and the parameter 'device' in the yaml file" + ) + logger.error(e) + return False self.executor._init_from_path( model_type=self.config.model_type, @@ -190,22 +203,42 @@ class ASREngine(BaseEngine): logger.info("Initialize ASR server engine successfully.") return True + +class PaddleASRConnectionHandler(ASRServerExecutor): + def __init__(self, asr_engine): + """The PaddleSpeech ASR Server Connection Handler + This connection process every asr server request + Args: + asr_engine (ASREngine): The ASR engine + """ + super().__init__() + self.input = None + self.output = None + self.asr_engine = asr_engine + self.executor = self.asr_engine.executor + self.config = self.executor.config + self.max_len = self.executor.max_len + self.decoder = self.executor.decoder + self.am_predictor = self.executor.am_predictor + self.text_feature = self.executor.text_feature + self.collate_fn_test = self.executor.collate_fn_test + def run(self, audio_data): """engine run Args: audio_data (bytes): base64.b64decode """ - if self.executor._check( - io.BytesIO(audio_data), self.config.sample_rate, - self.config.force_yes): + if self._check( + io.BytesIO(audio_data), self.asr_engine.config.sample_rate, + self.asr_engine.config.force_yes): logger.info("start running asr engine") - self.executor.preprocess(self.config.model_type, - io.BytesIO(audio_data)) + self.preprocess(self.asr_engine.config.model_type, + io.BytesIO(audio_data)) st = time.time() - self.executor.infer(self.config.model_type) + self.infer(self.asr_engine.config.model_type) infer_time = time.time() - st - self.output = self.executor.postprocess() # Retrieve result of asr. + self.output = self.postprocess() # Retrieve result of asr. 
logger.info("end inferring asr engine") else: logger.info("file check failed!") @@ -213,8 +246,3 @@ class ASREngine(BaseEngine): logger.info("inference time: {}".format(infer_time)) logger.info("asr engine type: paddle inference") - - def postprocess(self): - """postprocess - """ - return self.output diff --git a/paddlespeech/server/engine/asr/python/asr_engine.py b/paddlespeech/server/engine/asr/python/asr_engine.py index d60a5feaeca6caa5e385f872872104df2a8aa124..f9cc3a6650cdaff91fdf5c52ffa285aa4d7f2d16 100644 --- a/paddlespeech/server/engine/asr/python/asr_engine.py +++ b/paddlespeech/server/engine/asr/python/asr_engine.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import io +import sys import time import paddle @@ -20,7 +21,7 @@ from paddlespeech.cli.asr.infer import ASRExecutor from paddlespeech.cli.log import logger from paddlespeech.server.engine.base_engine import BaseEngine -__all__ = ['ASREngine'] +__all__ = ['ASREngine', 'PaddleASRConnectionHandler'] class ASRServerExecutor(ASRExecutor): @@ -48,20 +49,23 @@ class ASREngine(BaseEngine): Returns: bool: init failed or success """ - self.input = None - self.output = None self.executor = ASRServerExecutor() self.config = config + self.engine_type = "python" + try: - if self.config.device: + if self.config.device is not None: self.device = self.config.device else: self.device = paddle.get_device() + paddle.set_device(self.device) - except BaseException: + except Exception as e: logger.error( "Set device failed, please check if device is already used and the parameter 'device' in the yaml file" ) + logger.error(e) + return False self.executor._init_from_path( self.config.model, self.config.lang, self.config.sample_rate, @@ -72,6 +76,24 @@ class ASREngine(BaseEngine): (self.device)) return True + +class PaddleASRConnectionHandler(ASRServerExecutor): + def __init__(self, asr_engine): + """The PaddleSpeech ASR Server Connection Handler + This connection process every asr server request + Args: + asr_engine (ASREngine): The ASR engine + """ + super().__init__() + self.input = None + self.output = None + self.asr_engine = asr_engine + self.executor = self.asr_engine.executor + self.max_len = self.executor.max_len + self.text_feature = self.executor.text_feature + self.model = self.executor.model + self.config = self.executor.config + def run(self, audio_data): """engine run @@ -79,17 +101,16 @@ class ASREngine(BaseEngine): audio_data (bytes): base64.b64decode """ try: - if self.executor._check( - io.BytesIO(audio_data), self.config.sample_rate, - self.config.force_yes): + if self._check( + io.BytesIO(audio_data), self.asr_engine.config.sample_rate, + self.asr_engine.config.force_yes): logger.info("start run asr engine") - self.executor.preprocess(self.config.model, - io.BytesIO(audio_data)) + self.preprocess(self.asr_engine.config.model, + io.BytesIO(audio_data)) st = time.time() - self.executor.infer(self.config.model) + self.infer(self.asr_engine.config.model) infer_time = time.time() - st - self.output = self.executor.postprocess( - ) # Retrieve result of asr. + self.output = self.postprocess() # Retrieve result of asr. 
else: logger.info("file check failed!") self.output = None @@ -98,8 +119,4 @@ class ASREngine(BaseEngine): logger.info("asr engine type: python") except Exception as e: logger.info(e) - - def postprocess(self): - """postprocess - """ - return self.output + sys.exit(-1) diff --git a/paddlespeech/server/engine/cls/paddleinference/cls_engine.py b/paddlespeech/server/engine/cls/paddleinference/cls_engine.py index 0906c2412d36f2d27393731da18e994772c2addd..44750c4747ed2ac3e01c1423d0ca65941d3b833e 100644 --- a/paddlespeech/server/engine/cls/paddleinference/cls_engine.py +++ b/paddlespeech/server/engine/cls/paddleinference/cls_engine.py @@ -14,6 +14,7 @@ import io import os import time +from collections import OrderedDict from typing import Optional import numpy as np @@ -27,7 +28,7 @@ from paddlespeech.server.engine.base_engine import BaseEngine from paddlespeech.server.utils.paddle_predictor import init_predictor from paddlespeech.server.utils.paddle_predictor import run_model -__all__ = ['CLSEngine'] +__all__ = ['CLSEngine', 'PaddleCLSConnectionHandler'] class CLSServerExecutor(CLSExecutor): @@ -119,14 +120,55 @@ class CLSEngine(BaseEngine): """ self.executor = CLSServerExecutor() self.config = config - self.executor._init_from_path( - self.config.model_type, self.config.cfg_path, - self.config.model_path, self.config.params_path, - self.config.label_file, self.config.predictor_conf) + self.engine_type = "inference" + + try: + if self.config.predictor_conf.device is not None: + self.device = self.config.predictor_conf.device + else: + self.device = paddle.get_device() + paddle.set_device(self.device) + except Exception as e: + logger.error( + "Set device failed, please check if device is already used and the parameter 'device' in the yaml file" + ) + logger.error(e) + return False + + try: + self.executor._init_from_path( + self.config.model_type, self.config.cfg_path, + self.config.model_path, self.config.params_path, + self.config.label_file, self.config.predictor_conf) + + except Exception as e: + logger.error("Initialize CLS server engine Failed.") + logger.error(e) + return False logger.info("Initialize CLS server engine successfully.") return True + +class PaddleCLSConnectionHandler(CLSServerExecutor): + def __init__(self, cls_engine): + """The PaddleSpeech CLS Server Connection Handler + This connection process every cls server request + Args: + cls_engine (CLSEngine): The CLS engine + """ + super().__init__() + logger.info( + "Create PaddleCLSConnectionHandler to process the cls request") + + self._inputs = OrderedDict() + self._outputs = OrderedDict() + self.cls_engine = cls_engine + self.executor = self.cls_engine.executor + self._conf = self.executor._conf + self._label_list = self.executor._label_list + self.predictor = self.executor.predictor + def run(self, audio_data): """engine run @@ -134,9 +176,9 @@ class CLSEngine(BaseEngine): audio_data (bytes): base64.b64decode """ - self.executor.preprocess(io.BytesIO(audio_data)) + self.preprocess(io.BytesIO(audio_data)) st = time.time() - self.executor.infer() + self.infer() infer_time = time.time() - st logger.info("inference time: {}".format(infer_time)) @@ -145,15 +187,15 @@ class CLSEngine(BaseEngine): def postprocess(self, topk: int): """postprocess """ - assert topk <= len(self.executor._label_list - ), 'Value of topk is larger than number of labels.' + assert topk <= len( + self._label_list), 'Value of topk is larger than number of labels.' 
- result = np.squeeze(self.executor._outputs['logits'], axis=0) + result = np.squeeze(self._outputs['logits'], axis=0) topk_idx = (-result).argsort()[:topk] topk_results = [] for idx in topk_idx: res = {} - label, score = self.executor._label_list[idx], result[idx] + label, score = self._label_list[idx], result[idx] res['class_name'] = label res['prob'] = score topk_results.append(res) diff --git a/paddlespeech/server/engine/cls/python/cls_engine.py b/paddlespeech/server/engine/cls/python/cls_engine.py index 1a975b0a05b4d0163e47877b5141da529ad5f004..f8d8f20ef215da47c823e2bfef056b2c4ec4bb6d 100644 --- a/paddlespeech/server/engine/cls/python/cls_engine.py +++ b/paddlespeech/server/engine/cls/python/cls_engine.py @@ -13,7 +13,7 @@ # limitations under the License. import io import time -from typing import List +from collections import OrderedDict import paddle @@ -21,7 +21,7 @@ from paddlespeech.cli.cls.infer import CLSExecutor from paddlespeech.cli.log import logger from paddlespeech.server.engine.base_engine import BaseEngine -__all__ = ['CLSEngine'] +__all__ = ['CLSEngine', 'PaddleCLSConnectionHandler'] class CLSServerExecutor(CLSExecutor): @@ -29,21 +29,6 @@ class CLSServerExecutor(CLSExecutor): super().__init__() pass - def get_topk_results(self, topk: int) -> List: - assert topk <= len( - self._label_list), 'Value of topk is larger than number of labels.' - - result = self._outputs['logits'].squeeze(0).numpy() - topk_idx = (-result).argsort()[:topk] - res = {} - topk_results = [] - for idx in topk_idx: - label, score = self._label_list[idx], result[idx] - res['class'] = label - res['prob'] = score - topk_results.append(res) - return topk_results - class CLSEngine(BaseEngine): """CLS server engine @@ -64,42 +49,65 @@ class CLSEngine(BaseEngine): Returns: bool: init failed or success """ - self.input = None - self.output = None self.executor = CLSServerExecutor() self.config = config + self.engine_type = "python" + try: - if self.config.device: + if self.config.device is not None: self.device = self.config.device else: self.device = paddle.get_device() paddle.set_device(self.device) - except BaseException: + except Exception as e: logger.error( "Set device failed, please check if device is already used and the parameter 'device' in the yaml file" ) + logger.error(e) + return False try: self.executor._init_from_path( self.config.model, self.config.cfg_path, self.config.ckpt_path, self.config.label_file) - except BaseException: + except Exception as e: logger.error("Initialize CLS server engine Failed.") + logger.error(e) return False logger.info("Initialize CLS server engine successfully on device: %s." 
% (self.device)) return True + +class PaddleCLSConnectionHandler(CLSServerExecutor): + def __init__(self, cls_engine): + """The PaddleSpeech CLS Server Connection Handler + This connection process every cls server request + Args: + cls_engine (CLSEngine): The CLS engine + """ + super().__init__() + logger.info( + "Create PaddleCLSConnectionHandler to process the cls request") + + self._inputs = OrderedDict() + self._outputs = OrderedDict() + self.cls_engine = cls_engine + self.executor = self.cls_engine.executor + self._conf = self.executor._conf + self._label_list = self.executor._label_list + self.model = self.executor.model + def run(self, audio_data): """engine run Args: audio_data (bytes): base64.b64decode """ - self.executor.preprocess(io.BytesIO(audio_data)) + self.preprocess(io.BytesIO(audio_data)) st = time.time() - self.executor.infer() + self.infer() infer_time = time.time() - st logger.info("inference time: {}".format(infer_time)) @@ -108,15 +116,15 @@ class CLSEngine(BaseEngine): def postprocess(self, topk: int): """postprocess """ - assert topk <= len(self.executor._label_list - ), 'Value of topk is larger than number of labels.' + assert topk <= len( + self._label_list), 'Value of topk is larger than number of labels.' - result = self.executor._outputs['logits'].squeeze(0).numpy() + result = self._outputs['logits'].squeeze(0).numpy() topk_idx = (-result).argsort()[:topk] topk_results = [] for idx in topk_idx: res = {} - label, score = self.executor._label_list[idx], result[idx] + label, score = self._label_list[idx], result[idx] res['class_name'] = label res['prob'] = score topk_results.append(res) diff --git a/paddlespeech/server/engine/engine_warmup.py b/paddlespeech/server/engine/engine_warmup.py new file mode 100644 index 0000000000000000000000000000000000000000..5f548f71dbe6c673564350a197b314a55710989f --- /dev/null +++ b/paddlespeech/server/engine/engine_warmup.py @@ -0,0 +1,75 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import time + +from paddlespeech.cli.log import logger +from paddlespeech.server.engine.engine_pool import get_engine_pool + + +def warm_up(engine_and_type: str, warm_up_time: int=3) -> bool: + engine_pool = get_engine_pool() + + if "tts" in engine_and_type: + tts_engine = engine_pool['tts'] + flag_online = False + if tts_engine.lang == 'zh': + sentence = "您好,欢迎使用语音合成服务。" + elif tts_engine.lang == 'en': + sentence = "Hello and welcome to the speech synthesis service." 
+        else:
+            logger.error("tts engine only support lang: zh or en.")
+            return False
+
+        if engine_and_type == "tts_python":
+            from paddlespeech.server.engine.tts.python.tts_engine import PaddleTTSConnectionHandler
+        elif engine_and_type == "tts_inference":
+            from paddlespeech.server.engine.tts.paddleinference.tts_engine import PaddleTTSConnectionHandler
+        elif engine_and_type == "tts_online":
+            from paddlespeech.server.engine.tts.online.python.tts_engine import PaddleTTSConnectionHandler
+            flag_online = True
+        elif engine_and_type == "tts_online-onnx":
+            from paddlespeech.server.engine.tts.online.onnx.tts_engine import PaddleTTSConnectionHandler
+            flag_online = True
+        else:
+            logger.error("Please check tts engine type.")
+
+        try:
+            logger.info("Start to warm up tts engine.")
+            for i in range(warm_up_time):
+                connection_handler = PaddleTTSConnectionHandler(tts_engine)
+                if flag_online:
+                    for wav in connection_handler.infer(
+                            text=sentence,
+                            lang=tts_engine.lang,
+                            am=tts_engine.config.am):
+                        logger.info(
+                            f"The first response time of the {i} warm up: {connection_handler.first_response_time} s"
+                        )
+                        break
+
+                else:
+                    st = time.time()
+                    connection_handler.infer(text=sentence)
+                    et = time.time()
+                    logger.info(
+                        f"The response time of the {i} warm up: {et - st} s")
+        except Exception as e:
+            logger.error("Failed to warm up on tts engine.")
+            logger.error(e)
+            return False
+
+    else:
+        pass
+
+    return True
diff --git a/paddlespeech/server/engine/tts/online/onnx/tts_engine.py b/paddlespeech/server/engine/tts/online/onnx/tts_engine.py
index 792442065074af9168f84b1ce695bb484b01e388..fd438da0314881875db0dfabf13ec9e04a8770cf 100644
--- a/paddlespeech/server/engine/tts/online/onnx/tts_engine.py
+++ b/paddlespeech/server/engine/tts/online/onnx/tts_engine.py
@@ -31,18 +31,12 @@ from paddlespeech.server.utils.util import get_chunks
 from paddlespeech.t2s.frontend import English
 from paddlespeech.t2s.frontend.zh_frontend import Frontend
 
-__all__ = ['TTSEngine']
+__all__ = ['TTSEngine', 'PaddleTTSConnectionHandler']
 
 
 class TTSServerExecutor(TTSExecutor):
-    def __init__(self, am_block, am_pad, voc_block, voc_pad, voc_upsample):
+    def __init__(self):
         super().__init__()
-        self.am_block = am_block
-        self.am_pad = am_pad
-        self.voc_block = voc_block
-        self.voc_pad = voc_pad
-        self.voc_upsample = voc_upsample
-
         self.pretrained_models = pretrained_models
 
     def _init_from_path(
@@ -161,6 +155,115 @@ class TTSServerExecutor(TTSExecutor):
             self.frontend = English(phone_vocab_path=self.phones_dict)
         logger.info("frontend done!")
 
+
+class TTSEngine(BaseEngine):
+    """TTS server engine
+
+    Args:
+        metaclass: Defaults to Singleton.
+    """
+
+    def __init__(self, name=None):
+        """Initialize TTS server engine
+        """
+        super().__init__()
+
+    def init(self, config: dict) -> bool:
+        self.executor = TTSServerExecutor()
+        self.config = config
+        self.lang = self.config.lang
+        self.engine_type = "online-onnx"
+
+        self.am_block = self.config.am_block
+        self.am_pad = self.config.am_pad
+        self.voc_block = self.config.voc_block
+        self.voc_pad = self.config.voc_pad
+        self.am_upsample = 1
+        self.voc_upsample = self.config.voc_upsample
+
+        assert (
+            self.config.am == "fastspeech2_csmsc_onnx" or
+            self.config.am == "fastspeech2_cnndecoder_csmsc_onnx"
+        ) and (
+            self.config.voc == "hifigan_csmsc_onnx" or
+            self.config.voc == "mb_melgan_csmsc_onnx"
+        ), 'Please check config, am support: fastspeech2, voc support: hifigan_csmsc-zh or mb_melgan_csmsc.'
+
+        assert (
+            self.config.voc_block > 0 and self.config.voc_pad > 0
+        ), "Please set correct voc_block and voc_pad, they should be more than 0."
+
+        assert (
+            self.config.voc_sample_rate == self.config.am_sample_rate
+        ), "The sample rate of AM and Vocoder model are different, please check model."
+
+        try:
+            if self.config.am_sess_conf.device is not None:
+                self.device = self.config.am_sess_conf.device
+            elif self.config.voc_sess_conf.device is not None:
+                self.device = self.config.voc_sess_conf.device
+            else:
+                self.device = paddle.get_device()
+            paddle.set_device(self.device)
+        except Exception as e:
+            logger.error(
+                "Set device failed, please check if device is already used and the parameter 'device' in the yaml file"
+            )
+            logger.error("Initialize TTS server engine Failed on device: %s." %
+                         (self.device))
+            logger.error(e)
+            return False
+
+        try:
+            self.executor._init_from_path(
+                am=self.config.am,
+                am_ckpt=self.config.am_ckpt,
+                am_stat=self.config.am_stat,
+                phones_dict=self.config.phones_dict,
+                tones_dict=self.config.tones_dict,
+                speaker_dict=self.config.speaker_dict,
+                am_sample_rate=self.config.am_sample_rate,
+                am_sess_conf=self.config.am_sess_conf,
+                voc=self.config.voc,
+                voc_ckpt=self.config.voc_ckpt,
+                voc_sample_rate=self.config.voc_sample_rate,
+                voc_sess_conf=self.config.voc_sess_conf,
+                lang=self.config.lang)
+
+        except Exception as e:
+            logger.error("Failed to get model related files.")
+            logger.error("Initialize TTS server engine Failed on device: %s." %
+                         (self.config.voc_sess_conf.device))
+            logger.error(e)
+            return False
+
+        logger.info("Initialize TTS server engine successfully on device: %s." %
+                    (self.config.voc_sess_conf.device))
+
+        return True
+
+
+class PaddleTTSConnectionHandler:
+    def __init__(self, tts_engine):
+        """The PaddleSpeech TTS Server Connection Handler
+           This connection process every tts server request
+        Args:
+            tts_engine (TTSEngine): The TTS engine
+        """
+        super().__init__()
+        logger.info(
+            "Create PaddleTTSConnectionHandler to process the tts request")
+
+        self.tts_engine = tts_engine
+        self.executor = self.tts_engine.executor
+        self.config = self.tts_engine.config
+        self.am_block = self.tts_engine.am_block
+        self.am_pad = self.tts_engine.am_pad
+        self.voc_block = self.tts_engine.voc_block
+        self.voc_pad = self.tts_engine.voc_pad
+        self.am_upsample = self.tts_engine.am_upsample
+        self.voc_upsample = self.tts_engine.voc_upsample
+
     def depadding(self, data, chunk_num, chunk_id, block, pad, upsample):
         """
         Streaming inference removes the result of pad inference
@@ -189,12 +292,6 @@ class TTSServerExecutor(TTSExecutor):
         Model inference and result stored in self.output.
""" - am_block = self.am_block - am_pad = self.am_pad - am_upsample = 1 - voc_block = self.voc_block - voc_pad = self.voc_pad - voc_upsample = self.voc_upsample # first_flag 用于标记首包 first_flag = 1 get_tone_ids = False @@ -203,7 +300,7 @@ class TTSServerExecutor(TTSExecutor): # front frontend_st = time.time() if lang == 'zh': - input_ids = self.frontend.get_input_ids( + input_ids = self.executor.frontend.get_input_ids( text, merge_sentences=merge_sentences, get_tone_ids=get_tone_ids) @@ -211,7 +308,7 @@ class TTSServerExecutor(TTSExecutor): if get_tone_ids: tone_ids = input_ids["tone_ids"] elif lang == 'en': - input_ids = self.frontend.get_input_ids( + input_ids = self.executor.frontend.get_input_ids( text, merge_sentences=merge_sentences) phone_ids = input_ids["phone_ids"] else: @@ -226,7 +323,7 @@ class TTSServerExecutor(TTSExecutor): # fastspeech2_csmsc if am == "fastspeech2_csmsc_onnx": # am - mel = self.am_sess.run( + mel = self.executor.am_sess.run( output_names=None, input_feed={'text': part_phone_ids}) mel = mel[0] if first_flag == 1: @@ -234,14 +331,16 @@ class TTSServerExecutor(TTSExecutor): self.first_am_infer = first_am_et - frontend_et # voc streaming - mel_chunks = get_chunks(mel, voc_block, voc_pad, "voc") + mel_chunks = get_chunks(mel, self.voc_block, self.voc_pad, + "voc") voc_chunk_num = len(mel_chunks) voc_st = time.time() for i, mel_chunk in enumerate(mel_chunks): - sub_wav = self.voc_sess.run( + sub_wav = self.executor.voc_sess.run( output_names=None, input_feed={'logmel': mel_chunk}) sub_wav = self.depadding(sub_wav[0], voc_chunk_num, i, - voc_block, voc_pad, voc_upsample) + self.voc_block, self.voc_pad, + self.voc_upsample) if first_flag == 1: first_voc_et = time.time() self.first_voc_infer = first_voc_et - first_am_et @@ -253,7 +352,7 @@ class TTSServerExecutor(TTSExecutor): # fastspeech2_cnndecoder_csmsc elif am == "fastspeech2_cnndecoder_csmsc_onnx": # am - orig_hs = self.am_encoder_infer_sess.run( + orig_hs = self.executor.am_encoder_infer_sess.run( None, input_feed={'text': part_phone_ids}) orig_hs = orig_hs[0] @@ -267,9 +366,9 @@ class TTSServerExecutor(TTSExecutor): hss = get_chunks(orig_hs, self.am_block, self.am_pad, "am") am_chunk_num = len(hss) for i, hs in enumerate(hss): - am_decoder_output = self.am_decoder_sess.run( + am_decoder_output = self.executor.am_decoder_sess.run( None, input_feed={'xs': hs}) - am_postnet_output = self.am_postnet_sess.run( + am_postnet_output = self.executor.am_postnet_sess.run( None, input_feed={ 'xs': np.transpose(am_decoder_output[0], (0, 2, 1)) @@ -278,9 +377,11 @@ class TTSServerExecutor(TTSExecutor): am_postnet_output[0], (0, 2, 1)) normalized_mel = am_output_data[0][0] - sub_mel = denorm(normalized_mel, self.am_mu, self.am_std) - sub_mel = self.depadding(sub_mel, am_chunk_num, i, am_block, - am_pad, am_upsample) + sub_mel = denorm(normalized_mel, self.executor.am_mu, + self.executor.am_std) + sub_mel = self.depadding(sub_mel, am_chunk_num, i, + self.am_block, self.am_pad, + self.am_upsample) if i == 0: mel_streaming = sub_mel @@ -297,11 +398,11 @@ class TTSServerExecutor(TTSExecutor): self.first_am_infer = first_am_et - frontend_et voc_chunk = mel_streaming[start:end, :] - sub_wav = self.voc_sess.run( + sub_wav = self.executor.voc_sess.run( output_names=None, input_feed={'logmel': voc_chunk}) - sub_wav = self.depadding(sub_wav[0], voc_chunk_num, - voc_chunk_id, voc_block, - voc_pad, voc_upsample) + sub_wav = self.depadding( + sub_wav[0], voc_chunk_num, voc_chunk_id, + self.voc_block, self.voc_pad, self.voc_upsample) if 
first_flag == 1: first_voc_et = time.time() self.first_voc_infer = first_voc_et - first_am_et @@ -311,9 +412,11 @@ class TTSServerExecutor(TTSExecutor): yield sub_wav voc_chunk_id += 1 - start = max(0, voc_chunk_id * voc_block - voc_pad) - end = min((voc_chunk_id + 1) * voc_block + voc_pad, - mel_len) + start = max( + 0, voc_chunk_id * self.voc_block - self.voc_pad) + end = min( + (voc_chunk_id + 1) * self.voc_block + self.voc_pad, + mel_len) else: logger.error( @@ -322,111 +425,6 @@ class TTSServerExecutor(TTSExecutor): self.final_response_time = time.time() - frontend_st - -class TTSEngine(BaseEngine): - """TTS server engine - - Args: - metaclass: Defaults to Singleton. - """ - - def __init__(self, name=None): - """Initialize TTS server engine - """ - super().__init__() - - def init(self, config: dict) -> bool: - self.config = config - assert ( - self.config.am == "fastspeech2_csmsc_onnx" or - self.config.am == "fastspeech2_cnndecoder_csmsc_onnx" - ) and ( - self.config.voc == "hifigan_csmsc_onnx" or - self.config.voc == "mb_melgan_csmsc_onnx" - ), 'Please check config, am support: fastspeech2, voc support: hifigan_csmsc-zh or mb_melgan_csmsc.' - - assert ( - self.config.voc_block > 0 and self.config.voc_pad > 0 - ), "Please set correct voc_block and voc_pad, they should be more than 0." - - assert ( - self.config.voc_sample_rate == self.config.am_sample_rate - ), "The sample rate of AM and Vocoder model are different, please check model." - - self.executor = TTSServerExecutor( - self.config.am_block, self.config.am_pad, self.config.voc_block, - self.config.voc_pad, self.config.voc_upsample) - - try: - if self.config.am_sess_conf.device is not None: - self.device = self.config.am_sess_conf.device - elif self.config.voc_sess_conf.device is not None: - self.device = self.config.voc_sess_conf.device - else: - self.device = paddle.get_device() - paddle.set_device(self.device) - except BaseException as e: - logger.error( - "Set device failed, please check if device is already used and the parameter 'device' in the yaml file" - ) - logger.error("Initialize TTS server engine Failed on device: %s." % - (self.device)) - return False - - try: - self.executor._init_from_path( - am=self.config.am, - am_ckpt=self.config.am_ckpt, - am_stat=self.config.am_stat, - phones_dict=self.config.phones_dict, - tones_dict=self.config.tones_dict, - speaker_dict=self.config.speaker_dict, - am_sample_rate=self.config.am_sample_rate, - am_sess_conf=self.config.am_sess_conf, - voc=self.config.voc, - voc_ckpt=self.config.voc_ckpt, - voc_sample_rate=self.config.voc_sample_rate, - voc_sess_conf=self.config.voc_sess_conf, - lang=self.config.lang) - - except Exception as e: - logger.error("Failed to get model related files.") - logger.error("Initialize TTS server engine Failed on device: %s." % - (self.config.voc_sess_conf.device)) - return False - - # warm up - try: - self.warm_up() - logger.info("Warm up successfully.") - except Exception as e: - logger.error("Failed to warm up on tts engine.") - return False - - logger.info("Initialize TTS server engine successfully on device: %s." % - (self.config.voc_sess_conf.device)) - - return True - - def warm_up(self): - """warm up - """ - if self.config.lang == 'zh': - sentence = "您好,欢迎使用语音合成服务。" - if self.config.lang == 'en': - sentence = "Hello and welcome to the speech synthesis service." 
- logger.info("Start to warm up.") - for i in range(3): - for wav in self.executor.infer( - text=sentence, - lang=self.config.lang, - am=self.config.am, - spk_id=0, ): - logger.info( - f"The first response time of the {i} warm up: {self.executor.first_response_time} s" - ) - break - def preprocess(self, text_bese64: str=None, text_bytes: bytes=None): # Convert byte to text if text_bese64: @@ -459,7 +457,7 @@ class TTSEngine(BaseEngine): """ wav_list = [] - for wav in self.executor.infer( + for wav in self.infer( text=sentence, lang=self.config.lang, am=self.config.am, @@ -477,11 +475,9 @@ class TTSEngine(BaseEngine): duration = len(wav_all) / self.config.voc_sample_rate logger.info(f"sentence: {sentence}") logger.info(f"The durations of audio is: {duration} s") + logger.info(f"first response time: {self.first_response_time} s") + logger.info(f"final response time: {self.final_response_time} s") + logger.info(f"RTF: {self.final_response_time / duration}") logger.info( - f"first response time: {self.executor.first_response_time} s") - logger.info( - f"final response time: {self.executor.final_response_time} s") - logger.info(f"RTF: {self.executor.final_response_time / duration}") - logger.info( - f"Other info: front time: {self.executor.frontend_time} s, first am infer time: {self.executor.first_am_infer} s, first voc infer time: {self.executor.first_voc_infer} s," + f"Other info: front time: {self.frontend_time} s, first am infer time: {self.first_am_infer} s, first voc infer time: {self.first_voc_infer} s," ) diff --git a/paddlespeech/server/engine/tts/online/python/tts_engine.py b/paddlespeech/server/engine/tts/online/python/tts_engine.py index 8dc36f8ef8f6d0d2316e59e8090f43aa2702f8e2..eaa179929f40625bafc35a58de5d30a8808830e6 100644 --- a/paddlespeech/server/engine/tts/online/python/tts_engine.py +++ b/paddlespeech/server/engine/tts/online/python/tts_engine.py @@ -34,16 +34,12 @@ from paddlespeech.t2s.frontend.zh_frontend import Frontend from paddlespeech.t2s.modules.normalizer import ZScore from paddlespeech.utils.dynamic_import import dynamic_import -__all__ = ['TTSEngine'] +__all__ = ['TTSEngine', 'PaddleTTSConnectionHandler'] class TTSServerExecutor(TTSExecutor): - def __init__(self, am_block, am_pad, voc_block, voc_pad): + def __init__(self): super().__init__() - self.am_block = am_block - self.am_pad = am_pad - self.voc_block = voc_block - self.voc_pad = voc_pad self.pretrained_models = pretrained_models def get_model_info(self, @@ -205,6 +201,106 @@ class TTSServerExecutor(TTSExecutor): self.voc_inference.eval() print("voc done!") + +class TTSEngine(BaseEngine): + """TTS server engine + + Args: + metaclass: Defaults to Singleton. + """ + + def __init__(self, name=None): + """Initialize TTS server engine + """ + super().__init__() + + def init(self, config: dict) -> bool: + self.executor = TTSServerExecutor() + self.config = config + self.lang = self.config.lang + self.engine_type = "online" + + assert ( + config.am == "fastspeech2_csmsc" or + config.am == "fastspeech2_cnndecoder_csmsc" + ) and ( + config.voc == "hifigan_csmsc" or config.voc == "mb_melgan_csmsc" + ), 'Please check config, am support: fastspeech2, voc support: hifigan_csmsc-zh or mb_melgan_csmsc.' + + assert ( + config.voc_block > 0 and config.voc_pad > 0 + ), "Please set correct voc_block and voc_pad, they should be more than 0." 
+ + try: + if self.config.device is not None: + self.device = self.config.device + else: + self.device = paddle.get_device() + paddle.set_device(self.device) + except Exception as e: + logger.error( + "Set device failed, please check if device is already used and the parameter 'device' in the yaml file" + ) + logger.error("Initialize TTS server engine Failed on device: %s." % + (self.device)) + logger.error(e) + return False + + try: + self.executor._init_from_path( + am=self.config.am, + am_config=self.config.am_config, + am_ckpt=self.config.am_ckpt, + am_stat=self.config.am_stat, + phones_dict=self.config.phones_dict, + tones_dict=self.config.tones_dict, + speaker_dict=self.config.speaker_dict, + voc=self.config.voc, + voc_config=self.config.voc_config, + voc_ckpt=self.config.voc_ckpt, + voc_stat=self.config.voc_stat, + lang=self.config.lang) + except Exception as e: + logger.error("Failed to get model related files.") + logger.error("Initialize TTS server engine Failed on device: %s." % + (self.device)) + logger.error(e) + return False + + self.am_block = self.config.am_block + self.am_pad = self.config.am_pad + self.voc_block = self.config.voc_block + self.voc_pad = self.config.voc_pad + self.am_upsample = 1 + self.voc_upsample = self.executor.voc_config.n_shift + + logger.info("Initialize TTS server engine successfully on device: %s." % + (self.device)) + + return True + + +class PaddleTTSConnectionHandler: + def __init__(self, tts_engine): + """The PaddleSpeech TTS Server Connection Handler + This connection process every tts server request + Args: + tts_engine (TTSEngine): The TTS engine + """ + super().__init__() + logger.info( + "Create PaddleTTSConnectionHandler to process the tts request") + + self.tts_engine = tts_engine + self.executor = self.tts_engine.executor + self.config = self.tts_engine.config + self.am_block = self.tts_engine.am_block + self.am_pad = self.tts_engine.am_pad + self.voc_block = self.tts_engine.voc_block + self.voc_pad = self.tts_engine.voc_pad + self.am_upsample = self.tts_engine.am_upsample + self.voc_upsample = self.tts_engine.voc_upsample + def depadding(self, data, chunk_num, chunk_id, block, pad, upsample): """ Streaming inference removes the result of pad inference @@ -233,12 +329,6 @@ class TTSServerExecutor(TTSExecutor): Model inference and result stored in self.output. 
""" - am_block = self.am_block - am_pad = self.am_pad - am_upsample = 1 - voc_block = self.voc_block - voc_pad = self.voc_pad - voc_upsample = self.voc_config.n_shift # first_flag 用于标记首包 first_flag = 1 @@ -246,7 +336,7 @@ class TTSServerExecutor(TTSExecutor): merge_sentences = False frontend_st = time.time() if lang == 'zh': - input_ids = self.frontend.get_input_ids( + input_ids = self.executor.frontend.get_input_ids( text, merge_sentences=merge_sentences, get_tone_ids=get_tone_ids) @@ -254,7 +344,7 @@ class TTSServerExecutor(TTSExecutor): if get_tone_ids: tone_ids = input_ids["tone_ids"] elif lang == 'en': - input_ids = self.frontend.get_input_ids( + input_ids = self.executor.frontend.get_input_ids( text, merge_sentences=merge_sentences) phone_ids = input_ids["phone_ids"] else: @@ -269,19 +359,21 @@ class TTSServerExecutor(TTSExecutor): # fastspeech2_csmsc if am == "fastspeech2_csmsc": # am - mel = self.am_inference(part_phone_ids) + mel = self.executor.am_inference(part_phone_ids) if first_flag == 1: first_am_et = time.time() self.first_am_infer = first_am_et - frontend_et # voc streaming - mel_chunks = get_chunks(mel, voc_block, voc_pad, "voc") + mel_chunks = get_chunks(mel, self.voc_block, self.voc_pad, + "voc") voc_chunk_num = len(mel_chunks) voc_st = time.time() for i, mel_chunk in enumerate(mel_chunks): - sub_wav = self.voc_inference(mel_chunk) + sub_wav = self.executor.voc_inference(mel_chunk) sub_wav = self.depadding(sub_wav, voc_chunk_num, i, - voc_block, voc_pad, voc_upsample) + self.voc_block, self.voc_pad, + self.voc_upsample) if first_flag == 1: first_voc_et = time.time() self.first_voc_infer = first_voc_et - first_am_et @@ -293,7 +385,8 @@ class TTSServerExecutor(TTSExecutor): # fastspeech2_cnndecoder_csmsc elif am == "fastspeech2_cnndecoder_csmsc": # am - orig_hs = self.am_inference.encoder_infer(part_phone_ids) + orig_hs = self.executor.am_inference.encoder_infer( + part_phone_ids) # streaming voc chunk info mel_len = orig_hs.shape[1] @@ -305,13 +398,15 @@ class TTSServerExecutor(TTSExecutor): hss = get_chunks(orig_hs, self.am_block, self.am_pad, "am") am_chunk_num = len(hss) for i, hs in enumerate(hss): - before_outs = self.am_inference.decoder(hs) - after_outs = before_outs + self.am_inference.postnet( + before_outs = self.executor.am_inference.decoder(hs) + after_outs = before_outs + self.executor.am_inference.postnet( before_outs.transpose((0, 2, 1))).transpose((0, 2, 1)) normalized_mel = after_outs[0] - sub_mel = denorm(normalized_mel, self.am_mu, self.am_std) - sub_mel = self.depadding(sub_mel, am_chunk_num, i, am_block, - am_pad, am_upsample) + sub_mel = denorm(normalized_mel, self.executor.am_mu, + self.executor.am_std) + sub_mel = self.depadding(sub_mel, am_chunk_num, i, + self.am_block, self.am_pad, + self.am_upsample) if i == 0: mel_streaming = sub_mel @@ -328,11 +423,11 @@ class TTSServerExecutor(TTSExecutor): self.first_am_infer = first_am_et - frontend_et voc_chunk = mel_streaming[start:end, :] voc_chunk = paddle.to_tensor(voc_chunk) - sub_wav = self.voc_inference(voc_chunk) + sub_wav = self.executor.voc_inference(voc_chunk) - sub_wav = self.depadding(sub_wav, voc_chunk_num, - voc_chunk_id, voc_block, - voc_pad, voc_upsample) + sub_wav = self.depadding( + sub_wav, voc_chunk_num, voc_chunk_id, + self.voc_block, self.voc_pad, self.voc_upsample) if first_flag == 1: first_voc_et = time.time() self.first_voc_infer = first_voc_et - first_am_et @@ -342,9 +437,11 @@ class TTSServerExecutor(TTSExecutor): yield sub_wav voc_chunk_id += 1 - start = max(0, voc_chunk_id * 
voc_block - voc_pad) - end = min((voc_chunk_id + 1) * voc_block + voc_pad, - mel_len) + start = max( + 0, voc_chunk_id * self.voc_block - self.voc_pad) + end = min( + (voc_chunk_id + 1) * self.voc_block + self.voc_pad, + mel_len) else: logger.error( @@ -353,100 +450,6 @@ class TTSServerExecutor(TTSExecutor): self.final_response_time = time.time() - frontend_st - -class TTSEngine(BaseEngine): - """TTS server engine - - Args: - metaclass: Defaults to Singleton. - """ - - def __init__(self, name=None): - """Initialize TTS server engine - """ - super().__init__() - - def init(self, config: dict) -> bool: - self.config = config - assert ( - config.am == "fastspeech2_csmsc" or - config.am == "fastspeech2_cnndecoder_csmsc" - ) and ( - config.voc == "hifigan_csmsc" or config.voc == "mb_melgan_csmsc" - ), 'Please check config, am support: fastspeech2, voc support: hifigan_csmsc-zh or mb_melgan_csmsc.' - - assert ( - config.voc_block > 0 and config.voc_pad > 0 - ), "Please set correct voc_block and voc_pad, they should be more than 0." - - try: - if self.config.device is not None: - self.device = self.config.device - else: - self.device = paddle.get_device() - paddle.set_device(self.device) - except Exception as e: - logger.error( - "Set device failed, please check if device is already used and the parameter 'device' in the yaml file" - ) - logger.error("Initialize TTS server engine Failed on device: %s." % - (self.device)) - return False - - self.executor = TTSServerExecutor(config.am_block, config.am_pad, - config.voc_block, config.voc_pad) - - try: - self.executor._init_from_path( - am=self.config.am, - am_config=self.config.am_config, - am_ckpt=self.config.am_ckpt, - am_stat=self.config.am_stat, - phones_dict=self.config.phones_dict, - tones_dict=self.config.tones_dict, - speaker_dict=self.config.speaker_dict, - voc=self.config.voc, - voc_config=self.config.voc_config, - voc_ckpt=self.config.voc_ckpt, - voc_stat=self.config.voc_stat, - lang=self.config.lang) - except Exception as e: - logger.error("Failed to get model related files.") - logger.error("Initialize TTS server engine Failed on device: %s." % - (self.device)) - return False - - # warm up - try: - self.warm_up() - logger.info("Warm up successfully.") - except Exception as e: - logger.error("Failed to warm up on tts engine.") - return False - - logger.info("Initialize TTS server engine successfully on device: %s." % - (self.device)) - return True - - def warm_up(self): - """warm up - """ - if self.config.lang == 'zh': - sentence = "您好,欢迎使用语音合成服务。" - if self.config.lang == 'en': - sentence = "Hello and welcome to the speech synthesis service." 
- logger.info("Start to warm up.") - for i in range(3): - for wav in self.executor.infer( - text=sentence, - lang=self.config.lang, - am=self.config.am, - spk_id=0, ): - logger.info( - f"The first response time of the {i} warm up: {self.executor.first_response_time} s" - ) - break - def preprocess(self, text_bese64: str=None, text_bytes: bytes=None): # Convert byte to text if text_bese64: @@ -480,7 +483,7 @@ class TTSEngine(BaseEngine): wav_list = [] - for wav in self.executor.infer( + for wav in self.infer( text=sentence, lang=self.config.lang, am=self.config.am, @@ -496,13 +499,12 @@ class TTSEngine(BaseEngine): wav_all = np.concatenate(wav_list, axis=0) duration = len(wav_all) / self.executor.am_config.fs + logger.info(f"sentence: {sentence}") logger.info(f"The durations of audio is: {duration} s") + logger.info(f"first response time: {self.first_response_time} s") + logger.info(f"final response time: {self.final_response_time} s") + logger.info(f"RTF: {self.final_response_time / duration}") logger.info( - f"first response time: {self.executor.first_response_time} s") - logger.info( - f"final response time: {self.executor.final_response_time} s") - logger.info(f"RTF: {self.executor.final_response_time / duration}") - logger.info( - f"Other info: front time: {self.executor.frontend_time} s, first am infer time: {self.executor.first_am_infer} s, first voc infer time: {self.executor.first_voc_infer} s," - ) \ No newline at end of file + f"Other info: front time: {self.frontend_time} s, first am infer time: {self.first_am_infer} s, first voc infer time: {self.first_voc_infer} s," + ) diff --git a/paddlespeech/server/engine/tts/paddleinference/tts_engine.py b/paddlespeech/server/engine/tts/paddleinference/tts_engine.py index f1ce8b76e2eacd378ccb8657486716ffb5ad4036..1676801e7cd1f9d7a5843a5bf7ac8b339eaf8f54 100644 --- a/paddlespeech/server/engine/tts/paddleinference/tts_engine.py +++ b/paddlespeech/server/engine/tts/paddleinference/tts_engine.py @@ -14,6 +14,7 @@ import base64 import io import os +import sys import time from typing import Optional @@ -35,7 +36,7 @@ from paddlespeech.server.utils.paddle_predictor import run_model from paddlespeech.t2s.frontend import English from paddlespeech.t2s.frontend.zh_frontend import Frontend -__all__ = ['TTSEngine'] +__all__ = ['TTSEngine', 'PaddleTTSConnectionHandler'] class TTSServerExecutor(TTSExecutor): @@ -245,7 +246,7 @@ class TTSServerExecutor(TTSExecutor): else: wav_all = paddle.concat([wav_all, wav]) self.voc_time += (time.time() - voc_st) - self._outputs['wav'] = wav_all + self._outputs["wav"] = wav_all class TTSEngine(BaseEngine): @@ -263,6 +264,8 @@ class TTSEngine(BaseEngine): def init(self, config: dict) -> bool: self.executor = TTSServerExecutor() self.config = config + self.lang = self.config.lang + self.engine_type = "inference" try: if self.config.am_predictor_conf.device is not None: @@ -272,58 +275,59 @@ class TTSEngine(BaseEngine): else: self.device = paddle.get_device() paddle.set_device(self.device) - except BaseException as e: + except Exception as e: logger.error( "Set device failed, please check if device is already used and the parameter 'device' in the yaml file" ) logger.error("Initialize TTS server engine Failed on device: %s." 
% (self.device)) + logger.error(e) return False - self.executor._init_from_path( - am=self.config.am, - am_model=self.config.am_model, - am_params=self.config.am_params, - am_sample_rate=self.config.am_sample_rate, - phones_dict=self.config.phones_dict, - tones_dict=self.config.tones_dict, - speaker_dict=self.config.speaker_dict, - voc=self.config.voc, - voc_model=self.config.voc_model, - voc_params=self.config.voc_params, - voc_sample_rate=self.config.voc_sample_rate, - lang=self.config.lang, - am_predictor_conf=self.config.am_predictor_conf, - voc_predictor_conf=self.config.voc_predictor_conf, ) - - # warm up try: - self.warm_up() - logger.info("Warm up successfully.") + self.executor._init_from_path( + am=self.config.am, + am_model=self.config.am_model, + am_params=self.config.am_params, + am_sample_rate=self.config.am_sample_rate, + phones_dict=self.config.phones_dict, + tones_dict=self.config.tones_dict, + speaker_dict=self.config.speaker_dict, + voc=self.config.voc, + voc_model=self.config.voc_model, + voc_params=self.config.voc_params, + voc_sample_rate=self.config.voc_sample_rate, + lang=self.config.lang, + am_predictor_conf=self.config.am_predictor_conf, + voc_predictor_conf=self.config.voc_predictor_conf, ) except Exception as e: - logger.error("Failed to warm up on tts engine.") + logger.error("Failed to get model related files.") + logger.error("Initialize TTS server engine Failed on device: %s." % + (self.device)) + logger.error(e) return False logger.info("Initialize TTS server engine successfully.") return True - def warm_up(self): - """warm up + +class PaddleTTSConnectionHandler(TTSServerExecutor): + def __init__(self, tts_engine): + """The PaddleSpeech TTS Server Connection Handler + This connection process every tts server request + Args: + tts_engine (TTSEngine): The TTS engine """ - if self.config.lang == 'zh': - sentence = "您好,欢迎使用语音合成服务。" - if self.config.lang == 'en': - sentence = "Hello and welcome to the speech synthesis service." - logger.info("Start to warm up.") - for i in range(3): - st = time.time() - self.executor.infer( - text=sentence, - lang=self.config.lang, - am=self.config.am, - spk_id=0, ) - logger.info( - f"The response time of the {i} warm up: {time.time() - st} s") + super().__init__() + logger.info( + "Create PaddleTTSConnectionHandler to process the tts request") + + self.tts_engine = tts_engine + self.executor = self.tts_engine.executor + self.config = self.tts_engine.config + self.frontend = self.executor.frontend + self.am_predictor = self.executor.am_predictor + self.voc_predictor = self.executor.voc_predictor def postprocess(self, wav, @@ -375,8 +379,11 @@ class TTSEngine(BaseEngine): ErrorCode.SERVER_INTERNAL_ERR, "Failed to transform speed. Can not install soxbindings on your system. 
\ You need to set speed value 1.0.") - except BaseException: + sys.exit(-1) + except Exception as e: logger.error("Failed to transform speed.") + logger.error(e) + sys.exit(-1) # wav to base64 buf = io.BytesIO() @@ -433,7 +440,7 @@ class TTSEngine(BaseEngine): try: infer_st = time.time() - self.executor.infer( + self.infer( text=sentence, lang=lang, am=self.config.am, spk_id=spk_id) infer_et = time.time() infer_time = infer_et - infer_st @@ -441,13 +448,16 @@ class TTSEngine(BaseEngine): except ServerBaseException: raise ServerBaseException(ErrorCode.SERVER_INTERNAL_ERR, "tts infer failed.") - except BaseException: + sys.exit(-1) + except Exception as e: logger.error("tts infer failed.") + logger.error(e) + sys.exit(-1) try: postprocess_st = time.time() target_sample_rate, wav_base64 = self.postprocess( - wav=self.executor._outputs['wav'].numpy(), + wav=self._outputs["wav"].numpy(), original_fs=self.executor.am_sample_rate, target_fs=sample_rate, volume=volume, @@ -455,26 +465,28 @@ class TTSEngine(BaseEngine): audio_path=save_path) postprocess_et = time.time() postprocess_time = postprocess_et - postprocess_st - duration = len(self.executor._outputs['wav'] - .numpy()) / self.executor.am_sample_rate + duration = len( + self._outputs["wav"].numpy()) / self.executor.am_sample_rate rtf = infer_time / duration except ServerBaseException: raise ServerBaseException(ErrorCode.SERVER_INTERNAL_ERR, "tts postprocess failed.") - except BaseException: + sys.exit(-1) + except Exception as e: logger.error("tts postprocess failed.") + logger.error(e) + sys.exit(-1) logger.info("AM model: {}".format(self.config.am)) logger.info("Vocoder model: {}".format(self.config.voc)) logger.info("Language: {}".format(lang)) - logger.info("tts engine type: paddle inference") + logger.info("tts engine type: python") logger.info("audio duration: {}".format(duration)) - logger.info( - "frontend inference time: {}".format(self.executor.frontend_time)) - logger.info("AM inference time: {}".format(self.executor.am_time)) - logger.info("Vocoder inference time: {}".format(self.executor.voc_time)) + logger.info("frontend inference time: {}".format(self.frontend_time)) + logger.info("AM inference time: {}".format(self.am_time)) + logger.info("Vocoder inference time: {}".format(self.voc_time)) logger.info("total inference time: {}".format(infer_time)) logger.info( "postprocess (change speed, volume, target sample rate) time: {}". @@ -482,5 +494,6 @@ class TTSEngine(BaseEngine): logger.info("total generate audio time: {}".format(infer_time + postprocess_time)) logger.info("RTF: {}".format(rtf)) + logger.info("device: {}".format(self.tts_engine.device)) return lang, target_sample_rate, duration, wav_base64 diff --git a/paddlespeech/server/engine/tts/python/tts_engine.py b/paddlespeech/server/engine/tts/python/tts_engine.py index d0002baa4f46c949e8258a7bea527a18b781b657..b048b01a49f1cf34a1edd4b10d5b85da74e579f4 100644 --- a/paddlespeech/server/engine/tts/python/tts_engine.py +++ b/paddlespeech/server/engine/tts/python/tts_engine.py @@ -13,6 +13,7 @@ # limitations under the License. 
import base64 import io +import sys import time import librosa @@ -28,7 +29,7 @@ from paddlespeech.server.utils.audio_process import change_speed from paddlespeech.server.utils.errors import ErrorCode from paddlespeech.server.utils.exception import ServerBaseException -__all__ = ['TTSEngine'] +__all__ = ['TTSEngine', 'PaddleTTSConnectionHandler'] class TTSServerExecutor(TTSExecutor): @@ -52,6 +53,8 @@ class TTSEngine(BaseEngine): def init(self, config: dict) -> bool: self.executor = TTSServerExecutor() self.config = config + self.lang = self.config.lang + self.engine_type = "python" try: if self.config.device is not None: @@ -59,12 +62,13 @@ class TTSEngine(BaseEngine): else: self.device = paddle.get_device() paddle.set_device(self.device) - except BaseException as e: + except Exception as e: logger.error( "Set device failed, please check if device is already used and the parameter 'device' in the yaml file" ) logger.error("Initialize TTS server engine Failed on device: %s." % (self.device)) + logger.error(e) return False try: @@ -81,41 +85,35 @@ class TTSEngine(BaseEngine): voc_ckpt=self.config.voc_ckpt, voc_stat=self.config.voc_stat, lang=self.config.lang) - except BaseException: + except Exception as e: logger.error("Failed to get model related files.") logger.error("Initialize TTS server engine Failed on device: %s." % (self.device)) - return False - - # warm up - try: - self.warm_up() - logger.info("Warm up successfully.") - except Exception as e: - logger.error("Failed to warm up on tts engine.") + logger.error(e) return False logger.info("Initialize TTS server engine successfully on device: %s." % (self.device)) return True - def warm_up(self): - """warm up + +class PaddleTTSConnectionHandler(TTSServerExecutor): + def __init__(self, tts_engine): + """The PaddleSpeech TTS Server Connection Handler + This connection process every tts server request + Args: + tts_engine (TTSEngine): The TTS engine """ - if self.config.lang == 'zh': - sentence = "您好,欢迎使用语音合成服务。" - if self.config.lang == 'en': - sentence = "Hello and welcome to the speech synthesis service." - logger.info("Start to warm up.") - for i in range(3): - st = time.time() - self.executor.infer( - text=sentence, - lang=self.config.lang, - am=self.config.am, - spk_id=0, ) - logger.info( - f"The response time of the {i} warm up: {time.time() - st} s") + super().__init__() + logger.info( + "Create PaddleTTSConnectionHandler to process the tts request") + + self.tts_engine = tts_engine + self.executor = self.tts_engine.executor + self.config = self.tts_engine.config + self.frontend = self.executor.frontend + self.am_inference = self.executor.am_inference + self.voc_inference = self.executor.voc_inference def postprocess(self, wav, @@ -167,8 +165,11 @@ class TTSEngine(BaseEngine): ErrorCode.SERVER_INTERNAL_ERR, "Failed to transform speed. Can not install soxbindings on your system. 
\ You need to set speed value 1.0.") - except BaseException: + sys.exit(-1) + except Exception as e: logger.error("Failed to transform speed.") + logger.error(e) + sys.exit(-1) # wav to base64 buf = io.BytesIO() @@ -225,24 +226,27 @@ class TTSEngine(BaseEngine): try: infer_st = time.time() - self.executor.infer( + self.infer( text=sentence, lang=lang, am=self.config.am, spk_id=spk_id) infer_et = time.time() infer_time = infer_et - infer_st - duration = len(self.executor._outputs['wav'] - .numpy()) / self.executor.am_config.fs + duration = len( + self._outputs["wav"].numpy()) / self.executor.am_config.fs rtf = infer_time / duration except ServerBaseException: raise ServerBaseException(ErrorCode.SERVER_INTERNAL_ERR, "tts infer failed.") - except BaseException: + sys.exit(-1) + except Exception as e: logger.error("tts infer failed.") + logger.error(e) + sys.exit(-1) try: postprocess_st = time.time() target_sample_rate, wav_base64 = self.postprocess( - wav=self.executor._outputs['wav'].numpy(), + wav=self._outputs["wav"].numpy(), original_fs=self.executor.am_config.fs, target_fs=sample_rate, volume=volume, @@ -254,8 +258,11 @@ class TTSEngine(BaseEngine): except ServerBaseException: raise ServerBaseException(ErrorCode.SERVER_INTERNAL_ERR, "tts postprocess failed.") - except BaseException: + sys.exit(-1) + except Exception as e: logger.error("tts postprocess failed.") + logger.error(e) + sys.exit(-1) logger.info("AM model: {}".format(self.config.am)) logger.info("Vocoder model: {}".format(self.config.voc)) @@ -263,10 +270,9 @@ class TTSEngine(BaseEngine): logger.info("tts engine type: python") logger.info("audio duration: {}".format(duration)) - logger.info( - "frontend inference time: {}".format(self.executor.frontend_time)) - logger.info("AM inference time: {}".format(self.executor.am_time)) - logger.info("Vocoder inference time: {}".format(self.executor.voc_time)) + logger.info("frontend inference time: {}".format(self.frontend_time)) + logger.info("AM inference time: {}".format(self.am_time)) + logger.info("Vocoder inference time: {}".format(self.voc_time)) logger.info("total inference time: {}".format(infer_time)) logger.info( "postprocess (change speed, volume, target sample rate) time: {}". @@ -274,6 +280,6 @@ class TTSEngine(BaseEngine): logger.info("total generate audio time: {}".format(infer_time + postprocess_time)) logger.info("RTF: {}".format(rtf)) - logger.info("device: {}".format(self.device)) + logger.info("device: {}".format(self.tts_engine.device)) return lang, target_sample_rate, duration, wav_base64 diff --git a/paddlespeech/server/restful/asr_api.py b/paddlespeech/server/restful/asr_api.py index cf46735dcc84dc92c8bfcfa71b426604ed7c1843..c7bc50ce4b8a166096ac613a19bf9d9cf126ed77 100644 --- a/paddlespeech/server/restful/asr_api.py +++ b/paddlespeech/server/restful/asr_api.py @@ -12,11 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
 import base64
+import sys
 import traceback
 from typing import Union
 
 from fastapi import APIRouter
 
+from paddlespeech.cli.log import logger
 from paddlespeech.server.engine.engine_pool import get_engine_pool
 from paddlespeech.server.restful.request import ASRRequest
 from paddlespeech.server.restful.response import ASRResponse
@@ -68,8 +70,18 @@ def asr(request_body: ASRRequest):
         engine_pool = get_engine_pool()
         asr_engine = engine_pool['asr']
 
-        asr_engine.run(audio_data)
-        asr_results = asr_engine.postprocess()
+        if asr_engine.engine_type == "python":
+            from paddlespeech.server.engine.asr.python.asr_engine import PaddleASRConnectionHandler
+        elif asr_engine.engine_type == "inference":
+            from paddlespeech.server.engine.asr.paddleinference.asr_engine import PaddleASRConnectionHandler
+        else:
+            logger.error("Offline asr engine only support python or inference.")
+            sys.exit(-1)
+
+        connection_handler = PaddleASRConnectionHandler(asr_engine)
+
+        connection_handler.run(audio_data)
+        asr_results = connection_handler.postprocess()
 
         response = {
             "success": True,
diff --git a/paddlespeech/server/restful/cls_api.py b/paddlespeech/server/restful/cls_api.py
index 306d9ca9c11ce824cba3982492ea285f6d99a3ff..7cfb4a297f0fb08a9c5f521e0222fa305b331453 100644
--- a/paddlespeech/server/restful/cls_api.py
+++ b/paddlespeech/server/restful/cls_api.py
@@ -12,11 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import base64
+import sys
 import traceback
 from typing import Union
 
 from fastapi import APIRouter
 
+from paddlespeech.cli.log import logger
 from paddlespeech.server.engine.engine_pool import get_engine_pool
 from paddlespeech.server.restful.request import CLSRequest
 from paddlespeech.server.restful.response import CLSResponse
@@ -68,8 +70,18 @@ def cls(request_body: CLSRequest):
         engine_pool = get_engine_pool()
         cls_engine = engine_pool['cls']
 
-        cls_engine.run(audio_data)
-        cls_results = cls_engine.postprocess(request_body.topk)
+        if cls_engine.engine_type == "python":
+            from paddlespeech.server.engine.cls.python.cls_engine import PaddleCLSConnectionHandler
+        elif cls_engine.engine_type == "inference":
+            from paddlespeech.server.engine.cls.paddleinference.cls_engine import PaddleCLSConnectionHandler
+        else:
+            logger.error("Offline cls engine only support python or inference.")
+            sys.exit(-1)
+
+        connection_handler = PaddleCLSConnectionHandler(cls_engine)
+
+        connection_handler.run(audio_data)
+        cls_results = connection_handler.postprocess(request_body.topk)
 
         response = {
             "success": True,
@@ -85,8 +97,11 @@ def cls(request_body: CLSRequest):
     except ServerBaseException as e:
         response = failed_response(e.error_code, e.msg)
-    except BaseException:
+        logger.error(e)
+        sys.exit(-1)
+    except Exception as e:
         response = failed_response(ErrorCode.SERVER_UNKOWN_ERR)
+        logger.error(e)
         traceback.print_exc()
 
     return response
diff --git a/paddlespeech/server/restful/tts_api.py b/paddlespeech/server/restful/tts_api.py
index 15d618d9324fcda2616d571a4d074ea0876f0fb5..53fe159fdc02f75f76d438d8ab5876d440fc19c0 100644
--- a/paddlespeech/server/restful/tts_api.py
+++ b/paddlespeech/server/restful/tts_api.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import sys
 import traceback
 from typing import Union
 
@@ -99,7 +100,16 @@ def tts(request_body: TTSRequest):
         tts_engine = engine_pool['tts']
         logger.info("Get tts engine successfully.")
 
-        lang, target_sample_rate, duration, wav_base64 = tts_engine.run(
+        if tts_engine.engine_type == "python":
+            from paddlespeech.server.engine.tts.python.tts_engine import PaddleTTSConnectionHandler
+        elif tts_engine.engine_type == "inference":
+            from paddlespeech.server.engine.tts.paddleinference.tts_engine import PaddleTTSConnectionHandler
+        else:
+            logger.error("Offline tts engine only support python or inference.")
+            sys.exit(-1)
+
+        connection_handler = PaddleTTSConnectionHandler(tts_engine)
+        lang, target_sample_rate, duration, wav_base64 = connection_handler.run(
             text, spk_id, speed, volume, sample_rate, save_path)
 
         response = {
@@ -136,4 +146,14 @@ async def stream_tts(request_body: TTSRequest):
     tts_engine = engine_pool['tts']
     logger.info("Get tts engine successfully.")
 
-    return StreamingResponse(tts_engine.run(sentence=text))
+    if tts_engine.engine_type == "online":
+        from paddlespeech.server.engine.tts.online.python.tts_engine import PaddleTTSConnectionHandler
+    elif tts_engine.engine_type == "online-onnx":
+        from paddlespeech.server.engine.tts.online.onnx.tts_engine import PaddleTTSConnectionHandler
+    else:
+        logger.error("Online tts engine only support online or online-onnx.")
+        sys.exit(-1)
+
+    connection_handler = PaddleTTSConnectionHandler(tts_engine)
+
+    return StreamingResponse(connection_handler.run(sentence=text))
diff --git a/paddlespeech/server/ws/tts_api.py b/paddlespeech/server/ws/tts_api.py
index a3a4c4d4b1cbf298f258e5fa064aa425c6f1bdea..3d8b222ead1f8568417e2ba005b04cd2ddd6fbef 100644
--- a/paddlespeech/server/ws/tts_api.py
+++ b/paddlespeech/server/ws/tts_api.py
@@ -40,6 +40,16 @@ async def websocket_endpoint(websocket: WebSocket):
     engine_pool = get_engine_pool()
     tts_engine = engine_pool['tts']
 
+    connection_handler = None
+
+    if tts_engine.engine_type == "online":
+        from paddlespeech.server.engine.tts.online.python.tts_engine import PaddleTTSConnectionHandler
+    elif tts_engine.engine_type == "online-onnx":
+        from paddlespeech.server.engine.tts.online.onnx.tts_engine import PaddleTTSConnectionHandler
+    else:
+        logger.error("Online tts engine only support online or online-onnx.")
+        sys.exit(-1)
+
     try:
         while True:
             # careful here, changed the source code from starlette.websockets
@@ -57,10 +67,13 @@
                     "signal": "server ready",
                     "session": session
                 }
+
+                connection_handler = PaddleTTSConnectionHandler(tts_engine)
                 await websocket.send_json(resp)
 
             # end request
             elif message['signal'] == 'end':
+                connection_handler = None
                 resp = {
                     "status": 0,
                     "signal": "connection will be closed",
@@ -75,10 +88,11 @@
             # speech synthesis request
             elif 'text' in message:
                 text_bese64 = message["text"]
-                sentence = tts_engine.preprocess(text_bese64=text_bese64)
+                sentence = connection_handler.preprocess(
+                    text_bese64=text_bese64)
 
                 # run
-                wav_generator = tts_engine.run(sentence)
+                wav_generator = connection_handler.run(sentence)
 
                 while True:
                     try:
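For reference, a minimal usage sketch of the per-request connection handler introduced above (not part of the patch itself): it assumes init_engine_pool(config) has already populated the engine pool with an offline TTS engine, and the speed, volume, sample_rate, and save_path values passed to run() are illustrative defaults rather than values mandated by the API.

import sys

from paddlespeech.cli.log import logger
from paddlespeech.server.engine.engine_pool import get_engine_pool


def synthesize_once(text: str, spk_id: int = 0) -> str:
    """Run one TTS request against the shared engine via a fresh handler."""
    engine_pool = get_engine_pool()
    tts_engine = engine_pool['tts']

    # Pick the handler class that matches the engine type, mirroring tts_api.py.
    if tts_engine.engine_type == "python":
        from paddlespeech.server.engine.tts.python.tts_engine import PaddleTTSConnectionHandler
    elif tts_engine.engine_type == "inference":
        from paddlespeech.server.engine.tts.paddleinference.tts_engine import PaddleTTSConnectionHandler
    else:
        logger.error("Offline tts engine only support python or inference.")
        sys.exit(-1)

    # One handler per request; the shared engine keeps the loaded models and device.
    connection_handler = PaddleTTSConnectionHandler(tts_engine)
    # Arguments after spk_id are speed, volume, sample_rate, save_path (illustrative values).
    lang, target_sample_rate, duration, wav_base64 = connection_handler.run(
        text, spk_id, 1.0, 1.0, 0, None)
    return wav_base64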