未验证 提交 b7126f09 编写于 作者: W WilliamZhang06 提交者: GitHub

Merge pull request #1 from PaddlePaddle/server

Server
# This is the parameter configuration file for TTS server.
##################################################################
# TTS SERVER SETTING #
##################################################################
host: '0.0.0.0'
port: 8692
##################################################################
# ACOUSTIC MODEL SETTING #
# am choices=['speedyspeech_csmsc', 'fastspeech2_csmsc',
# 'fastspeech2_ljspeech', 'fastspeech2_aishell3',
# 'fastspeech2_vctk']
##################################################################
am: 'fastspeech2_csmsc'
am_config:
am_ckpt:
am_stat:
phones_dict:
tones_dict:
speaker_dict:
spk_id: 0
##################################################################
# VOCODER SETTING #
# voc choices=['pwgan_csmsc', 'pwgan_ljspeech', 'pwgan_aishell3',
# 'pwgan_vctk', 'mb_melgan_csmsc']
##################################################################
voc: 'pwgan_csmsc'
voc_config:
voc_ckpt:
voc_stat:
##################################################################
# OTHERS #
##################################################################
lang: 'zh'
device: paddle.get_device()
\ No newline at end of file
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import base64
import librosa
import numpy as np
import soundfile as sf
import yaml
from engine.base_engine import BaseEngine
from paddlespeech.cli.log import logger
from paddlespeech.cli.tts.infer import TTSExecutor
from utils.errors import ErrorCode
from utils.exception import ServerBaseException
__all__ = ['TTSEngine']
class TTSServerExecutor(TTSExecutor):
def __init__(self):
super().__init__()
self.parser = argparse.ArgumentParser(
prog='paddlespeech.tts', add_help=True)
self.parser.add_argument(
'--conf',
type=str,
default='./conf/tts/tts.yaml',
help='Configuration parameters.')
class TTSEngine(BaseEngine):
"""TTS server engine
Args:
metaclass: Defaults to Singleton.
"""
def __init__(self, name=None):
"""Initialize TTS server engine
"""
super(TTSEngine, self).__init__()
self.executor = TTSServerExecutor()
config_path = self.executor.parser.parse_args().conf
with open(config_path, 'rt') as f:
self.conf_dict = yaml.safe_load(f)
self.executor._init_from_path(
am=self.conf_dict["am"],
am_config=self.conf_dict["am_config"],
am_ckpt=self.conf_dict["am_ckpt"],
am_stat=self.conf_dict["am_stat"],
phones_dict=self.conf_dict["phones_dict"],
tones_dict=self.conf_dict["tones_dict"],
speaker_dict=self.conf_dict["speaker_dict"],
voc=self.conf_dict["voc"],
voc_config=self.conf_dict["voc_config"],
voc_ckpt=self.conf_dict["voc_ckpt"],
voc_stat=self.conf_dict["voc_stat"],
lang=self.conf_dict["lang"])
logger.info("Initialize TTS server engine successfully.")
def postprocess(self,
wav,
original_fs: int,
target_fs: int=16000,
volume: float=1.0,
speed: float=1.0,
audio_path: str=None,
audio_format: str="wav"):
"""Post-processing operations, including speech, volume, sample rate, save audio file
Args:
wav (numpy(float)): Synthesized audio sample points
original_fs (int): original audio sample rate
target_fs (int): target audio sample rate
volume (float): target volume
speed (float): target speed
"""
# transform sample_rate
if target_fs == 0 or target_fs > original_fs:
target_fs = original_fs
wav_tar_fs = wav
else:
wav_tar_fs = librosa.resample(
np.squeeze(wav), original_fs, target_fs)
# transform volume
wav_vol = wav_tar_fs * volume
# transform speed
# TODO
target_wav = wav_vol.reshape(-1, 1)
# save audio
if audio_path is not None:
sf.write(audio_path, target_wav, target_fs)
logger.info('Wave file has been generated: {}'.format(audio_path))
# wav to base64
base64_bytes = base64.b64encode(target_wav)
base64_string = base64_bytes.decode('utf-8')
wav_base64 = base64_string
return target_fs, wav_base64
def run(self,
sentence: str,
spk_id: int=0,
speed: float=1.0,
volume: float=1.0,
sample_rate: int=0,
save_path: str=None,
audio_format: str="wav"):
lang = self.conf_dict["lang"]
try:
self.executor.infer(
text=sentence,
lang=lang,
am=self.conf_dict["am"],
spk_id=spk_id)
except:
raise ServerBaseException(ErrorCode.SERVER_INTERNAL_ERR,
"tts infer failed.")
try:
target_sample_rate, wav_base64 = self.postprocess(
wav=self.executor._outputs['wav'].numpy(),
original_fs=self.executor.am_config.fs,
target_fs=sample_rate,
volume=volume,
speed=speed,
audio_path=save_path,
audio_format=audio_format)
except:
raise ServerBaseException(ErrorCode.SERVER_INTERNAL_ERR,
"tts postprocess failed.")
return lang, target_sample_rate, wav_base64
......@@ -15,11 +15,12 @@ import argparse
import uvicorn
import yaml
from engine.asr.python.asr_engine import ASREngine
from engine.tts.python.tts_engine import TTSEngine
from fastapi import FastAPI
from restful.api import router as api_router
from utils.log import logger
from paddlespeech.cli.log import logger
app = FastAPI(
title="PaddleSpeech Serving API", description="Api", version="0.0.1")
......@@ -31,7 +32,8 @@ def init(args):
app.include_router(api_router)
# engine single
ASR_ENGINE = ASREngine("asr")
TTS_ENGINE = TTSEngine()
# todo others
......@@ -56,7 +58,8 @@ if __name__ == "__main__":
"--config_file",
action="store",
help="yaml file of the app",
default="./conf/application.yaml")
default="./conf/tts/tts.yaml")
parser.add_argument(
"--log_file",
action="store",
......
......@@ -13,9 +13,9 @@
# limitations under the License.
from fastapi import APIRouter
from .asr_api import router as asr_router
from .tts_api import router as tts_router
#from .asr_api import router as asr_router
router = APIRouter()
router.include_router(asr_router)
#router.include_router(asr_router)
router.include_router(tts_router)
......@@ -16,7 +16,8 @@ from typing import Optional
from pydantic import BaseModel
__all__ = ['ASRRequest, TTSRequest']
__all__ = ['ASRRequest', 'TTSRequest']
#****************************************************************************************/
......@@ -44,13 +45,26 @@ class ASRRequest(BaseModel):
#************************************ TTS request ***************************************/
#****************************************************************************************/
class TTSRequest(BaseModel):
"""
"""TTS request
request body example
{
"audio": "exSI6ICJlbiIsCgkgICAgInBvc2l0aW9uIjogImZhbHNlIgoJf...",
"audio_format": "wav",
"sample_rate": 16000,
"lang ": "zh_cn",
"ptt ":false
"text": "你好,欢迎使用百度飞桨语音合成服务。",
"spk_id": 0,
"speed": 1.0,
"volume": 1.0,
"sample_rate": 0,
"tts_audio_path": "./tts.wav",
"audio_format": "wav"
}
"""
text: str
spk_id: int = 0
speed: float = 1.0
volume: float = 1.0
sample_rate: int = 0
save_path: str = None
audio_format: str = "wav"
......@@ -16,7 +16,7 @@ from typing import Optional
from pydantic import BaseModel
__all__ = ['ASRResponse']
__all__ = ['ASRResponse', 'TTSResponse']
class Message(BaseModel):
......@@ -53,3 +53,36 @@ class ASRResponse(BaseModel):
#****************************************************************************************/
#************************************ TTS response **************************************/
#****************************************************************************************/
class TTSResult(BaseModel):
lang: str = "zh"
sample_rate: int
spk_id: int = 0
speed: float = 1.0
volume: float = 1.0
save_path: str = None
audio: str
class TTSResponse(BaseModel):
"""
response example
{
"success": true,
"code": 200,
"message": {
"description": "success"
},
"result": {
"lang": "zh",
"sample_rate": 24000,
"speed": 1.0,
"volume": 1.0,
"audio": "LTI1OTIuNjI1OTUwMzQsOTk2OS41NDk4...",
"save_path": "./tts.wav"
}
}
"""
success: bool
code: int
message: Message
result: TTSResult
......@@ -11,8 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import traceback
from engine.tts.python.tts_engine import TTSEngine
from fastapi import APIRouter
from .request import TTSRequest
from .response import TTSResponse
from utils.errors import ErrorCode
from utils.errors import ErrorMsg
from utils.errors import failed_response
from utils.exception import ServerBaseException
router = APIRouter()
......@@ -24,6 +33,76 @@ def help():
Returns:
json: [description]
"""
return {'hello': 'world'}
json_body = {
"success": "True",
"code": 0,
"message": {
"global": "success"
},
"result": {
"description": "tts server",
"text": "sentence to be synthesized",
"audio": "the base64 of audio"
}
}
return json_body
@router.post("/paddlespeech/tts", response_model=TTSResponse)
def tts(request_body: TTSRequest):
"""tts api
Args:
request_body (TTSRequest): [description]
Returns:
json: [description]
"""
# json to dict
item_dict = request_body.dict()
sentence = item_dict['text']
spk_id = item_dict['spk_id']
speed = item_dict['speed']
volume = item_dict['volume']
sample_rate = item_dict['sample_rate']
save_path = item_dict['save_path']
audio_format = item_dict['audio_format']
# Check parameters
if speed <=0 or speed > 3 or volume <=0 or volume > 3 or \
sample_rate not in [0, 16000, 8000] or \
audio_format not in ["pcm", "wav"]:
return failed_response(ErrorCode.SERVER_PARAM_ERR)
# single
tts_engine = TTSEngine()
# run
try:
lang, target_sample_rate, wav_base64 = tts_engine.run(
sentence, spk_id, speed, volume, sample_rate, save_path,
audio_format)
except ServerBaseException as e:
response = failed_response(e.error_code, e.msg)
except:
response = failed_response(ErrorCode.SERVER_UNKOWN_ERR)
traceback.print_exc()
json_body = {
"success": True,
"code": 200,
"message": {
"description": "success."
},
"result": {
"lang": lang,
"spk_id": spk_id,
"speed": speed,
"volume": volume,
"sample_rate": target_sample_rate,
"save_path": save_path,
"audio": wav_base64
}
}
return json_body
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from enum import IntEnum
from fastapi import Response
class ErrorCode(IntEnum):
SERVER_OK = 200 # success.
SERVER_PARAM_ERR = 400 # Input parameters are not valid.
SERVER_TASK_NOT_EXIST = 404 # Task is not exist.
SERVER_INTERNAL_ERR = 500 # Internal error.
SERVER_NETWORK_ERR = 502 # Network exception.
SERVER_UNKOWN_ERR = 509 # Unknown error occurred.
ErrorMsg = {
ErrorCode.SERVER_OK: "success.",
ErrorCode.SERVER_PARAM_ERR: "Input parameters are not valid.",
ErrorCode.SERVER_TASK_NOT_EXIST: "Task is not exist.",
ErrorCode.SERVER_INTERNAL_ERR: "Internal error.",
ErrorCode.SERVER_NETWORK_ERR: "Network exception.",
ErrorCode.SERVER_UNKOWN_ERR: "Unknown error occurred."
}
def failed_response(code, msg=""):
"""Interface call failure response
Args:
code (int): error code number
msg (str, optional): Interface call failure information. Defaults to "".
Returns:
Response (json): failure json information.
"""
if not msg:
msg = ErrorMsg.get(code, "Unknown error occurred.")
res = {"success": False, "code": int(code), "message": {"global": msg}}
return Response(content=json.dumps(res), media_type="application/json")
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import traceback
from utils.errors import ErrorMsg
class ServerBaseException(Exception):
""" Server Base exception
"""
def __init__(self, error_code, msg=None):
#if msg:
#log.error(msg)
msg = msg if msg else ErrorMsg.get(error_code, "")
super(ServerBaseException, self).__init__(error_code, msg)
self.error_code = error_code
self.msg = msg
traceback.print_exc()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册