Commit d48c4d68 authored by lym0302

update engine, test=doc

Parent f07f57a3
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import sys
import warnings

import uvicorn
from fastapi import FastAPI
from starlette.middleware.cors import CORSMiddleware

from paddlespeech.server.engine.engine_pool import init_engine_pool
from paddlespeech.server.restful.api import setup_router as setup_http_router
from paddlespeech.server.utils.config import get_config
from paddlespeech.server.ws.api import setup_router as setup_ws_router

warnings.filterwarnings("ignore")

app = FastAPI(
    title="PaddleSpeech Serving API", description="Api", version="0.0.1")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"])

# change yaml file here
config_file = "./conf/application.yaml"
config = get_config(config_file)

# init engine
if not init_engine_pool(config):
    print("Failed to init engine.")
    sys.exit(-1)

# get api_router
api_list = [engine.split("_")[0] for engine in config.engine_list]
if config.protocol == "websocket":
    api_router = setup_ws_router(api_list)
elif config.protocol == "http":
    api_router = setup_http_router(api_list)
else:
    raise Exception("unsupported protocol")

# app needs to operate outside the main function
app.include_router(api_router)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(add_help=True)
    parser.add_argument(
        "--workers", type=int, help="workers of server", default=1)
    args = parser.parse_args()

    uvicorn.run(
        "start_multi_progress_server:app",
        host=config.host,
        port=config.port,
        debug=True,
        workers=args.workers)
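The script above serves the REST or websocket API with multiple uvicorn workers. A hedged client sketch for the HTTP protocol follows; the endpoint path and JSON field names are assumptions about PaddleSpeech's restful API, not something shown in this file:

import base64
import requests

with open("zh.wav", "rb") as f:
    audio = base64.b64encode(f.read()).decode("utf8")
resp = requests.post(
    "http://127.0.0.1:8090/paddlespeech/asr",  # host/port come from application.yaml
    json={"audio": audio, "audio_format": "wav",
          "sample_rate": 16000, "lang": "zh_cn"})
print(resp.json())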
@@ -26,6 +26,7 @@ from ..util import cli_server_register
 from ..util import stats_wrapper
 from paddlespeech.cli.log import logger
 from paddlespeech.server.engine.engine_pool import init_engine_pool
+from paddlespeech.server.engine.engine_warmup import warm_up
 from paddlespeech.server.restful.api import setup_router as setup_http_router
 from paddlespeech.server.utils.config import get_config
 from paddlespeech.server.ws.api import setup_router as setup_ws_router
@@ -86,6 +87,11 @@ class ServerExecutor(BaseExecutor):
         if not init_engine_pool(config):
             return False

+        # warm up
+        for engine_and_type in config.engine_list:
+            if not warm_up(engine_and_type):
+                return False
+
         return True

     def execute(self, argv: List[str]) -> bool:
......
@@ -30,7 +30,7 @@ from paddlespeech.server.engine.base_engine import BaseEngine
 from paddlespeech.server.utils.paddle_predictor import init_predictor
 from paddlespeech.server.utils.paddle_predictor import run_model

-__all__ = ['ASREngine']
+__all__ = ['ASREngine', 'PaddleASRConnectionHandler']


 class ASRServerExecutor(ASRExecutor):
@@ -50,7 +50,7 @@ class ASRServerExecutor(ASRExecutor):
         """
         Init model and other resources from a specific path.
         """
+        self.max_len = 50
         sample_rate_str = '16k' if sample_rate == 16000 else '8k'
         tag = model_type + '-' + lang + '-' + sample_rate_str
         if cfg_path is None or am_model is None or am_params is None:
@@ -172,10 +172,23 @@ class ASREngine(BaseEngine):
         Returns:
             bool: init failed or success
         """
-        self.input = None
-        self.output = None
         self.executor = ASRServerExecutor()
         self.config = config
+        self.engine_type = "inference"
+
+        try:
+            if self.config.am_predictor_conf.device is not None:
+                self.device = self.config.am_predictor_conf.device
+            else:
+                self.device = paddle.get_device()
+            paddle.set_device(self.device)
+        except Exception as e:
+            logger.error(
+                "Set device failed, please check if device is already used and the parameter 'device' in the yaml file"
+            )
+            logger.error(e)
+            return False

         self.executor._init_from_path(
             model_type=self.config.model_type,
@@ -190,22 +203,42 @@ class ASREngine(BaseEngine):
         logger.info("Initialize ASR server engine successfully.")
         return True


+class PaddleASRConnectionHandler(ASRServerExecutor):
+    def __init__(self, asr_engine):
+        """The PaddleSpeech ASR Server Connection Handler
+           This connection process every asr server request
+        Args:
+            asr_engine (ASREngine): The ASR engine
+        """
+        super().__init__()
+        self.input = None
+        self.output = None
+        self.asr_engine = asr_engine
+        self.executor = self.asr_engine.executor
+        self.config = self.executor.config
+        self.max_len = self.executor.max_len
+        self.decoder = self.executor.decoder
+        self.am_predictor = self.executor.am_predictor
+        self.text_feature = self.executor.text_feature
+        self.collate_fn_test = self.executor.collate_fn_test
+
     def run(self, audio_data):
         """engine run
         Args:
             audio_data (bytes): base64.b64decode
         """
-        if self.executor._check(
-                io.BytesIO(audio_data), self.config.sample_rate,
-                self.config.force_yes):
+        if self._check(
+                io.BytesIO(audio_data), self.asr_engine.config.sample_rate,
+                self.asr_engine.config.force_yes):
             logger.info("start running asr engine")
-            self.executor.preprocess(self.config.model_type,
-                                     io.BytesIO(audio_data))
+            self.preprocess(self.asr_engine.config.model_type,
+                            io.BytesIO(audio_data))
             st = time.time()
-            self.executor.infer(self.config.model_type)
+            self.infer(self.asr_engine.config.model_type)
             infer_time = time.time() - st
-            self.output = self.executor.postprocess()  # Retrieve result of asr.
+            self.output = self.postprocess()  # Retrieve result of asr.
             logger.info("end inferring asr engine")
         else:
             logger.info("file check failed!")
@@ -213,8 +246,3 @@ class ASREngine(BaseEngine):
         logger.info("inference time: {}".format(infer_time))
         logger.info("asr engine type: paddle inference")
-
-    def postprocess(self):
-        """postprocess
-        """
-        return self.output
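The run method above now executes on a per-connection handler rather than on the shared engine singleton. A hedged per-request flow, assuming an initialized engine pool:

from paddlespeech.server.engine.engine_pool import get_engine_pool

asr_engine = get_engine_pool()['asr']
handler = PaddleASRConnectionHandler(asr_engine)
with open("zh.wav", "rb") as f:
    handler.run(f.read())    # raw wav bytes, as the restful layer would pass them
print(handler.output)        # transcription text, or None if the file check failed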
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import io
+import sys
 import time

 import paddle
@@ -20,7 +21,7 @@ from paddlespeech.cli.asr.infer import ASRExecutor
 from paddlespeech.cli.log import logger
 from paddlespeech.server.engine.base_engine import BaseEngine

-__all__ = ['ASREngine']
+__all__ = ['ASREngine', 'PaddleASRConnectionHandler']


 class ASRServerExecutor(ASRExecutor):
@@ -48,20 +49,23 @@ class ASREngine(BaseEngine):
         Returns:
             bool: init failed or success
         """
-        self.input = None
-        self.output = None
         self.executor = ASRServerExecutor()
         self.config = config
+        self.engine_type = "python"
+
         try:
-            if self.config.device:
+            if self.config.device is not None:
                 self.device = self.config.device
             else:
                 self.device = paddle.get_device()
             paddle.set_device(self.device)
-        except BaseException:
+        except Exception as e:
             logger.error(
                 "Set device failed, please check if device is already used and the parameter 'device' in the yaml file"
             )
+            logger.error(e)
+            return False

         self.executor._init_from_path(
             self.config.model, self.config.lang, self.config.sample_rate,
@@ -72,6 +76,24 @@ class ASREngine(BaseEngine):
                     (self.device))
         return True


+class PaddleASRConnectionHandler(ASRServerExecutor):
+    def __init__(self, asr_engine):
+        """The PaddleSpeech ASR Server Connection Handler
+           This connection process every asr server request
+        Args:
+            asr_engine (ASREngine): The ASR engine
+        """
+        super().__init__()
+        self.input = None
+        self.output = None
+        self.asr_engine = asr_engine
+        self.executor = self.asr_engine.executor
+        self.max_len = self.executor.max_len
+        self.text_feature = self.executor.text_feature
+        self.model = self.executor.model
+        self.config = self.executor.config
+
     def run(self, audio_data):
         """engine run
@@ -79,17 +101,16 @@ class ASREngine(BaseEngine):
             audio_data (bytes): base64.b64decode
         """
         try:
-            if self.executor._check(
-                    io.BytesIO(audio_data), self.config.sample_rate,
-                    self.config.force_yes):
+            if self._check(
+                    io.BytesIO(audio_data), self.asr_engine.config.sample_rate,
+                    self.asr_engine.config.force_yes):
                 logger.info("start run asr engine")
-                self.executor.preprocess(self.config.model,
-                                         io.BytesIO(audio_data))
+                self.preprocess(self.asr_engine.config.model,
+                                io.BytesIO(audio_data))
                 st = time.time()
-                self.executor.infer(self.config.model)
+                self.infer(self.asr_engine.config.model)
                 infer_time = time.time() - st
-                self.output = self.executor.postprocess(
-                )  # Retrieve result of asr.
+                self.output = self.postprocess()  # Retrieve result of asr.
             else:
                 logger.info("file check failed!")
                 self.output = None
@@ -98,8 +119,4 @@ class ASREngine(BaseEngine):
             logger.info("asr engine type: python")
         except Exception as e:
             logger.info(e)
+            sys.exit(-1)
-
-    def postprocess(self):
-        """postprocess
-        """
-        return self.output
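This dynamic-graph variant tags itself engine_type = "python", while the Paddle Inference variant above uses "inference". A hedged sketch of how an engine_list entry decomposes; the mapping to engine modules is assumed by analogy with the tts imports in engine_warmup.py further down:

engine_and_type = "asr_python"      # or "asr_inference"
task, engine_type = engine_and_type.split("_", 1)
print(task)                         # "asr": selects the API router
print(engine_type)                  # "python": selects this engine implementation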
@@ -14,6 +14,7 @@
 import io
 import os
 import time
+from collections import OrderedDict
 from typing import Optional

 import numpy as np
@@ -27,7 +28,7 @@ from paddlespeech.server.engine.base_engine import BaseEngine
 from paddlespeech.server.utils.paddle_predictor import init_predictor
 from paddlespeech.server.utils.paddle_predictor import run_model

-__all__ = ['CLSEngine']
+__all__ = ['CLSEngine', 'PaddleCLSConnectionHandler']


 class CLSServerExecutor(CLSExecutor):
@@ -119,14 +120,55 @@ class CLSEngine(BaseEngine):
         """
         self.executor = CLSServerExecutor()
         self.config = config
-        self.executor._init_from_path(
-            self.config.model_type, self.config.cfg_path,
-            self.config.model_path, self.config.params_path,
-            self.config.label_file, self.config.predictor_conf)
+        self.engine_type = "inference"
+
+        try:
+            if self.config.predictor_conf.device is not None:
+                self.device = self.config.predictor_conf.device
+            else:
+                self.device = paddle.get_device()
+            paddle.set_device(self.device)
+        except Exception as e:
+            logger.error(
+                "Set device failed, please check if device is already used and the parameter 'device' in the yaml file"
+            )
+            logger.error(e)
+            return False
+
+        try:
+            self.executor._init_from_path(
+                self.config.model_type, self.config.cfg_path,
+                self.config.model_path, self.config.params_path,
+                self.config.label_file, self.config.predictor_conf)
+        except Exception as e:
+            logger.error("Initialize CLS server engine Failed.")
+            logger.error(e)
+            return False

         logger.info("Initialize CLS server engine successfully.")
         return True


+class PaddleCLSConnectionHandler(CLSServerExecutor):
+    def __init__(self, cls_engine):
+        """The PaddleSpeech CLS Server Connection Handler
+           This connection process every cls server request
+        Args:
+            cls_engine (CLSEngine): The CLS engine
+        """
+        super().__init__()
+        logger.info(
+            "Create PaddleCLSConnectionHandler to process the cls request")
+
+        self._inputs = OrderedDict()
+        self._outputs = OrderedDict()
+        self.cls_engine = cls_engine
+        self.executor = self.cls_engine.executor
+        self._conf = self.executor._conf
+        self._label_list = self.executor._label_list
+        self.predictor = self.executor.predictor
+
     def run(self, audio_data):
         """engine run
@@ -134,9 +176,9 @@ class CLSEngine(BaseEngine):
             audio_data (bytes): base64.b64decode
         """
-        self.executor.preprocess(io.BytesIO(audio_data))
+        self.preprocess(io.BytesIO(audio_data))
         st = time.time()
-        self.executor.infer()
+        self.infer()
         infer_time = time.time() - st

         logger.info("inference time: {}".format(infer_time))
@@ -145,15 +187,15 @@ class CLSEngine(BaseEngine):
     def postprocess(self, topk: int):
         """postprocess
         """
-        assert topk <= len(self.executor._label_list
-                           ), 'Value of topk is larger than number of labels.'
+        assert topk <= len(
+            self._label_list), 'Value of topk is larger than number of labels.'

-        result = np.squeeze(self.executor._outputs['logits'], axis=0)
+        result = np.squeeze(self._outputs['logits'], axis=0)
         topk_idx = (-result).argsort()[:topk]
         topk_results = []
         for idx in topk_idx:
             res = {}
-            label, score = self.executor._label_list[idx], result[idx]
+            label, score = self._label_list[idx], result[idx]
             res['class_name'] = label
             res['prob'] = score
             topk_results.append(res)
......
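A hedged request cycle for the handler above, assuming an initialized engine pool; postprocess keeps the top-k labels:

from paddlespeech.server.engine.engine_pool import get_engine_pool

cls_engine = get_engine_pool()['cls']
handler = PaddleCLSConnectionHandler(cls_engine)
with open("cat.wav", "rb") as f:
    handler.run(f.read())
results = handler.postprocess(topk=3)  # [{'class_name': ..., 'prob': ...}, ...]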
@@ -13,7 +13,7 @@
 # limitations under the License.
 import io
 import time
-from typing import List
+from collections import OrderedDict

 import paddle
@@ -21,7 +21,7 @@ from paddlespeech.cli.cls.infer import CLSExecutor
 from paddlespeech.cli.log import logger
 from paddlespeech.server.engine.base_engine import BaseEngine

-__all__ = ['CLSEngine']
+__all__ = ['CLSEngine', 'PaddleCLSConnectionHandler']


 class CLSServerExecutor(CLSExecutor):
@@ -29,21 +29,6 @@ class CLSServerExecutor(CLSExecutor):
         super().__init__()
         pass

-    def get_topk_results(self, topk: int) -> List:
-        assert topk <= len(
-            self._label_list), 'Value of topk is larger than number of labels.'
-
-        result = self._outputs['logits'].squeeze(0).numpy()
-        topk_idx = (-result).argsort()[:topk]
-        res = {}
-        topk_results = []
-        for idx in topk_idx:
-            label, score = self._label_list[idx], result[idx]
-            res['class'] = label
-            res['prob'] = score
-            topk_results.append(res)
-        return topk_results


 class CLSEngine(BaseEngine):
     """CLS server engine
@@ -64,42 +49,65 @@ class CLSEngine(BaseEngine):
         Returns:
             bool: init failed or success
         """
-        self.input = None
-        self.output = None
         self.executor = CLSServerExecutor()
         self.config = config
+        self.engine_type = "python"
+
         try:
-            if self.config.device:
+            if self.config.device is not None:
                 self.device = self.config.device
             else:
                 self.device = paddle.get_device()
             paddle.set_device(self.device)
-        except BaseException:
+        except Exception as e:
             logger.error(
                 "Set device failed, please check if device is already used and the parameter 'device' in the yaml file"
             )
+            logger.error(e)
+            return False

         try:
             self.executor._init_from_path(
                 self.config.model, self.config.cfg_path, self.config.ckpt_path,
                 self.config.label_file)
-        except BaseException:
+        except Exception as e:
             logger.error("Initialize CLS server engine Failed.")
+            logger.error(e)
             return False

         logger.info("Initialize CLS server engine successfully on device: %s." %
                     (self.device))
         return True


+class PaddleCLSConnectionHandler(CLSServerExecutor):
+    def __init__(self, cls_engine):
+        """The PaddleSpeech CLS Server Connection Handler
+           This connection process every cls server request
+        Args:
+            cls_engine (CLSEngine): The CLS engine
+        """
+        super().__init__()
+        logger.info(
+            "Create PaddleCLSConnectionHandler to process the cls request")
+
+        self._inputs = OrderedDict()
+        self._outputs = OrderedDict()
+        self.cls_engine = cls_engine
+        self.executor = self.cls_engine.executor
+        self._conf = self.executor._conf
+        self._label_list = self.executor._label_list
+        self.model = self.executor.model
+
     def run(self, audio_data):
         """engine run
         Args:
             audio_data (bytes): base64.b64decode
         """
-        self.executor.preprocess(io.BytesIO(audio_data))
+        self.preprocess(io.BytesIO(audio_data))
         st = time.time()
-        self.executor.infer()
+        self.infer()
         infer_time = time.time() - st

         logger.info("inference time: {}".format(infer_time))
@@ -108,15 +116,15 @@ class CLSEngine(BaseEngine):
     def postprocess(self, topk: int):
         """postprocess
         """
-        assert topk <= len(self.executor._label_list
-                           ), 'Value of topk is larger than number of labels.'
+        assert topk <= len(
+            self._label_list), 'Value of topk is larger than number of labels.'

-        result = self.executor._outputs['logits'].squeeze(0).numpy()
+        result = self._outputs['logits'].squeeze(0).numpy()
         topk_idx = (-result).argsort()[:topk]
         topk_results = []
         for idx in topk_idx:
             res = {}
-            label, score = self.executor._label_list[idx], result[idx]
+            label, score = self._label_list[idx], result[idx]
             res['class_name'] = label
             res['prob'] = score
             topk_results.append(res)
......
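One detail worth keeping from this hunk: the removed get_topk_results allocated res = {} once outside its loop, so every entry in topk_results aliased the same dict, whereas the surviving postprocess allocates a fresh dict per label. A minimal illustration of that aliasing pitfall:

shared, res = [], {}
for label in ["dog", "cat"]:
    res['class'] = label                 # mutates the one shared dict
    shared.append(res)
print([r['class'] for r in shared])      # ['cat', 'cat'], the bug

fixed = [{'class': label} for label in ["dog", "cat"]]
print([r['class'] for r in fixed])       # ['dog', 'cat']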
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import time

from paddlespeech.cli.log import logger
from paddlespeech.server.engine.engine_pool import get_engine_pool


def warm_up(engine_and_type: str, warm_up_time: int=3) -> bool:
    engine_pool = get_engine_pool()

    if "tts" in engine_and_type:
        tts_engine = engine_pool['tts']
        flag_online = False
        if tts_engine.lang == 'zh':
            sentence = "您好,欢迎使用语音合成服务。"
        elif tts_engine.lang == 'en':
            sentence = "Hello and welcome to the speech synthesis service."
        else:
            logger.error("tts engine only support lang: zh or en.")
            sys.exit(-1)

        if engine_and_type == "tts_python":
            from paddlespeech.server.engine.tts.python.tts_engine import PaddleTTSConnectionHandler
        elif engine_and_type == "tts_inference":
            from paddlespeech.server.engine.tts.paddleinference.tts_engine import PaddleTTSConnectionHandler
        elif engine_and_type == "tts_online":
            from paddlespeech.server.engine.tts.online.python.tts_engine import PaddleTTSConnectionHandler
            flag_online = True
        elif engine_and_type == "tts_online-onnx":
            from paddlespeech.server.engine.tts.online.onnx.tts_engine import PaddleTTSConnectionHandler
            flag_online = True
        else:
            logger.error("Please check tts engine type.")
            return False

        try:
            logger.info("Start to warm up tts engine.")
            for i in range(warm_up_time):
                connection_handler = PaddleTTSConnectionHandler(tts_engine)
                if flag_online:
                    # streaming engines yield audio chunks; one chunk is
                    # enough to measure the first response time
                    for wav in connection_handler.infer(
                            text=sentence,
                            lang=tts_engine.lang,
                            am=tts_engine.config.am):
                        logger.info(
                            f"The first response time of the {i} warm up: {connection_handler.first_response_time} s"
                        )
                        break
                else:
                    st = time.time()
                    connection_handler.infer(text=sentence)
                    et = time.time()
                    logger.info(
                        f"The response time of the {i} warm up: {et - st} s")
        except Exception as e:
            logger.error("Failed to warm up on tts engine.")
            logger.error(e)
            return False

    else:
        # non-TTS engines currently need no warm up
        pass

    return True
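A hedged sketch of how this function is driven from the server startup path, mirroring the ServerExecutor.init hunk earlier in this commit:

from paddlespeech.server.engine.engine_pool import init_engine_pool
from paddlespeech.server.engine.engine_warmup import warm_up
from paddlespeech.server.utils.config import get_config

config = get_config("./conf/application.yaml")
if init_engine_pool(config):
    for engine_and_type in config.engine_list:  # e.g. ["tts_online-onnx"]
        if not warm_up(engine_and_type):
            raise SystemExit("warm up failed")  # illustrative error handling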
@@ -31,18 +31,12 @@ from paddlespeech.server.utils.util import get_chunks
 from paddlespeech.t2s.frontend import English
 from paddlespeech.t2s.frontend.zh_frontend import Frontend

-__all__ = ['TTSEngine']
+__all__ = ['TTSEngine', 'PaddleTTSConnectionHandler']


 class TTSServerExecutor(TTSExecutor):
-    def __init__(self, am_block, am_pad, voc_block, voc_pad, voc_upsample):
+    def __init__(self):
         super().__init__()
-        self.am_block = am_block
-        self.am_pad = am_pad
-        self.voc_block = voc_block
-        self.voc_pad = voc_pad
-        self.voc_upsample = voc_upsample
         self.pretrained_models = pretrained_models

     def _init_from_path(
@@ -161,6 +155,115 @@ class TTSServerExecutor(TTSExecutor):
             self.frontend = English(phone_vocab_path=self.phones_dict)
         logger.info("frontend done!")


+class TTSEngine(BaseEngine):
+    """TTS server engine
+
+    Args:
+        metaclass: Defaults to Singleton.
+    """
+
+    def __init__(self, name=None):
+        """Initialize TTS server engine
+        """
+        super().__init__()
+
+    def init(self, config: dict) -> bool:
+        self.executor = TTSServerExecutor()
+        self.config = config
+        self.lang = self.config.lang
+        self.engine_type = "online-onnx"
+
+        self.am_block = self.config.am_block
+        self.am_pad = self.config.am_pad
+        self.voc_block = self.config.voc_block
+        self.voc_pad = self.config.voc_pad
+        self.am_upsample = 1
+        self.voc_upsample = self.config.voc_upsample
+
+        assert (
+            self.config.am == "fastspeech2_csmsc_onnx" or
+            self.config.am == "fastspeech2_cnndecoder_csmsc_onnx"
+        ) and (
+            self.config.voc == "hifigan_csmsc_onnx" or
+            self.config.voc == "mb_melgan_csmsc_onnx"
+        ), 'Please check config, am support: fastspeech2, voc support: hifigan_csmsc-zh or mb_melgan_csmsc.'
+
+        assert (
+            self.config.voc_block > 0 and self.config.voc_pad > 0
+        ), "Please set correct voc_block and voc_pad, they should be more than 0."
+
+        assert (
+            self.config.voc_sample_rate == self.config.am_sample_rate
+        ), "The sample rate of AM and Vocoder model are different, please check model."
+
+        try:
+            if self.config.am_sess_conf.device is not None:
+                self.device = self.config.am_sess_conf.device
+            elif self.config.voc_sess_conf.device is not None:
+                self.device = self.config.voc_sess_conf.device
+            else:
+                self.device = paddle.get_device()
+            paddle.set_device(self.device)
+        except Exception as e:
+            logger.error(
+                "Set device failed, please check if device is already used and the parameter 'device' in the yaml file"
+            )
+            logger.error("Initialize TTS server engine Failed on device: %s." %
+                         (self.device))
+            logger.error(e)
+            return False
+
+        try:
+            self.executor._init_from_path(
+                am=self.config.am,
+                am_ckpt=self.config.am_ckpt,
+                am_stat=self.config.am_stat,
+                phones_dict=self.config.phones_dict,
+                tones_dict=self.config.tones_dict,
+                speaker_dict=self.config.speaker_dict,
+                am_sample_rate=self.config.am_sample_rate,
+                am_sess_conf=self.config.am_sess_conf,
+                voc=self.config.voc,
+                voc_ckpt=self.config.voc_ckpt,
+                voc_sample_rate=self.config.voc_sample_rate,
+                voc_sess_conf=self.config.voc_sess_conf,
+                lang=self.config.lang)
+        except Exception as e:
+            logger.error("Failed to get model related files.")
+            logger.error("Initialize TTS server engine Failed on device: %s." %
+                         (self.config.voc_sess_conf.device))
+            logger.error(e)
+            return False
+
+        logger.info("Initialize TTS server engine successfully on device: %s." %
+                    (self.config.voc_sess_conf.device))
+
+        return True
+
+
+class PaddleTTSConnectionHandler:
+    def __init__(self, tts_engine):
+        """The PaddleSpeech TTS Server Connection Handler
+           This connection process every tts server request
+        Args:
+            tts_engine (TTSEngine): The TTS engine
+        """
+        super().__init__()
+        logger.info(
+            "Create PaddleTTSConnectionHandler to process the tts request")
+
+        self.tts_engine = tts_engine
+        self.executor = self.tts_engine.executor
+        self.config = self.tts_engine.config
+        self.am_block = self.tts_engine.am_block
+        self.am_pad = self.tts_engine.am_pad
+        self.voc_block = self.tts_engine.voc_block
+        self.voc_pad = self.tts_engine.voc_pad
+        self.am_upsample = self.tts_engine.am_upsample
+        self.voc_upsample = self.tts_engine.voc_upsample
+
     def depadding(self, data, chunk_num, chunk_id, block, pad, upsample):
         """
         Streaming inference removes the result of pad inference
@@ -189,12 +292,6 @@ class TTSServerExecutor(TTSExecutor):
             Model inference and result stored in self.output.
         """
-        am_block = self.am_block
-        am_pad = self.am_pad
-        am_upsample = 1
-        voc_block = self.voc_block
-        voc_pad = self.voc_pad
-        voc_upsample = self.voc_upsample

         # first_flag marks the first packet
         first_flag = 1
         get_tone_ids = False
@@ -203,7 +300,7 @@ class TTSServerExecutor(TTSExecutor):
         # front
         frontend_st = time.time()
         if lang == 'zh':
-            input_ids = self.frontend.get_input_ids(
+            input_ids = self.executor.frontend.get_input_ids(
                 text,
                 merge_sentences=merge_sentences,
                 get_tone_ids=get_tone_ids)
@@ -211,7 +308,7 @@ class TTSServerExecutor(TTSExecutor):
             if get_tone_ids:
                 tone_ids = input_ids["tone_ids"]
         elif lang == 'en':
-            input_ids = self.frontend.get_input_ids(
+            input_ids = self.executor.frontend.get_input_ids(
                 text, merge_sentences=merge_sentences)
             phone_ids = input_ids["phone_ids"]
         else:
@@ -226,7 +323,7 @@ class TTSServerExecutor(TTSExecutor):
             # fastspeech2_csmsc
             if am == "fastspeech2_csmsc_onnx":
                 # am
-                mel = self.am_sess.run(
+                mel = self.executor.am_sess.run(
                     output_names=None, input_feed={'text': part_phone_ids})
                 mel = mel[0]
                 if first_flag == 1:
@@ -234,14 +331,16 @@ class TTSServerExecutor(TTSExecutor):
                     self.first_am_infer = first_am_et - frontend_et

                 # voc streaming
-                mel_chunks = get_chunks(mel, voc_block, voc_pad, "voc")
+                mel_chunks = get_chunks(mel, self.voc_block, self.voc_pad,
+                                        "voc")
                 voc_chunk_num = len(mel_chunks)
                 voc_st = time.time()
                 for i, mel_chunk in enumerate(mel_chunks):
-                    sub_wav = self.voc_sess.run(
+                    sub_wav = self.executor.voc_sess.run(
                         output_names=None, input_feed={'logmel': mel_chunk})
-                    sub_wav = self.depadding(sub_wav[0], voc_chunk_num, i,
-                                             voc_block, voc_pad, voc_upsample)
+                    sub_wav = self.depadding(sub_wav[0], voc_chunk_num, i,
+                                             self.voc_block, self.voc_pad,
+                                             self.voc_upsample)
                     if first_flag == 1:
                         first_voc_et = time.time()
                         self.first_voc_infer = first_voc_et - first_am_et
@@ -253,7 +352,7 @@ class TTSServerExecutor(TTSExecutor):
             # fastspeech2_cnndecoder_csmsc
             elif am == "fastspeech2_cnndecoder_csmsc_onnx":
                 # am
-                orig_hs = self.am_encoder_infer_sess.run(
+                orig_hs = self.executor.am_encoder_infer_sess.run(
                     None, input_feed={'text': part_phone_ids})
                 orig_hs = orig_hs[0]
@@ -267,9 +366,9 @@ class TTSServerExecutor(TTSExecutor):
                 hss = get_chunks(orig_hs, self.am_block, self.am_pad, "am")
                 am_chunk_num = len(hss)
                 for i, hs in enumerate(hss):
-                    am_decoder_output = self.am_decoder_sess.run(
+                    am_decoder_output = self.executor.am_decoder_sess.run(
                         None, input_feed={'xs': hs})
-                    am_postnet_output = self.am_postnet_sess.run(
+                    am_postnet_output = self.executor.am_postnet_sess.run(
                         None,
                         input_feed={
                             'xs': np.transpose(am_decoder_output[0], (0, 2, 1))
@@ -278,9 +377,11 @@ class TTSServerExecutor(TTSExecutor):
                         am_postnet_output[0], (0, 2, 1))
                     normalized_mel = am_output_data[0][0]

-                    sub_mel = denorm(normalized_mel, self.am_mu, self.am_std)
-                    sub_mel = self.depadding(sub_mel, am_chunk_num, i, am_block,
-                                             am_pad, am_upsample)
+                    sub_mel = denorm(normalized_mel, self.executor.am_mu,
+                                     self.executor.am_std)
+                    sub_mel = self.depadding(sub_mel, am_chunk_num, i,
+                                             self.am_block, self.am_pad,
+                                             self.am_upsample)

                     if i == 0:
                         mel_streaming = sub_mel
@@ -297,11 +398,11 @@ class TTSServerExecutor(TTSExecutor):
                             self.first_am_infer = first_am_et - frontend_et
                         voc_chunk = mel_streaming[start:end, :]

-                        sub_wav = self.voc_sess.run(
+                        sub_wav = self.executor.voc_sess.run(
                             output_names=None, input_feed={'logmel': voc_chunk})
-                        sub_wav = self.depadding(sub_wav[0], voc_chunk_num,
-                                                 voc_chunk_id, voc_block,
-                                                 voc_pad, voc_upsample)
+                        sub_wav = self.depadding(
+                            sub_wav[0], voc_chunk_num, voc_chunk_id,
+                            self.voc_block, self.voc_pad, self.voc_upsample)
                         if first_flag == 1:
                             first_voc_et = time.time()
                             self.first_voc_infer = first_voc_et - first_am_et
@@ -311,9 +412,11 @@ class TTSServerExecutor(TTSExecutor):
                         yield sub_wav

                         voc_chunk_id += 1
-                        start = max(0, voc_chunk_id * voc_block - voc_pad)
-                        end = min((voc_chunk_id + 1) * voc_block + voc_pad,
-                                  mel_len)
+                        start = max(
+                            0, voc_chunk_id * self.voc_block - self.voc_pad)
+                        end = min(
+                            (voc_chunk_id + 1) * self.voc_block + self.voc_pad,
+                            mel_len)
             else:
                 logger.error(
@@ -322,111 +425,6 @@ class TTSServerExecutor(TTSExecutor):

         self.final_response_time = time.time() - frontend_st


-class TTSEngine(BaseEngine):
-    """TTS server engine
-
-    Args:
-        metaclass: Defaults to Singleton.
-    """
-
-    def __init__(self, name=None):
-        """Initialize TTS server engine
-        """
-        super().__init__()
-
-    def init(self, config: dict) -> bool:
-        self.config = config
-        assert (
-            self.config.am == "fastspeech2_csmsc_onnx" or
-            self.config.am == "fastspeech2_cnndecoder_csmsc_onnx"
-        ) and (
-            self.config.voc == "hifigan_csmsc_onnx" or
-            self.config.voc == "mb_melgan_csmsc_onnx"
-        ), 'Please check config, am support: fastspeech2, voc support: hifigan_csmsc-zh or mb_melgan_csmsc.'
-
-        assert (
-            self.config.voc_block > 0 and self.config.voc_pad > 0
-        ), "Please set correct voc_block and voc_pad, they should be more than 0."
-
-        assert (
-            self.config.voc_sample_rate == self.config.am_sample_rate
-        ), "The sample rate of AM and Vocoder model are different, please check model."
-
-        self.executor = TTSServerExecutor(
-            self.config.am_block, self.config.am_pad, self.config.voc_block,
-            self.config.voc_pad, self.config.voc_upsample)
-
-        try:
-            if self.config.am_sess_conf.device is not None:
-                self.device = self.config.am_sess_conf.device
-            elif self.config.voc_sess_conf.device is not None:
-                self.device = self.config.voc_sess_conf.device
-            else:
-                self.device = paddle.get_device()
-            paddle.set_device(self.device)
-        except BaseException as e:
-            logger.error(
-                "Set device failed, please check if device is already used and the parameter 'device' in the yaml file"
-            )
-            logger.error("Initialize TTS server engine Failed on device: %s." %
-                         (self.device))
-            return False
-
-        try:
-            self.executor._init_from_path(
-                am=self.config.am,
-                am_ckpt=self.config.am_ckpt,
-                am_stat=self.config.am_stat,
-                phones_dict=self.config.phones_dict,
-                tones_dict=self.config.tones_dict,
-                speaker_dict=self.config.speaker_dict,
-                am_sample_rate=self.config.am_sample_rate,
-                am_sess_conf=self.config.am_sess_conf,
-                voc=self.config.voc,
-                voc_ckpt=self.config.voc_ckpt,
-                voc_sample_rate=self.config.voc_sample_rate,
-                voc_sess_conf=self.config.voc_sess_conf,
-                lang=self.config.lang)
-        except Exception as e:
-            logger.error("Failed to get model related files.")
-            logger.error("Initialize TTS server engine Failed on device: %s." %
-                         (self.config.voc_sess_conf.device))
-            return False
-
-        # warm up
-        try:
-            self.warm_up()
-            logger.info("Warm up successfully.")
-        except Exception as e:
-            logger.error("Failed to warm up on tts engine.")
-            return False
-
-        logger.info("Initialize TTS server engine successfully on device: %s." %
-                    (self.config.voc_sess_conf.device))
-        return True
-
-    def warm_up(self):
-        """warm up
-        """
-        if self.config.lang == 'zh':
-            sentence = "您好,欢迎使用语音合成服务。"
-        if self.config.lang == 'en':
-            sentence = "Hello and welcome to the speech synthesis service."
-        logger.info("Start to warm up.")
-        for i in range(3):
-            for wav in self.executor.infer(
-                    text=sentence,
-                    lang=self.config.lang,
-                    am=self.config.am,
-                    spk_id=0, ):
-                logger.info(
-                    f"The first response time of the {i} warm up: {self.executor.first_response_time} s"
-                )
-                break

     def preprocess(self, text_bese64: str=None, text_bytes: bytes=None):
         # Convert byte to text
         if text_bese64:
@@ -459,7 +457,7 @@ class TTSEngine(BaseEngine):
         """
         wav_list = []

-        for wav in self.executor.infer(
+        for wav in self.infer(
                 text=sentence,
                 lang=self.config.lang,
                 am=self.config.am,
@@ -477,11 +475,9 @@ class TTSEngine(BaseEngine):
         duration = len(wav_all) / self.config.voc_sample_rate
         logger.info(f"sentence: {sentence}")
         logger.info(f"The durations of audio is: {duration} s")
-        logger.info(
-            f"first response time: {self.executor.first_response_time} s")
-        logger.info(
-            f"final response time: {self.executor.final_response_time} s")
-        logger.info(f"RTF: {self.executor.final_response_time / duration}")
-        logger.info(
-            f"Other info: front time: {self.executor.frontend_time} s, first am infer time: {self.executor.first_am_infer} s, first voc infer time: {self.executor.first_voc_infer} s,"
-        )
+        logger.info(f"first response time: {self.first_response_time} s")
+        logger.info(f"final response time: {self.final_response_time} s")
+        logger.info(f"RTF: {self.final_response_time / duration}")
+        logger.info(
+            f"Other info: front time: {self.frontend_time} s, first am infer time: {self.first_am_infer} s, first voc infer time: {self.first_voc_infer} s,"
+        )
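A hedged worked example of the chunk bookkeeping behind get_chunks and depadding above; the ceil-based chunk count is an assumption consistent with the loop accounting here, and the block/pad/upsample values are illustrative:

import math

mel_len, block, pad, upsample = 100, 36, 14, 300
chunk_num = math.ceil(mel_len / block)                    # 3 overlapping chunks
spans = [(max(0, i * block - pad), min((i + 1) * block + pad, mel_len))
         for i in range(chunk_num)]                       # [(0, 50), (22, 86), (58, 100)]
kept = [min((i + 1) * block, mel_len) - i * block
        for i in range(chunk_num)]                        # frames surviving depadding
assert sum(kept) == mel_len                               # pads cancel out exactly
print(sum(kept) * upsample)                               # 30000 samples, same as non-streaming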
@@ -34,16 +34,12 @@ from paddlespeech.t2s.frontend.zh_frontend import Frontend
 from paddlespeech.t2s.modules.normalizer import ZScore
 from paddlespeech.utils.dynamic_import import dynamic_import

-__all__ = ['TTSEngine']
+__all__ = ['TTSEngine', 'PaddleTTSConnectionHandler']


 class TTSServerExecutor(TTSExecutor):
-    def __init__(self, am_block, am_pad, voc_block, voc_pad):
+    def __init__(self):
         super().__init__()
-        self.am_block = am_block
-        self.am_pad = am_pad
-        self.voc_block = voc_block
-        self.voc_pad = voc_pad
         self.pretrained_models = pretrained_models

     def get_model_info(self,
@@ -205,6 +201,106 @@ class TTSServerExecutor(TTSExecutor):
         self.voc_inference.eval()
         print("voc done!")


+class TTSEngine(BaseEngine):
+    """TTS server engine
+
+    Args:
+        metaclass: Defaults to Singleton.
+    """
+
+    def __init__(self, name=None):
+        """Initialize TTS server engine
+        """
+        super().__init__()
+
+    def init(self, config: dict) -> bool:
+        self.executor = TTSServerExecutor()
+        self.config = config
+        self.lang = self.config.lang
+        self.engine_type = "online"
+
+        assert (
+            config.am == "fastspeech2_csmsc" or
+            config.am == "fastspeech2_cnndecoder_csmsc"
+        ) and (
+            config.voc == "hifigan_csmsc" or config.voc == "mb_melgan_csmsc"
+        ), 'Please check config, am support: fastspeech2, voc support: hifigan_csmsc-zh or mb_melgan_csmsc.'
+
+        assert (
+            config.voc_block > 0 and config.voc_pad > 0
+        ), "Please set correct voc_block and voc_pad, they should be more than 0."
+
+        try:
+            if self.config.device is not None:
+                self.device = self.config.device
+            else:
+                self.device = paddle.get_device()
+            paddle.set_device(self.device)
+        except Exception as e:
+            logger.error(
+                "Set device failed, please check if device is already used and the parameter 'device' in the yaml file"
+            )
+            logger.error("Initialize TTS server engine Failed on device: %s." %
+                         (self.device))
+            logger.error(e)
+            return False

+        try:
+            self.executor._init_from_path(
+                am=self.config.am,
+                am_config=self.config.am_config,
+                am_ckpt=self.config.am_ckpt,
+                am_stat=self.config.am_stat,
+                phones_dict=self.config.phones_dict,
+                tones_dict=self.config.tones_dict,
+                speaker_dict=self.config.speaker_dict,
+                voc=self.config.voc,
+                voc_config=self.config.voc_config,
+                voc_ckpt=self.config.voc_ckpt,
+                voc_stat=self.config.voc_stat,
+                lang=self.config.lang)
+        except Exception as e:
+            logger.error("Failed to get model related files.")
+            logger.error("Initialize TTS server engine Failed on device: %s." %
+                         (self.device))
+            logger.error(e)
+            return False
+
+        self.am_block = self.config.am_block
+        self.am_pad = self.config.am_pad
+        self.voc_block = self.config.voc_block
+        self.voc_pad = self.config.voc_pad
+        self.am_upsample = 1
+        self.voc_upsample = self.executor.voc_config.n_shift
+
+        logger.info("Initialize TTS server engine successfully on device: %s." %
+                    (self.device))
+        return True
+
+
+class PaddleTTSConnectionHandler:
+    def __init__(self, tts_engine):
+        """The PaddleSpeech TTS Server Connection Handler
+           This connection process every tts server request
+        Args:
+            tts_engine (TTSEngine): The TTS engine
+        """
+        super().__init__()
+        logger.info(
+            "Create PaddleTTSConnectionHandler to process the tts request")
+
+        self.tts_engine = tts_engine
+        self.executor = self.tts_engine.executor
+        self.config = self.tts_engine.config
+        self.am_block = self.tts_engine.am_block
+        self.am_pad = self.tts_engine.am_pad
+        self.voc_block = self.tts_engine.voc_block
+        self.voc_pad = self.tts_engine.voc_pad
+        self.am_upsample = self.tts_engine.am_upsample
+        self.voc_upsample = self.tts_engine.voc_upsample
+
     def depadding(self, data, chunk_num, chunk_id, block, pad, upsample):
         """
         Streaming inference removes the result of pad inference
@@ -233,12 +329,6 @@ class TTSServerExecutor(TTSExecutor):
             Model inference and result stored in self.output.
         """
-        am_block = self.am_block
-        am_pad = self.am_pad
-        am_upsample = 1
-        voc_block = self.voc_block
-        voc_pad = self.voc_pad
-        voc_upsample = self.voc_config.n_shift

         # first_flag marks the first packet
         first_flag = 1
@@ -246,7 +336,7 @@ class TTSServerExecutor(TTSExecutor):
         merge_sentences = False
         frontend_st = time.time()
         if lang == 'zh':
-            input_ids = self.frontend.get_input_ids(
+            input_ids = self.executor.frontend.get_input_ids(
                 text,
                 merge_sentences=merge_sentences,
                 get_tone_ids=get_tone_ids)
@@ -254,7 +344,7 @@ class TTSServerExecutor(TTSExecutor):
             if get_tone_ids:
                 tone_ids = input_ids["tone_ids"]
         elif lang == 'en':
-            input_ids = self.frontend.get_input_ids(
+            input_ids = self.executor.frontend.get_input_ids(
                 text, merge_sentences=merge_sentences)
             phone_ids = input_ids["phone_ids"]
         else:
@@ -269,19 +359,21 @@ class TTSServerExecutor(TTSExecutor):
             # fastspeech2_csmsc
             if am == "fastspeech2_csmsc":
                 # am
-                mel = self.am_inference(part_phone_ids)
+                mel = self.executor.am_inference(part_phone_ids)
                 if first_flag == 1:
                     first_am_et = time.time()
                     self.first_am_infer = first_am_et - frontend_et

                 # voc streaming
-                mel_chunks = get_chunks(mel, voc_block, voc_pad, "voc")
+                mel_chunks = get_chunks(mel, self.voc_block, self.voc_pad,
+                                        "voc")
                 voc_chunk_num = len(mel_chunks)
                 voc_st = time.time()
                 for i, mel_chunk in enumerate(mel_chunks):
-                    sub_wav = self.voc_inference(mel_chunk)
-                    sub_wav = self.depadding(sub_wav, voc_chunk_num, i,
-                                             voc_block, voc_pad, voc_upsample)
+                    sub_wav = self.executor.voc_inference(mel_chunk)
+                    sub_wav = self.depadding(sub_wav, voc_chunk_num, i,
+                                             self.voc_block, self.voc_pad,
+                                             self.voc_upsample)
                     if first_flag == 1:
                         first_voc_et = time.time()
                         self.first_voc_infer = first_voc_et - first_am_et
@@ -293,7 +385,8 @@ class TTSServerExecutor(TTSExecutor):
             # fastspeech2_cnndecoder_csmsc
             elif am == "fastspeech2_cnndecoder_csmsc":
                 # am
-                orig_hs = self.am_inference.encoder_infer(part_phone_ids)
+                orig_hs = self.executor.am_inference.encoder_infer(
+                    part_phone_ids)

                 # streaming voc chunk info
                 mel_len = orig_hs.shape[1]
@@ -305,13 +398,15 @@ class TTSServerExecutor(TTSExecutor):
                 hss = get_chunks(orig_hs, self.am_block, self.am_pad, "am")
                 am_chunk_num = len(hss)
                 for i, hs in enumerate(hss):
-                    before_outs = self.am_inference.decoder(hs)
-                    after_outs = before_outs + self.am_inference.postnet(
+                    before_outs = self.executor.am_inference.decoder(hs)
+                    after_outs = before_outs + self.executor.am_inference.postnet(
                         before_outs.transpose((0, 2, 1))).transpose((0, 2, 1))
                     normalized_mel = after_outs[0]
-                    sub_mel = denorm(normalized_mel, self.am_mu, self.am_std)
-                    sub_mel = self.depadding(sub_mel, am_chunk_num, i, am_block,
-                                             am_pad, am_upsample)
+                    sub_mel = denorm(normalized_mel, self.executor.am_mu,
+                                     self.executor.am_std)
+                    sub_mel = self.depadding(sub_mel, am_chunk_num, i,
+                                             self.am_block, self.am_pad,
+                                             self.am_upsample)

                     if i == 0:
                         mel_streaming = sub_mel
@@ -328,11 +423,11 @@ class TTSServerExecutor(TTSExecutor):
                             self.first_am_infer = first_am_et - frontend_et
                         voc_chunk = mel_streaming[start:end, :]
                         voc_chunk = paddle.to_tensor(voc_chunk)
-                        sub_wav = self.voc_inference(voc_chunk)
-                        sub_wav = self.depadding(sub_wav, voc_chunk_num,
-                                                 voc_chunk_id, voc_block,
-                                                 voc_pad, voc_upsample)
+                        sub_wav = self.executor.voc_inference(voc_chunk)
+                        sub_wav = self.depadding(
+                            sub_wav, voc_chunk_num, voc_chunk_id,
+                            self.voc_block, self.voc_pad, self.voc_upsample)
                         if first_flag == 1:
                             first_voc_et = time.time()
                             self.first_voc_infer = first_voc_et - first_am_et
@@ -342,9 +437,11 @@ class TTSServerExecutor(TTSExecutor):
                         yield sub_wav

                         voc_chunk_id += 1
-                        start = max(0, voc_chunk_id * voc_block - voc_pad)
-                        end = min((voc_chunk_id + 1) * voc_block + voc_pad,
-                                  mel_len)
+                        start = max(
+                            0, voc_chunk_id * self.voc_block - self.voc_pad)
+                        end = min(
+                            (voc_chunk_id + 1) * self.voc_block + self.voc_pad,
+                            mel_len)
             else:
                 logger.error(
@@ -353,100 +450,6 @@ class TTSServerExecutor(TTSExecutor):

         self.final_response_time = time.time() - frontend_st


-class TTSEngine(BaseEngine):
-    """TTS server engine
-
-    Args:
-        metaclass: Defaults to Singleton.
-    """
-
-    def __init__(self, name=None):
-        """Initialize TTS server engine
-        """
-        super().__init__()
-
-    def init(self, config: dict) -> bool:
-        self.config = config
-        assert (
-            config.am == "fastspeech2_csmsc" or
-            config.am == "fastspeech2_cnndecoder_csmsc"
-        ) and (
-            config.voc == "hifigan_csmsc" or config.voc == "mb_melgan_csmsc"
-        ), 'Please check config, am support: fastspeech2, voc support: hifigan_csmsc-zh or mb_melgan_csmsc.'
-
-        assert (
-            config.voc_block > 0 and config.voc_pad > 0
-        ), "Please set correct voc_block and voc_pad, they should be more than 0."
-
-        try:
-            if self.config.device is not None:
-                self.device = self.config.device
-            else:
-                self.device = paddle.get_device()
-            paddle.set_device(self.device)
-        except Exception as e:
-            logger.error(
-                "Set device failed, please check if device is already used and the parameter 'device' in the yaml file"
-            )
-            logger.error("Initialize TTS server engine Failed on device: %s." %
-                         (self.device))
-            return False
-
-        self.executor = TTSServerExecutor(config.am_block, config.am_pad,
-                                          config.voc_block, config.voc_pad)
-
-        try:
-            self.executor._init_from_path(
-                am=self.config.am,
-                am_config=self.config.am_config,
-                am_ckpt=self.config.am_ckpt,
-                am_stat=self.config.am_stat,
-                phones_dict=self.config.phones_dict,
-                tones_dict=self.config.tones_dict,
-                speaker_dict=self.config.speaker_dict,
-                voc=self.config.voc,
-                voc_config=self.config.voc_config,
-                voc_ckpt=self.config.voc_ckpt,
-                voc_stat=self.config.voc_stat,
-                lang=self.config.lang)
-        except Exception as e:
-            logger.error("Failed to get model related files.")
-            logger.error("Initialize TTS server engine Failed on device: %s." %
-                         (self.device))
-            return False
-
-        # warm up
-        try:
-            self.warm_up()
-            logger.info("Warm up successfully.")
-        except Exception as e:
-            logger.error("Failed to warm up on tts engine.")
-            return False
-
-        logger.info("Initialize TTS server engine successfully on device: %s." %
-                    (self.device))
-        return True
-
-    def warm_up(self):
-        """warm up
-        """
-        if self.config.lang == 'zh':
-            sentence = "您好,欢迎使用语音合成服务。"
-        if self.config.lang == 'en':
-            sentence = "Hello and welcome to the speech synthesis service."
-        logger.info("Start to warm up.")
-        for i in range(3):
-            for wav in self.executor.infer(
-                    text=sentence,
-                    lang=self.config.lang,
-                    am=self.config.am,
-                    spk_id=0, ):
-                logger.info(
-                    f"The first response time of the {i} warm up: {self.executor.first_response_time} s"
-                )
-                break

     def preprocess(self, text_bese64: str=None, text_bytes: bytes=None):
         # Convert byte to text
         if text_bese64:
@@ -480,7 +483,7 @@ class TTSEngine(BaseEngine):
         wav_list = []

-        for wav in self.executor.infer(
+        for wav in self.infer(
                 text=sentence,
                 lang=self.config.lang,
                 am=self.config.am,
@@ -496,13 +499,12 @@ class TTSEngine(BaseEngine):
         wav_all = np.concatenate(wav_list, axis=0)
         duration = len(wav_all) / self.executor.am_config.fs
         logger.info(f"sentence: {sentence}")
         logger.info(f"The durations of audio is: {duration} s")
-        logger.info(
-            f"first response time: {self.executor.first_response_time} s")
-        logger.info(
-            f"final response time: {self.executor.final_response_time} s")
-        logger.info(f"RTF: {self.executor.final_response_time / duration}")
-        logger.info(
-            f"Other info: front time: {self.executor.frontend_time} s, first am infer time: {self.executor.first_am_infer} s, first voc infer time: {self.executor.first_voc_infer} s,"
-        )
+        logger.info(f"first response time: {self.first_response_time} s")
+        logger.info(f"final response time: {self.final_response_time} s")
+        logger.info(f"RTF: {self.final_response_time / duration}")
+        logger.info(
+            f"Other info: front time: {self.frontend_time} s, first am infer time: {self.first_am_infer} s, first voc infer time: {self.first_voc_infer} s,"
+        )
\ No newline at end of file
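Since infer is a generator that yields sub_wav chunks, a caller can stream audio as it is produced. A hedged per-request sketch, patterned on the warm-up loop in engine_warmup.py above; send_chunk is a hypothetical sink:

from paddlespeech.server.engine.engine_pool import get_engine_pool

tts_engine = get_engine_pool()['tts']
handler = PaddleTTSConnectionHandler(tts_engine)
for sub_wav in handler.infer(
        text="您好,欢迎使用语音合成服务。",
        lang=tts_engine.lang,
        am=tts_engine.config.am):
    send_chunk(sub_wav)  # hypothetical: forward each audio chunk to the client
print(f"first response time: {handler.first_response_time} s")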
@@ -14,6 +14,7 @@
 import base64
 import io
 import os
+import sys
 import time
 from typing import Optional
@@ -35,7 +36,7 @@ from paddlespeech.server.utils.paddle_predictor import run_model
 from paddlespeech.t2s.frontend import English
 from paddlespeech.t2s.frontend.zh_frontend import Frontend

-__all__ = ['TTSEngine']
+__all__ = ['TTSEngine', 'PaddleTTSConnectionHandler']

 class TTSServerExecutor(TTSExecutor):
@@ -245,7 +246,7 @@ class TTSServerExecutor(TTSExecutor):
             else:
                 wav_all = paddle.concat([wav_all, wav])
             self.voc_time += (time.time() - voc_st)
-        self._outputs['wav'] = wav_all
+        self._outputs["wav"] = wav_all

 class TTSEngine(BaseEngine):
@@ -263,6 +264,8 @@ class TTSEngine(BaseEngine):
     def init(self, config: dict) -> bool:
         self.executor = TTSServerExecutor()
         self.config = config
+        self.lang = self.config.lang
+        self.engine_type = "inference"

         try:
             if self.config.am_predictor_conf.device is not None:
@@ -272,58 +275,59 @@ class TTSEngine(BaseEngine):
             else:
                 self.device = paddle.get_device()
             paddle.set_device(self.device)
-        except BaseException as e:
+        except Exception as e:
             logger.error(
                 "Set device failed, please check if device is already used and the parameter 'device' in the yaml file"
             )
             logger.error("Initialize TTS server engine Failed on device: %s." %
                          (self.device))
+            logger.error(e)
             return False

-        self.executor._init_from_path(
-            am=self.config.am,
-            am_model=self.config.am_model,
-            am_params=self.config.am_params,
-            am_sample_rate=self.config.am_sample_rate,
-            phones_dict=self.config.phones_dict,
-            tones_dict=self.config.tones_dict,
-            speaker_dict=self.config.speaker_dict,
-            voc=self.config.voc,
-            voc_model=self.config.voc_model,
-            voc_params=self.config.voc_params,
-            voc_sample_rate=self.config.voc_sample_rate,
-            lang=self.config.lang,
-            am_predictor_conf=self.config.am_predictor_conf,
-            voc_predictor_conf=self.config.voc_predictor_conf, )
-        # warm up
         try:
-            self.warm_up()
-            logger.info("Warm up successfully.")
+            self.executor._init_from_path(
+                am=self.config.am,
+                am_model=self.config.am_model,
+                am_params=self.config.am_params,
+                am_sample_rate=self.config.am_sample_rate,
+                phones_dict=self.config.phones_dict,
+                tones_dict=self.config.tones_dict,
+                speaker_dict=self.config.speaker_dict,
+                voc=self.config.voc,
+                voc_model=self.config.voc_model,
+                voc_params=self.config.voc_params,
+                voc_sample_rate=self.config.voc_sample_rate,
+                lang=self.config.lang,
+                am_predictor_conf=self.config.am_predictor_conf,
+                voc_predictor_conf=self.config.voc_predictor_conf, )
         except Exception as e:
-            logger.error("Failed to warm up on tts engine.")
+            logger.error("Failed to get model related files.")
+            logger.error("Initialize TTS server engine Failed on device: %s." %
+                         (self.device))
+            logger.error(e)
             return False

         logger.info("Initialize TTS server engine successfully.")
         return True

-    def warm_up(self):
-        """warm up
+
+class PaddleTTSConnectionHandler(TTSServerExecutor):
+    def __init__(self, tts_engine):
+        """The PaddleSpeech TTS Server Connection Handler
+        This connection process every tts server request
+
+        Args:
+            tts_engine (TTSEngine): The TTS engine
         """
-        if self.config.lang == 'zh':
-            sentence = "您好,欢迎使用语音合成服务。"
-        if self.config.lang == 'en':
-            sentence = "Hello and welcome to the speech synthesis service."
-        logger.info("Start to warm up.")
-        for i in range(3):
-            st = time.time()
-            self.executor.infer(
-                text=sentence,
-                lang=self.config.lang,
-                am=self.config.am,
-                spk_id=0, )
-            logger.info(
-                f"The response time of the {i} warm up: {time.time() - st} s")
+        super().__init__()
+        logger.info(
+            "Create PaddleTTSConnectionHandler to process the tts request")
+
+        self.tts_engine = tts_engine
+        self.executor = self.tts_engine.executor
+        self.config = self.tts_engine.config
+        self.frontend = self.executor.frontend
+        self.am_predictor = self.executor.am_predictor
+        self.voc_predictor = self.executor.voc_predictor

     def postprocess(self,
                     wav,
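Note: judging from the predictor imports, this first file is the Paddle Inference TTS engine. The change moves per-request work out of the shared TTSEngine singleton into a PaddleTTSConnectionHandler constructed per request, so the mutable per-request state inherited from TTSServerExecutor (the timing counters and the _outputs buffer) is no longer shared across concurrent requests. A minimal usage sketch, assuming an engine that has already passed init(); the keyword names mirror the positional run() call in the REST router further down and are assumptions:

    # Sketch only: one handler per request, heavy predictors stay on the engine.
    handler = PaddleTTSConnectionHandler(tts_engine)
    lang, sr, duration, wav_base64 = handler.run(
        text="您好,欢迎使用语音合成服务。",
        spk_id=0,
        speed=1.0,
        volume=1.0,
        sample_rate=0,   # assumed: 0 keeps the acoustic model's native rate
        save_path=None)
    # Per-request timing now lives on the handler, not the engine:
    print(handler.frontend_time, handler.am_time, handler.voc_time)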
@@ -375,8 +379,11 @@
                     ErrorCode.SERVER_INTERNAL_ERR,
                     "Failed to transform speed. Can not install soxbindings on your system. \
                         You need to set speed value 1.0.")
-            except BaseException:
+                sys.exit(-1)
+            except Exception as e:
                 logger.error("Failed to transform speed.")
+                logger.error(e)
+                sys.exit(-1)

         # wav to base64
         buf = io.BytesIO()
@@ -433,7 +440,7 @@
         try:
             infer_st = time.time()
-            self.executor.infer(
+            self.infer(
                 text=sentence, lang=lang, am=self.config.am, spk_id=spk_id)
             infer_et = time.time()
             infer_time = infer_et - infer_st
@@ -441,13 +448,16 @@
         except ServerBaseException:
             raise ServerBaseException(ErrorCode.SERVER_INTERNAL_ERR,
                                       "tts infer failed.")
-        except BaseException:
+            sys.exit(-1)
+        except Exception as e:
             logger.error("tts infer failed.")
+            logger.error(e)
+            sys.exit(-1)

         try:
             postprocess_st = time.time()
             target_sample_rate, wav_base64 = self.postprocess(
-                wav=self.executor._outputs['wav'].numpy(),
+                wav=self._outputs["wav"].numpy(),
                 original_fs=self.executor.am_sample_rate,
                 target_fs=sample_rate,
                 volume=volume,
@@ -455,26 +465,28 @@
                 audio_path=save_path)
             postprocess_et = time.time()
             postprocess_time = postprocess_et - postprocess_st
-            duration = len(self.executor._outputs['wav']
-                           .numpy()) / self.executor.am_sample_rate
+            duration = len(
+                self._outputs["wav"].numpy()) / self.executor.am_sample_rate
             rtf = infer_time / duration
         except ServerBaseException:
             raise ServerBaseException(ErrorCode.SERVER_INTERNAL_ERR,
                                       "tts postprocess failed.")
-        except BaseException:
+            sys.exit(-1)
+        except Exception as e:
             logger.error("tts postprocess failed.")
+            logger.error(e)
+            sys.exit(-1)

         logger.info("AM model: {}".format(self.config.am))
         logger.info("Vocoder model: {}".format(self.config.voc))
         logger.info("Language: {}".format(lang))
-        logger.info("tts engine type: paddle inference")
+        logger.info("tts engine type: python")
         logger.info("audio duration: {}".format(duration))
-        logger.info(
-            "frontend inference time: {}".format(self.executor.frontend_time))
-        logger.info("AM inference time: {}".format(self.executor.am_time))
-        logger.info("Vocoder inference time: {}".format(self.executor.voc_time))
+        logger.info("frontend inference time: {}".format(self.frontend_time))
+        logger.info("AM inference time: {}".format(self.am_time))
+        logger.info("Vocoder inference time: {}".format(self.voc_time))
         logger.info("total inference time: {}".format(infer_time))
         logger.info(
             "postprocess (change speed, volume, target sample rate) time: {}".
@@ -482,5 +494,6 @@
         logger.info("total generate audio time: {}".format(infer_time +
                                                            postprocess_time))
         logger.info("RTF: {}".format(rtf))
+        logger.info("device: {}".format(self.tts_engine.device))

         return lang, target_sample_rate, duration, wav_base64
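Note: the run() method ends by logging a stage-by-stage timing breakdown and the RTF (real-time factor). The arithmetic, restated with made-up numbers for clarity:

    # duration = synthesized samples / sample rate; rtf = compute time / audio time
    wav_samples = 48000        # length of the vocoder output, in samples
    am_sample_rate = 24000     # acoustic model output rate, in Hz
    infer_time = 0.5           # seconds spent in frontend + AM + vocoder

    duration = wav_samples / am_sample_rate  # 2.0 seconds of audio
    rtf = infer_time / duration              # 0.25, i.e. 4x faster than real time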
@@ -13,6 +13,7 @@
 # limitations under the License.
 import base64
 import io
+import sys
 import time

 import librosa
@@ -28,7 +29,7 @@ from paddlespeech.server.utils.audio_process import change_speed
 from paddlespeech.server.utils.errors import ErrorCode
 from paddlespeech.server.utils.exception import ServerBaseException

-__all__ = ['TTSEngine']
+__all__ = ['TTSEngine', 'PaddleTTSConnectionHandler']

 class TTSServerExecutor(TTSExecutor):
@@ -52,6 +53,8 @@ class TTSEngine(BaseEngine):
     def init(self, config: dict) -> bool:
         self.executor = TTSServerExecutor()
         self.config = config
+        self.lang = self.config.lang
+        self.engine_type = "python"

         try:
             if self.config.device is not None:
@@ -59,12 +62,13 @@ class TTSEngine(BaseEngine):
             else:
                 self.device = paddle.get_device()
             paddle.set_device(self.device)
-        except BaseException as e:
+        except Exception as e:
             logger.error(
                 "Set device failed, please check if device is already used and the parameter 'device' in the yaml file"
             )
             logger.error("Initialize TTS server engine Failed on device: %s." %
                          (self.device))
+            logger.error(e)
             return False

         try:
@@ -81,41 +85,35 @@ class TTSEngine(BaseEngine):
                 voc_ckpt=self.config.voc_ckpt,
                 voc_stat=self.config.voc_stat,
                 lang=self.config.lang)
-        except BaseException:
+        except Exception as e:
             logger.error("Failed to get model related files.")
             logger.error("Initialize TTS server engine Failed on device: %s." %
                          (self.device))
-            return False
-
-        # warm up
-        try:
-            self.warm_up()
-            logger.info("Warm up successfully.")
-        except Exception as e:
-            logger.error("Failed to warm up on tts engine.")
+            logger.error(e)
             return False

         logger.info("Initialize TTS server engine successfully on device: %s." %
                     (self.device))
         return True

-    def warm_up(self):
-        """warm up
+
+class PaddleTTSConnectionHandler(TTSServerExecutor):
+    def __init__(self, tts_engine):
+        """The PaddleSpeech TTS Server Connection Handler
+        This connection process every tts server request
+
+        Args:
+            tts_engine (TTSEngine): The TTS engine
         """
-        if self.config.lang == 'zh':
-            sentence = "您好,欢迎使用语音合成服务。"
-        if self.config.lang == 'en':
-            sentence = "Hello and welcome to the speech synthesis service."
-        logger.info("Start to warm up.")
-        for i in range(3):
-            st = time.time()
-            self.executor.infer(
-                text=sentence,
-                lang=self.config.lang,
-                am=self.config.am,
-                spk_id=0, )
-            logger.info(
-                f"The response time of the {i} warm up: {time.time() - st} s")
+        super().__init__()
+        logger.info(
+            "Create PaddleTTSConnectionHandler to process the tts request")
+
+        self.tts_engine = tts_engine
+        self.executor = self.tts_engine.executor
+        self.config = self.tts_engine.config
+        self.frontend = self.executor.frontend
+        self.am_inference = self.executor.am_inference
+        self.voc_inference = self.executor.voc_inference

     def postprocess(self,
                     wav,
@@ -167,8 +165,11 @@
                     ErrorCode.SERVER_INTERNAL_ERR,
                     "Failed to transform speed. Can not install soxbindings on your system. \
                         You need to set speed value 1.0.")
-            except BaseException:
+                sys.exit(-1)
+            except Exception as e:
                 logger.error("Failed to transform speed.")
+                logger.error(e)
+                sys.exit(-1)

         # wav to base64
         buf = io.BytesIO()
@@ -225,24 +226,27 @@
         try:
             infer_st = time.time()
-            self.executor.infer(
+            self.infer(
                 text=sentence, lang=lang, am=self.config.am, spk_id=spk_id)
             infer_et = time.time()
             infer_time = infer_et - infer_st
-            duration = len(self.executor._outputs['wav']
-                           .numpy()) / self.executor.am_config.fs
+            duration = len(
+                self._outputs["wav"].numpy()) / self.executor.am_config.fs
             rtf = infer_time / duration
         except ServerBaseException:
             raise ServerBaseException(ErrorCode.SERVER_INTERNAL_ERR,
                                       "tts infer failed.")
-        except BaseException:
+            sys.exit(-1)
+        except Exception as e:
             logger.error("tts infer failed.")
+            logger.error(e)
+            sys.exit(-1)

         try:
             postprocess_st = time.time()
             target_sample_rate, wav_base64 = self.postprocess(
-                wav=self.executor._outputs['wav'].numpy(),
+                wav=self._outputs["wav"].numpy(),
                 original_fs=self.executor.am_config.fs,
                 target_fs=sample_rate,
                 volume=volume,
@@ -254,8 +258,11 @@
         except ServerBaseException:
             raise ServerBaseException(ErrorCode.SERVER_INTERNAL_ERR,
                                       "tts postprocess failed.")
-        except BaseException:
+            sys.exit(-1)
+        except Exception as e:
             logger.error("tts postprocess failed.")
+            logger.error(e)
+            sys.exit(-1)

         logger.info("AM model: {}".format(self.config.am))
         logger.info("Vocoder model: {}".format(self.config.voc))
@@ -263,10 +270,9 @@
         logger.info("tts engine type: python")
         logger.info("audio duration: {}".format(duration))
-        logger.info(
-            "frontend inference time: {}".format(self.executor.frontend_time))
-        logger.info("AM inference time: {}".format(self.executor.am_time))
-        logger.info("Vocoder inference time: {}".format(self.executor.voc_time))
+        logger.info("frontend inference time: {}".format(self.frontend_time))
+        logger.info("AM inference time: {}".format(self.am_time))
+        logger.info("Vocoder inference time: {}".format(self.voc_time))
         logger.info("total inference time: {}".format(infer_time))
         logger.info(
             "postprocess (change speed, volume, target sample rate) time: {}".
@@ -274,6 +280,6 @@
         logger.info("total generate audio time: {}".format(infer_time +
                                                            postprocess_time))
         logger.info("RTF: {}".format(rtf))
-        logger.info("device: {}".format(self.device))
+        logger.info("device: {}".format(self.tts_engine.device))

         return lang, target_sample_rate, duration, wav_base64
@@ -12,11 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import base64
+import sys
 import traceback
 from typing import Union

 from fastapi import APIRouter

+from paddlespeech.cli.log import logger
 from paddlespeech.server.engine.engine_pool import get_engine_pool
 from paddlespeech.server.restful.request import ASRRequest
 from paddlespeech.server.restful.response import ASRResponse
@@ -68,8 +70,18 @@ def asr(request_body: ASRRequest):
         engine_pool = get_engine_pool()
         asr_engine = engine_pool['asr']

-        asr_engine.run(audio_data)
-        asr_results = asr_engine.postprocess()
+        if asr_engine.engine_type == "python":
+            from paddlespeech.server.engine.asr.python.asr_engine import PaddleASRConnectionHandler
+        elif asr_engine.engine_type == "inference":
+            from paddlespeech.server.engine.asr.paddleinference.asr_engine import PaddleASRConnectionHandler
+        else:
+            logger.error("Offline asr engine only support python or inference.")
+            sys.exit(-1)
+
+        connection_handler = PaddleASRConnectionHandler(asr_engine)
+        connection_handler.run(audio_data)
+        asr_results = connection_handler.postprocess()

         response = {
             "success": True,
...
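Note: the REST routers now dispatch on engine_type instead of calling the engine directly: the matching connection handler class is imported lazily and a fresh instance serves each request. A self-contained sketch of that dispatch shape, with stub classes standing in for the real paddlespeech handlers:

    import sys

    class PythonASRHandler:
        # Stand-in for the python engine's PaddleASRConnectionHandler.
        def __init__(self, asr_engine):
            self.asr_engine = asr_engine

    class InferenceASRHandler:
        # Stand-in for the paddleinference PaddleASRConnectionHandler.
        def __init__(self, asr_engine):
            self.asr_engine = asr_engine

    def make_connection_handler(asr_engine):
        # Same branch structure as the routers: pick the class by engine_type.
        if asr_engine.engine_type == "python":
            return PythonASRHandler(asr_engine)
        elif asr_engine.engine_type == "inference":
            return InferenceASRHandler(asr_engine)
        else:
            sys.exit(-1)  # the real code logs an error first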
@@ -12,11 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import base64
+import sys
 import traceback
 from typing import Union

 from fastapi import APIRouter

+from paddlespeech.cli.log import logger
 from paddlespeech.server.engine.engine_pool import get_engine_pool
 from paddlespeech.server.restful.request import CLSRequest
 from paddlespeech.server.restful.response import CLSResponse
@@ -68,8 +70,18 @@ def cls(request_body: CLSRequest):
         engine_pool = get_engine_pool()
         cls_engine = engine_pool['cls']

-        cls_engine.run(audio_data)
-        cls_results = cls_engine.postprocess(request_body.topk)
+        if cls_engine.engine_type == "python":
+            from paddlespeech.server.engine.cls.python.cls_engine import PaddleCLSConnectionHandler
+        elif cls_engine.engine_type == "inference":
+            from paddlespeech.server.engine.cls.paddleinference.cls_engine import PaddleCLSConnectionHandler
+        else:
+            logger.error("Offline cls engine only support python or inference.")
+            sys.exit(-1)
+
+        connection_handler = PaddleCLSConnectionHandler(cls_engine)
+        connection_handler.run(audio_data)
+        cls_results = connection_handler.postprocess(request_body.topk)

         response = {
             "success": True,
@@ -85,8 +97,11 @@ def cls(request_body: CLSRequest):
     except ServerBaseException as e:
         response = failed_response(e.error_code, e.msg)
-    except BaseException:
+        logger.error(e)
+        sys.exit(-1)
+    except Exception as e:
         response = failed_response(ErrorCode.SERVER_UNKOWN_ERR)
+        logger.error(e)
         traceback.print_exc()

     return response
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import sys
 import traceback
 from typing import Union
@@ -99,7 +100,16 @@ def tts(request_body: TTSRequest):
         tts_engine = engine_pool['tts']
         logger.info("Get tts engine successfully.")

-        lang, target_sample_rate, duration, wav_base64 = tts_engine.run(
+        if tts_engine.engine_type == "python":
+            from paddlespeech.server.engine.tts.python.tts_engine import PaddleTTSConnectionHandler
+        elif tts_engine.engine_type == "inference":
+            from paddlespeech.server.engine.tts.paddleinference.tts_engine import PaddleTTSConnectionHandler
+        else:
+            logger.error("Offline tts engine only support python or inference.")
+            sys.exit(-1)
+
+        connection_handler = PaddleTTSConnectionHandler(tts_engine)
+        lang, target_sample_rate, duration, wav_base64 = connection_handler.run(
             text, spk_id, speed, volume, sample_rate, save_path)

         response = {
@@ -136,4 +146,14 @@ async def stream_tts(request_body: TTSRequest):
     tts_engine = engine_pool['tts']
     logger.info("Get tts engine successfully.")

-    return StreamingResponse(tts_engine.run(sentence=text))
+    if tts_engine.engine_type == "online":
+        from paddlespeech.server.engine.tts.online.python.tts_engine import PaddleTTSConnectionHandler
+    elif tts_engine.engine_type == "online-onnx":
+        from paddlespeech.server.engine.tts.online.onnx.tts_engine import PaddleTTSConnectionHandler
+    else:
+        logger.error("Online tts engine only support online or online-onnx.")
+        sys.exit(-1)
+
+    connection_handler = PaddleTTSConnectionHandler(tts_engine)
+    return StreamingResponse(connection_handler.run(sentence=text))
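Note: stream_tts hands the connection handler's generator straight to StreamingResponse, so audio chunks are flushed to the client as the vocoder produces them instead of after full synthesis. A runnable toy version of the same pattern, with dummy byte chunks in place of real audio:

    from fastapi import FastAPI
    from starlette.responses import StreamingResponse

    app = FastAPI()

    def fake_wav_generator():
        # Dummy chunks standing in for streamed wav data.
        for chunk in (b"RIFF", b"fake", b"audio"):
            yield chunk

    @app.get("/dummy_stream")
    def dummy_stream():
        # Same shape as stream_tts: pass a generator to StreamingResponse.
        return StreamingResponse(fake_wav_generator())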
@@ -40,6 +40,16 @@ async def websocket_endpoint(websocket: WebSocket):
     engine_pool = get_engine_pool()
     tts_engine = engine_pool['tts']

+    connection_handler = None
+
+    if tts_engine.engine_type == "online":
+        from paddlespeech.server.engine.tts.online.python.tts_engine import PaddleTTSConnectionHandler
+    elif tts_engine.engine_type == "online-onnx":
+        from paddlespeech.server.engine.tts.online.onnx.tts_engine import PaddleTTSConnectionHandler
+    else:
+        logger.error("Online tts engine only support online or online-onnx.")
+        sys.exit(-1)
+
     try:
         while True:
             # careful here, changed the source code from starlette.websockets
@@ -57,10 +67,13 @@
                     "signal": "server ready",
                     "session": session
                 }
+                connection_handler = PaddleTTSConnectionHandler(tts_engine)
                 await websocket.send_json(resp)

             # end request
             elif message['signal'] == 'end':
+                connection_handler = None
                 resp = {
                     "status": 0,
                     "signal": "connection will be closed",
@@ -75,10 +88,11 @@
             # speech synthesis request
             elif 'text' in message:
                 text_bese64 = message["text"]
-                sentence = tts_engine.preprocess(text_bese64=text_bese64)
+                sentence = connection_handler.preprocess(
+                    text_bese64=text_bese64)

                 # run
-                wav_generator = tts_engine.run(sentence)
+                wav_generator = connection_handler.run(sentence)

                 while True:
                     try:
...
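Note: the websocket endpoint binds one connection handler per session: created when the opening signal arrives (the branch that answers "server ready"), cleared on the 'end' signal, and used for every 'text' message in between. A client-side sketch of that message flow; the server address and the exact value of the opening signal are assumptions not shown in this diff:

    import asyncio
    import base64
    import json

    import websockets  # third-party client library, assumed available

    async def synthesize(text, uri="ws://127.0.0.1:8092/ws/tts"):  # uri assumed
        async with websockets.connect(uri) as ws:
            await ws.send(json.dumps({"signal": "start"}))  # assumed opening signal
            print(await ws.recv())  # expect {"signal": "server ready", ...}
            await ws.send(json.dumps({
                "text": base64.b64encode(text.encode("utf8")).decode("utf8")
            }))
            # ... receive streamed wav chunks here ...
            await ws.send(json.dumps({"signal": "end"}))
            print(await ws.recv())  # {"signal": "connection will be closed", ...}

    # asyncio.run(synthesize("您好"))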