提交 48e61075 编写于 作者: H huangyuxin

Merge branch 'develop' of https://github.com/PaddlePaddle/DeepSpeech into unit_test

...@@ -9,9 +9,17 @@ port: 8090 ...@@ -9,9 +9,17 @@ port: 8090
################################################################## ##################################################################
# CONFIG FILE # # CONFIG FILE #
################################################################## ##################################################################
# add engine type (Options: asr, tts) and config file here. # The engine_type of speech task needs to keep the same type as the config file of speech task.
# E.g: The engine_type of asr is 'python', the engine_backend of asr is 'XX/asr.yaml'
# E.g: The engine_type of asr is 'inference', the engine_backend of asr is 'XX/asr_pd.yaml'
#
# add engine type (Options: python, inference)
engine_type:
asr: 'inference'
tts: 'inference'
# add engine backend type (Options: asr, tts) and config file here.
# Adding a speech task to engine_backend means starting the service.
engine_backend: engine_backend:
asr: 'conf/asr/asr.yaml' asr: 'conf/asr/asr_pd.yaml'
tts: 'conf/tts/tts.yaml' tts: 'conf/tts/tts_pd.yaml'
model: 'conformer_wenetspeech' model: 'conformer_wenetspeech'
lang: 'zh' lang: 'zh'
sample_rate: 16000 sample_rate: 16000
cfg_path: cfg_path: # [optional]
ckpt_path: ckpt_path: # [optional]
decode_method: 'attention_rescoring' decode_method: 'attention_rescoring'
force_yes: False force_yes: True
device: 'cpu' # set 'gpu:id' or 'cpu'
# This is the parameter configuration file for ASR server.
# These are the static models that support paddle inference.
##################################################################
# ACOUSTIC MODEL SETTING #
# am choices=['deepspeech2offline_aishell'] TODO
##################################################################
model_type: 'deepspeech2offline_aishell'
am_model: # the pdmodel file of am static model [optional]
am_params: # the pdiparams file of am static model [optional]
lang: 'zh'
sample_rate: 16000
cfg_path:
decode_method:
force_yes: True
am_predictor_conf:
device: 'cpu' # set 'gpu:id' or 'cpu'
enable_mkldnn: True
switch_ir_optim: True
##################################################################
# OTHERS #
##################################################################
...@@ -29,4 +29,4 @@ voc_stat: ...@@ -29,4 +29,4 @@ voc_stat:
# OTHERS # # OTHERS #
################################################################## ##################################################################
lang: 'zh' lang: 'zh'
device: 'gpu:2' device: 'cpu' # set 'gpu:id' or 'cpu'
...@@ -6,8 +6,8 @@ ...@@ -6,8 +6,8 @@
# am choices=['speedyspeech_csmsc', 'fastspeech2_csmsc'] # am choices=['speedyspeech_csmsc', 'fastspeech2_csmsc']
################################################################## ##################################################################
am: 'fastspeech2_csmsc' am: 'fastspeech2_csmsc'
am_model: # the pdmodel file of am static model am_model: # the pdmodel file of your am static model (XX.pdmodel)
am_params: # the pdiparams file of am static model am_params: # the pdiparams file of your am static model (XX.pdipparams)
am_sample_rate: 24000 am_sample_rate: 24000
phones_dict: phones_dict:
tones_dict: tones_dict:
...@@ -15,9 +15,9 @@ speaker_dict: ...@@ -15,9 +15,9 @@ speaker_dict:
spk_id: 0 spk_id: 0
am_predictor_conf: am_predictor_conf:
use_gpu: True device: 'cpu' # set 'gpu:id' or 'cpu'
enable_mkldnn: True enable_mkldnn: False
switch_ir_optim: True switch_ir_optim: False
################################################################## ##################################################################
...@@ -25,17 +25,16 @@ am_predictor_conf: ...@@ -25,17 +25,16 @@ am_predictor_conf:
# voc choices=['pwgan_csmsc', 'mb_melgan_csmsc','hifigan_csmsc'] # voc choices=['pwgan_csmsc', 'mb_melgan_csmsc','hifigan_csmsc']
################################################################## ##################################################################
voc: 'pwgan_csmsc' voc: 'pwgan_csmsc'
voc_model: # the pdmodel file of vocoder static model voc_model: # the pdmodel file of your vocoder static model (XX.pdmodel)
voc_params: # the pdiparams file of vocoder static model voc_params: # the pdiparams file of your vocoder static model (XX.pdipparams)
voc_sample_rate: 24000 voc_sample_rate: 24000
voc_predictor_conf: voc_predictor_conf:
use_gpu: True device: 'cpu' # set 'gpu:id' or 'cpu'
enable_mkldnn: True enable_mkldnn: False
switch_ir_optim: True switch_ir_optim: False
################################################################## ##################################################################
# OTHERS # # OTHERS #
################################################################## ##################################################################
lang: 'zh' lang: 'zh'
device: paddle.get_device()
...@@ -12,8 +12,8 @@ ...@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import argparse import argparse
import uvicorn import uvicorn
import yaml
from fastapi import FastAPI from fastapi import FastAPI
from paddlespeech.server.engine.engine_pool import init_engine_pool from paddlespeech.server.engine.engine_pool import init_engine_pool
......
...@@ -48,8 +48,9 @@ class TTSClientExecutor(BaseExecutor): ...@@ -48,8 +48,9 @@ class TTSClientExecutor(BaseExecutor):
self.parser.add_argument( self.parser.add_argument(
'--input', '--input',
type=str, type=str,
default="你好,欢迎使用语音合成服务", default=None,
help='A sentence to be synthesized.') help='Text to be synthesized.',
required=True)
self.parser.add_argument( self.parser.add_argument(
'--spk_id', type=int, default=0, help='Speaker id') '--spk_id', type=int, default=0, help='Speaker id')
self.parser.add_argument( self.parser.add_argument(
...@@ -123,7 +124,7 @@ class TTSClientExecutor(BaseExecutor): ...@@ -123,7 +124,7 @@ class TTSClientExecutor(BaseExecutor):
logger.info("RTF: %f " % (time_consume / duration)) logger.info("RTF: %f " % (time_consume / duration))
return True return True
except: except BaseException:
logger.error("Failed to synthesized audio.") logger.error("Failed to synthesized audio.")
return False return False
...@@ -163,7 +164,7 @@ class TTSClientExecutor(BaseExecutor): ...@@ -163,7 +164,7 @@ class TTSClientExecutor(BaseExecutor):
print("Audio duration: %f s." % (duration)) print("Audio duration: %f s." % (duration))
print("Response time: %f s." % (time_consume)) print("Response time: %f s." % (time_consume))
print("RTF: %f " % (time_consume / duration)) print("RTF: %f " % (time_consume / duration))
except: except BaseException:
print("Failed to synthesized audio.") print("Failed to synthesized audio.")
...@@ -181,8 +182,9 @@ class ASRClientExecutor(BaseExecutor): ...@@ -181,8 +182,9 @@ class ASRClientExecutor(BaseExecutor):
self.parser.add_argument( self.parser.add_argument(
'--input', '--input',
type=str, type=str,
default="./paddlespeech/server/tests/16_audio.wav", default=None,
help='Audio file to be recognized') help='Audio file to be recognized',
required=True)
self.parser.add_argument( self.parser.add_argument(
'--sample_rate', type=int, default=16000, help='audio sample rate') '--sample_rate', type=int, default=16000, help='audio sample rate')
self.parser.add_argument( self.parser.add_argument(
...@@ -209,7 +211,7 @@ class ASRClientExecutor(BaseExecutor): ...@@ -209,7 +211,7 @@ class ASRClientExecutor(BaseExecutor):
logger.info(r.json()) logger.info(r.json())
logger.info("time cost %f s." % (time_end - time_start)) logger.info("time cost %f s." % (time_end - time_start))
return True return True
except: except BaseException:
logger.error("Failed to speech recognition.") logger.error("Failed to speech recognition.")
return False return False
...@@ -240,5 +242,5 @@ class ASRClientExecutor(BaseExecutor): ...@@ -240,5 +242,5 @@ class ASRClientExecutor(BaseExecutor):
time_end = time.time() time_end = time.time()
print(r.json()) print(r.json())
print("time cost %f s." % (time_end - time_start)) print("time cost %f s." % (time_end - time_start))
except: except BaseException:
print("Failed to speech recognition.") print("Failed to speech recognition.")
\ No newline at end of file
...@@ -20,7 +20,7 @@ from fastapi import FastAPI ...@@ -20,7 +20,7 @@ from fastapi import FastAPI
from ..executor import BaseExecutor from ..executor import BaseExecutor
from ..util import cli_server_register from ..util import cli_server_register
from ..util import stats_wrapper from ..util import stats_wrapper
from paddlespeech.server.engine.engine_factory import EngineFactory from paddlespeech.server.engine.engine_pool import init_engine_pool
from paddlespeech.server.restful.api import setup_router from paddlespeech.server.restful.api import setup_router
from paddlespeech.server.utils.config import get_config from paddlespeech.server.utils.config import get_config
...@@ -41,7 +41,8 @@ class ServerExecutor(BaseExecutor): ...@@ -41,7 +41,8 @@ class ServerExecutor(BaseExecutor):
"--config_file", "--config_file",
action="store", action="store",
help="yaml file of the app", help="yaml file of the app",
default="./conf/application.yaml") default=None,
required=True)
self.parser.add_argument( self.parser.add_argument(
"--log_file", "--log_file",
...@@ -51,8 +52,10 @@ class ServerExecutor(BaseExecutor): ...@@ -51,8 +52,10 @@ class ServerExecutor(BaseExecutor):
def init(self, config) -> bool: def init(self, config) -> bool:
"""system initialization """system initialization
Args: Args:
config (CfgNode): config object config (CfgNode): config object
Returns: Returns:
bool: bool:
""" """
...@@ -61,13 +64,8 @@ class ServerExecutor(BaseExecutor): ...@@ -61,13 +64,8 @@ class ServerExecutor(BaseExecutor):
api_router = setup_router(api_list) api_router = setup_router(api_list)
app.include_router(api_router) app.include_router(api_router)
# init engine if not init_engine_pool(config):
engine_pool = [] return False
for engine in config.engine_backend:
engine_pool.append(EngineFactory.get_engine(engine_name=engine))
if not engine_pool[-1].init(
config_file=config.engine_backend[engine]):
return False
return True return True
......
...@@ -9,12 +9,17 @@ port: 8090 ...@@ -9,12 +9,17 @@ port: 8090
################################################################## ##################################################################
# CONFIG FILE # # CONFIG FILE #
################################################################## ##################################################################
# The engine_type of speech task needs to keep the same type as the config file of speech task.
# E.g: The engine_type of asr is 'python', the engine_backend of asr is 'XX/asr.yaml'
# E.g: The engine_type of asr is 'inference', the engine_backend of asr is 'XX/asr_pd.yaml'
#
# add engine type (Options: python, inference) # add engine type (Options: python, inference)
engine_type: engine_type:
asr: 'inference' asr: 'python'
# tts: 'inference' tts: 'python'
# add engine backend type (Options: asr, tts) and config file here. # add engine backend type (Options: asr, tts) and config file here.
# Adding a speech task to engine_backend means starting the service.
engine_backend: engine_backend:
asr: 'conf/asr/asr_pd.yaml' asr: 'conf/asr/asr.yaml'
#tts: 'conf/tts/tts_pd.yaml' tts: 'conf/tts/tts.yaml'
...@@ -5,3 +5,4 @@ cfg_path: # [optional] ...@@ -5,3 +5,4 @@ cfg_path: # [optional]
ckpt_path: # [optional] ckpt_path: # [optional]
decode_method: 'attention_rescoring' decode_method: 'attention_rescoring'
force_yes: True force_yes: True
device: 'cpu' # set 'gpu:id' or 'cpu'
...@@ -15,7 +15,7 @@ decode_method: ...@@ -15,7 +15,7 @@ decode_method:
force_yes: True force_yes: True
am_predictor_conf: am_predictor_conf:
use_gpu: True device: 'cpu' # set 'gpu:id' or 'cpu'
enable_mkldnn: True enable_mkldnn: True
switch_ir_optim: True switch_ir_optim: True
......
...@@ -29,4 +29,4 @@ voc_stat: ...@@ -29,4 +29,4 @@ voc_stat:
# OTHERS # # OTHERS #
################################################################## ##################################################################
lang: 'zh' lang: 'zh'
device: paddle.get_device() device: 'cpu' # set 'gpu:id' or 'cpu'
\ No newline at end of file
...@@ -6,18 +6,18 @@ ...@@ -6,18 +6,18 @@
# am choices=['speedyspeech_csmsc', 'fastspeech2_csmsc'] # am choices=['speedyspeech_csmsc', 'fastspeech2_csmsc']
################################################################## ##################################################################
am: 'fastspeech2_csmsc' am: 'fastspeech2_csmsc'
am_model: # the pdmodel file of am static model am_model: # the pdmodel file of your am static model (XX.pdmodel)
am_params: # the pdiparams file of am static model am_params: # the pdiparams file of your am static model (XX.pdipparams)
am_sample_rate: 24000 am_sample_rate: 24000 # must match the model
phones_dict: phones_dict:
tones_dict: tones_dict:
speaker_dict: speaker_dict:
spk_id: 0 spk_id: 0
am_predictor_conf: am_predictor_conf:
use_gpu: True device: 'cpu' # set 'gpu:id' or 'cpu'
enable_mkldnn: True enable_mkldnn: False
switch_ir_optim: True switch_ir_optim: False
################################################################## ##################################################################
...@@ -25,17 +25,16 @@ am_predictor_conf: ...@@ -25,17 +25,16 @@ am_predictor_conf:
# voc choices=['pwgan_csmsc', 'mb_melgan_csmsc','hifigan_csmsc'] # voc choices=['pwgan_csmsc', 'mb_melgan_csmsc','hifigan_csmsc']
################################################################## ##################################################################
voc: 'pwgan_csmsc' voc: 'pwgan_csmsc'
voc_model: # the pdmodel file of vocoder static model voc_model: # the pdmodel file of your vocoder static model (XX.pdmodel)
voc_params: # the pdiparams file of vocoder static model voc_params: # the pdiparams file of your vocoder static model (XX.pdipparams)
voc_sample_rate: 24000 voc_sample_rate: 24000 #must match the model
voc_predictor_conf: voc_predictor_conf:
use_gpu: True device: 'cpu' # set 'gpu:id' or 'cpu'
enable_mkldnn: True enable_mkldnn: False
switch_ir_optim: True switch_ir_optim: False
################################################################## ##################################################################
# OTHERS # # OTHERS #
################################################################## ##################################################################
lang: 'zh' lang: 'zh'
device: paddle.get_device()
...@@ -13,31 +13,24 @@ ...@@ -13,31 +13,24 @@
# limitations under the License. # limitations under the License.
import io import io
import os import os
from typing import List
from typing import Optional from typing import Optional
from typing import Union
import librosa
import paddle import paddle
import soundfile
from yacs.config import CfgNode from yacs.config import CfgNode
from paddlespeech.cli.utils import MODEL_HOME
from paddlespeech.s2t.modules.ctc import CTCDecoder
from paddlespeech.cli.asr.infer import ASRExecutor from paddlespeech.cli.asr.infer import ASRExecutor
from paddlespeech.cli.log import logger from paddlespeech.cli.log import logger
from paddlespeech.cli.utils import MODEL_HOME
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.transform.transformation import Transformation from paddlespeech.s2t.modules.ctc import CTCDecoder
from paddlespeech.s2t.utils.dynamic_import import dynamic_import
from paddlespeech.s2t.utils.utility import UpdateConfig from paddlespeech.s2t.utils.utility import UpdateConfig
from paddlespeech.server.engine.base_engine import BaseEngine
from paddlespeech.server.utils.config import get_config from paddlespeech.server.utils.config import get_config
from paddlespeech.server.utils.paddle_predictor import init_predictor from paddlespeech.server.utils.paddle_predictor import init_predictor
from paddlespeech.server.utils.paddle_predictor import run_model from paddlespeech.server.utils.paddle_predictor import run_model
from paddlespeech.server.engine.base_engine import BaseEngine
__all__ = ['ASREngine'] __all__ = ['ASREngine']
pretrained_models = { pretrained_models = {
"deepspeech2offline_aishell-zh-16k": { "deepspeech2offline_aishell-zh-16k": {
'url': 'url':
...@@ -143,7 +136,6 @@ class ASRServerExecutor(ASRExecutor): ...@@ -143,7 +136,6 @@ class ASRServerExecutor(ASRExecutor):
batch_average=True, # sum / batch_size batch_average=True, # sum / batch_size
grad_norm_type=self.config.get('ctc_grad_norm_type', None)) grad_norm_type=self.config.get('ctc_grad_norm_type', None))
@paddle.no_grad() @paddle.no_grad()
def infer(self, model_type: str): def infer(self, model_type: str):
""" """
...@@ -161,9 +153,8 @@ class ASRServerExecutor(ASRExecutor): ...@@ -161,9 +153,8 @@ class ASRServerExecutor(ASRExecutor):
cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n, cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n,
cfg.num_proc_bsearch) cfg.num_proc_bsearch)
output_data = run_model( output_data = run_model(self.am_predictor,
self.am_predictor, [audio.numpy(), audio_len.numpy()])
[audio.numpy(), audio_len.numpy()])
probs = output_data[0] probs = output_data[0]
eouts_len = output_data[1] eouts_len = output_data[1]
...@@ -208,14 +199,14 @@ class ASREngine(BaseEngine): ...@@ -208,14 +199,14 @@ class ASREngine(BaseEngine):
paddle.set_device(paddle.get_device()) paddle.set_device(paddle.get_device())
self.executor._init_from_path( self.executor._init_from_path(
model_type=self.config.model_type, model_type=self.config.model_type,
am_model=self.config.am_model, am_model=self.config.am_model,
am_params=self.config.am_params, am_params=self.config.am_params,
lang=self.config.lang, lang=self.config.lang,
sample_rate=self.config.sample_rate, sample_rate=self.config.sample_rate,
cfg_path=self.config.cfg_path, cfg_path=self.config.cfg_path,
decode_method=self.config.decode_method, decode_method=self.config.decode_method,
am_predictor_conf=self.config.am_predictor_conf) am_predictor_conf=self.config.am_predictor_conf)
logger.info("Initialize ASR server engine successfully.") logger.info("Initialize ASR server engine successfully.")
return True return True
...@@ -230,7 +221,8 @@ class ASREngine(BaseEngine): ...@@ -230,7 +221,8 @@ class ASREngine(BaseEngine):
io.BytesIO(audio_data), self.config.sample_rate, io.BytesIO(audio_data), self.config.sample_rate,
self.config.force_yes): self.config.force_yes):
logger.info("start running asr engine") logger.info("start running asr engine")
self.executor.preprocess(self.config.model_type, io.BytesIO(audio_data)) self.executor.preprocess(self.config.model_type,
io.BytesIO(audio_data))
self.executor.infer(self.config.model_type) self.executor.infer(self.config.model_type)
self.output = self.executor.postprocess() # Retrieve result of asr. self.output = self.executor.postprocess() # Retrieve result of asr.
logger.info("end inferring asr engine") logger.info("end inferring asr engine")
......
...@@ -12,21 +12,11 @@ ...@@ -12,21 +12,11 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import io import io
import os
from typing import List
from typing import Optional
from typing import Union
import librosa
import paddle import paddle
import soundfile
from paddlespeech.cli.asr.infer import ASRExecutor from paddlespeech.cli.asr.infer import ASRExecutor
from paddlespeech.cli.log import logger from paddlespeech.cli.log import logger
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.transform.transformation import Transformation
from paddlespeech.s2t.utils.dynamic_import import dynamic_import
from paddlespeech.s2t.utils.utility import UpdateConfig
from paddlespeech.server.engine.base_engine import BaseEngine from paddlespeech.server.engine.base_engine import BaseEngine
from paddlespeech.server.utils.config import get_config from paddlespeech.server.utils.config import get_config
...@@ -63,7 +53,10 @@ class ASREngine(BaseEngine): ...@@ -63,7 +53,10 @@ class ASREngine(BaseEngine):
self.executor = ASRServerExecutor() self.executor = ASRServerExecutor()
self.config = get_config(config_file) self.config = get_config(config_file)
paddle.set_device(paddle.get_device()) if self.config.device is None:
paddle.set_device(paddle.get_device())
else:
paddle.set_device(self.config.device)
self.executor._init_from_path( self.executor._init_from_path(
self.config.model, self.config.lang, self.config.sample_rate, self.config.model, self.config.lang, self.config.sample_rate,
self.config.cfg_path, self.config.decode_method, self.config.cfg_path, self.config.decode_method,
......
...@@ -12,8 +12,6 @@ ...@@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os import os
from typing import Any
from typing import List
from typing import Union from typing import Union
from pattern_singleton import Singleton from pattern_singleton import Singleton
......
...@@ -13,7 +13,6 @@ ...@@ -13,7 +13,6 @@
# limitations under the License. # limitations under the License.
from typing import Text from typing import Text
__all__ = ['EngineFactory'] __all__ = ['EngineFactory']
......
...@@ -29,8 +29,10 @@ def init_engine_pool(config) -> bool: ...@@ -29,8 +29,10 @@ def init_engine_pool(config) -> bool:
""" """
global ENGINE_POOL global ENGINE_POOL
for engine in config.engine_backend: for engine in config.engine_backend:
ENGINE_POOL[engine] = EngineFactory.get_engine(engine_name=engine, engine_type=config.engine_type[engine]) ENGINE_POOL[engine] = EngineFactory.get_engine(
if not ENGINE_POOL[engine].init(config_file=config.engine_backend[engine]): engine_name=engine, engine_type=config.engine_type[engine])
if not ENGINE_POOL[engine].init(
config_file=config.engine_backend[engine]):
return False return False
return True return True
...@@ -344,7 +344,6 @@ class TTSEngine(BaseEngine): ...@@ -344,7 +344,6 @@ class TTSEngine(BaseEngine):
try: try:
self.config = get_config(config_file) self.config = get_config(config_file)
self.executor._init_from_path( self.executor._init_from_path(
am=self.config.am, am=self.config.am,
am_model=self.config.am_model, am_model=self.config.am_model,
...@@ -361,8 +360,8 @@ class TTSEngine(BaseEngine): ...@@ -361,8 +360,8 @@ class TTSEngine(BaseEngine):
am_predictor_conf=self.config.am_predictor_conf, am_predictor_conf=self.config.am_predictor_conf,
voc_predictor_conf=self.config.voc_predictor_conf, ) voc_predictor_conf=self.config.voc_predictor_conf, )
except: except BaseException:
logger.info("Initialize TTS server engine Failed.") logger.error("Initialize TTS server engine Failed.")
return False return False
logger.info("Initialize TTS server engine successfully.") logger.info("Initialize TTS server engine successfully.")
...@@ -406,11 +405,13 @@ class TTSEngine(BaseEngine): ...@@ -406,11 +405,13 @@ class TTSEngine(BaseEngine):
# transform speed # transform speed
try: # windows not support soxbindings try: # windows not support soxbindings
wav_speed = change_speed(wav_vol, speed, target_fs) wav_speed = change_speed(wav_vol, speed, target_fs)
except: except ServerBaseException:
raise ServerBaseException( raise ServerBaseException(
ErrorCode.SERVER_INTERNAL_ERR, ErrorCode.SERVER_INTERNAL_ERR,
"Transform speed failed. Can not install soxbindings on your system. \ "Transform speed failed. Can not install soxbindings on your system. \
You need to set speed value 1.0.") You need to set speed value 1.0.")
except BaseException:
logger.error("Transform speed failed.")
# wav to base64 # wav to base64
buf = io.BytesIO() buf = io.BytesIO()
...@@ -463,9 +464,11 @@ class TTSEngine(BaseEngine): ...@@ -463,9 +464,11 @@ class TTSEngine(BaseEngine):
try: try:
self.executor.infer( self.executor.infer(
text=sentence, lang=lang, am=self.config.am, spk_id=spk_id) text=sentence, lang=lang, am=self.config.am, spk_id=spk_id)
except: except ServerBaseException:
raise ServerBaseException(ErrorCode.SERVER_INTERNAL_ERR, raise ServerBaseException(ErrorCode.SERVER_INTERNAL_ERR,
"tts infer failed.") "tts infer failed.")
except BaseException:
logger.error("tts infer failed.")
try: try:
target_sample_rate, wav_base64 = self.postprocess( target_sample_rate, wav_base64 = self.postprocess(
...@@ -475,8 +478,10 @@ class TTSEngine(BaseEngine): ...@@ -475,8 +478,10 @@ class TTSEngine(BaseEngine):
volume=volume, volume=volume,
speed=speed, speed=speed,
audio_path=save_path) audio_path=save_path)
except: except ServerBaseException:
raise ServerBaseException(ErrorCode.SERVER_INTERNAL_ERR, raise ServerBaseException(ErrorCode.SERVER_INTERNAL_ERR,
"tts postprocess failed.") "tts postprocess failed.")
except BaseException:
logger.error("tts postprocess failed.")
return lang, target_sample_rate, wav_base64 return lang, target_sample_rate, wav_base64
...@@ -54,7 +54,10 @@ class TTSEngine(BaseEngine): ...@@ -54,7 +54,10 @@ class TTSEngine(BaseEngine):
try: try:
self.config = get_config(config_file) self.config = get_config(config_file)
paddle.set_device(self.config.device) if self.config.device is None:
paddle.set_device(paddle.get_device())
else:
paddle.set_device(self.config.device)
self.executor._init_from_path( self.executor._init_from_path(
am=self.config.am, am=self.config.am,
...@@ -69,8 +72,8 @@ class TTSEngine(BaseEngine): ...@@ -69,8 +72,8 @@ class TTSEngine(BaseEngine):
voc_ckpt=self.config.voc_ckpt, voc_ckpt=self.config.voc_ckpt,
voc_stat=self.config.voc_stat, voc_stat=self.config.voc_stat,
lang=self.config.lang) lang=self.config.lang)
except: except BaseException:
logger.info("Initialize TTS server engine Failed.") logger.error("Initialize TTS server engine Failed.")
return False return False
logger.info("Initialize TTS server engine successfully.") logger.info("Initialize TTS server engine successfully.")
...@@ -114,10 +117,13 @@ class TTSEngine(BaseEngine): ...@@ -114,10 +117,13 @@ class TTSEngine(BaseEngine):
# transform speed # transform speed
try: # windows not support soxbindings try: # windows not support soxbindings
wav_speed = change_speed(wav_vol, speed, target_fs) wav_speed = change_speed(wav_vol, speed, target_fs)
except: except ServerBaseException:
raise ServerBaseException( raise ServerBaseException(
ErrorCode.SERVER_INTERNAL_ERR, ErrorCode.SERVER_INTERNAL_ERR,
"Can not install soxbindings on your system.") "Transform speed failed. Can not install soxbindings on your system. \
You need to set speed value 1.0.")
except BaseException:
logger.error("Transform speed failed.")
# wav to base64 # wav to base64
buf = io.BytesIO() buf = io.BytesIO()
...@@ -170,9 +176,11 @@ class TTSEngine(BaseEngine): ...@@ -170,9 +176,11 @@ class TTSEngine(BaseEngine):
try: try:
self.executor.infer( self.executor.infer(
text=sentence, lang=lang, am=self.config.am, spk_id=spk_id) text=sentence, lang=lang, am=self.config.am, spk_id=spk_id)
except: except ServerBaseException:
raise ServerBaseException(ErrorCode.SERVER_INTERNAL_ERR, raise ServerBaseException(ErrorCode.SERVER_INTERNAL_ERR,
"tts infer failed.") "tts infer failed.")
except BaseException:
logger.error("tts infer failed.")
try: try:
target_sample_rate, wav_base64 = self.postprocess( target_sample_rate, wav_base64 = self.postprocess(
...@@ -182,8 +190,10 @@ class TTSEngine(BaseEngine): ...@@ -182,8 +190,10 @@ class TTSEngine(BaseEngine):
volume=volume, volume=volume,
speed=speed, speed=speed,
audio_path=save_path) audio_path=save_path)
except: except ServerBaseException:
raise ServerBaseException(ErrorCode.SERVER_INTERNAL_ERR, raise ServerBaseException(ErrorCode.SERVER_INTERNAL_ERR,
"tts postprocess failed.") "tts postprocess failed.")
except BaseException:
logger.error("tts postprocess failed.")
return lang, target_sample_rate, wav_base64 return lang, target_sample_rate, wav_base64
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
import base64 import base64
import traceback import traceback
from typing import Union from typing import Union
from fastapi import APIRouter from fastapi import APIRouter
from paddlespeech.server.engine.engine_pool import get_engine_pool from paddlespeech.server.engine.engine_pool import get_engine_pool
...@@ -83,7 +84,7 @@ def asr(request_body: ASRRequest): ...@@ -83,7 +84,7 @@ def asr(request_body: ASRRequest):
except ServerBaseException as e: except ServerBaseException as e:
response = failed_response(e.error_code, e.msg) response = failed_response(e.error_code, e.msg)
except: except BaseException:
response = failed_response(ErrorCode.SERVER_UNKOWN_ERR) response = failed_response(ErrorCode.SERVER_UNKOWN_ERR)
traceback.print_exc() traceback.print_exc()
......
...@@ -11,7 +11,6 @@ ...@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import List
from typing import Optional from typing import Optional
from pydantic import BaseModel from pydantic import BaseModel
......
...@@ -11,9 +11,6 @@ ...@@ -11,9 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import List
from typing import Optional
from pydantic import BaseModel from pydantic import BaseModel
__all__ = ['ASRResponse', 'TTSResponse'] __all__ = ['ASRResponse', 'TTSResponse']
......
...@@ -16,7 +16,7 @@ from typing import Union ...@@ -16,7 +16,7 @@ from typing import Union
from fastapi import APIRouter from fastapi import APIRouter
from paddlespeech.server.engine.tts.paddleinference.tts_engine import TTSEngine from paddlespeech.server.engine.engine_pool import get_engine_pool
from paddlespeech.server.restful.request import TTSRequest from paddlespeech.server.restful.request import TTSRequest
from paddlespeech.server.restful.response import ErrorResponse from paddlespeech.server.restful.response import ErrorResponse
from paddlespeech.server.restful.response import TTSResponse from paddlespeech.server.restful.response import TTSResponse
...@@ -60,28 +60,41 @@ def tts(request_body: TTSRequest): ...@@ -60,28 +60,41 @@ def tts(request_body: TTSRequest):
Returns: Returns:
json: [description] json: [description]
""" """
# json to dict # get params
item_dict = request_body.dict() text = request_body.text
sentence = item_dict['text'] spk_id = request_body.spk_id
spk_id = item_dict['spk_id'] speed = request_body.speed
speed = item_dict['speed'] volume = request_body.volume
volume = item_dict['volume'] sample_rate = request_body.sample_rate
sample_rate = item_dict['sample_rate'] save_path = request_body.save_path
save_path = item_dict['save_path']
# Check parameters # Check parameters
if speed <=0 or speed > 3 or volume <=0 or volume > 3 or \ if speed <= 0 or speed > 3:
sample_rate not in [0, 16000, 8000] or \ return failed_response(
(save_path is not None and not save_path.endswith("pcm") and not save_path.endswith("wav")): ErrorCode.SERVER_PARAM_ERR,
return failed_response(ErrorCode.SERVER_PARAM_ERR) "invalid speed value, the value should be between 0 and 3.")
if volume <= 0 or volume > 3:
# single return failed_response(
tts_engine = TTSEngine() ErrorCode.SERVER_PARAM_ERR,
"invalid volume value, the value should be between 0 and 3.")
if sample_rate not in [0, 16000, 8000]:
return failed_response(
ErrorCode.SERVER_PARAM_ERR,
"invalid sample_rate value, the choice of value is 0, 8000, 16000.")
if save_path is not None and not save_path.endswith(
"pcm") and not save_path.endswith("wav"):
return failed_response(
ErrorCode.SERVER_PARAM_ERR,
"invalid save_path, saved audio formats support pcm and wav")
# run # run
try: try:
# get single engine from engine pool
engine_pool = get_engine_pool()
tts_engine = engine_pool['tts']
lang, target_sample_rate, wav_base64 = tts_engine.run( lang, target_sample_rate, wav_base64 = tts_engine.run(
sentence, spk_id, speed, volume, sample_rate, save_path) text, spk_id, speed, volume, sample_rate, save_path)
response = { response = {
"success": True, "success": True,
...@@ -101,7 +114,7 @@ def tts(request_body: TTSRequest): ...@@ -101,7 +114,7 @@ def tts(request_body: TTSRequest):
} }
except ServerBaseException as e: except ServerBaseException as e:
response = failed_response(e.error_code, e.msg) response = failed_response(e.error_code, e.msg)
except: except BaseException:
response = failed_response(ErrorCode.SERVER_UNKOWN_ERR) response = failed_response(ErrorCode.SERVER_UNKOWN_ERR)
traceback.print_exc() traceback.print_exc()
......
...@@ -10,11 +10,11 @@ ...@@ -10,11 +10,11 @@
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the # See the License for the
import requests import base64
import json import json
import time import time
import base64
import io import requests
def readwav2base64(wav_file): def readwav2base64(wav_file):
...@@ -34,23 +34,23 @@ def main(): ...@@ -34,23 +34,23 @@ def main():
url = "http://127.0.0.1:8090/paddlespeech/asr" url = "http://127.0.0.1:8090/paddlespeech/asr"
# start Timestamp # start Timestamp
time_start=time.time() time_start = time.time()
test_audio_dir = "./16_audio.wav" test_audio_dir = "./16_audio.wav"
audio = readwav2base64(test_audio_dir) audio = readwav2base64(test_audio_dir)
data = { data = {
"audio": audio, "audio": audio,
"audio_format": "wav", "audio_format": "wav",
"sample_rate": 16000, "sample_rate": 16000,
"lang": "zh_cn", "lang": "zh_cn",
} }
r = requests.post(url=url, data=json.dumps(data)) r = requests.post(url=url, data=json.dumps(data))
# ending Timestamp # ending Timestamp
time_end=time.time() time_end = time.time()
print('time cost',time_end - time_start, 's') print('time cost', time_end - time_start, 's')
print(r.json()) print(r.json())
......
...@@ -25,6 +25,7 @@ import soundfile ...@@ -25,6 +25,7 @@ import soundfile
from paddlespeech.server.utils.audio_process import wav2pcm from paddlespeech.server.utils.audio_process import wav2pcm
# Request and response # Request and response
def tts_client(args): def tts_client(args):
""" Request and response """ Request and response
...@@ -99,5 +100,5 @@ if __name__ == "__main__": ...@@ -99,5 +100,5 @@ if __name__ == "__main__":
print("Inference time: %f" % (time_consume)) print("Inference time: %f" % (time_consume))
print("The duration of synthesized audio: %f" % (duration)) print("The duration of synthesized audio: %f" % (duration))
print("The RTF is: %f" % (rtf)) print("The RTF is: %f" % (rtf))
except: except BaseException:
print("Failed to synthesized audio.") print("Failed to synthesized audio.")
...@@ -219,7 +219,7 @@ class ConfigCache: ...@@ -219,7 +219,7 @@ class ConfigCache:
try: try:
cfg = yaml.load(file, Loader=yaml.FullLoader) cfg = yaml.load(file, Loader=yaml.FullLoader)
self._data.update(cfg) self._data.update(cfg)
except: except BaseException:
self.flush() self.flush()
@property @property
......
...@@ -41,8 +41,9 @@ def init_predictor(model_dir: Optional[os.PathLike]=None, ...@@ -41,8 +41,9 @@ def init_predictor(model_dir: Optional[os.PathLike]=None,
config = Config(model_file, params_file) config = Config(model_file, params_file)
config.enable_memory_optim() config.enable_memory_optim()
if predictor_conf["use_gpu"]: if "gpu" in predictor_conf["device"]:
config.enable_use_gpu(1000, 0) gpu_id = predictor_conf["device"].split(":")[-1]
config.enable_use_gpu(1000, int(gpu_id))
if predictor_conf["enable_mkldnn"]: if predictor_conf["enable_mkldnn"]:
config.enable_mkldnn() config.enable_mkldnn()
if predictor_conf["switch_ir_optim"]: if predictor_conf["switch_ir_optim"]:
......
...@@ -27,46 +27,52 @@ from setuptools.command.install import install ...@@ -27,46 +27,52 @@ from setuptools.command.install import install
HERE = Path(os.path.abspath(os.path.dirname(__file__))) HERE = Path(os.path.abspath(os.path.dirname(__file__)))
VERSION = '0.1.1' VERSION = '0.1.2'
base = [
"editdistance",
"g2p_en",
"g2pM",
"h5py",
"inflect",
"jieba",
"jsonlines",
"kaldiio",
"librosa==0.8.1",
"loguru",
"matplotlib",
"nara_wpe",
"pandas",
"paddleaudio",
"paddlenlp",
"paddlespeech_feat",
"praatio==5.0.0",
"pypinyin",
"python-dateutil",
"pyworld",
"resampy==0.2.2",
"sacrebleu",
"scipy",
"sentencepiece~=0.1.96",
"soundfile~=0.10",
"textgrid",
"timer",
"tqdm",
"typeguard",
"visualdl",
"webrtcvad",
"yacs~=0.1.8",
]
server = [
"fastapi",
"uvicorn",
"pattern_singleton",
]
requirements = { requirements = {
"install": [ "install":
"editdistance", base + server,
"g2p_en",
"g2pM",
"h5py",
"inflect",
"jieba",
"jsonlines",
"kaldiio",
"librosa",
"loguru",
"matplotlib",
"nara_wpe",
"pandas",
"paddleaudio",
"paddlenlp",
"paddlespeech_feat",
"praatio==5.0.0",
"pypinyin",
"python-dateutil",
"pyworld",
"resampy==0.2.2",
"sacrebleu",
"scipy",
"sentencepiece~=0.1.96",
"soundfile~=0.10",
"textgrid",
"timer",
"tqdm",
"typeguard",
"visualdl",
"webrtcvad",
"yacs~=0.1.8",
# fastapi server
"fastapi",
"uvicorn",
],
"develop": [ "develop": [
"ConfigArgParse", "ConfigArgParse",
"coverage", "coverage",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册