Commit 2296e015 authored by Yang Zhou

Merge branch 'develop' of github.com:SmileGoat/PaddleSpeech into refactor_file_struct

# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import sys
import warnings

import uvicorn
from fastapi import FastAPI
from starlette.middleware.cors import CORSMiddleware

from paddlespeech.server.engine.engine_pool import init_engine_pool
from paddlespeech.server.restful.api import setup_router as setup_http_router
from paddlespeech.server.utils.config import get_config
from paddlespeech.server.ws.api import setup_router as setup_ws_router

warnings.filterwarnings("ignore")

app = FastAPI(
    title="PaddleSpeech Serving API", description="Api", version="0.0.1")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"])

# change yaml file here
config_file = "./conf/application.yaml"
config = get_config(config_file)

# init engine
if not init_engine_pool(config):
    print("Failed to init engine.")
    sys.exit(-1)

# get api_router
api_list = list(engine.split("_")[0] for engine in config.engine_list)
if config.protocol == "websocket":
    api_router = setup_ws_router(api_list)
elif config.protocol == "http":
    api_router = setup_http_router(api_list)
else:
    raise Exception("unsupported protocol")

# app needs to operate outside the main function
app.include_router(api_router)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(add_help=True)
    parser.add_argument(
        "--workers", type=int, help="workers of server", default=1)
    args = parser.parse_args()
    uvicorn.run(
        "start_multi_progress_server:app",
        host=config.host,
        port=config.port,
        debug=True,
        workers=args.workers)
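
For reference, a minimal sanity-check sketch of the fields this launcher reads from `conf/application.yaml` (host, port, protocol, engine_list); it is illustrative only and not part of the server code:

# Illustrative check of the configuration consumed by the launcher above.
# Assumes conf/application.yaml exists; the attribute names are exactly the
# ones the launcher reads (config.host, config.port, config.protocol,
# config.engine_list).
from paddlespeech.server.utils.config import get_config

cfg = get_config("./conf/application.yaml")
assert cfg.protocol in ("http", "websocket")
assert len(cfg.engine_list) > 0   # e.g. ["asr_python", "tts_inference"]
print(f"serving on {cfg.host}:{cfg.port} via {cfg.protocol}")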
......@@ -27,7 +27,7 @@ The configuration file can be found in `conf/tts_online_application.yaml`.
- In streaming voc inference, one chunk of data is inferred at a time to achieve a streaming effect. `voc_block` is the number of valid frames in a chunk, and `voc_pad` is the number of frames padded before and after the voc_block within the chunk. The padding exists to eliminate the error introduced by chunk-wise inference, so that streaming inference does not degrade the quality of the synthesized audio (see the sketch after this list).
- Both hifigan and mb_melgan support streaming voc inference.
- When the voc model is mb_melgan, setting voc_pad=14 makes the streaming synthesized audio identical to the non-streaming result; voc_pad can be lowered to a minimum of 7 and the audio still sounds normal, while values below 7 produce audible artifacts.
- When the voc model is hifigan, setting voc_pad=20 makes the streaming synthesized audio identical to the non-streaming result; with voc_pad=14 the audio still sounds normal.
- When the voc model is hifigan, setting voc_pad=19 makes the streaming synthesized audio identical to the non-streaming result; with voc_pad=14 the audio still sounds normal.
- Inference speed: mb_melgan > hifigan; Audio quality: mb_melgan < hifigan
- **Note:** If the service can be started normally in the container, but the client access IP is unreachable, you can try to replace the `host` address in the configuration file with the local IP address.
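
As a rough illustration of the block/pad mechanism described above, the sketch below splits a mel spectrogram into chunks of `voc_block` valid frames with `voc_pad` frames of context on each side, runs a placeholder vocoder chunk by chunk, and trims the padded part of each output. The helper names and the upsample factor are illustrative stand-ins for the repository's `get_chunks`/`depadding` helpers and the vocoder's `n_shift`, not their exact implementations.

import numpy as np

def split_into_chunks(mel, block, pad):
    """Yield (chunk, is_first, is_last); each chunk carries pad frames of context."""
    n = mel.shape[0]
    for start in range(0, n, block):
        lo = max(0, start - pad)
        hi = min(n, start + block + pad)
        yield mel[lo:hi], start == 0, start + block >= n

def trim_pad(wav_chunk, is_first, is_last, pad, upsample):
    """Drop the waveform samples produced from the padded mel frames."""
    front = 0 if is_first else pad * upsample
    back = wav_chunk.shape[0] if is_last else wav_chunk.shape[0] - pad * upsample
    return wav_chunk[front:back]

upsample = 300                    # assumed hop size (n_shift) of the vocoder
voc_block, voc_pad = 36, 14       # values matching the config files in this change
mel = np.random.randn(100, 80)    # 100 frames of 80-dim mel

wav = []
for chunk, first, last in split_into_chunks(mel, voc_block, voc_pad):
    wav_chunk = np.zeros(chunk.shape[0] * upsample)   # placeholder vocoder output
    wav.append(trim_pad(wav_chunk, first, last, voc_pad, upsample))
wav = np.concatenate(wav)
assert wav.shape[0] == mel.shape[0] * upsample        # no samples lost or duplicated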
......
......@@ -27,7 +27,7 @@
- In streaming voc inference, one chunk of data is inferred at a time to achieve a streaming effect. `voc_block` is the number of valid frames in a chunk, and `voc_pad` is the number of frames padded before and after the voc_block within the chunk. The padding exists to eliminate the error introduced by chunk-wise inference, so that streaming inference does not degrade the quality of the synthesized audio.
- Both hifigan and mb_melgan support streaming voc inference.
- When the voc model is mb_melgan, setting voc_pad=14 makes the streaming synthesized audio identical to the non-streaming result; voc_pad can be lowered to a minimum of 7 and the audio still sounds normal, while values below 7 produce audible artifacts.
- When the voc model is hifigan, setting voc_pad=20 makes the streaming synthesized audio identical to the non-streaming result; with voc_pad=14 the audio still sounds normal.
- When the voc model is hifigan, setting voc_pad=19 makes the streaming synthesized audio identical to the non-streaming result; with voc_pad=14 the audio still sounds normal.
- Inference speed: mb_melgan > hifigan; audio quality: mb_melgan < hifigan
- **Note:** If the service starts normally inside the container but the client cannot reach the access IP, try replacing the `host` address in the configuration file with the local IP address.
......
......@@ -47,7 +47,7 @@ tts_online:
am_pad: 12
# voc_pad and voc_block are used by the voc model for streaming voc inference
# when voc model is mb_melgan_csmsc, voc_pad=14 makes streaming synthetic audio identical to non-streaming synthetic audio; the minimum pad is 7, at which streaming synthetic audio still sounds normal
# when voc model is hifigan_csmsc, voc_pad=20 makes streaming synthetic audio identical to non-streaming synthetic audio; with voc_pad=14 streaming synthetic audio still sounds normal
# when voc model is hifigan_csmsc, voc_pad=19 makes streaming synthetic audio identical to non-streaming synthetic audio; with voc_pad=14 streaming synthetic audio still sounds normal
voc_block: 36
voc_pad: 14
......@@ -95,7 +95,7 @@ tts_online-onnx:
am_pad: 12
# voc_pad and voc_block are used by the voc model for streaming voc inference
# when voc model is mb_melgan_csmsc_onnx, voc_pad=14 makes streaming synthetic audio identical to non-streaming synthetic audio; the minimum pad is 7, at which streaming synthetic audio still sounds normal
# when voc model is hifigan_csmsc_onnx, voc_pad=20 makes streaming synthetic audio identical to non-streaming synthetic audio; with voc_pad=14 streaming synthetic audio still sounds normal
# when voc model is hifigan_csmsc_onnx, voc_pad=19 makes streaming synthetic audio identical to non-streaming synthetic audio; with voc_pad=14 streaming synthetic audio still sounds normal
voc_block: 36
voc_pad: 14
# voc_upsample should be the same as n_shift in the voc config.
......
......@@ -27,6 +27,7 @@ from ..util import stats_wrapper
from paddlespeech.cli.log import logger
from paddlespeech.resource import CommonTaskResource
from paddlespeech.server.engine.engine_pool import init_engine_pool
from paddlespeech.server.engine.engine_warmup import warm_up
from paddlespeech.server.restful.api import setup_router as setup_http_router
from paddlespeech.server.utils.config import get_config
from paddlespeech.server.ws.api import setup_router as setup_ws_router
......@@ -87,6 +88,11 @@ class ServerExecutor(BaseExecutor):
if not init_engine_pool(config):
return False
# warm up
for engine_and_type in config.engine_list:
if not warm_up(engine_and_type):
return False
return True
def execute(self, argv: List[str]) -> bool:
......
......@@ -47,7 +47,7 @@ tts_online:
am_pad: 12
# voc_pad and voc_block are used by the voc model for streaming voc inference
# when voc model is mb_melgan_csmsc, voc_pad=14 makes streaming synthetic audio identical to non-streaming synthetic audio; the minimum pad is 7, at which streaming synthetic audio still sounds normal
# when voc model is hifigan_csmsc, voc_pad=20 makes streaming synthetic audio identical to non-streaming synthetic audio; with voc_pad=14 streaming synthetic audio still sounds normal
# when voc model is hifigan_csmsc, voc_pad=19 makes streaming synthetic audio identical to non-streaming synthetic audio; with voc_pad=14 streaming synthetic audio still sounds normal
voc_block: 36
voc_pad: 14
......@@ -95,7 +95,7 @@ tts_online-onnx:
am_pad: 12
# voc_pad and voc_block are used by the voc model for streaming voc inference
# when voc model is mb_melgan_csmsc_onnx, voc_pad=14 makes streaming synthetic audio identical to non-streaming synthetic audio; the minimum pad is 7, at which streaming synthetic audio still sounds normal
# when voc model is hifigan_csmsc_onnx, voc_pad=20 makes streaming synthetic audio identical to non-streaming synthetic audio; with voc_pad=14 streaming synthetic audio still sounds normal
# when voc model is hifigan_csmsc_onnx, voc_pad=19 makes streaming synthetic audio identical to non-streaming synthetic audio; with voc_pad=14 streaming synthetic audio still sounds normal
voc_block: 36
voc_pad: 14
# voc_upsample should be the same as n_shift in the voc config.
......
......@@ -30,7 +30,7 @@ from paddlespeech.server.engine.base_engine import BaseEngine
from paddlespeech.server.utils.paddle_predictor import init_predictor
from paddlespeech.server.utils.paddle_predictor import run_model
__all__ = ['ASREngine']
__all__ = ['ASREngine', 'PaddleASRConnectionHandler']
class ASRServerExecutor(ASRExecutor):
......@@ -51,7 +51,7 @@ class ASRServerExecutor(ASRExecutor):
"""
Init model and other resources from a specific path.
"""
self.max_len = 50
sample_rate_str = '16k' if sample_rate == 16000 else '8k'
tag = model_type + '-' + lang + '-' + sample_rate_str
self.max_len = 50
......@@ -87,7 +87,8 @@ class ASRServerExecutor(ASRExecutor):
self.config.spm_model_prefix = os.path.join(
self.res_path, self.config.spm_model_prefix)
self.text_feature = TextFeaturizer(
unit_type=self.config.unit_type, vocab=self.vocab,
unit_type=self.config.unit_type,
vocab=self.vocab,
spm_model_prefix=self.config.spm_model_prefix)
self.config.decode.lang_model_path = os.path.join(
MODEL_HOME, 'language_model',
......@@ -176,10 +177,23 @@ class ASREngine(BaseEngine):
Returns:
bool: init failed or success
"""
self.input = None
self.output = None
self.executor = ASRServerExecutor()
self.config = config
self.engine_type = "inference"
try:
if self.config.am_predictor_conf.device is not None:
self.device = self.config.am_predictor_conf.device
else:
self.device = paddle.get_device()
paddle.set_device(self.device)
except Exception as e:
logger.error(
"Set device failed, please check if device is already used and the parameter 'device' in the yaml file"
)
logger.error(e)
return False
self.executor._init_from_path(
model_type=self.config.model_type,
......@@ -194,22 +208,41 @@ class ASREngine(BaseEngine):
logger.info("Initialize ASR server engine successfully.")
return True
class PaddleASRConnectionHandler(ASRServerExecutor):
def __init__(self, asr_engine):
"""The PaddleSpeech ASR Server Connection Handler
This connection handler processes every asr server request
Args:
asr_engine (ASREngine): The ASR engine
"""
super().__init__()
self.input = None
self.output = None
self.asr_engine = asr_engine
self.executor = self.asr_engine.executor
self.config = self.executor.config
self.max_len = self.executor.max_len
self.decoder = self.executor.decoder
self.am_predictor = self.executor.am_predictor
self.text_feature = self.executor.text_feature
def run(self, audio_data):
"""engine run
Args:
audio_data (bytes): base64.b64decode
"""
if self.executor._check(
io.BytesIO(audio_data), self.config.sample_rate,
self.config.force_yes):
if self._check(
io.BytesIO(audio_data), self.asr_engine.config.sample_rate,
self.asr_engine.config.force_yes):
logger.info("start running asr engine")
self.executor.preprocess(self.config.model_type,
io.BytesIO(audio_data))
self.preprocess(self.asr_engine.config.model_type,
io.BytesIO(audio_data))
st = time.time()
self.executor.infer(self.config.model_type)
self.infer(self.asr_engine.config.model_type)
infer_time = time.time() - st
self.output = self.executor.postprocess() # Retrieve result of asr.
self.output = self.postprocess() # Retrieve result of asr.
logger.info("end inferring asr engine")
else:
logger.info("file check failed!")
......@@ -217,8 +250,3 @@ class ASREngine(BaseEngine):
logger.info("inference time: {}".format(infer_time))
logger.info("asr engine type: paddle inference")
def postprocess(self):
"""postprocess
"""
return self.output
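
The request-side usage of this handler (shown later in the restful API diff) boils down to the pattern below; a minimal sketch, assuming the server has already called init_engine_pool and that the wav path is illustrative:

from paddlespeech.server.engine.engine_pool import get_engine_pool
from paddlespeech.server.engine.asr.paddleinference.asr_engine import PaddleASRConnectionHandler

# One shared ASREngine lives in the engine pool; each request gets its own
# lightweight connection handler that borrows the engine's loaded resources.
with open("input_16k.wav", "rb") as f:        # any 16 kHz wav; path is illustrative
    audio_data = f.read()

asr_engine = get_engine_pool()['asr']
connection_handler = PaddleASRConnectionHandler(asr_engine)
connection_handler.run(audio_data)
asr_results = connection_handler.postprocess()
print(asr_results)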
......@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import sys
import time
import paddle
......@@ -20,7 +21,7 @@ from paddlespeech.cli.asr.infer import ASRExecutor
from paddlespeech.cli.log import logger
from paddlespeech.server.engine.base_engine import BaseEngine
__all__ = ['ASREngine']
__all__ = ['ASREngine', 'PaddleASRConnectionHandler']
class ASRServerExecutor(ASRExecutor):
......@@ -48,20 +49,23 @@ class ASREngine(BaseEngine):
Returns:
bool: init failed or success
"""
self.input = None
self.output = None
self.executor = ASRServerExecutor()
self.config = config
self.engine_type = "python"
try:
if self.config.device:
if self.config.device is not None:
self.device = self.config.device
else:
self.device = paddle.get_device()
paddle.set_device(self.device)
except BaseException:
except Exception as e:
logger.error(
"Set device failed, please check if device is already used and the parameter 'device' in the yaml file"
)
logger.error(e)
return False
self.executor._init_from_path(
self.config.model, self.config.lang, self.config.sample_rate,
......@@ -72,6 +76,24 @@ class ASREngine(BaseEngine):
(self.device))
return True
class PaddleASRConnectionHandler(ASRServerExecutor):
def __init__(self, asr_engine):
"""The PaddleSpeech ASR Server Connection Handler
This connection handler processes every asr server request
Args:
asr_engine (ASREngine): The ASR engine
"""
super().__init__()
self.input = None
self.output = None
self.asr_engine = asr_engine
self.executor = self.asr_engine.executor
self.max_len = self.executor.max_len
self.text_feature = self.executor.text_feature
self.model = self.executor.model
self.config = self.executor.config
def run(self, audio_data):
"""engine run
......@@ -79,17 +101,16 @@ class ASREngine(BaseEngine):
audio_data (bytes): base64.b64decode
"""
try:
if self.executor._check(
io.BytesIO(audio_data), self.config.sample_rate,
self.config.force_yes):
if self._check(
io.BytesIO(audio_data), self.asr_engine.config.sample_rate,
self.asr_engine.config.force_yes):
logger.info("start run asr engine")
self.executor.preprocess(self.config.model,
io.BytesIO(audio_data))
self.preprocess(self.asr_engine.config.model,
io.BytesIO(audio_data))
st = time.time()
self.executor.infer(self.config.model)
self.infer(self.asr_engine.config.model)
infer_time = time.time() - st
self.output = self.executor.postprocess(
) # Retrieve result of asr.
self.output = self.postprocess() # Retrieve result of asr.
else:
logger.info("file check failed!")
self.output = None
......@@ -98,8 +119,4 @@ class ASREngine(BaseEngine):
logger.info("asr engine type: python")
except Exception as e:
logger.info(e)
def postprocess(self):
"""postprocess
"""
return self.output
sys.exit(-1)
......@@ -14,6 +14,7 @@
import io
import os
import time
from collections import OrderedDict
from typing import Optional
import numpy as np
......@@ -27,7 +28,7 @@ from paddlespeech.server.engine.base_engine import BaseEngine
from paddlespeech.server.utils.paddle_predictor import init_predictor
from paddlespeech.server.utils.paddle_predictor import run_model
__all__ = ['CLSEngine']
__all__ = ['CLSEngine', 'PaddleCLSConnectionHandler']
class CLSServerExecutor(CLSExecutor):
......@@ -121,14 +122,55 @@ class CLSEngine(BaseEngine):
"""
self.executor = CLSServerExecutor()
self.config = config
self.executor._init_from_path(
self.config.model_type, self.config.cfg_path,
self.config.model_path, self.config.params_path,
self.config.label_file, self.config.predictor_conf)
self.engine_type = "inference"
try:
if self.config.predictor_conf.device is not None:
self.device = self.config.predictor_conf.device
else:
self.device = paddle.get_device()
paddle.set_device(self.device)
except Exception as e:
logger.error(
"Set device failed, please check if device is already used and the parameter 'device' in the yaml file"
)
logger.error(e)
return False
try:
self.executor._init_from_path(
self.config.model_type, self.config.cfg_path,
self.config.model_path, self.config.params_path,
self.config.label_file, self.config.predictor_conf)
except Exception as e:
logger.error("Initialize CLS server engine Failed.")
logger.error(e)
return False
logger.info("Initialize CLS server engine successfully.")
return True
class PaddleCLSConnectionHandler(CLSServerExecutor):
def __init__(self, cls_engine):
"""The PaddleSpeech CLS Server Connection Handler
This connection handler processes every cls server request
Args:
cls_engine (CLSEngine): The CLS engine
"""
super().__init__()
logger.info(
"Create PaddleCLSConnectionHandler to process the cls request")
self._inputs = OrderedDict()
self._outputs = OrderedDict()
self.cls_engine = cls_engine
self.executor = self.cls_engine.executor
self._conf = self.executor._conf
self._label_list = self.executor._label_list
self.predictor = self.executor.predictor
def run(self, audio_data):
"""engine run
......@@ -136,9 +178,9 @@ class CLSEngine(BaseEngine):
audio_data (bytes): base64.b64decode
"""
self.executor.preprocess(io.BytesIO(audio_data))
self.preprocess(io.BytesIO(audio_data))
st = time.time()
self.executor.infer()
self.infer()
infer_time = time.time() - st
logger.info("inference time: {}".format(infer_time))
......@@ -147,15 +189,15 @@ class CLSEngine(BaseEngine):
def postprocess(self, topk: int):
"""postprocess
"""
assert topk <= len(self.executor._label_list
), 'Value of topk is larger than number of labels.'
assert topk <= len(
self._label_list), 'Value of topk is larger than number of labels.'
result = np.squeeze(self.executor._outputs['logits'], axis=0)
result = np.squeeze(self._outputs['logits'], axis=0)
topk_idx = (-result).argsort()[:topk]
topk_results = []
for idx in topk_idx:
res = {}
label, score = self.executor._label_list[idx], result[idx]
label, score = self._label_list[idx], result[idx]
res['class_name'] = label
res['prob'] = score
topk_results.append(res)
......
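
The top-k selection in postprocess above relies on an argsort of the negated scores; a tiny standalone illustration with made-up values:

import numpy as np

logits = np.array([0.10, 0.70, 0.05, 0.15])   # made-up class scores
topk = 2
topk_idx = (-logits).argsort()[:topk]          # indices of the two largest scores
print(topk_idx)                                # -> [1 3]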
......@@ -13,7 +13,7 @@
# limitations under the License.
import io
import time
from typing import List
from collections import OrderedDict
import paddle
......@@ -21,7 +21,7 @@ from paddlespeech.cli.cls.infer import CLSExecutor
from paddlespeech.cli.log import logger
from paddlespeech.server.engine.base_engine import BaseEngine
__all__ = ['CLSEngine']
__all__ = ['CLSEngine', 'PaddleCLSConnectionHandler']
class CLSServerExecutor(CLSExecutor):
......@@ -29,21 +29,6 @@ class CLSServerExecutor(CLSExecutor):
super().__init__()
pass
def get_topk_results(self, topk: int) -> List:
assert topk <= len(
self._label_list), 'Value of topk is larger than number of labels.'
result = self._outputs['logits'].squeeze(0).numpy()
topk_idx = (-result).argsort()[:topk]
res = {}
topk_results = []
for idx in topk_idx:
label, score = self._label_list[idx], result[idx]
res['class'] = label
res['prob'] = score
topk_results.append(res)
return topk_results
class CLSEngine(BaseEngine):
"""CLS server engine
......@@ -64,42 +49,65 @@ class CLSEngine(BaseEngine):
Returns:
bool: init failed or success
"""
self.input = None
self.output = None
self.executor = CLSServerExecutor()
self.config = config
self.engine_type = "python"
try:
if self.config.device:
if self.config.device is not None:
self.device = self.config.device
else:
self.device = paddle.get_device()
paddle.set_device(self.device)
except BaseException:
except Exception as e:
logger.error(
"Set device failed, please check if device is already used and the parameter 'device' in the yaml file"
)
logger.error(e)
return False
try:
self.executor._init_from_path(
self.config.model, self.config.cfg_path, self.config.ckpt_path,
self.config.label_file)
except BaseException:
except Exception as e:
logger.error("Initialize CLS server engine Failed.")
logger.error(e)
return False
logger.info("Initialize CLS server engine successfully on device: %s." %
(self.device))
return True
class PaddleCLSConnectionHandler(CLSServerExecutor):
def __init__(self, cls_engine):
"""The PaddleSpeech CLS Server Connection Handler
This connection handler processes every cls server request
Args:
cls_engine (CLSEngine): The CLS engine
"""
super().__init__()
logger.info(
"Create PaddleCLSConnectionHandler to process the cls request")
self._inputs = OrderedDict()
self._outputs = OrderedDict()
self.cls_engine = cls_engine
self.executor = self.cls_engine.executor
self._conf = self.executor._conf
self._label_list = self.executor._label_list
self.model = self.executor.model
def run(self, audio_data):
"""engine run
Args:
audio_data (bytes): base64.b64decode
"""
self.executor.preprocess(io.BytesIO(audio_data))
self.preprocess(io.BytesIO(audio_data))
st = time.time()
self.executor.infer()
self.infer()
infer_time = time.time() - st
logger.info("inference time: {}".format(infer_time))
......@@ -108,15 +116,15 @@ class CLSEngine(BaseEngine):
def postprocess(self, topk: int):
"""postprocess
"""
assert topk <= len(self.executor._label_list
), 'Value of topk is larger than number of labels.'
assert topk <= len(
self._label_list), 'Value of topk is larger than number of labels.'
result = self.executor._outputs['logits'].squeeze(0).numpy()
result = self._outputs['logits'].squeeze(0).numpy()
topk_idx = (-result).argsort()[:topk]
topk_results = []
for idx in topk_idx:
res = {}
label, score = self.executor._label_list[idx], result[idx]
label, score = self._label_list[idx], result[idx]
res['class_name'] = label
res['prob'] = score
topk_results.append(res)
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
from paddlespeech.cli.log import logger
from paddlespeech.server.engine.engine_pool import get_engine_pool
def warm_up(engine_and_type: str, warm_up_time: int=3) -> bool:
    engine_pool = get_engine_pool()

    if "tts" in engine_and_type:
        tts_engine = engine_pool['tts']
        flag_online = False
        if tts_engine.lang == 'zh':
            sentence = "您好,欢迎使用语音合成服务。"
        elif tts_engine.lang == 'en':
            sentence = "Hello and welcome to the speech synthesis service."
        else:
            logger.error("tts engine only support lang: zh or en.")
            return False

        if engine_and_type == "tts_python":
            from paddlespeech.server.engine.tts.python.tts_engine import PaddleTTSConnectionHandler
        elif engine_and_type == "tts_inference":
            from paddlespeech.server.engine.tts.paddleinference.tts_engine import PaddleTTSConnectionHandler
        elif engine_and_type == "tts_online":
            from paddlespeech.server.engine.tts.online.python.tts_engine import PaddleTTSConnectionHandler
            flag_online = True
        elif engine_and_type == "tts_online-onnx":
            from paddlespeech.server.engine.tts.online.onnx.tts_engine import PaddleTTSConnectionHandler
            flag_online = True
        else:
            logger.error("Please check tts engine type.")
            return False

        try:
            logger.info("Start to warm up tts engine.")
            for i in range(warm_up_time):
                connection_handler = PaddleTTSConnectionHandler(tts_engine)
                if flag_online:
                    for wav in connection_handler.infer(
                            text=sentence,
                            lang=tts_engine.lang,
                            am=tts_engine.config.am):
                        logger.info(
                            f"The first response time of the {i} warm up: {connection_handler.first_response_time} s"
                        )
                        break
                else:
                    st = time.time()
                    connection_handler.infer(text=sentence)
                    et = time.time()
                    logger.info(
                        f"The response time of the {i} warm up: {et - st} s")
        except Exception as e:
            logger.error("Failed to warm up on tts engine.")
            logger.error(e)
            return False
    else:
        pass

    return True
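
A minimal sketch of how this helper is driven, mirroring the ServerExecutor change earlier in this commit; it assumes the same application config and that the engine pool is initialized before warming up:

from paddlespeech.server.engine.engine_pool import init_engine_pool
from paddlespeech.server.engine.engine_warmup import warm_up
from paddlespeech.server.utils.config import get_config

config = get_config("./conf/application.yaml")
if not init_engine_pool(config):
    raise SystemExit("Failed to init engine pool.")
for engine_and_type in config.engine_list:     # e.g. ["tts_online-onnx"]
    if not warm_up(engine_and_type):
        raise SystemExit(f"Warm up failed for {engine_and_type}.")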
......@@ -31,18 +31,12 @@ from paddlespeech.server.utils.util import get_chunks
from paddlespeech.t2s.frontend import English
from paddlespeech.t2s.frontend.zh_frontend import Frontend
__all__ = ['TTSEngine']
__all__ = ['TTSEngine', 'PaddleTTSConnectionHandler']
class TTSServerExecutor(TTSExecutor):
def __init__(self, am_block, am_pad, voc_block, voc_pad, voc_upsample):
def __init__(self):
super().__init__()
self.am_block = am_block
self.am_pad = am_pad
self.voc_block = voc_block
self.voc_pad = voc_pad
self.voc_upsample = voc_upsample
self.task_resource = CommonTaskResource(task='tts', model_format='onnx')
def _init_from_path(
......@@ -170,6 +164,115 @@ class TTSServerExecutor(TTSExecutor):
self.frontend = English(phone_vocab_path=self.phones_dict)
logger.info("frontend done!")
class TTSEngine(BaseEngine):
"""TTS server engine
Args:
metaclass: Defaults to Singleton.
"""
def __init__(self, name=None):
"""Initialize TTS server engine
"""
super().__init__()
def init(self, config: dict) -> bool:
self.executor = TTSServerExecutor()
self.config = config
self.lang = self.config.lang
self.engine_type = "online-onnx"
self.am_block = self.config.am_block
self.am_pad = self.config.am_pad
self.voc_block = self.config.voc_block
self.voc_pad = self.config.voc_pad
self.am_upsample = 1
self.voc_upsample = self.config.voc_upsample
assert (
self.config.am == "fastspeech2_csmsc_onnx" or
self.config.am == "fastspeech2_cnndecoder_csmsc_onnx"
) and (
self.config.voc == "hifigan_csmsc_onnx" or
self.config.voc == "mb_melgan_csmsc_onnx"
), 'Please check config, am support: fastspeech2, voc support: hifigan_csmsc-zh or mb_melgan_csmsc.'
assert (
self.config.voc_block > 0 and self.config.voc_pad > 0
), "Please set correct voc_block and voc_pad, they should be more than 0."
assert (
self.config.voc_sample_rate == self.config.am_sample_rate
), "The sample rate of AM and Vocoder model are different, please check model."
try:
if self.config.am_sess_conf.device is not None:
self.device = self.config.am_sess_conf.device
elif self.config.voc_sess_conf.device is not None:
self.device = self.config.voc_sess_conf.device
else:
self.device = paddle.get_device()
paddle.set_device(self.device)
except Exception as e:
logger.error(
"Set device failed, please check if device is already used and the parameter 'device' in the yaml file"
)
logger.error("Initialize TTS server engine Failed on device: %s." %
(self.device))
logger.error(e)
return False
try:
self.executor._init_from_path(
am=self.config.am,
am_ckpt=self.config.am_ckpt,
am_stat=self.config.am_stat,
phones_dict=self.config.phones_dict,
tones_dict=self.config.tones_dict,
speaker_dict=self.config.speaker_dict,
am_sample_rate=self.config.am_sample_rate,
am_sess_conf=self.config.am_sess_conf,
voc=self.config.voc,
voc_ckpt=self.config.voc_ckpt,
voc_sample_rate=self.config.voc_sample_rate,
voc_sess_conf=self.config.voc_sess_conf,
lang=self.config.lang)
except Exception as e:
logger.error("Failed to get model related files.")
logger.error("Initialize TTS server engine Failed on device: %s." %
(self.config.voc_sess_conf.device))
logger.error(e)
return False
logger.info("Initialize TTS server engine successfully on device: %s." %
(self.config.voc_sess_conf.device))
return True
class PaddleTTSConnectionHandler:
def __init__(self, tts_engine):
"""The PaddleSpeech TTS Server Connection Handler
This connection handler processes every tts server request
Args:
tts_engine (TTSEngine): The TTS engine
"""
super().__init__()
logger.info(
"Create PaddleTTSConnectionHandler to process the tts request")
self.tts_engine = tts_engine
self.executor = self.tts_engine.executor
self.config = self.tts_engine.config
self.am_block = self.tts_engine.am_block
self.am_pad = self.tts_engine.am_pad
self.voc_block = self.tts_engine.voc_block
self.voc_pad = self.tts_engine.voc_pad
self.am_upsample = self.tts_engine.am_upsample
self.voc_upsample = self.tts_engine.voc_upsample
def depadding(self, data, chunk_num, chunk_id, block, pad, upsample):
"""
Streaming inference removes the result of pad inference
......@@ -198,12 +301,6 @@ class TTSServerExecutor(TTSExecutor):
Model inference and result stored in self.output.
"""
am_block = self.am_block
am_pad = self.am_pad
am_upsample = 1
voc_block = self.voc_block
voc_pad = self.voc_pad
voc_upsample = self.voc_upsample
# first_flag marks the first packet (the first returned audio chunk)
first_flag = 1
get_tone_ids = False
......@@ -212,7 +309,7 @@ class TTSServerExecutor(TTSExecutor):
# front
frontend_st = time.time()
if lang == 'zh':
input_ids = self.frontend.get_input_ids(
input_ids = self.executor.frontend.get_input_ids(
text,
merge_sentences=merge_sentences,
get_tone_ids=get_tone_ids)
......@@ -220,7 +317,7 @@ class TTSServerExecutor(TTSExecutor):
if get_tone_ids:
tone_ids = input_ids["tone_ids"]
elif lang == 'en':
input_ids = self.frontend.get_input_ids(
input_ids = self.executor.frontend.get_input_ids(
text, merge_sentences=merge_sentences)
phone_ids = input_ids["phone_ids"]
else:
......@@ -235,7 +332,7 @@ class TTSServerExecutor(TTSExecutor):
# fastspeech2_csmsc
if am == "fastspeech2_csmsc_onnx":
# am
mel = self.am_sess.run(
mel = self.executor.am_sess.run(
output_names=None, input_feed={'text': part_phone_ids})
mel = mel[0]
if first_flag == 1:
......@@ -243,14 +340,16 @@ class TTSServerExecutor(TTSExecutor):
self.first_am_infer = first_am_et - frontend_et
# voc streaming
mel_chunks = get_chunks(mel, voc_block, voc_pad, "voc")
mel_chunks = get_chunks(mel, self.voc_block, self.voc_pad,
"voc")
voc_chunk_num = len(mel_chunks)
voc_st = time.time()
for i, mel_chunk in enumerate(mel_chunks):
sub_wav = self.voc_sess.run(
sub_wav = self.executor.voc_sess.run(
output_names=None, input_feed={'logmel': mel_chunk})
sub_wav = self.depadding(sub_wav[0], voc_chunk_num, i,
voc_block, voc_pad, voc_upsample)
self.voc_block, self.voc_pad,
self.voc_upsample)
if first_flag == 1:
first_voc_et = time.time()
self.first_voc_infer = first_voc_et - first_am_et
......@@ -262,7 +361,7 @@ class TTSServerExecutor(TTSExecutor):
# fastspeech2_cnndecoder_csmsc
elif am == "fastspeech2_cnndecoder_csmsc_onnx":
# am
orig_hs = self.am_encoder_infer_sess.run(
orig_hs = self.executor.am_encoder_infer_sess.run(
None, input_feed={'text': part_phone_ids})
orig_hs = orig_hs[0]
......@@ -276,9 +375,9 @@ class TTSServerExecutor(TTSExecutor):
hss = get_chunks(orig_hs, self.am_block, self.am_pad, "am")
am_chunk_num = len(hss)
for i, hs in enumerate(hss):
am_decoder_output = self.am_decoder_sess.run(
am_decoder_output = self.executor.am_decoder_sess.run(
None, input_feed={'xs': hs})
am_postnet_output = self.am_postnet_sess.run(
am_postnet_output = self.executor.am_postnet_sess.run(
None,
input_feed={
'xs': np.transpose(am_decoder_output[0], (0, 2, 1))
......@@ -287,9 +386,11 @@ class TTSServerExecutor(TTSExecutor):
am_postnet_output[0], (0, 2, 1))
normalized_mel = am_output_data[0][0]
sub_mel = denorm(normalized_mel, self.am_mu, self.am_std)
sub_mel = self.depadding(sub_mel, am_chunk_num, i, am_block,
am_pad, am_upsample)
sub_mel = denorm(normalized_mel, self.executor.am_mu,
self.executor.am_std)
sub_mel = self.depadding(sub_mel, am_chunk_num, i,
self.am_block, self.am_pad,
self.am_upsample)
if i == 0:
mel_streaming = sub_mel
......@@ -306,11 +407,11 @@ class TTSServerExecutor(TTSExecutor):
self.first_am_infer = first_am_et - frontend_et
voc_chunk = mel_streaming[start:end, :]
sub_wav = self.voc_sess.run(
sub_wav = self.executor.voc_sess.run(
output_names=None, input_feed={'logmel': voc_chunk})
sub_wav = self.depadding(sub_wav[0], voc_chunk_num,
voc_chunk_id, voc_block,
voc_pad, voc_upsample)
sub_wav = self.depadding(
sub_wav[0], voc_chunk_num, voc_chunk_id,
self.voc_block, self.voc_pad, self.voc_upsample)
if first_flag == 1:
first_voc_et = time.time()
self.first_voc_infer = first_voc_et - first_am_et
......@@ -320,9 +421,11 @@ class TTSServerExecutor(TTSExecutor):
yield sub_wav
voc_chunk_id += 1
start = max(0, voc_chunk_id * voc_block - voc_pad)
end = min((voc_chunk_id + 1) * voc_block + voc_pad,
mel_len)
start = max(
0, voc_chunk_id * self.voc_block - self.voc_pad)
end = min(
(voc_chunk_id + 1) * self.voc_block + self.voc_pad,
mel_len)
else:
logger.error(
......@@ -331,111 +434,6 @@ class TTSServerExecutor(TTSExecutor):
self.final_response_time = time.time() - frontend_st
class TTSEngine(BaseEngine):
"""TTS server engine
Args:
metaclass: Defaults to Singleton.
"""
def __init__(self, name=None):
"""Initialize TTS server engine
"""
super().__init__()
def init(self, config: dict) -> bool:
self.config = config
assert (
self.config.am == "fastspeech2_csmsc_onnx" or
self.config.am == "fastspeech2_cnndecoder_csmsc_onnx"
) and (
self.config.voc == "hifigan_csmsc_onnx" or
self.config.voc == "mb_melgan_csmsc_onnx"
), 'Please check config, am support: fastspeech2, voc support: hifigan_csmsc-zh or mb_melgan_csmsc.'
assert (
self.config.voc_block > 0 and self.config.voc_pad > 0
), "Please set correct voc_block and voc_pad, they should be more than 0."
assert (
self.config.voc_sample_rate == self.config.am_sample_rate
), "The sample rate of AM and Vocoder model are different, please check model."
self.executor = TTSServerExecutor(
self.config.am_block, self.config.am_pad, self.config.voc_block,
self.config.voc_pad, self.config.voc_upsample)
try:
if self.config.am_sess_conf.device is not None:
self.device = self.config.am_sess_conf.device
elif self.config.voc_sess_conf.device is not None:
self.device = self.config.voc_sess_conf.device
else:
self.device = paddle.get_device()
paddle.set_device(self.device)
except BaseException as e:
logger.error(
"Set device failed, please check if device is already used and the parameter 'device' in the yaml file"
)
logger.error("Initialize TTS server engine Failed on device: %s." %
(self.device))
return False
try:
self.executor._init_from_path(
am=self.config.am,
am_ckpt=self.config.am_ckpt,
am_stat=self.config.am_stat,
phones_dict=self.config.phones_dict,
tones_dict=self.config.tones_dict,
speaker_dict=self.config.speaker_dict,
am_sample_rate=self.config.am_sample_rate,
am_sess_conf=self.config.am_sess_conf,
voc=self.config.voc,
voc_ckpt=self.config.voc_ckpt,
voc_sample_rate=self.config.voc_sample_rate,
voc_sess_conf=self.config.voc_sess_conf,
lang=self.config.lang)
except Exception as e:
logger.error("Failed to get model related files.")
logger.error("Initialize TTS server engine Failed on device: %s." %
(self.config.voc_sess_conf.device))
return False
# warm up
try:
self.warm_up()
logger.info("Warm up successfully.")
except Exception as e:
logger.error("Failed to warm up on tts engine.")
return False
logger.info("Initialize TTS server engine successfully on device: %s." %
(self.config.voc_sess_conf.device))
return True
def warm_up(self):
"""warm up
"""
if self.config.lang == 'zh':
sentence = "您好,欢迎使用语音合成服务。"
if self.config.lang == 'en':
sentence = "Hello and welcome to the speech synthesis service."
logger.info("Start to warm up.")
for i in range(3):
for wav in self.executor.infer(
text=sentence,
lang=self.config.lang,
am=self.config.am,
spk_id=0, ):
logger.info(
f"The first response time of the {i} warm up: {self.executor.first_response_time} s"
)
break
def preprocess(self, text_bese64: str=None, text_bytes: bytes=None):
# Convert byte to text
if text_bese64:
......@@ -468,7 +466,7 @@ class TTSEngine(BaseEngine):
"""
wav_list = []
for wav in self.executor.infer(
for wav in self.infer(
text=sentence,
lang=self.config.lang,
am=self.config.am,
......@@ -486,11 +484,9 @@ class TTSEngine(BaseEngine):
duration = len(wav_all) / self.config.voc_sample_rate
logger.info(f"sentence: {sentence}")
logger.info(f"The durations of audio is: {duration} s")
logger.info(f"first response time: {self.first_response_time} s")
logger.info(f"final response time: {self.final_response_time} s")
logger.info(f"RTF: {self.final_response_time / duration}")
logger.info(
f"first response time: {self.executor.first_response_time} s")
logger.info(
f"final response time: {self.executor.final_response_time} s")
logger.info(f"RTF: {self.executor.final_response_time / duration}")
logger.info(
f"Other info: front time: {self.executor.frontend_time} s, first am infer time: {self.executor.first_am_infer} s, first voc infer time: {self.executor.first_voc_infer} s,"
f"Other info: front time: {self.frontend_time} s, first am infer time: {self.first_am_infer} s, first voc infer time: {self.first_voc_infer} s,"
)
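
For reference, the streaming generator above is consumed per request in the same way the warm-up helper uses it; a minimal sketch, assuming the online-onnx TTS engine has already been initialized into the engine pool:

from paddlespeech.server.engine.engine_pool import get_engine_pool
from paddlespeech.server.engine.tts.online.onnx.tts_engine import PaddleTTSConnectionHandler

tts_engine = get_engine_pool()['tts']
handler = PaddleTTSConnectionHandler(tts_engine)

# infer() yields audio chunk by chunk; first_response_time is set after the
# first chunk and final_response_time after the last one.
for wav in handler.infer(
        text="您好,欢迎使用语音合成服务。",
        lang=tts_engine.lang,
        am=tts_engine.config.am):
    pass  # stream each wav chunk to the client here
print(handler.first_response_time, handler.final_response_time)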
......@@ -33,19 +33,16 @@ from paddlespeech.t2s.frontend import English
from paddlespeech.t2s.frontend.zh_frontend import Frontend
from paddlespeech.t2s.modules.normalizer import ZScore
__all__ = ['TTSEngine']
__all__ = ['TTSEngine', 'PaddleTTSConnectionHandler']
class TTSServerExecutor(TTSExecutor):
def __init__(self, am_block, am_pad, voc_block, voc_pad):
def __init__(self):
super().__init__()
self.am_block = am_block
self.am_pad = am_pad
self.voc_block = voc_block
self.voc_pad = voc_pad
self.task_resource = CommonTaskResource(
task='tts', model_format='dynamic', inference_mode='online')
def get_model_info(self,
field: str,
model_name: str,
......@@ -214,6 +211,106 @@ class TTSServerExecutor(TTSExecutor):
self.voc_inference.eval()
print("voc done!")
class TTSEngine(BaseEngine):
"""TTS server engine
Args:
metaclass: Defaults to Singleton.
"""
def __init__(self, name=None):
"""Initialize TTS server engine
"""
super().__init__()
def init(self, config: dict) -> bool:
self.executor = TTSServerExecutor()
self.config = config
self.lang = self.config.lang
self.engine_type = "online"
assert (
config.am == "fastspeech2_csmsc" or
config.am == "fastspeech2_cnndecoder_csmsc"
) and (
config.voc == "hifigan_csmsc" or config.voc == "mb_melgan_csmsc"
), 'Please check config, am support: fastspeech2, voc support: hifigan_csmsc-zh or mb_melgan_csmsc.'
assert (
config.voc_block > 0 and config.voc_pad > 0
), "Please set correct voc_block and voc_pad, they should be more than 0."
try:
if self.config.device is not None:
self.device = self.config.device
else:
self.device = paddle.get_device()
paddle.set_device(self.device)
except Exception as e:
logger.error(
"Set device failed, please check if device is already used and the parameter 'device' in the yaml file"
)
logger.error("Initialize TTS server engine Failed on device: %s." %
(self.device))
logger.error(e)
return False
try:
self.executor._init_from_path(
am=self.config.am,
am_config=self.config.am_config,
am_ckpt=self.config.am_ckpt,
am_stat=self.config.am_stat,
phones_dict=self.config.phones_dict,
tones_dict=self.config.tones_dict,
speaker_dict=self.config.speaker_dict,
voc=self.config.voc,
voc_config=self.config.voc_config,
voc_ckpt=self.config.voc_ckpt,
voc_stat=self.config.voc_stat,
lang=self.config.lang)
except Exception as e:
logger.error("Failed to get model related files.")
logger.error("Initialize TTS server engine Failed on device: %s." %
(self.device))
logger.error(e)
return False
self.am_block = self.config.am_block
self.am_pad = self.config.am_pad
self.voc_block = self.config.voc_block
self.voc_pad = self.config.voc_pad
self.am_upsample = 1
self.voc_upsample = self.executor.voc_config.n_shift
logger.info("Initialize TTS server engine successfully on device: %s." %
(self.device))
return True
class PaddleTTSConnectionHandler:
def __init__(self, tts_engine):
"""The PaddleSpeech TTS Server Connection Handler
This connection handler processes every tts server request
Args:
tts_engine (TTSEngine): The TTS engine
"""
super().__init__()
logger.info(
"Create PaddleTTSConnectionHandler to process the tts request")
self.tts_engine = tts_engine
self.executor = self.tts_engine.executor
self.config = self.tts_engine.config
self.am_block = self.tts_engine.am_block
self.am_pad = self.tts_engine.am_pad
self.voc_block = self.tts_engine.voc_block
self.voc_pad = self.tts_engine.voc_pad
self.am_upsample = self.tts_engine.am_upsample
self.voc_upsample = self.tts_engine.voc_upsample
def depadding(self, data, chunk_num, chunk_id, block, pad, upsample):
"""
Streaming inference removes the result of pad inference
......@@ -242,12 +339,6 @@ class TTSServerExecutor(TTSExecutor):
Model inference and result stored in self.output.
"""
am_block = self.am_block
am_pad = self.am_pad
am_upsample = 1
voc_block = self.voc_block
voc_pad = self.voc_pad
voc_upsample = self.voc_config.n_shift
# first_flag marks the first packet (the first returned audio chunk)
first_flag = 1
......@@ -255,7 +346,7 @@ class TTSServerExecutor(TTSExecutor):
merge_sentences = False
frontend_st = time.time()
if lang == 'zh':
input_ids = self.frontend.get_input_ids(
input_ids = self.executor.frontend.get_input_ids(
text,
merge_sentences=merge_sentences,
get_tone_ids=get_tone_ids)
......@@ -263,7 +354,7 @@ class TTSServerExecutor(TTSExecutor):
if get_tone_ids:
tone_ids = input_ids["tone_ids"]
elif lang == 'en':
input_ids = self.frontend.get_input_ids(
input_ids = self.executor.frontend.get_input_ids(
text, merge_sentences=merge_sentences)
phone_ids = input_ids["phone_ids"]
else:
......@@ -278,19 +369,21 @@ class TTSServerExecutor(TTSExecutor):
# fastspeech2_csmsc
if am == "fastspeech2_csmsc":
# am
mel = self.am_inference(part_phone_ids)
mel = self.executor.am_inference(part_phone_ids)
if first_flag == 1:
first_am_et = time.time()
self.first_am_infer = first_am_et - frontend_et
# voc streaming
mel_chunks = get_chunks(mel, voc_block, voc_pad, "voc")
mel_chunks = get_chunks(mel, self.voc_block, self.voc_pad,
"voc")
voc_chunk_num = len(mel_chunks)
voc_st = time.time()
for i, mel_chunk in enumerate(mel_chunks):
sub_wav = self.voc_inference(mel_chunk)
sub_wav = self.executor.voc_inference(mel_chunk)
sub_wav = self.depadding(sub_wav, voc_chunk_num, i,
voc_block, voc_pad, voc_upsample)
self.voc_block, self.voc_pad,
self.voc_upsample)
if first_flag == 1:
first_voc_et = time.time()
self.first_voc_infer = first_voc_et - first_am_et
......@@ -302,7 +395,8 @@ class TTSServerExecutor(TTSExecutor):
# fastspeech2_cnndecoder_csmsc
elif am == "fastspeech2_cnndecoder_csmsc":
# am
orig_hs = self.am_inference.encoder_infer(part_phone_ids)
orig_hs = self.executor.am_inference.encoder_infer(
part_phone_ids)
# streaming voc chunk info
mel_len = orig_hs.shape[1]
......@@ -314,13 +408,15 @@ class TTSServerExecutor(TTSExecutor):
hss = get_chunks(orig_hs, self.am_block, self.am_pad, "am")
am_chunk_num = len(hss)
for i, hs in enumerate(hss):
before_outs = self.am_inference.decoder(hs)
after_outs = before_outs + self.am_inference.postnet(
before_outs = self.executor.am_inference.decoder(hs)
after_outs = before_outs + self.executor.am_inference.postnet(
before_outs.transpose((0, 2, 1))).transpose((0, 2, 1))
normalized_mel = after_outs[0]
sub_mel = denorm(normalized_mel, self.am_mu, self.am_std)
sub_mel = self.depadding(sub_mel, am_chunk_num, i, am_block,
am_pad, am_upsample)
sub_mel = denorm(normalized_mel, self.executor.am_mu,
self.executor.am_std)
sub_mel = self.depadding(sub_mel, am_chunk_num, i,
self.am_block, self.am_pad,
self.am_upsample)
if i == 0:
mel_streaming = sub_mel
......@@ -337,11 +433,11 @@ class TTSServerExecutor(TTSExecutor):
self.first_am_infer = first_am_et - frontend_et
voc_chunk = mel_streaming[start:end, :]
voc_chunk = paddle.to_tensor(voc_chunk)
sub_wav = self.voc_inference(voc_chunk)
sub_wav = self.executor.voc_inference(voc_chunk)
sub_wav = self.depadding(sub_wav, voc_chunk_num,
voc_chunk_id, voc_block,
voc_pad, voc_upsample)
sub_wav = self.depadding(
sub_wav, voc_chunk_num, voc_chunk_id,
self.voc_block, self.voc_pad, self.voc_upsample)
if first_flag == 1:
first_voc_et = time.time()
self.first_voc_infer = first_voc_et - first_am_et
......@@ -351,9 +447,11 @@ class TTSServerExecutor(TTSExecutor):
yield sub_wav
voc_chunk_id += 1
start = max(0, voc_chunk_id * voc_block - voc_pad)
end = min((voc_chunk_id + 1) * voc_block + voc_pad,
mel_len)
start = max(
0, voc_chunk_id * self.voc_block - self.voc_pad)
end = min(
(voc_chunk_id + 1) * self.voc_block + self.voc_pad,
mel_len)
else:
logger.error(
......@@ -362,100 +460,6 @@ class TTSServerExecutor(TTSExecutor):
self.final_response_time = time.time() - frontend_st
class TTSEngine(BaseEngine):
"""TTS server engine
Args:
metaclass: Defaults to Singleton.
"""
def __init__(self, name=None):
"""Initialize TTS server engine
"""
super().__init__()
def init(self, config: dict) -> bool:
self.config = config
assert (
config.am == "fastspeech2_csmsc" or
config.am == "fastspeech2_cnndecoder_csmsc"
) and (
config.voc == "hifigan_csmsc" or config.voc == "mb_melgan_csmsc"
), 'Please check config, am support: fastspeech2, voc support: hifigan_csmsc-zh or mb_melgan_csmsc.'
assert (
config.voc_block > 0 and config.voc_pad > 0
), "Please set correct voc_block and voc_pad, they should be more than 0."
try:
if self.config.device is not None:
self.device = self.config.device
else:
self.device = paddle.get_device()
paddle.set_device(self.device)
except Exception as e:
logger.error(
"Set device failed, please check if device is already used and the parameter 'device' in the yaml file"
)
logger.error("Initialize TTS server engine Failed on device: %s." %
(self.device))
return False
self.executor = TTSServerExecutor(config.am_block, config.am_pad,
config.voc_block, config.voc_pad)
try:
self.executor._init_from_path(
am=self.config.am,
am_config=self.config.am_config,
am_ckpt=self.config.am_ckpt,
am_stat=self.config.am_stat,
phones_dict=self.config.phones_dict,
tones_dict=self.config.tones_dict,
speaker_dict=self.config.speaker_dict,
voc=self.config.voc,
voc_config=self.config.voc_config,
voc_ckpt=self.config.voc_ckpt,
voc_stat=self.config.voc_stat,
lang=self.config.lang)
except Exception as e:
logger.error("Failed to get model related files.")
logger.error("Initialize TTS server engine Failed on device: %s." %
(self.device))
return False
# warm up
try:
self.warm_up()
logger.info("Warm up successfully.")
except Exception as e:
logger.error("Failed to warm up on tts engine.")
return False
logger.info("Initialize TTS server engine successfully on device: %s." %
(self.device))
return True
def warm_up(self):
"""warm up
"""
if self.config.lang == 'zh':
sentence = "您好,欢迎使用语音合成服务。"
if self.config.lang == 'en':
sentence = "Hello and welcome to the speech synthesis service."
logger.info("Start to warm up.")
for i in range(3):
for wav in self.executor.infer(
text=sentence,
lang=self.config.lang,
am=self.config.am,
spk_id=0, ):
logger.info(
f"The first response time of the {i} warm up: {self.executor.first_response_time} s"
)
break
def preprocess(self, text_bese64: str=None, text_bytes: bytes=None):
# Convert byte to text
if text_bese64:
......@@ -489,7 +493,7 @@ class TTSEngine(BaseEngine):
wav_list = []
for wav in self.executor.infer(
for wav in self.infer(
text=sentence,
lang=self.config.lang,
am=self.config.am,
......@@ -505,13 +509,12 @@ class TTSEngine(BaseEngine):
wav_all = np.concatenate(wav_list, axis=0)
duration = len(wav_all) / self.executor.am_config.fs
logger.info(f"sentence: {sentence}")
logger.info(f"The durations of audio is: {duration} s")
logger.info(f"first response time: {self.first_response_time} s")
logger.info(f"final response time: {self.final_response_time} s")
logger.info(f"RTF: {self.final_response_time / duration}")
logger.info(
f"first response time: {self.executor.first_response_time} s")
logger.info(
f"final response time: {self.executor.final_response_time} s")
logger.info(f"RTF: {self.executor.final_response_time / duration}")
logger.info(
f"Other info: front time: {self.executor.frontend_time} s, first am infer time: {self.executor.first_am_infer} s, first voc infer time: {self.executor.first_voc_infer} s,"
f"Other info: front time: {self.frontend_time} s, first am infer time: {self.first_am_infer} s, first voc infer time: {self.first_voc_infer} s,"
)
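
A quick worked example of the RTF figure logged above, with made-up numbers:

# 48 000 waveform samples at a 24 000 Hz sample rate give 2.0 s of audio.
num_samples, sample_rate = 48000, 24000
duration = num_samples / sample_rate        # 2.0 s
final_response_time = 0.5                   # s, made-up measurement
rtf = final_response_time / duration        # 0.25, i.e. 4x faster than real time
print(f"RTF: {rtf}")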
......@@ -14,6 +14,7 @@
import base64
import io
import os
import sys
import time
from typing import Optional
......@@ -35,7 +36,7 @@ from paddlespeech.server.utils.paddle_predictor import run_model
from paddlespeech.t2s.frontend import English
from paddlespeech.t2s.frontend.zh_frontend import Frontend
__all__ = ['TTSEngine']
__all__ = ['TTSEngine', 'PaddleTTSConnectionHandler']
class TTSServerExecutor(TTSExecutor):
......@@ -254,7 +255,7 @@ class TTSServerExecutor(TTSExecutor):
else:
wav_all = paddle.concat([wav_all, wav])
self.voc_time += (time.time() - voc_st)
self._outputs['wav'] = wav_all
self._outputs["wav"] = wav_all
class TTSEngine(BaseEngine):
......@@ -272,6 +273,8 @@ class TTSEngine(BaseEngine):
def init(self, config: dict) -> bool:
self.executor = TTSServerExecutor()
self.config = config
self.lang = self.config.lang
self.engine_type = "inference"
try:
if self.config.am_predictor_conf.device is not None:
......@@ -281,58 +284,59 @@ class TTSEngine(BaseEngine):
else:
self.device = paddle.get_device()
paddle.set_device(self.device)
except BaseException as e:
except Exception as e:
logger.error(
"Set device failed, please check if device is already used and the parameter 'device' in the yaml file"
)
logger.error("Initialize TTS server engine Failed on device: %s." %
(self.device))
logger.error(e)
return False
self.executor._init_from_path(
am=self.config.am,
am_model=self.config.am_model,
am_params=self.config.am_params,
am_sample_rate=self.config.am_sample_rate,
phones_dict=self.config.phones_dict,
tones_dict=self.config.tones_dict,
speaker_dict=self.config.speaker_dict,
voc=self.config.voc,
voc_model=self.config.voc_model,
voc_params=self.config.voc_params,
voc_sample_rate=self.config.voc_sample_rate,
lang=self.config.lang,
am_predictor_conf=self.config.am_predictor_conf,
voc_predictor_conf=self.config.voc_predictor_conf, )
# warm up
try:
self.warm_up()
logger.info("Warm up successfully.")
self.executor._init_from_path(
am=self.config.am,
am_model=self.config.am_model,
am_params=self.config.am_params,
am_sample_rate=self.config.am_sample_rate,
phones_dict=self.config.phones_dict,
tones_dict=self.config.tones_dict,
speaker_dict=self.config.speaker_dict,
voc=self.config.voc,
voc_model=self.config.voc_model,
voc_params=self.config.voc_params,
voc_sample_rate=self.config.voc_sample_rate,
lang=self.config.lang,
am_predictor_conf=self.config.am_predictor_conf,
voc_predictor_conf=self.config.voc_predictor_conf, )
except Exception as e:
logger.error("Failed to warm up on tts engine.")
logger.error("Failed to get model related files.")
logger.error("Initialize TTS server engine Failed on device: %s." %
(self.device))
logger.error(e)
return False
logger.info("Initialize TTS server engine successfully.")
return True
def warm_up(self):
"""warm up
class PaddleTTSConnectionHandler(TTSServerExecutor):
def __init__(self, tts_engine):
"""The PaddleSpeech TTS Server Connection Handler
This connection handler processes every tts server request
Args:
tts_engine (TTSEngine): The TTS engine
"""
if self.config.lang == 'zh':
sentence = "您好,欢迎使用语音合成服务。"
if self.config.lang == 'en':
sentence = "Hello and welcome to the speech synthesis service."
logger.info("Start to warm up.")
for i in range(3):
st = time.time()
self.executor.infer(
text=sentence,
lang=self.config.lang,
am=self.config.am,
spk_id=0, )
logger.info(
f"The response time of the {i} warm up: {time.time() - st} s")
super().__init__()
logger.info(
"Create PaddleTTSConnectionHandler to process the tts request")
self.tts_engine = tts_engine
self.executor = self.tts_engine.executor
self.config = self.tts_engine.config
self.frontend = self.executor.frontend
self.am_predictor = self.executor.am_predictor
self.voc_predictor = self.executor.voc_predictor
def postprocess(self,
wav,
......@@ -384,8 +388,11 @@ class TTSEngine(BaseEngine):
ErrorCode.SERVER_INTERNAL_ERR,
"Failed to transform speed. Can not install soxbindings on your system. \
You need to set speed value 1.0.")
except BaseException:
sys.exit(-1)
except Exception as e:
logger.error("Failed to transform speed.")
logger.error(e)
sys.exit(-1)
# wav to base64
buf = io.BytesIO()
......@@ -442,7 +449,7 @@ class TTSEngine(BaseEngine):
try:
infer_st = time.time()
self.executor.infer(
self.infer(
text=sentence, lang=lang, am=self.config.am, spk_id=spk_id)
infer_et = time.time()
infer_time = infer_et - infer_st
......@@ -450,13 +457,16 @@ class TTSEngine(BaseEngine):
except ServerBaseException:
raise ServerBaseException(ErrorCode.SERVER_INTERNAL_ERR,
"tts infer failed.")
except BaseException:
sys.exit(-1)
except Exception as e:
logger.error("tts infer failed.")
logger.error(e)
sys.exit(-1)
try:
postprocess_st = time.time()
target_sample_rate, wav_base64 = self.postprocess(
wav=self.executor._outputs['wav'].numpy(),
wav=self._outputs["wav"].numpy(),
original_fs=self.executor.am_sample_rate,
target_fs=sample_rate,
volume=volume,
......@@ -464,26 +474,28 @@ class TTSEngine(BaseEngine):
audio_path=save_path)
postprocess_et = time.time()
postprocess_time = postprocess_et - postprocess_st
duration = len(self.executor._outputs['wav']
.numpy()) / self.executor.am_sample_rate
duration = len(
self._outputs["wav"].numpy()) / self.executor.am_sample_rate
rtf = infer_time / duration
except ServerBaseException:
raise ServerBaseException(ErrorCode.SERVER_INTERNAL_ERR,
"tts postprocess failed.")
except BaseException:
sys.exit(-1)
except Exception as e:
logger.error("tts postprocess failed.")
logger.error(e)
sys.exit(-1)
logger.info("AM model: {}".format(self.config.am))
logger.info("Vocoder model: {}".format(self.config.voc))
logger.info("Language: {}".format(lang))
logger.info("tts engine type: paddle inference")
logger.info("tts engine type: python")
logger.info("audio duration: {}".format(duration))
logger.info(
"frontend inference time: {}".format(self.executor.frontend_time))
logger.info("AM inference time: {}".format(self.executor.am_time))
logger.info("Vocoder inference time: {}".format(self.executor.voc_time))
logger.info("frontend inference time: {}".format(self.frontend_time))
logger.info("AM inference time: {}".format(self.am_time))
logger.info("Vocoder inference time: {}".format(self.voc_time))
logger.info("total inference time: {}".format(infer_time))
logger.info(
"postprocess (change speed, volume, target sample rate) time: {}".
......@@ -491,5 +503,6 @@ class TTSEngine(BaseEngine):
logger.info("total generate audio time: {}".format(infer_time +
postprocess_time))
logger.info("RTF: {}".format(rtf))
logger.info("device: {}".format(self.tts_engine.device))
return lang, target_sample_rate, duration, wav_base64
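
On the client side, the wav_base64 value returned above can be decoded back into the audio bytes that were written into the BytesIO buffer; a minimal sketch (the output path is illustrative):

import base64

wav_bytes = base64.b64decode(wav_base64)     # wav_base64 as returned by run()
with open("synthesized.wav", "wb") as f:     # illustrative output path
    f.write(wav_bytes)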
......@@ -13,6 +13,7 @@
# limitations under the License.
import base64
import io
import sys
import time
import librosa
......@@ -28,7 +29,7 @@ from paddlespeech.server.utils.audio_process import change_speed
from paddlespeech.server.utils.errors import ErrorCode
from paddlespeech.server.utils.exception import ServerBaseException
__all__ = ['TTSEngine']
__all__ = ['TTSEngine', 'PaddleTTSConnectionHandler']
class TTSServerExecutor(TTSExecutor):
......@@ -52,6 +53,8 @@ class TTSEngine(BaseEngine):
def init(self, config: dict) -> bool:
self.executor = TTSServerExecutor()
self.config = config
self.lang = self.config.lang
self.engine_type = "python"
try:
if self.config.device is not None:
......@@ -59,12 +62,13 @@ class TTSEngine(BaseEngine):
else:
self.device = paddle.get_device()
paddle.set_device(self.device)
except BaseException as e:
except Exception as e:
logger.error(
"Set device failed, please check if device is already used and the parameter 'device' in the yaml file"
)
logger.error("Initialize TTS server engine Failed on device: %s." %
(self.device))
logger.error(e)
return False
try:
......@@ -81,41 +85,35 @@ class TTSEngine(BaseEngine):
voc_ckpt=self.config.voc_ckpt,
voc_stat=self.config.voc_stat,
lang=self.config.lang)
except BaseException:
except Exception as e:
logger.error("Failed to get model related files.")
logger.error("Initialize TTS server engine Failed on device: %s." %
(self.device))
return False
# warm up
try:
self.warm_up()
logger.info("Warm up successfully.")
except Exception as e:
logger.error("Failed to warm up on tts engine.")
logger.error(e)
return False
logger.info("Initialize TTS server engine successfully on device: %s." %
(self.device))
return True
def warm_up(self):
"""warm up
class PaddleTTSConnectionHandler(TTSServerExecutor):
def __init__(self, tts_engine):
"""The PaddleSpeech TTS Server Connection Handler
This connection handler processes every tts server request
Args:
tts_engine (TTSEngine): The TTS engine
"""
if self.config.lang == 'zh':
sentence = "您好,欢迎使用语音合成服务。"
if self.config.lang == 'en':
sentence = "Hello and welcome to the speech synthesis service."
logger.info("Start to warm up.")
for i in range(3):
st = time.time()
self.executor.infer(
text=sentence,
lang=self.config.lang,
am=self.config.am,
spk_id=0, )
logger.info(
f"The response time of the {i} warm up: {time.time() - st} s")
super().__init__()
logger.info(
"Create PaddleTTSConnectionHandler to process the tts request")
self.tts_engine = tts_engine
self.executor = self.tts_engine.executor
self.config = self.tts_engine.config
self.frontend = self.executor.frontend
self.am_inference = self.executor.am_inference
self.voc_inference = self.executor.voc_inference
def postprocess(self,
wav,
......@@ -167,8 +165,11 @@ class TTSEngine(BaseEngine):
ErrorCode.SERVER_INTERNAL_ERR,
"Failed to transform speed. Can not install soxbindings on your system. \
You need to set speed value 1.0.")
except BaseException:
sys.exit(-1)
except Exception as e:
logger.error("Failed to transform speed.")
logger.error(e)
sys.exit(-1)
# wav to base64
buf = io.BytesIO()
......@@ -225,24 +226,27 @@ class TTSEngine(BaseEngine):
try:
infer_st = time.time()
self.executor.infer(
self.infer(
text=sentence, lang=lang, am=self.config.am, spk_id=spk_id)
infer_et = time.time()
infer_time = infer_et - infer_st
duration = len(self.executor._outputs['wav']
.numpy()) / self.executor.am_config.fs
duration = len(
self._outputs["wav"].numpy()) / self.executor.am_config.fs
rtf = infer_time / duration
except ServerBaseException:
raise ServerBaseException(ErrorCode.SERVER_INTERNAL_ERR,
"tts infer failed.")
except BaseException:
sys.exit(-1)
except Exception as e:
logger.error("tts infer failed.")
logger.error(e)
sys.exit(-1)
try:
postprocess_st = time.time()
target_sample_rate, wav_base64 = self.postprocess(
wav=self.executor._outputs['wav'].numpy(),
wav=self._outputs["wav"].numpy(),
original_fs=self.executor.am_config.fs,
target_fs=sample_rate,
volume=volume,
......@@ -254,8 +258,11 @@ class TTSEngine(BaseEngine):
except ServerBaseException:
raise ServerBaseException(ErrorCode.SERVER_INTERNAL_ERR,
"tts postprocess failed.")
except BaseException:
sys.exit(-1)
except Exception as e:
logger.error("tts postprocess failed.")
logger.error(e)
sys.exit(-1)
logger.info("AM model: {}".format(self.config.am))
logger.info("Vocoder model: {}".format(self.config.voc))
......@@ -263,10 +270,9 @@ class TTSEngine(BaseEngine):
logger.info("tts engine type: python")
logger.info("audio duration: {}".format(duration))
logger.info(
"frontend inference time: {}".format(self.executor.frontend_time))
logger.info("AM inference time: {}".format(self.executor.am_time))
logger.info("Vocoder inference time: {}".format(self.executor.voc_time))
logger.info("frontend inference time: {}".format(self.frontend_time))
logger.info("AM inference time: {}".format(self.am_time))
logger.info("Vocoder inference time: {}".format(self.voc_time))
logger.info("total inference time: {}".format(infer_time))
logger.info(
"postprocess (change speed, volume, target sample rate) time: {}".
......@@ -274,6 +280,6 @@ class TTSEngine(BaseEngine):
logger.info("total generate audio time: {}".format(infer_time +
postprocess_time))
logger.info("RTF: {}".format(rtf))
logger.info("device: {}".format(self.device))
logger.info("device: {}".format(self.tts_engine.device))
return lang, target_sample_rate, duration, wav_base64
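# --- Editor's note: illustrative sketch only, not part of this commit. ---
# It mirrors how the restful router further below drives PaddleTTSConnectionHandler
# for a single request; the default values for spk_id, speed, volume, sample_rate
# and save_path are placeholders for illustration, not defaults defined by the server.
from paddlespeech.server.engine.engine_pool import get_engine_pool


def synthesize_once(text, spk_id=0, speed=1.0, volume=1.0, sample_rate=0, save_path=None):
    # one shared, long-lived engine; one fresh handler per request
    tts_engine = get_engine_pool()['tts']
    handler = PaddleTTSConnectionHandler(tts_engine)
    lang, target_sample_rate, duration, wav_base64 = handler.run(
        text, spk_id, speed, volume, sample_rate, save_path)
    return wav_base64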
......@@ -12,11 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import base64
import sys
import traceback
from typing import Union
from fastapi import APIRouter
from paddlespeech.cli.log import logger
from paddlespeech.server.engine.engine_pool import get_engine_pool
from paddlespeech.server.restful.request import ASRRequest
from paddlespeech.server.restful.response import ASRResponse
......@@ -68,8 +70,18 @@ def asr(request_body: ASRRequest):
engine_pool = get_engine_pool()
asr_engine = engine_pool['asr']
asr_engine.run(audio_data)
asr_results = asr_engine.postprocess()
if asr_engine.engine_type == "python":
from paddlespeech.server.engine.asr.python.asr_engine import PaddleASRConnectionHandler
elif asr_engine.engine_type == "inference":
from paddlespeech.server.engine.asr.paddleinference.asr_engine import PaddleASRConnectionHandler
else:
logger.error("Offline asr engine only support python or inference.")
sys.exit(-1)
connection_handler = PaddleASRConnectionHandler(asr_engine)
connection_handler.run(audio_data)
asr_results = connection_handler.postprocess()
response = {
"success": True,
......
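# --- Editor's note: illustrative sketch only, not part of this commit. ---
# The same engine_type dispatch used above, written as a table-driven helper;
# the module paths are copied from the imports above, everything else is an
# assumption for illustration.
import importlib

_ASR_HANDLER_MODULES = {
    "python": "paddlespeech.server.engine.asr.python.asr_engine",
    "inference": "paddlespeech.server.engine.asr.paddleinference.asr_engine",
}


def get_asr_connection_handler(asr_engine):
    module_path = _ASR_HANDLER_MODULES.get(asr_engine.engine_type)
    if module_path is None:
        raise ValueError("Offline asr engine only supports python or inference.")
    handler_cls = getattr(
        importlib.import_module(module_path), "PaddleASRConnectionHandler")
    return handler_cls(asr_engine)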
......@@ -12,11 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import base64
import sys
import traceback
from typing import Union
from fastapi import APIRouter
from paddlespeech.cli.log import logger
from paddlespeech.server.engine.engine_pool import get_engine_pool
from paddlespeech.server.restful.request import CLSRequest
from paddlespeech.server.restful.response import CLSResponse
......@@ -68,8 +70,18 @@ def cls(request_body: CLSRequest):
engine_pool = get_engine_pool()
cls_engine = engine_pool['cls']
cls_engine.run(audio_data)
cls_results = cls_engine.postprocess(request_body.topk)
if cls_engine.engine_type == "python":
from paddlespeech.server.engine.cls.python.cls_engine import PaddleCLSConnectionHandler
elif cls_engine.engine_type == "inference":
from paddlespeech.server.engine.cls.paddleinference.cls_engine import PaddleCLSConnectionHandler
else:
logger.error("Offline cls engine only support python or inference.")
sys.exit(-1)
connection_handler = PaddleCLSConnectionHandler(cls_engine)
connection_handler.run(audio_data)
cls_results = connection_handler.postprocess(request_body.topk)
response = {
"success": True,
......@@ -85,8 +97,11 @@ def cls(request_body: CLSRequest):
except ServerBaseException as e:
response = failed_response(e.error_code, e.msg)
except BaseException:
logger.error(e)
sys.exit(-1)
except Exception as e:
response = failed_response(ErrorCode.SERVER_UNKOWN_ERR)
logger.error(e)
traceback.print_exc()
return response
......@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import traceback
from typing import Union
......@@ -99,7 +100,16 @@ def tts(request_body: TTSRequest):
tts_engine = engine_pool['tts']
logger.info("Get tts engine successfully.")
lang, target_sample_rate, duration, wav_base64 = tts_engine.run(
if tts_engine.engine_type == "python":
from paddlespeech.server.engine.tts.python.tts_engine import PaddleTTSConnectionHandler
elif tts_engine.engine_type == "inference":
from paddlespeech.server.engine.tts.paddleinference.tts_engine import PaddleTTSConnectionHandler
else:
logger.error("Offline tts engine only support python or inference.")
sys.exit(-1)
connection_handler = PaddleTTSConnectionHandler(tts_engine)
lang, target_sample_rate, duration, wav_base64 = connection_handler.run(
text, spk_id, speed, volume, sample_rate, save_path)
response = {
......@@ -136,4 +146,14 @@ async def stream_tts(request_body: TTSRequest):
tts_engine = engine_pool['tts']
logger.info("Get tts engine successfully.")
return StreamingResponse(tts_engine.run(sentence=text))
if tts_engine.engine_type == "online":
from paddlespeech.server.engine.tts.online.python.tts_engine import PaddleTTSConnectionHandler
elif tts_engine.engine_type == "online-onnx":
from paddlespeech.server.engine.tts.online.onnx.tts_engine import PaddleTTSConnectionHandler
else:
logger.error("Online tts engine only support online or online-onnx.")
sys.exit(-1)
connection_handler = PaddleTTSConnectionHandler(tts_engine)
return StreamingResponse(connection_handler.run(sentence=text))
......@@ -40,6 +40,16 @@ async def websocket_endpoint(websocket: WebSocket):
engine_pool = get_engine_pool()
tts_engine = engine_pool['tts']
connection_handler = None
if tts_engine.engine_type == "online":
from paddlespeech.server.engine.tts.online.python.tts_engine import PaddleTTSConnectionHandler
elif tts_engine.engine_type == "online-onnx":
from paddlespeech.server.engine.tts.online.onnx.tts_engine import PaddleTTSConnectionHandler
else:
logger.error("Online tts engine only support online or online-onnx.")
sys.exit(-1)
try:
while True:
# careful here: the source code from starlette.websockets has been changed
......@@ -57,10 +67,13 @@ async def websocket_endpoint(websocket: WebSocket):
"signal": "server ready",
"session": session
}
connection_handler = PaddleTTSConnectionHandler(tts_engine)
await websocket.send_json(resp)
# end request
elif message['signal'] == 'end':
connection_handler = None
resp = {
"status": 0,
"signal": "connection will be closed",
......@@ -75,10 +88,11 @@ async def websocket_endpoint(websocket: WebSocket):
# speech synthesis request
elif 'text' in message:
text_bese64 = message["text"]
sentence = tts_engine.preprocess(text_bese64=text_bese64)
sentence = connection_handler.preprocess(
text_bese64=text_bese64)
# run
wav_generator = tts_engine.run(sentence)
wav_generator = connection_handler.run(sentence)
while True:
try:
......
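# --- Editor's note: illustrative sketch only, not part of this commit. ---
# A minimal client for the message flow handled above: send a "start" signal,
# then base64-encoded text, then an "end" signal. The endpoint URL is a
# placeholder and the exact shape of the audio replies is not shown in this
# diff, so the receive calls below are only indicative.
import asyncio
import base64
import json

import websockets  # third-party client library, assumed to be installed


async def stream_tts_demo(text, url="ws://127.0.0.1:8092/paddlespeech/tts/streaming"):
    async with websockets.connect(url) as ws:
        await ws.send(json.dumps({"signal": "start"}))
        print("server ready:", await ws.recv())

        # the handler above expects base64-encoded text in the "text" field
        await ws.send(json.dumps(
            {"text": base64.b64encode(text.encode("utf8")).decode("utf8")}))
        print("first reply type:", type(await ws.recv()))

        await ws.send(json.dumps({"signal": "end"}))
        print("closing reply:", await ws.recv())


# asyncio.run(stream_tts_demo("您好,欢迎使用语音合成服务。"))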
......@@ -28,7 +28,7 @@ StartService(){
ClientTest_http(){
for ((i=1; i<=3;i++))
do
paddlespeech_client tts_online --input "您好,欢迎使用百度飞桨深度学习框架。"
paddlespeech_client tts_online --input "您好,欢迎使用百度飞桨深度学习框架。" --port $port
((http_test_times+=1))
done
}
......@@ -36,7 +36,7 @@ ClientTest_http(){
ClientTest_ws(){
for ((i=1; i<=3;i++))
do
paddlespeech_client tts_online --input "您好,欢迎使用百度飞桨深度学习框架。" --protocol websocket
paddlespeech_client tts_online --input "您好,欢迎使用百度飞桨深度学习框架。" --protocol websocket --port $port
((ws_test_times+=1))
done
}
......@@ -54,7 +54,7 @@ GetTestResult_http() {
GetTestResult_ws() {
# Determine if the test was successful
ws_response_success_time=$(cat $log/server.log.wf | grep "Complete the transmission of audio streams" -c)
ws_response_success_time=$(cat $log/server.log.wf | grep "Complete the synthesis of the audio streams" -c)
if (( $ws_response_success_time == $ws_test_times )) ; then
echo "Testing successfully. $info" | tee -a $log/test_result.log
else
......@@ -313,4 +313,4 @@ cat $log/test_result.log
# Restore conf, same as in demos/speech_server
cp ./tts_online_application.yaml ./conf/application.yaml -rf
sleep 2s
\ No newline at end of file
sleep 2s