diff --git a/paddlespeech/server/README.md b/paddlespeech/server/README.md new file mode 100644 index 0000000000000000000000000000000000000000..4ce9605d62a0c411840f9f861a5f251b146110ab --- /dev/null +++ b/paddlespeech/server/README.md @@ -0,0 +1,33 @@ +# PaddleSpeech Server Command Line + +([简体中文](./README_cn.md)|English) + + The simplest way to use PaddleSpeech Server: both the server and the client can be run with a single command. + + ## PaddleSpeech Server + ### Help + ```bash + paddlespeech_server help + ``` + ### Start the server + First, set the service-related configuration parameters following `./conf/application.yaml`, as well as the speech task model configuration it references (e.g. `./conf/tts/tts.yaml`). + Then start the service: + ```bash + paddlespeech_server start --config_file ./conf/application.yaml + ``` + + ## PaddleSpeech Client + ### Help + ```bash + paddlespeech_client help + ``` + ### Access the speech recognition service + ```bash + paddlespeech_client asr --server_ip 127.0.0.1 --port 8090 --input ./tests/16_audio.wav + ``` + + ### Access the text-to-speech service + ```bash + paddlespeech_client tts --server_ip 127.0.0.1 --port 8090 --input "你好,欢迎使用百度飞桨深度学习框架!" --output output.wav + ``` + diff --git a/paddlespeech/server/README_cn.md b/paddlespeech/server/README_cn.md new file mode 100644 index 0000000000000000000000000000000000000000..2dfd9474ba6490dedbb8d984c5ba9810506fa415 --- /dev/null +++ b/paddlespeech/server/README_cn.md @@ -0,0 +1,32 @@ +# PaddleSpeech Server 命令行工具 + +(简体中文|[English](./README.md)) + +它提供了最简便的方式调用 PaddleSpeech 语音服务,用一行命令就可以轻松启动服务和调用服务。 + + ## 服务端命令行使用 + ### 帮助 + ```bash + paddlespeech_server help + ``` + ### 启动服务 + 首先设置服务相关配置文件,类似于 `./conf/application.yaml`,同时设置服务配置中的语音任务模型相关配置,类似于 `./conf/tts/tts.yaml`。 + 然后启动服务: + ```bash + paddlespeech_server start --config_file ./conf/application.yaml + ``` + + ## 客户端命令行使用 + ### 帮助 + ```bash + paddlespeech_client help + ``` + ### 访问语音识别服务 + ```bash + paddlespeech_client asr --server_ip 127.0.0.1 --port 8090 --input input_16k.wav + ``` + + ### 访问语音合成服务 + ```bash + paddlespeech_client tts --server_ip 127.0.0.1 --port 8090 --input "你好,欢迎使用百度飞桨深度学习框架!" --output output.wav + ``` diff --git a/paddlespeech/server/__init__.py b/paddlespeech/server/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..384061ddae2c089c7dfb245d17ac46d2371f0b5f --- /dev/null +++ b/paddlespeech/server/__init__.py @@ -0,0 +1,24 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
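+# This package exposes the PaddleSpeech server/client CLI commands and executors. +# The `_locale` patch at the bottom of this module pins the default locale to +# en_US UTF-8 so that CLI output is consistent across platforms.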
+import _locale + +from .base_commands import ClientBaseCommand +from .base_commands import ClientHelpCommand +from .base_commands import ServerBaseCommand +from .base_commands import ServerHelpCommand +from .bin.paddlespeech_client import ASRClientExecutor +from .bin.paddlespeech_client import TTSClientExecutor +from .bin.paddlespeech_server import ServerExecutor + +_locale._getdefaultlocale = (lambda *args: ['en_US', 'utf8']) diff --git a/paddlespeech/server/base_commands.py b/paddlespeech/server/base_commands.py new file mode 100644 index 0000000000000000000000000000000000000000..d1239297d47c88d6169c2622ff89b568a9292c68 --- /dev/null +++ b/paddlespeech/server/base_commands.py @@ -0,0 +1,82 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import List + +from .entry import client_commands +from .entry import server_commands +from .util import cli_client_register +from .util import cli_server_register +from .util import get_client_command +from .util import get_server_command + +__all__ = [ + 'ServerBaseCommand', + 'ServerHelpCommand', + 'ClientBaseCommand', + 'ClientHelpCommand', +] + + +@cli_server_register(name='paddlespeech_server') +class ServerBaseCommand: + def execute(self, argv: List[str]) -> bool: + help = get_server_command('paddlespeech_server.help') + return help().execute(argv) + + +@cli_server_register( + name='paddlespeech_server.help', description='Show help for commands.') +class ServerHelpCommand: + def execute(self, argv: List[str]) -> bool: + msg = 'Usage:\n' + msg += ' paddlespeech_server \n\n' + msg += 'Commands:\n' + for command, detail in server_commands['paddlespeech_server'].items(): + if command.startswith('_'): + continue + + if '_description' not in detail: + continue + msg += ' {:<15} {}\n'.format(command, + detail['_description']) + + print(msg) + return True + + +@cli_client_register(name='paddlespeech_client') +class ClientBaseCommand: + def execute(self, argv: List[str]) -> bool: + help = get_client_command('paddlespeech_client.help') + return help().execute(argv) + + +@cli_client_register( + name='paddlespeech_client.help', description='Show help for commands.') +class ClientHelpCommand: + def execute(self, argv: List[str]) -> bool: + msg = 'Usage:\n' + msg += ' paddlespeech_client \n\n' + msg += 'Commands:\n' + for command, detail in client_commands['paddlespeech_client'].items(): + if command.startswith('_'): + continue + + if '_description' not in detail: + continue + msg += ' {:<15} {}\n'.format(command, + detail['_description']) + + print(msg) + return True diff --git a/paddlespeech/server/bin/__init__.py b/paddlespeech/server/bin/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bd75747f79948ea42229b8c164174dbe4240d4b1 --- /dev/null +++ b/paddlespeech/server/bin/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .paddlespeech_client import ASRClientExecutor +from .paddlespeech_client import TTSClientExecutor +from .paddlespeech_server import ServerExecutor diff --git a/paddlespeech/server/bin/main.py b/paddlespeech/server/bin/main.py new file mode 100644 index 0000000000000000000000000000000000000000..af51f3f2e7eabb0cdfe8a9fff830e2cd7d00280f --- /dev/null +++ b/paddlespeech/server/bin/main.py @@ -0,0 +1,76 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import uvicorn +import yaml +from fastapi import FastAPI + +from paddlespeech.server.engine.engine_factory import EngineFactory +from paddlespeech.server.restful.api import setup_router +from paddlespeech.server.utils.config import get_config +from paddlespeech.server.utils.log import logger + +app = FastAPI( + title="PaddleSpeech Serving API", description="Api", version="0.0.1") + + +def init(config): + """system initialization + + Args: + config (CfgNode): config object + + Returns: + bool: + """ + # init api + api_list = list(config.engine_backend) + api_router = setup_router(api_list) + app.include_router(api_router) + + # init engine + engine_pool = [] + for engine in config.engine_backend: + engine_pool.append(EngineFactory.get_engine(engine_name=engine)) + if not engine_pool[-1].init(config_file=config.engine_backend[engine]): + return False + + return True + + +def main(args): + """main function""" + + config = get_config(args.config_file) + + if init(config): + uvicorn.run(app, host=config.host, port=config.port, debug=True) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--config_file", + action="store", + help="yaml file of the app", + default="./conf/application.yaml") + + parser.add_argument( + "--log_file", + action="store", + help="log file", + default="./log/paddlespeech.log") + args = parser.parse_args() + + main(args) diff --git a/paddlespeech/server/bin/paddlespeech_client.py b/paddlespeech/server/bin/paddlespeech_client.py new file mode 100644 index 0000000000000000000000000000000000000000..0e030da9b9728eb6c5225d8fe8e06e64c877ae47 --- /dev/null +++ b/paddlespeech/server/bin/paddlespeech_client.py @@ -0,0 +1,162 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import base64 +import io +import json +import os +import random +import time +from typing import List + +import numpy as np +import requests +import soundfile + +from ..executor import BaseExecutor +from ..util import cli_client_register +from paddlespeech.server.utils.audio_process import wav2pcm +from paddlespeech.server.utils.util import wav2base64 + +__all__ = ['TTSClientExecutor', 'ASRClientExecutor'] + + +@cli_client_register( + name='paddlespeech_client.tts', description='visit tts service') +class TTSClientExecutor(BaseExecutor): + def __init__(self): + super().__init__() + self.parser = argparse.ArgumentParser() + self.parser.add_argument( + '--server_ip', type=str, default='127.0.0.1', help='server ip') + self.parser.add_argument( + '--port', type=int, default=8090, help='server port') + self.parser.add_argument( + '--input', + type=str, + default="你好,欢迎使用语音合成服务", + help='A sentence to be synthesized') + self.parser.add_argument( + '--spk_id', type=int, default=0, help='Speaker id') + self.parser.add_argument( + '--speed', type=float, default=1.0, help='Audio speed') + self.parser.add_argument( + '--volume', type=float, default=1.0, help='Audio volume') + self.parser.add_argument( + '--sample_rate', + type=int, + default=0, + help='Sampling rate, the default is the same as the model') + self.parser.add_argument( + '--output', + type=str, + default="./output.wav", + help='Synthesized audio file') + + # Request and response + def tts_client(self, args): + """ Request and response + Args: + input: A sentence to be synthesized + outfile: Synthetic audio file + """ + url = 'http://' + args.server_ip + ":" + str( + args.port) + '/paddlespeech/tts' + request = { + "text": args.input, + "spk_id": args.spk_id, + "speed": args.speed, + "volume": args.volume, + "sample_rate": args.sample_rate, + "save_path": args.output + } + + response = requests.post(url, json.dumps(request)) + response_dict = response.json() + print(response_dict["message"]) + wav_base64 = response_dict["result"]["audio"] + + audio_data_byte = base64.b64decode(wav_base64) + # from byte + samples, sample_rate = soundfile.read( + io.BytesIO(audio_data_byte), dtype='float32') + + # transform audio + outfile = args.output + if outfile.endswith(".wav"): + soundfile.write(outfile, samples, sample_rate) + elif outfile.endswith(".pcm"): + temp_wav = str(random.getrandbits(128)) + ".wav" + soundfile.write(temp_wav, samples, sample_rate) + wav2pcm(temp_wav, outfile, data_type=np.int16) + os.system("rm %s" % (temp_wav)) + else: + print("The format for saving audio only supports wav or pcm") + + return len(samples), sample_rate + + def execute(self, argv: List[str]) -> bool: + args = self.parser.parse_args(argv) + st = time.time() + try: + samples_length, sample_rate = self.tts_client(args) + time_consume = time.time() - st + print("Save synthesized audio successfully on %s." % (args.output)) + print("Inference time: %f s." 
% (time_consume)) + except Exception: + print("Failed to synthesize audio.") + + +@cli_client_register( + name='paddlespeech_client.asr', description='visit asr service') +class ASRClientExecutor(BaseExecutor): + def __init__(self): + super().__init__() + self.parser = argparse.ArgumentParser() + self.parser.add_argument( + '--server_ip', type=str, default='127.0.0.1', help='server ip') + self.parser.add_argument( + '--port', type=int, default=8090, help='server port') + self.parser.add_argument( + '--input', + type=str, + default="./paddlespeech/server/tests/16_audio.wav", + help='Audio file to be recognized') + self.parser.add_argument( + '--sample_rate', type=int, default=16000, help='audio sample rate') + self.parser.add_argument( + '--lang', type=str, default="zh_cn", help='language') + self.parser.add_argument( + '--audio_format', type=str, default="wav", help='audio format') + + def execute(self, argv: List[str]) -> bool: + args = self.parser.parse_args(argv) + url = 'http://' + args.server_ip + ":" + str( + args.port) + '/paddlespeech/asr' + audio = wav2base64(args.input) + data = { + "audio": audio, + "audio_format": args.audio_format, + "sample_rate": args.sample_rate, + "lang": args.lang, + } + time_start = time.time() + try: + r = requests.post(url=url, data=json.dumps(data)) + # ending timestamp + time_end = time.time() + print(r.json()) + print('time cost', time_end - time_start, 's') + except Exception: + print("Speech recognition failed.") diff --git a/paddlespeech/server/bin/paddlespeech_server.py b/paddlespeech/server/bin/paddlespeech_server.py new file mode 100644 index 0000000000000000000000000000000000000000..367375fc74d04ed77f807b85becef4d6a9f645b3 --- /dev/null +++ b/paddlespeech/server/bin/paddlespeech_server.py @@ -0,0 +1,79 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
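+# Entry point for `paddlespeech_server start`: it loads the YAML application +# config, registers the configured engines (asr/tts) and their REST routes, +# and serves the FastAPI app with uvicorn.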
+import argparse +from typing import List + +import uvicorn +from fastapi import FastAPI + +from ..executor import BaseExecutor +from ..util import cli_server_register +from paddlespeech.server.engine.engine_factory import EngineFactory +from paddlespeech.server.restful.api import setup_router +from paddlespeech.server.utils.config import get_config + +__all__ = ['ServerExecutor'] + +app = FastAPI( + title="PaddleSpeech Serving API", description="Api", version="0.0.1") + + +@cli_server_register( + name='paddlespeech_server.start', description='Start the service') +class ServerExecutor(BaseExecutor): + def __init__(self): + super().__init__() + self.parser = argparse.ArgumentParser() + self.parser.add_argument( + "--config_file", + action="store", + help="yaml file of the app", + default="./conf/application.yaml") + + self.parser.add_argument( + "--log_file", + action="store", + help="log file", + default="./log/paddlespeech.log") + + def init(self, config) -> bool: + """system initialization + + Args: + config (CfgNode): config object + + Returns: + bool: + """ + # init api + api_list = list(config.engine_backend) + api_router = setup_router(api_list) + app.include_router(api_router) + + # init engine + engine_pool = [] + for engine in config.engine_backend: + engine_pool.append(EngineFactory.get_engine(engine_name=engine)) + if not engine_pool[-1].init( + config_file=config.engine_backend[engine]): + return False + + return True + + def execute(self, argv: List[str]) -> bool: + args = self.parser.parse_args(argv) + config = get_config(args.config_file) + + if self.init(config): + uvicorn.run(app, host=config.host, port=config.port, debug=True) diff --git a/paddlespeech/server/conf/application.yaml b/paddlespeech/server/conf/application.yaml new file mode 100644 index 0000000000000000000000000000000000000000..67cc3b34be7cbc46f445cc8edac218707edd2acd --- /dev/null +++ b/paddlespeech/server/conf/application.yaml @@ -0,0 +1,17 @@ +# This is the parameter configuration file for PaddleSpeech Serving. + +################################################################## +# SERVER SETTING # +################################################################## +host: '0.0.0.0' +port: 8090 + +################################################################## +# CONFIG FILE # +################################################################## +# add engine type (Options: asr, tts) and config file here. + +engine_backend: + asr: 'conf/asr/asr.yaml' + tts: 'conf/tts/tts_pd.yaml' + diff --git a/paddlespeech/server/conf/asr/asr.yaml b/paddlespeech/server/conf/asr/asr.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4c3b0a67e30273681fe765fc2e827f86a21ac380 --- /dev/null +++ b/paddlespeech/server/conf/asr/asr.yaml @@ -0,0 +1,7 @@ +model: 'conformer_wenetspeech' +lang: 'zh' +sample_rate: 16000 +cfg_path: +ckpt_path: +decode_method: 'attention_rescoring' +force_yes: False diff --git a/paddlespeech/server/conf/tts/tts.yaml b/paddlespeech/server/conf/tts/tts.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d0e128eaee0c14783d23867563ee0275fbceef1b --- /dev/null +++ b/paddlespeech/server/conf/tts/tts.yaml @@ -0,0 +1,32 @@ +# This is the parameter configuration file for TTS server. 
+ +################################################################## +# ACOUSTIC MODEL SETTING # +# am choices=['speedyspeech_csmsc', 'fastspeech2_csmsc', +# 'fastspeech2_ljspeech', 'fastspeech2_aishell3', +# 'fastspeech2_vctk'] +################################################################## +am: 'fastspeech2_csmsc' +am_config: +am_ckpt: +am_stat: +phones_dict: +tones_dict: +speaker_dict: +spk_id: 0 + +################################################################## +# VOCODER SETTING # +# voc choices=['pwgan_csmsc', 'pwgan_ljspeech', 'pwgan_aishell3', +# 'pwgan_vctk', 'mb_melgan_csmsc'] +################################################################## +voc: 'pwgan_csmsc' +voc_config: +voc_ckpt: +voc_stat: + +################################################################## +# OTHERS # +################################################################## +lang: 'zh' +device: paddle.get_device() \ No newline at end of file diff --git a/paddlespeech/server/conf/tts/tts_pd.yaml b/paddlespeech/server/conf/tts/tts_pd.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c268c6a336bb21be7879980cb3cb3c59611d64cd --- /dev/null +++ b/paddlespeech/server/conf/tts/tts_pd.yaml @@ -0,0 +1,41 @@ +# This is the parameter configuration file for TTS server. +# These are the static models that support paddle inference. + +################################################################## +# ACOUSTIC MODEL SETTING # +# am choices=['speedyspeech_csmsc', 'fastspeech2_csmsc'] +################################################################## +am: 'fastspeech2_csmsc' +am_model: # the pdmodel file of am static model +am_params: # the pdiparams file of am static model +am_sample_rate: 24000 +phones_dict: +tones_dict: +speaker_dict: +spk_id: 0 + +am_predictor_conf: + use_gpu: True + enable_mkldnn: True + switch_ir_optim: True + + +################################################################## +# VOCODER SETTING # +# voc choices=['pwgan_csmsc', 'mb_melgan_csmsc','hifigan_csmsc'] +################################################################## +voc: 'pwgan_csmsc' +voc_model: # the pdmodel file of vocoder static model +voc_params: # the pdiparams file of vocoder static model +voc_sample_rate: 24000 + +voc_predictor_conf: + use_gpu: True + enable_mkldnn: True + switch_ir_optim: True + +################################################################## +# OTHERS # +################################################################## +lang: 'zh' +device: paddle.get_device() diff --git a/paddlespeech/server/download.py b/paddlespeech/server/download.py new file mode 100644 index 0000000000000000000000000000000000000000..ea943dd8745c17cacdb0575a8552ba1a75ab4a7c --- /dev/null +++ b/paddlespeech/server/download.py @@ -0,0 +1,329 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
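+# Download helpers (requests/wget fetching, md5 verification, tar/zip +# decompression) used to retrieve pretrained model archives.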
+from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import hashlib +import os +import os.path as osp +import shutil +import subprocess +import tarfile +import time +import zipfile + +import requests +from tqdm import tqdm + +from paddlespeech.cli.log import logger + +__all__ = ['get_path_from_url'] + +DOWNLOAD_RETRY_LIMIT = 3 + + +def _is_url(path): + """ + Whether path is URL. + Args: + path (string): URL string or not. + """ + return path.startswith('http://') or path.startswith('https://') + + +def _map_path(url, root_dir): + # parse path after download under root_dir + fname = osp.split(url)[-1] + fpath = fname + return osp.join(root_dir, fpath) + + +def _get_unique_endpoints(trainer_endpoints): + # Sorting is to avoid different environmental variables for each card + trainer_endpoints.sort() + ips = set() + unique_endpoints = set() + for endpoint in trainer_endpoints: + ip = endpoint.split(":")[0] + if ip in ips: + continue + ips.add(ip) + unique_endpoints.add(endpoint) + logger.info("unique_endpoints {}".format(unique_endpoints)) + return unique_endpoints + + +def get_path_from_url(url, + root_dir, + md5sum=None, + check_exist=True, + decompress=True, + method='get'): + """ Download from given url to root_dir. + if file or directory specified by url is exists under + root_dir, return the path directly, otherwise download + from url and decompress it, return the path. + Args: + url (str): download url + root_dir (str): root dir for downloading, it should be + WEIGHTS_HOME or DATASET_HOME + md5sum (str): md5 sum of download package + decompress (bool): decompress zip or tar file. Default is `True` + method (str): which download method to use. Support `wget` and `get`. Default is `get`. + Returns: + str: a local path to save downloaded models & weights & datasets. + """ + + from paddle.fluid.dygraph.parallel import ParallelEnv + + assert _is_url(url), "downloading from {} not a url".format(url) + # parse path after download to decompress under root_dir + fullpath = _map_path(url, root_dir) + # Mainly used to solve the problem of downloading data from different + # machines in the case of multiple machines. Different ips will download + # data, and the same ip will only download data once. 
+ unique_endpoints = _get_unique_endpoints(ParallelEnv().trainer_endpoints[:]) + if osp.exists(fullpath) and check_exist and _md5check(fullpath, md5sum): + logger.info("Found {}".format(fullpath)) + else: + if ParallelEnv().current_endpoint in unique_endpoints: + fullpath = _download(url, root_dir, md5sum, method=method) + else: + while not os.path.exists(fullpath): + time.sleep(1) + + if ParallelEnv().current_endpoint in unique_endpoints: + if decompress and (tarfile.is_tarfile(fullpath) or + zipfile.is_zipfile(fullpath)): + fullpath = _decompress(fullpath) + + return fullpath + + +def _get_download(url, fullname): + # using requests.get method + fname = osp.basename(fullname) + try: + req = requests.get(url, stream=True) + except Exception as e: # requests.exceptions.ConnectionError + logger.info("Downloading {} from {} failed with exception {}".format( + fname, url, str(e))) + return False + + if req.status_code != 200: + raise RuntimeError("Downloading from {} failed with code " + "{}!".format(url, req.status_code)) + + # For protecting download interupted, download to + # tmp_fullname firstly, move tmp_fullname to fullname + # after download finished + tmp_fullname = fullname + "_tmp" + total_size = req.headers.get('content-length') + with open(tmp_fullname, 'wb') as f: + if total_size: + with tqdm(total=(int(total_size) + 1023) // 1024) as pbar: + for chunk in req.iter_content(chunk_size=1024): + f.write(chunk) + pbar.update(1) + else: + for chunk in req.iter_content(chunk_size=1024): + if chunk: + f.write(chunk) + shutil.move(tmp_fullname, fullname) + + return fullname + + +def _wget_download(url, fullname): + # using wget to download url + tmp_fullname = fullname + "_tmp" + # –user-agent + command = 'wget -O {} -t {} {}'.format(tmp_fullname, DOWNLOAD_RETRY_LIMIT, + url) + subprc = subprocess.Popen( + command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + _ = subprc.communicate() + + if subprc.returncode != 0: + raise RuntimeError( + '{} failed. Please make sure `wget` is installed or {} exists'. + format(command, url)) + + shutil.move(tmp_fullname, fullname) + + return fullname + + +_download_methods = { + 'get': _get_download, + 'wget': _wget_download, +} + + +def _download(url, path, md5sum=None, method='get'): + """ + Download from url, save to path. + url (str): download url + path (str): download to given path + md5sum (str): md5 sum of download package + method (str): which download method to use. Support `wget` and `get`. Default is `get`. + """ + assert method in _download_methods, 'make sure `{}` implemented'.format( + method) + + if not osp.exists(path): + os.makedirs(path) + + fname = osp.split(url)[-1] + fullname = osp.join(path, fname) + retry_cnt = 0 + + logger.info("Downloading {} from {}".format(fname, url)) + while not (osp.exists(fullname) and _md5check(fullname, md5sum)): + if retry_cnt < DOWNLOAD_RETRY_LIMIT: + retry_cnt += 1 + else: + raise RuntimeError("Download from {} failed. 
" + "Retry limit reached".format(url)) + + if not _download_methods[method](url, fullname): + time.sleep(1) + continue + + return fullname + + +def _md5check(fullname, md5sum=None): + if md5sum is None: + return True + + logger.info("File {} md5 checking...".format(fullname)) + md5 = hashlib.md5() + with open(fullname, 'rb') as f: + for chunk in iter(lambda: f.read(4096), b""): + md5.update(chunk) + calc_md5sum = md5.hexdigest() + + if calc_md5sum != md5sum: + logger.info("File {} md5 check failed, {}(calc) != " + "{}(base)".format(fullname, calc_md5sum, md5sum)) + return False + return True + + +def _decompress(fname): + """ + Decompress for zip and tar file + """ + logger.info("Decompressing {}...".format(fname)) + + # For protecting decompressing interupted, + # decompress to fpath_tmp directory firstly, if decompress + # successed, move decompress files to fpath and delete + # fpath_tmp and remove download compress file. + + if tarfile.is_tarfile(fname): + uncompressed_path = _uncompress_file_tar(fname) + elif zipfile.is_zipfile(fname): + uncompressed_path = _uncompress_file_zip(fname) + else: + raise TypeError("Unsupport compress file type {}".format(fname)) + + return uncompressed_path + + +def _uncompress_file_zip(filepath): + files = zipfile.ZipFile(filepath, 'r') + file_list = files.namelist() + + file_dir = os.path.dirname(filepath) + + if _is_a_single_file(file_list): + rootpath = file_list[0] + uncompressed_path = os.path.join(file_dir, rootpath) + + for item in file_list: + files.extract(item, file_dir) + + elif _is_a_single_dir(file_list): + rootpath = os.path.splitext(file_list[0])[0].split(os.sep)[0] + uncompressed_path = os.path.join(file_dir, rootpath) + + for item in file_list: + files.extract(item, file_dir) + + else: + rootpath = os.path.splitext(filepath)[0].split(os.sep)[-1] + uncompressed_path = os.path.join(file_dir, rootpath) + if not os.path.exists(uncompressed_path): + os.makedirs(uncompressed_path) + for item in file_list: + files.extract(item, os.path.join(file_dir, rootpath)) + + files.close() + + return uncompressed_path + + +def _uncompress_file_tar(filepath, mode="r:*"): + files = tarfile.open(filepath, mode) + file_list = files.getnames() + + file_dir = os.path.dirname(filepath) + + if _is_a_single_file(file_list): + rootpath = file_list[0] + uncompressed_path = os.path.join(file_dir, rootpath) + for item in file_list: + files.extract(item, file_dir) + elif _is_a_single_dir(file_list): + rootpath = os.path.splitext(file_list[0])[0].split(os.sep)[-1] + uncompressed_path = os.path.join(file_dir, rootpath) + for item in file_list: + files.extract(item, file_dir) + else: + rootpath = os.path.splitext(filepath)[0].split(os.sep)[-1] + uncompressed_path = os.path.join(file_dir, rootpath) + if not os.path.exists(uncompressed_path): + os.makedirs(uncompressed_path) + + for item in file_list: + files.extract(item, os.path.join(file_dir, rootpath)) + + files.close() + + return uncompressed_path + + +def _is_a_single_file(file_list): + if len(file_list) == 1 and file_list[0].find(os.sep) < -1: + return True + return False + + +def _is_a_single_dir(file_list): + new_file_list = [] + for file_path in file_list: + if '/' in file_path: + file_path = file_path.replace('/', os.sep) + elif '\\' in file_path: + file_path = file_path.replace('\\', os.sep) + new_file_list.append(file_path) + + file_name = new_file_list[0].split(os.sep)[0] + for i in range(1, len(new_file_list)): + if file_name != new_file_list[i].split(os.sep)[0]: + return False + return True diff --git 
a/paddlespeech/server/engine/__init__.py b/paddlespeech/server/engine/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..97043fd7ba6885aac81cad5a49924c23c67d4d47 --- /dev/null +++ b/paddlespeech/server/engine/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/paddlespeech/server/engine/asr/__init__.py b/paddlespeech/server/engine/asr/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..97043fd7ba6885aac81cad5a49924c23c67d4d47 --- /dev/null +++ b/paddlespeech/server/engine/asr/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/paddlespeech/server/engine/asr/python/__init__.py b/paddlespeech/server/engine/asr/python/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..97043fd7ba6885aac81cad5a49924c23c67d4d47 --- /dev/null +++ b/paddlespeech/server/engine/asr/python/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/paddlespeech/server/engine/asr/python/asr_engine.py b/paddlespeech/server/engine/asr/python/asr_engine.py new file mode 100644 index 0000000000000000000000000000000000000000..b11549832a3429f10b4af9361f67f779adda7baa --- /dev/null +++ b/paddlespeech/server/engine/asr/python/asr_engine.py @@ -0,0 +1,194 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import io +import os +from typing import List +from typing import Optional +from typing import Union + +import librosa +import numpy as np +import paddle +import soundfile + +from paddlespeech.cli.asr.infer import ASRExecutor +from paddlespeech.cli.log import logger +from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer +from paddlespeech.s2t.transform.transformation import Transformation +from paddlespeech.s2t.utils.dynamic_import import dynamic_import +from paddlespeech.s2t.utils.utility import UpdateConfig +from paddlespeech.server.engine.base_engine import BaseEngine +from paddlespeech.server.utils.config import get_config + +__all__ = ['ASREngine'] + + +class ASRServerExecutor(ASRExecutor): + def __init__(self): + super().__init__() + pass + + def _check(self, audio_file: str, sample_rate: int, force_yes: bool): + self.sample_rate = sample_rate + if self.sample_rate != 16000 and self.sample_rate != 8000: + logger.error("please input --sr 8000 or --sr 16000") + return False + + logger.info("checking the audio file format......") + try: + audio, audio_sample_rate = soundfile.read( + audio_file, dtype="int16", always_2d=True) + except Exception as e: + logger.exception(e) + logger.error( + "can not open the audio file, please check the audio file format is 'wav'. \n \ + you can try to use sox to change the file format.\n \ + For example: \n \ + sample rate: 16k \n \ + sox input_audio.xx --rate 16k --bits 16 --channels 1 output_audio.wav \n \ + sample rate: 8k \n \ + sox input_audio.xx --rate 8k --bits 16 --channels 1 output_audio.wav \n \ + ") + # the audio file could not be read, so the check fails + return False + + logger.info("The sample rate is %d" % audio_sample_rate) + if audio_sample_rate != self.sample_rate: + logger.warning("The sample rate of the input file is not {}.\n \ + The program will resample the wav file to {}.\n \ + If the result does not meet your expectations,\n \ + Please input the 16k 16 bit 1 channel wav file. \ + ".format(self.sample_rate, self.sample_rate)) + self.change_format = True + else: + logger.info("The audio file format is right") + self.change_format = False + + return True + + def preprocess(self, model_type: str, input: Union[str, os.PathLike]): + """ + Input preprocess and return paddle.Tensor stored in self.input. + Input content can be a text(tts), a file(asr, cls) or a streaming(not supported yet).
+ """ + + audio_file = input + + # Get the object for feature extraction + if "deepspeech2online" in model_type or "deepspeech2offline" in model_type: + audio, _ = self.collate_fn_test.process_utterance( + audio_file=audio_file, transcript=" ") + audio_len = audio.shape[0] + audio = paddle.to_tensor(audio, dtype='float32') + audio_len = paddle.to_tensor(audio_len) + audio = paddle.unsqueeze(audio, axis=0) + # vocab_list = collate_fn_test.vocab_list + self._inputs["audio"] = audio + self._inputs["audio_len"] = audio_len + logger.info(f"audio feat shape: {audio.shape}") + + elif "conformer" in model_type or "transformer" in model_type or "wenetspeech" in model_type: + logger.info("get the preprocess conf") + preprocess_conf = self.config.preprocess_config + preprocess_args = {"train": False} + preprocessing = Transformation(preprocess_conf) + logger.info("read the audio file") + audio, audio_sample_rate = soundfile.read( + audio_file, dtype="int16", always_2d=True) + + if self.change_format: + if audio.shape[1] >= 2: + audio = audio.mean(axis=1, dtype=np.int16) + else: + audio = audio[:, 0] + # pcm16 -> pcm 32 + audio = self._pcm16to32(audio) + audio = librosa.resample(audio, audio_sample_rate, + self.sample_rate) + audio_sample_rate = self.sample_rate + # pcm32 -> pcm 16 + audio = self._pcm32to16(audio) + else: + audio = audio[:, 0] + + logger.info(f"audio shape: {audio.shape}") + # fbank + audio = preprocessing(audio, **preprocess_args) + + audio_len = paddle.to_tensor(audio.shape[0]) + audio = paddle.to_tensor(audio, dtype='float32').unsqueeze(axis=0) + + self._inputs["audio"] = audio + self._inputs["audio_len"] = audio_len + logger.info(f"audio feat shape: {audio.shape}") + + else: + raise Exception("wrong type") + + +class ASREngine(BaseEngine): + """ASR server engine + + Args: + metaclass: Defaults to Singleton. + """ + + def __init__(self): + super(ASREngine, self).__init__() + + def init(self, config_file: str) -> bool: + """init engine resource + + Args: + config_file (str): config file + + Returns: + bool: init failed or success + """ + self.input = None + self.output = None + self.executor = ASRServerExecutor() + + try: + self.config = get_config(config_file) + paddle.set_device(paddle.get_device()) + self.executor._init_from_path( + self.config.model, self.config.lang, self.config.sample_rate, + self.config.cfg_path, self.config.decode_method, + self.config.ckpt_path) + except: + logger.info("Initialize ASR server engine Failed.") + return False + + logger.info("Initialize ASR server engine successfully.") + return True + + def run(self, audio_data): + """engine run + + Args: + audio_data (bytes): base64.b64decode + """ + if self.executor._check( + io.BytesIO(audio_data), self.config.sample_rate, + self.config.force_yes): + logger.info("start run asr engine") + self.executor.preprocess(self.config.model, io.BytesIO(audio_data)) + self.executor.infer(self.config.model) + self.output = self.executor.postprocess() # Retrieve result of asr. + else: + logger.info("file check failed!") + self.output = None + + def postprocess(self): + """postprocess + """ + return self.output diff --git a/paddlespeech/server/engine/base_engine.py b/paddlespeech/server/engine/base_engine.py new file mode 100644 index 0000000000000000000000000000000000000000..0cc20209479ea7e033943b799a7e161ac21e3b35 --- /dev/null +++ b/paddlespeech/server/engine/base_engine.py @@ -0,0 +1,60 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from typing import Any +from typing import List +from typing import Union + +from pattern_singleton import Singleton + +__all__ = ['BaseEngine'] + + +class BaseEngine(metaclass=Singleton): + """ + An base engine class + """ + + def __init__(self): + self._inputs = dict() + self._outputs = dict() + + def init(self, *args, **kwargs): + """ + init the engine + + Returns: + bool: true or false + """ + pass + + def postprocess(self, *args, **kwargs) -> Union[str, os.PathLike]: + """ + Output postprocess and return results. + This method get model output from self._outputs and convert it into human-readable results. + + Returns: + Union[str, os.PathLike]: Human-readable results such as texts and audio files. + """ + pass + + def run(self, *args, **kwargs) -> Union[str, os.PathLike]: + """ + Output postprocess and return results. + This method get model output from self._outputs and convert it into human-readable results. + + Returns: + Union[str, os.PathLike]: Human-readable results such as texts and audio files. + """ + pass diff --git a/paddlespeech/server/engine/engine_factory.py b/paddlespeech/server/engine/engine_factory.py new file mode 100644 index 0000000000000000000000000000000000000000..79319fd9566efc9916e9a658dca627592143760c --- /dev/null +++ b/paddlespeech/server/engine/engine_factory.py @@ -0,0 +1,32 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Text + +from paddlespeech.server.engine.asr.python.asr_engine import ASREngine +#from paddlespeech.server.engine.tts.python.tts_engine import TTSEngine +from paddlespeech.server.engine.tts.paddleinference.tts_engine import TTSEngine + + +__all__ = ['EngineFactory'] + + +class EngineFactory(object): + @staticmethod + def get_engine(engine_name: Text): + if engine_name == 'asr': + return ASREngine() + elif engine_name == 'tts': + return TTSEngine() + else: + return None diff --git a/paddlespeech/server/engine/tts/__init__.py b/paddlespeech/server/engine/tts/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..97043fd7ba6885aac81cad5a49924c23c67d4d47 --- /dev/null +++ b/paddlespeech/server/engine/tts/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/paddlespeech/server/engine/tts/paddleinference/__init__.py b/paddlespeech/server/engine/tts/paddleinference/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..97043fd7ba6885aac81cad5a49924c23c67d4d47 --- /dev/null +++ b/paddlespeech/server/engine/tts/paddleinference/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/paddlespeech/server/engine/tts/paddleinference/tts_engine.py b/paddlespeech/server/engine/tts/paddleinference/tts_engine.py new file mode 100644 index 0000000000000000000000000000000000000000..7679b02f03b2b5bf6f52482ae3a926f1081f3d65 --- /dev/null +++ b/paddlespeech/server/engine/tts/paddleinference/tts_engine.py @@ -0,0 +1,482 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
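+# TTS engine backed by Paddle Inference static models: an acoustic model +# (e.g. fastspeech2_csmsc) plus a vocoder (e.g. pwgan_csmsc), with Chinese and +# English text frontends.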
+import base64 +import io +import os +from typing import Optional + +import librosa +import numpy as np +import paddle +import soundfile as sf +from scipy.io import wavfile + +from paddlespeech.cli.log import logger +from paddlespeech.cli.tts.infer import TTSExecutor +from paddlespeech.cli.utils import download_and_decompress +from paddlespeech.cli.utils import MODEL_HOME +from paddlespeech.server.engine.base_engine import BaseEngine +from paddlespeech.server.utils.audio_process import change_speed +from paddlespeech.server.utils.config import get_config +from paddlespeech.server.utils.errors import ErrorCode +from paddlespeech.server.utils.exception import ServerBaseException +from paddlespeech.server.utils.paddle_predictor import init_predictor +from paddlespeech.server.utils.paddle_predictor import run_model +from paddlespeech.t2s.frontend import English +from paddlespeech.t2s.frontend.zh_frontend import Frontend + +__all__ = ['TTSEngine'] + +# Static model applied on paddle inference +pretrained_models = { + # speedyspeech + "speedyspeech_csmsc-zh": { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/speedyspeech/speedyspeech_nosil_baker_static_0.5.zip', + 'md5': + 'f10cbdedf47dc7a9668d2264494e1823', + 'model': + 'speedyspeech_csmsc.pdmodel', + 'params': + 'speedyspeech_csmsc.pdiparams', + 'phones_dict': + 'phone_id_map.txt', + 'tones_dict': + 'tone_id_map.txt', + 'sample_rate': + 24000, + }, + # fastspeech2 + "fastspeech2_csmsc-zh": { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_baker_static_0.4.zip', + 'md5': + '9788cd9745e14c7a5d12d32670b2a5a7', + 'model': + 'fastspeech2_csmsc.pdmodel', + 'params': + 'fastspeech2_csmsc.pdiparams', + 'phones_dict': + 'phone_id_map.txt', + 'sample_rate': + 24000, + }, + # pwgan + "pwgan_csmsc-zh": { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_baker_static_0.4.zip', + 'md5': + 'e3504aed9c5a290be12d1347836d2742', + 'model': + 'pwgan_csmsc.pdmodel', + 'params': + 'pwgan_csmsc.pdiparams', + 'sample_rate': + 24000, + }, + # mb_melgan + "mb_melgan_csmsc-zh": { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/mb_melgan/mb_melgan_csmsc_static_0.1.1.zip', + 'md5': + 'ac6eee94ba483421d750433f4c3b8d36', + 'model': + 'mb_melgan_csmsc.pdmodel', + 'params': + 'mb_melgan_csmsc.pdiparams', + 'sample_rate': + 24000, + }, + # hifigan + "hifigan_csmsc-zh": { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_csmsc_static_0.1.1.zip', + 'md5': + '7edd8c436b3a5546b3a7cb8cff9d5a0c', + 'model': + 'hifigan_csmsc.pdmodel', + 'params': + 'hifigan_csmsc.pdiparams', + 'sample_rate': + 24000, + }, +} + + +class TTSServerExecutor(TTSExecutor): + def __init__(self): + super().__init__() + pass + + def _get_pretrained_path(self, tag: str) -> os.PathLike: + """ + Download and returns pretrained resources path of current task. 
+ """ + assert tag in pretrained_models, 'Can not find pretrained resources of {}.'.format( + tag) + + res_path = os.path.join(MODEL_HOME, tag) + decompressed_path = download_and_decompress(pretrained_models[tag], + res_path) + decompressed_path = os.path.abspath(decompressed_path) + logger.info( + 'Use pretrained model stored in: {}'.format(decompressed_path)) + return decompressed_path + + def _init_from_path( + self, + am: str='fastspeech2_csmsc', + am_model: Optional[os.PathLike]=None, + am_params: Optional[os.PathLike]=None, + am_sample_rate: int=24000, + phones_dict: Optional[os.PathLike]=None, + tones_dict: Optional[os.PathLike]=None, + speaker_dict: Optional[os.PathLike]=None, + voc: str='pwgan_csmsc', + voc_model: Optional[os.PathLike]=None, + voc_params: Optional[os.PathLike]=None, + voc_sample_rate: int=24000, + lang: str='zh', + am_predictor_conf: dict=None, + voc_predictor_conf: dict=None, ): + """ + Init model and other resources from a specific path. + """ + if hasattr(self, 'am_predictor') and hasattr(self, 'voc_predictor'): + logger.info('Models had been initialized.') + return + # am + am_tag = am + '-' + lang + if am_model is None or am_params is None or phones_dict is None: + am_res_path = self._get_pretrained_path(am_tag) + self.am_res_path = am_res_path + self.am_model = os.path.join(am_res_path, + pretrained_models[am_tag]['model']) + self.am_params = os.path.join(am_res_path, + pretrained_models[am_tag]['params']) + # must have phones_dict in acoustic + self.phones_dict = os.path.join( + am_res_path, pretrained_models[am_tag]['phones_dict']) + self.am_sample_rate = pretrained_models[am_tag]['sample_rate'] + + logger.info(am_res_path) + logger.info(self.am_model) + logger.info(self.am_params) + else: + self.am_model = os.path.abspath(am_model) + self.am_params = os.path.abspath(am_params) + self.phones_dict = os.path.abspath(phones_dict) + self.am_sample_rate = am_sample_rate + self.am_res_path = os.path.dirname(os.path.abspath(self.am_model)) + print("self.phones_dict:", self.phones_dict) + + # for speedyspeech + self.tones_dict = None + if 'tones_dict' in pretrained_models[am_tag]: + self.tones_dict = os.path.join( + am_res_path, pretrained_models[am_tag]['tones_dict']) + if tones_dict: + self.tones_dict = tones_dict + + # for multi speaker fastspeech2 + self.speaker_dict = None + if 'speaker_dict' in pretrained_models[am_tag]: + self.speaker_dict = os.path.join( + am_res_path, pretrained_models[am_tag]['speaker_dict']) + if speaker_dict: + self.speaker_dict = speaker_dict + + # voc + voc_tag = voc + '-' + lang + if voc_model is None or voc_params is None: + voc_res_path = self._get_pretrained_path(voc_tag) + self.voc_res_path = voc_res_path + self.voc_model = os.path.join(voc_res_path, + pretrained_models[voc_tag]['model']) + self.voc_params = os.path.join(voc_res_path, + pretrained_models[voc_tag]['params']) + self.voc_sample_rate = pretrained_models[voc_tag]['sample_rate'] + logger.info(voc_res_path) + logger.info(self.voc_model) + logger.info(self.voc_params) + else: + self.voc_model = os.path.abspath(voc_model) + self.voc_params = os.path.abspath(voc_params) + self.voc_sample_rate = voc_sample_rate + self.voc_res_path = os.path.dirname(os.path.abspath(self.voc_model)) + + assert ( + self.voc_sample_rate == self.am_sample_rate + ), "The sample rate of AM and Vocoder model are different, please check model." + + # Init body. 
+ with open(self.phones_dict, "r") as f: + phn_id = [line.strip().split() for line in f.readlines()] + vocab_size = len(phn_id) + print("vocab_size:", vocab_size) + + tone_size = None + if self.tones_dict: + with open(self.tones_dict, "r") as f: + tone_id = [line.strip().split() for line in f.readlines()] + tone_size = len(tone_id) + print("tone_size:", tone_size) + + spk_num = None + if self.speaker_dict: + with open(self.speaker_dict, 'rt') as f: + spk_id = [line.strip().split() for line in f.readlines()] + spk_num = len(spk_id) + print("spk_num:", spk_num) + + # frontend + if lang == 'zh': + self.frontend = Frontend( + phone_vocab_path=self.phones_dict, + tone_vocab_path=self.tones_dict) + + elif lang == 'en': + self.frontend = English(phone_vocab_path=self.phones_dict) + print("frontend done!") + + # am predictor + self.am_predictor_conf = am_predictor_conf + self.am_predictor = init_predictor( + model_file=self.am_model, + params_file=self.am_params, + predictor_conf=self.am_predictor_conf) + + # voc predictor + self.voc_predictor_conf = voc_predictor_conf + self.voc_predictor = init_predictor( + model_file=self.voc_model, + params_file=self.voc_params, + predictor_conf=self.voc_predictor_conf) + + @paddle.no_grad() + def infer(self, + text: str, + lang: str='zh', + am: str='fastspeech2_csmsc', + spk_id: int=0): + """ + Model inference and result stored in self.output. + """ + am_name = am[:am.rindex('_')] + am_dataset = am[am.rindex('_') + 1:] + get_tone_ids = False + merge_sentences = False + if am_name == 'speedyspeech': + get_tone_ids = True + if lang == 'zh': + input_ids = self.frontend.get_input_ids( + text, + merge_sentences=merge_sentences, + get_tone_ids=get_tone_ids) + phone_ids = input_ids["phone_ids"] + if get_tone_ids: + tone_ids = input_ids["tone_ids"] + elif lang == 'en': + input_ids = self.frontend.get_input_ids( + text, merge_sentences=merge_sentences) + phone_ids = input_ids["phone_ids"] + else: + print("lang should in {'zh', 'en'}!") + + flags = 0 + for i in range(len(phone_ids)): + part_phone_ids = phone_ids[i] + # am + if am_name == 'speedyspeech': + part_tone_ids = tone_ids[i] + am_result = run_model( + self.am_predictor, + [part_phone_ids.numpy(), part_tone_ids.numpy()]) + mel = am_result[0] + + # fastspeech2 + else: + # multi speaker do not have static model + if am_dataset in {"aishell3", "vctk"}: + pass + else: + am_result = run_model(self.am_predictor, + [part_phone_ids.numpy()]) + mel = am_result[0] + # voc + voc_result = run_model(self.voc_predictor, [mel]) + wav = voc_result[0] + wav = paddle.to_tensor(wav) + + if flags == 0: + wav_all = wav + flags = 1 + else: + wav_all = paddle.concat([wav_all, wav]) + self._outputs['wav'] = wav_all + + +class TTSEngine(BaseEngine): + """TTS server engine + + Args: + metaclass: Defaults to Singleton. 
+ """ + + def __init__(self): + """Initialize TTS server engine + """ + super(TTSEngine, self).__init__() + + def init(self, config_file: str) -> bool: + self.executor = TTSServerExecutor() + + try: + self.config = get_config(config_file) + + self.executor._init_from_path( + am=self.config.am, + am_model=self.config.am_model, + am_params=self.config.am_params, + am_sample_rate=self.config.am_sample_rate, + phones_dict=self.config.phones_dict, + tones_dict=self.config.tones_dict, + speaker_dict=self.config.speaker_dict, + voc=self.config.voc, + voc_model=self.config.voc_model, + voc_params=self.config.voc_params, + voc_sample_rate=self.config.voc_sample_rate, + lang=self.config.lang, + am_predictor_conf=self.config.am_predictor_conf, + voc_predictor_conf=self.config.voc_predictor_conf, ) + + except: + logger.info("Initialize TTS server engine Failed.") + return False + + logger.info("Initialize TTS server engine successfully.") + return True + + def postprocess(self, + wav, + original_fs: int, + target_fs: int=16000, + volume: float=1.0, + speed: float=1.0, + audio_path: str=None): + """Post-processing operations, including speech, volume, sample rate, save audio file + + Args: + wav (numpy(float)): Synthesized audio sample points + original_fs (int): original audio sample rate + target_fs (int): target audio sample rate + volume (float): target volume + speed (float): target speed + + Raises: + ServerBaseException: Throws an exception if the change speed unsuccessfully. + + Returns: + target_fs: target sample rate for synthesized audio. + wav_base64: The base64 format of the synthesized audio. + """ + + # transform sample_rate + if target_fs == 0 or target_fs > original_fs: + target_fs = original_fs + wav_tar_fs = wav + else: + wav_tar_fs = librosa.resample( + np.squeeze(wav), original_fs, target_fs) + + # transform volume + wav_vol = wav_tar_fs * volume + + # transform speed + try: # windows not support soxbindings + wav_speed = change_speed(wav_vol, speed, target_fs) + except: + raise ServerBaseException( + ErrorCode.SERVER_INTERNAL_ERR, + "Transform speed failed. Can not install soxbindings on your system. \ + You need to set speed value 1.0.") + + # wav to base64 + buf = io.BytesIO() + wavfile.write(buf, target_fs, wav_speed) + base64_bytes = base64.b64encode(buf.read()) + wav_base64 = base64_bytes.decode('utf-8') + + # save audio + if audio_path is not None and audio_path.endswith(".wav"): + sf.write(audio_path, wav_speed, target_fs) + elif audio_path is not None and audio_path.endswith(".pcm"): + wav_norm = wav_speed * (32767 / max(0.001, + np.max(np.abs(wav_speed)))) + with open(audio_path, "wb") as f: + f.write(wav_norm.astype(np.int16)) + + return target_fs, wav_base64 + + def run(self, + sentence: str, + spk_id: int=0, + speed: float=1.0, + volume: float=1.0, + sample_rate: int=0, + save_path: str=None): + """get the result of the server response + + Args: + sentence (str): sentence to be synthesized + spk_id (int, optional): speaker id. Defaults to 0. + speed (float, optional): audio speed, 0 < speed <=3.0. Defaults to 1.0. + volume (float, optional): The volume relative to the audio synthesized by the model, + 0 < volume <=3.0. Defaults to 1.0. + sample_rate (int, optional): Set the sample rate of the synthesized audio. + 0 represents the sample rate for model synthesis. Defaults to 0. + save_path (str, optional): The save path of the synthesized audio. Defaults to None. + + Raises: + ServerBaseException: Throws an exception if tts inference unsuccessfully. 
+ ServerBaseException: Throws an exception if postprocess unsuccessfully. + + Returns: + lang: model language + target_sample_rate: target sample rate for synthesized audio. + wav_base64: The base64 format of the synthesized audio. + """ + + lang = self.config.lang + + try: + self.executor.infer( + text=sentence, lang=lang, am=self.config.am, spk_id=spk_id) + except: + raise ServerBaseException(ErrorCode.SERVER_INTERNAL_ERR, + "tts infer failed.") + + try: + target_sample_rate, wav_base64 = self.postprocess( + wav=self.executor._outputs['wav'].numpy(), + original_fs=self.executor.am_sample_rate, + target_fs=sample_rate, + volume=volume, + speed=speed, + audio_path=save_path) + except: + raise ServerBaseException(ErrorCode.SERVER_INTERNAL_ERR, + "tts postprocess failed.") + + return lang, target_sample_rate, wav_base64 diff --git a/paddlespeech/server/engine/tts/python/__init__.py b/paddlespeech/server/engine/tts/python/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..97043fd7ba6885aac81cad5a49924c23c67d4d47 --- /dev/null +++ b/paddlespeech/server/engine/tts/python/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/paddlespeech/server/engine/tts/python/tts_engine.py b/paddlespeech/server/engine/tts/python/tts_engine.py new file mode 100644 index 0000000000000000000000000000000000000000..e11cfb1d1671ae26816a8974c1d55bf0d39e3c06 --- /dev/null +++ b/paddlespeech/server/engine/tts/python/tts_engine.py @@ -0,0 +1,189 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import base64 +import io + +import librosa +import numpy as np +import paddle +import soundfile as sf +from scipy.io import wavfile + +from paddlespeech.cli.log import logger +from paddlespeech.cli.tts.infer import TTSExecutor +from paddlespeech.server.engine.base_engine import BaseEngine +from paddlespeech.server.utils.audio_process import change_speed +from paddlespeech.server.utils.config import get_config +from paddlespeech.server.utils.errors import ErrorCode +from paddlespeech.server.utils.exception import ServerBaseException + +__all__ = ['TTSEngine'] + + +class TTSServerExecutor(TTSExecutor): + def __init__(self): + super().__init__() + pass + + +class TTSEngine(BaseEngine): + """TTS server engine + + Args: + metaclass: Defaults to Singleton. 
+ """ + + def __init__(self, name=None): + """Initialize TTS server engine + """ + super(TTSEngine, self).__init__() + + def init(self, config_file: str) -> bool: + self.executor = TTSServerExecutor() + + try: + self.config = get_config(config_file) + paddle.set_device(self.config.device) + + self.executor._init_from_path( + am=self.config.am, + am_config=self.config.am_config, + am_ckpt=self.config.am_ckpt, + am_stat=self.config.am_stat, + phones_dict=self.config.phones_dict, + tones_dict=self.config.tones_dict, + speaker_dict=self.config.speaker_dict, + voc=self.config.voc, + voc_config=self.config.voc_config, + voc_ckpt=self.config.voc_ckpt, + voc_stat=self.config.voc_stat, + lang=self.config.lang) + except: + logger.info("Initialize TTS server engine Failed.") + return False + + logger.info("Initialize TTS server engine successfully.") + return True + + def postprocess(self, + wav, + original_fs: int, + target_fs: int=16000, + volume: float=1.0, + speed: float=1.0, + audio_path: str=None): + """Post-processing operations, including speech, volume, sample rate, save audio file + + Args: + wav (numpy(float)): Synthesized audio sample points + original_fs (int): original audio sample rate + target_fs (int): target audio sample rate + volume (float): target volume + speed (float): target speed + + Raises: + ServerBaseException: Throws an exception if the change speed unsuccessfully. + + Returns: + target_fs: target sample rate for synthesized audio. + wav_base64: The base64 format of the synthesized audio. + """ + + # transform sample_rate + if target_fs == 0 or target_fs > original_fs: + target_fs = original_fs + wav_tar_fs = wav + else: + wav_tar_fs = librosa.resample( + np.squeeze(wav), original_fs, target_fs) + + # transform volume + wav_vol = wav_tar_fs * volume + + # transform speed + try: # windows not support soxbindings + wav_speed = change_speed(wav_vol, speed, target_fs) + except: + raise ServerBaseException( + ErrorCode.SERVER_INTERNAL_ERR, + "Can not install soxbindings on your system.") + + # wav to base64 + buf = io.BytesIO() + wavfile.write(buf, target_fs, wav_speed) + base64_bytes = base64.b64encode(buf.read()) + wav_base64 = base64_bytes.decode('utf-8') + + # save audio + if audio_path is not None and audio_path.endswith(".wav"): + sf.write(audio_path, wav_speed, target_fs) + elif audio_path is not None and audio_path.endswith(".pcm"): + wav_norm = wav_speed * (32767 / max(0.001, + np.max(np.abs(wav_speed)))) + with open(audio_path, "wb") as f: + f.write(wav_norm.astype(np.int16)) + + return target_fs, wav_base64 + + def run(self, + sentence: str, + spk_id: int=0, + speed: float=1.0, + volume: float=1.0, + sample_rate: int=0, + save_path: str=None): + """ run include inference and postprocess. + + Args: + sentence (str): text to be synthesized + spk_id (int, optional): speaker id for multi-speaker speech synthesis. Defaults to 0. + speed (float, optional): speed. Defaults to 1.0. + volume (float, optional): volume. Defaults to 1.0. + sample_rate (int, optional): target sample rate for synthesized audio, + 0 means the same as the model sampling rate. Defaults to 0. + save_path (str, optional): The save path of the synthesized audio. + None means do not save audio. Defaults to None. + + Raises: + ServerBaseException: Throws an exception if tts inference unsuccessfully. + ServerBaseException: Throws an exception if postprocess unsuccessfully. + + Returns: + lang: model language + target_sample_rate: target sample rate for synthesized audio. 
+ wav_base64: The base64 format of the synthesized audio. + """ + + lang = self.config.lang + + try: + self.executor.infer( + text=sentence, lang=lang, am=self.config.am, spk_id=spk_id) + except: + raise ServerBaseException(ErrorCode.SERVER_INTERNAL_ERR, + "tts infer failed.") + + try: + target_sample_rate, wav_base64 = self.postprocess( + wav=self.executor._outputs['wav'].numpy(), + original_fs=self.executor.am_config.fs, + target_fs=sample_rate, + volume=volume, + speed=speed, + audio_path=save_path) + except: + raise ServerBaseException(ErrorCode.SERVER_INTERNAL_ERR, + "tts postprocess failed.") + + return lang, target_sample_rate, wav_base64 diff --git a/paddlespeech/server/entry.py b/paddlespeech/server/entry.py new file mode 100644 index 0000000000000000000000000000000000000000..f817321d06544db844fc6000616e70307a548379 --- /dev/null +++ b/paddlespeech/server/entry.py @@ -0,0 +1,57 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import sys +from collections import defaultdict + +__all__ = ['server_commands', 'client_commands'] + + +def _CommandDict(): + return defaultdict(_CommandDict) + + +def server_execute(): + com = server_commands + idx = 0 + for _argv in (['paddlespeech_server'] + sys.argv[1:]): + if _argv not in com: + break + idx += 1 + com = com[_argv] + + # The method 'execute' of a command instance returns 'True' for a success + # while 'False' for a failure. Here converts this result into a exit status + # in bash: 0 for a success and 1 for a failure. + status = 0 if com['_entry']().execute(sys.argv[idx:]) else 1 + return status + + +def client_execute(): + com = client_commands + idx = 0 + for _argv in (['paddlespeech_client'] + sys.argv[1:]): + if _argv not in com: + break + idx += 1 + com = com[_argv] + + # The method 'execute' of a command instance returns 'True' for a success + # while 'False' for a failure. Here converts this result into a exit status + # in bash: 0 for a success and 1 for a failure. + status = 0 if com['_entry']().execute(sys.argv[idx:]) else 1 + return status + + +server_commands = _CommandDict() +client_commands = _CommandDict() diff --git a/paddlespeech/server/executor.py b/paddlespeech/server/executor.py new file mode 100644 index 0000000000000000000000000000000000000000..192e1f1750ad0f89fbba3452f99f190cd96f8121 --- /dev/null +++ b/paddlespeech/server/executor.py @@ -0,0 +1,38 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +from abc import ABC +from abc import abstractmethod +from typing import List + +class BaseExecutor(ABC): + """ + An abstract executor of paddlespeech server tasks. + """ + + def __init__(self): + self.parser = argparse.ArgumentParser() + + @abstractmethod + def execute(self, argv: List[str]) -> bool: + """ + Command line entry. This method can only be accessed by a command line such as `paddlespeech asr`. + + Args: + argv (List[str]): Arguments from command line. + + Returns: + int: Result of the command execution. `True` for a success and `False` for a failure. + """ + pass diff --git a/paddlespeech/server/restful/__init__.py b/paddlespeech/server/restful/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..97043fd7ba6885aac81cad5a49924c23c67d4d47 --- /dev/null +++ b/paddlespeech/server/restful/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/paddlespeech/server/restful/api.py b/paddlespeech/server/restful/api.py new file mode 100644 index 0000000000000000000000000000000000000000..9ce5cad0ef7b3db718dda0b897ad416ed3ba825b --- /dev/null +++ b/paddlespeech/server/restful/api.py @@ -0,0 +1,34 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import List + +from fastapi import APIRouter + +from paddlespeech.server.restful.asr_api import router as asr_router +from paddlespeech.server.restful.tts_api import router as tts_router + +_router = APIRouter() + + +def setup_router(api_list: List): + + for api_name in api_list: + if api_name == 'asr': + _router.include_router(asr_router) + elif api_name == 'tts': + _router.include_router(tts_router) + else: + pass + + return _router diff --git a/paddlespeech/server/restful/asr_api.py b/paddlespeech/server/restful/asr_api.py new file mode 100644 index 0000000000000000000000000000000000000000..fcdb2f41fa420160501f4d35f600aa40a1cc089f --- /dev/null +++ b/paddlespeech/server/restful/asr_api.py @@ -0,0 +1,87 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import base64 +import traceback +from typing import Union +from fastapi import APIRouter + +from paddlespeech.server.engine.asr.python.asr_engine import ASREngine +from paddlespeech.server.restful.request import ASRRequest +from paddlespeech.server.restful.response import ASRResponse +from paddlespeech.server.restful.response import ErrorResponse +from paddlespeech.server.utils.errors import ErrorCode +from paddlespeech.server.utils.errors import failed_response +from paddlespeech.server.utils.exception import ServerBaseException + +router = APIRouter() + + +@router.get('/paddlespeech/asr/help') +def help(): + """help + + Returns: + json: [description] + """ + response = { + "success": "True", + "code": 200, + "message": { + "global": "success" + }, + "result": { + "description": "asr server", + "input": "base64 string of wavfile", + "output": "transcription" + } + } + return response + + +@router.post( + "/paddlespeech/asr", response_model=Union[ASRResponse, ErrorResponse]) +def asr(request_body: ASRRequest): + """asr api + + Args: + request_body (ASRRequest): [description] + + Returns: + json: [description] + """ + try: + # single + audio_data = base64.b64decode(request_body.audio) + asr_engine = ASREngine() + asr_engine.run(audio_data) + asr_results = asr_engine.postprocess() + + response = { + "success": True, + "code": 200, + "message": { + "description": "success" + }, + "result": { + "transcription": asr_results + } + } + + except ServerBaseException as e: + response = failed_response(e.error_code, e.msg) + except: + response = failed_response(ErrorCode.SERVER_UNKOWN_ERR) + traceback.print_exc() + + return response diff --git a/paddlespeech/server/restful/request.py b/paddlespeech/server/restful/request.py new file mode 100644 index 0000000000000000000000000000000000000000..2be5f0e546dee6c1c042820ac1a3838a446e23ea --- /dev/null +++ b/paddlespeech/server/restful/request.py @@ -0,0 +1,66 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
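+#
+# Note: the pydantic models below define the JSON bodies accepted by the
+# restful API. Illustrative sketch (field values are examples only):
+#
+#     TTSRequest(text="您好", spk_id=0).dict()
+#     # -> {"text": "您好", "spk_id": 0, "speed": 1.0, "volume": 1.0,
+#     #     "sample_rate": 0, "save_path": None}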
+from typing import List +from typing import Optional + +from pydantic import BaseModel + +__all__ = ['ASRRequest', 'TTSRequest'] + + +#****************************************************************************************/ +#************************************ ASR request ***************************************/ +#****************************************************************************************/ +class ASRRequest(BaseModel): + """ + request body example + { + "audio": "exSI6ICJlbiIsCgkgICAgInBvc2l0aW9uIjogImZhbHNlIgoJf...", + "audio_format": "wav", + "sample_rate": 16000, + "lang": "zh_cn", + "punc":false + } + """ + audio: str + audio_format: str + sample_rate: int + lang: str + punc: Optional[bool] = None + + +#****************************************************************************************/ +#************************************ TTS request ***************************************/ +#****************************************************************************************/ +class TTSRequest(BaseModel): + """TTS request + + request body example + { + "text": "你好,欢迎使用百度飞桨语音合成服务。", + "spk_id": 0, + "speed": 1.0, + "volume": 1.0, + "sample_rate": 0, + "tts_audio_path": "./tts.wav" + } + + """ + + text: str + spk_id: int = 0 + speed: float = 1.0 + volume: float = 1.0 + sample_rate: int = 0 + save_path: str = None diff --git a/paddlespeech/server/restful/response.py b/paddlespeech/server/restful/response.py new file mode 100644 index 0000000000000000000000000000000000000000..ab5e395ba6914482e320d13abf2744e2fef71ec0 --- /dev/null +++ b/paddlespeech/server/restful/response.py @@ -0,0 +1,107 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import List +from typing import Optional + +from pydantic import BaseModel + +__all__ = ['ASRResponse', 'TTSResponse'] + + +class Message(BaseModel): + description: str + + +#****************************************************************************************/ +#************************************ ASR response **************************************/ +#****************************************************************************************/ +class AsrResult(BaseModel): + transcription: str + + +class ASRResponse(BaseModel): + """ + response example + { + "success": true, + "code": 0, + "message": { + "description": "success" + }, + "result": { + "transcription": "你好,飞桨" + } + } + """ + success: bool + code: int + message: Message + result: AsrResult + + +#****************************************************************************************/ +#************************************ TTS response **************************************/ +#****************************************************************************************/ +class TTSResult(BaseModel): + lang: str = "zh" + sample_rate: int + spk_id: int = 0 + speed: float = 1.0 + volume: float = 1.0 + save_path: str = None + audio: str + + +class TTSResponse(BaseModel): + """ + response example + { + "success": true, + "code": 200, + "message": { + "description": "success" + }, + "result": { + "lang": "zh", + "sample_rate": 24000, + "speed": 1.0, + "volume": 1.0, + "audio": "LTI1OTIuNjI1OTUwMzQsOTk2OS41NDk4...", + "save_path": "./tts.wav" + } + } + """ + success: bool + code: int + message: Message + result: TTSResult + + +#****************************************************************************************/ +#********************************** Error response **************************************/ +#****************************************************************************************/ +class ErrorResponse(BaseModel): + """ + response example + { + "success": false, + "code": 0, + "message": { + "description": "Unknown error occurred." + } + } + """ + success: bool + code: int + message: Message diff --git a/paddlespeech/server/restful/tts_api.py b/paddlespeech/server/restful/tts_api.py new file mode 100644 index 0000000000000000000000000000000000000000..d5fa1d42c4db0e822ab2d545ad69225ebb382222 --- /dev/null +++ b/paddlespeech/server/restful/tts_api.py @@ -0,0 +1,108 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
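+#
+# Illustrative request (assumes the server listens on 127.0.0.1:8090, as the
+# test clients under paddlespeech/server/tests do):
+#
+#     curl -X POST http://127.0.0.1:8090/paddlespeech/tts \
+#         -H 'Content-Type: application/json' \
+#         -d '{"text": "您好", "spk_id": 0, "speed": 1.0, "volume": 1.0,
+#              "sample_rate": 0, "save_path": "./tts.wav"}'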
+import traceback +from typing import Union + +from fastapi import APIRouter + +from paddlespeech.server.engine.tts.paddleinference.tts_engine import TTSEngine +from paddlespeech.server.restful.request import TTSRequest +from paddlespeech.server.restful.response import ErrorResponse +from paddlespeech.server.restful.response import TTSResponse +from paddlespeech.server.utils.errors import ErrorCode +from paddlespeech.server.utils.errors import failed_response +from paddlespeech.server.utils.exception import ServerBaseException + +router = APIRouter() + + +@router.get('/paddlespeech/tts/help') +def help(): + """help + + Returns: + json: [description] + """ + response = { + "success": "True", + "code": 200, + "message": { + "global": "success" + }, + "result": { + "description": "tts server", + "text": "sentence to be synthesized", + "audio": "the base64 of audio" + } + } + return response + + +@router.post( + "/paddlespeech/tts", response_model=Union[TTSResponse, ErrorResponse]) +def tts(request_body: TTSRequest): + """tts api + + Args: + request_body (TTSRequest): [description] + + Returns: + json: [description] + """ + # json to dict + item_dict = request_body.dict() + sentence = item_dict['text'] + spk_id = item_dict['spk_id'] + speed = item_dict['speed'] + volume = item_dict['volume'] + sample_rate = item_dict['sample_rate'] + save_path = item_dict['save_path'] + + # Check parameters + if speed <=0 or speed > 3 or volume <=0 or volume > 3 or \ + sample_rate not in [0, 16000, 8000] or \ + (save_path is not None and not save_path.endswith("pcm") and not save_path.endswith("wav")): + return failed_response(ErrorCode.SERVER_PARAM_ERR) + + # single + tts_engine = TTSEngine() + + # run + try: + lang, target_sample_rate, wav_base64 = tts_engine.run( + sentence, spk_id, speed, volume, sample_rate, save_path) + + response = { + "success": True, + "code": 200, + "message": { + "description": "success." + }, + "result": { + "lang": lang, + "spk_id": spk_id, + "speed": speed, + "volume": volume, + "sample_rate": target_sample_rate, + "save_path": save_path, + "audio": wav_base64 + } + } + except ServerBaseException as e: + response = failed_response(e.error_code, e.msg) + except: + response = failed_response(ErrorCode.SERVER_UNKOWN_ERR) + traceback.print_exc() + + return response diff --git a/paddlespeech/server/tests/16_audio.wav b/paddlespeech/server/tests/16_audio.wav new file mode 100644 index 0000000000000000000000000000000000000000..3cfa5074efaea618684e3ca7b497a2b1f33fa7e4 Binary files /dev/null and b/paddlespeech/server/tests/16_audio.wav differ diff --git a/paddlespeech/server/tests/http_client.py b/paddlespeech/server/tests/http_client.py new file mode 100644 index 0000000000000000000000000000000000000000..14adb5741989790140fa509bb4e6eeca1b48546f --- /dev/null +++ b/paddlespeech/server/tests/http_client.py @@ -0,0 +1,59 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import base64
+import json
+import time
+
+import requests
+
+
+def readwav2base64(wav_file):
+    """
+    Read a wave file and convert it to a base64 string.
+    """
+    with open(wav_file, 'rb') as f:
+        base64_bytes = base64.b64encode(f.read())
+        base64_string = base64_bytes.decode('utf-8')
+    return base64_string
+
+
+def main():
+    """
+    Send one ASR request to the server and print the response.
+    """
+    url = "http://127.0.0.1:8090/paddlespeech/asr"
+
+    # start timestamp
+    time_start = time.time()
+
+    test_audio_dir = "./16_audio.wav"
+    audio = readwav2base64(test_audio_dir)
+
+    data = {
+        "audio": audio,
+        "audio_format": "wav",
+        "sample_rate": 16000,
+        "lang": "zh_cn",
+    }
+
+    r = requests.post(url=url, data=json.dumps(data))
+
+    # ending timestamp
+    time_end = time.time()
+    print('time cost', time_end - time_start, 's')
+
+    print(r.json())
+
+
+if __name__ == "__main__":
+    main()
diff --git a/paddlespeech/server/tests/tts/test_client.py b/paddlespeech/server/tests/tts/test_client.py
new file mode 100644
index 0000000000000000000000000000000000000000..65f4ccfece121f5ab472fe3a2e9e2f34244136b9
--- /dev/null
+++ b/paddlespeech/server/tests/tts/test_client.py
@@ -0,0 +1,103 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
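+#
+# Illustrative invocation (the server address is hard-coded to
+# http://127.0.0.1:8090 in tts_client below):
+#
+#     python test_client.py --text "你好" --output ./out.wav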
+import argparse +import base64 +import io +import json +import os +import random +import time + +import numpy as np +import requests +import soundfile + +from paddlespeech.server.utils.audio_process import wav2pcm + +# Request and response +def tts_client(args): + """ Request and response + Args: + text: A sentence to be synthesized + outfile: Synthetic audio file + """ + url = 'http://127.0.0.1:8090/paddlespeech/tts' + request = { + "text": args.text, + "spk_id": args.spk_id, + "speed": args.speed, + "volume": args.volume, + "sample_rate": args.sample_rate, + "save_path": args.output + } + + response = requests.post(url, json.dumps(request)) + response_dict = response.json() + wav_base64 = response_dict["result"]["audio"] + + audio_data_byte = base64.b64decode(wav_base64) + # from byte + samples, sample_rate = soundfile.read( + io.BytesIO(audio_data_byte), dtype='float32') + + # transform audio + outfile = args.output + if outfile.endswith(".wav"): + soundfile.write(outfile, samples, sample_rate) + elif outfile.endswith(".pcm"): + temp_wav = str(random.getrandbits(128)) + ".wav" + soundfile.write(temp_wav, samples, sample_rate) + wav2pcm(temp_wav, outfile, data_type=np.int16) + os.system("rm %s" % (temp_wav)) + else: + print("The format for saving audio only supports wav or pcm") + + return len(samples), sample_rate + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + '--text', + type=str, + default="你好,欢迎使用语音合成服务", + help='A sentence to be synthesized') + parser.add_argument('--spk_id', type=int, default=0, help='Speaker id') + parser.add_argument('--speed', type=float, default=1.0, help='Audio speed') + parser.add_argument( + '--volume', type=float, default=1.0, help='Audio volume') + parser.add_argument( + '--sample_rate', + type=int, + default=0, + help='Sampling rate, the default is the same as the model') + parser.add_argument( + '--output', + type=str, + default="./out.wav", + help='Synthesized audio file') + args = parser.parse_args() + + st = time.time() + try: + samples_length, sample_rate = tts_client(args) + time_consume = time.time() - st + duration = samples_length / sample_rate + rtf = time_consume / duration + print("Synthesized audio successfully.") + print("Inference time: %f" % (time_consume)) + print("The duration of synthesized audio: %f" % (duration)) + print("The RTF is: %f" % (rtf)) + except: + print("Failed to synthesized audio.") diff --git a/paddlespeech/server/util.py b/paddlespeech/server/util.py new file mode 100644 index 0000000000000000000000000000000000000000..58e86b27775c9e60f842d7f1b0459d23aa9b67a4 --- /dev/null +++ b/paddlespeech/server/util.py @@ -0,0 +1,364 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import hashlib +import inspect +import json +import os +import tarfile +import threading +import time +import uuid +import zipfile +from typing import Any +from typing import Dict + +import paddle +import requests +import yaml +from paddle.framework import load + +import paddleaudio +from . import download +from .. import __version__ +from .entry import client_commands +from .entry import server_commands + +requests.adapters.DEFAULT_RETRIES = 3 + +__all__ = [ + 'cli_server_register', + 'get_server_command', + 'cli_client_register', + 'get_client_command', + 'download_and_decompress', + 'load_state_dict_from_url', + 'stats_wrapper', +] + + +def cli_server_register(name: str, description: str='') -> Any: + def _warpper(command): + items = name.split('.') + + com = server_commands + for item in items: + com = com[item] + com['_entry'] = command + if description: + com['_description'] = description + return command + + return _warpper + + +def get_server_command(name: str) -> Any: + items = name.split('.') + com = server_commands + for item in items: + com = com[item] + + return com['_entry'] + + +def cli_client_register(name: str, description: str='') -> Any: + def _warpper(command): + items = name.split('.') + + com = client_commands + for item in items: + com = com[item] + com['_entry'] = command + if description: + com['_description'] = description + return command + + return _warpper + + +def get_client_command(name: str) -> Any: + items = name.split('.') + com = client_commands + for item in items: + com = com[item] + + return com['_entry'] + + +def _get_uncompress_path(filepath: os.PathLike) -> os.PathLike: + file_dir = os.path.dirname(filepath) + is_zip_file = False + if tarfile.is_tarfile(filepath): + files = tarfile.open(filepath, "r:*") + file_list = files.getnames() + elif zipfile.is_zipfile(filepath): + files = zipfile.ZipFile(filepath, 'r') + file_list = files.namelist() + is_zip_file = True + else: + return file_dir + + if download._is_a_single_file(file_list): + rootpath = file_list[0] + uncompressed_path = os.path.join(file_dir, rootpath) + elif download._is_a_single_dir(file_list): + if is_zip_file: + rootpath = os.path.splitext(file_list[0])[0].split(os.sep)[0] + else: + rootpath = os.path.splitext(file_list[0])[0].split(os.sep)[-1] + uncompressed_path = os.path.join(file_dir, rootpath) + else: + rootpath = os.path.splitext(filepath)[0].split(os.sep)[-1] + uncompressed_path = os.path.join(file_dir, rootpath) + + files.close() + return uncompressed_path + + +def download_and_decompress(archive: Dict[str, str], path: str) -> os.PathLike: + """ + Download archieves and decompress to specific path. 
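+
+    Args:
+        archive (Dict[str, str]): A dict that must contain 'url' and 'md5' keys.
+        path (str): Directory the archive is downloaded to and decompressed under.
+
+    Returns:
+        os.PathLike: Path of the decompressed resource.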
+ """ + if not os.path.isdir(path): + os.makedirs(path) + + assert 'url' in archive and 'md5' in archive, \ + 'Dictionary keys of "url" and "md5" are required in the archive, but got: {}'.format(list(archive.keys())) + + filepath = os.path.join(path, os.path.basename(archive['url'])) + if os.path.isfile(filepath) and download._md5check(filepath, + archive['md5']): + uncompress_path = _get_uncompress_path(filepath) + if not os.path.isdir(uncompress_path): + download._decompress(filepath) + else: + StatsWorker( + task='download', + version=__version__, + extra_info={ + 'download_url': archive['url'], + 'paddle_version': paddle.__version__ + }).start() + uncompress_path = download.get_path_from_url(archive['url'], path, + archive['md5']) + + return uncompress_path + + +def load_state_dict_from_url(url: str, path: str, md5: str=None) -> os.PathLike: + """ + Download and load a state dict from url + """ + if not os.path.isdir(path): + os.makedirs(path) + + download.get_path_from_url(url, path, md5) + return load(os.path.join(path, os.path.basename(url))) + + +def _get_user_home(): + return os.path.expanduser('~') + + +def _get_paddlespcceh_home(): + if 'PPSPEECH_HOME' in os.environ: + home_path = os.environ['PPSPEECH_HOME'] + if os.path.exists(home_path): + if os.path.isdir(home_path): + return home_path + else: + raise RuntimeError( + 'The environment variable PPSPEECH_HOME {} is not a directory.'. + format(home_path)) + else: + return home_path + return os.path.join(_get_user_home(), '.paddlespeech') + + +def _get_sub_home(directory): + home = os.path.join(_get_paddlespcceh_home(), directory) + if not os.path.exists(home): + os.makedirs(home) + return home + + +PPSPEECH_HOME = _get_paddlespcceh_home() +MODEL_HOME = _get_sub_home('models') +CONF_HOME = _get_sub_home('conf') + + +def _md5(text: str): + '''Calculate the md5 value of the input text.''' + md5code = hashlib.md5(text.encode()) + return md5code.hexdigest() + + +class ConfigCache: + def __init__(self): + self._data = {} + self._initialize() + self.file = os.path.join(CONF_HOME, 'cache.yaml') + if not os.path.exists(self.file): + self.flush() + return + + with open(self.file, 'r') as file: + try: + cfg = yaml.load(file, Loader=yaml.FullLoader) + self._data.update(cfg) + except: + self.flush() + + @property + def cache_info(self): + return self._data['cache_info'] + + def _initialize(self): + # Set default configuration values. 
+ cache_info = _md5(str(uuid.uuid1())[-12:]) + "-" + str(int(time.time())) + self._data['cache_info'] = cache_info + + def flush(self): + '''Flush the current configuration into the configuration file.''' + with open(self.file, 'w') as file: + cfg = json.loads(json.dumps(self._data)) + yaml.dump(cfg, file) + + +stats_api = "http://paddlepaddle.org.cn/paddlehub/stat" +cache_info = ConfigCache().cache_info + + +class StatsWorker(threading.Thread): + def __init__(self, + task="asr", + model=None, + version=__version__, + extra_info={}): + threading.Thread.__init__(self) + self._task = task + self._model = model + self._version = version + self._extra_info = extra_info + + def run(self): + params = { + 'task': self._task, + 'version': self._version, + 'from': 'ppspeech' + } + if self._model: + params['model'] = self._model + + self._extra_info.update({ + 'cache_info': cache_info, + }) + params.update({"extra": json.dumps(self._extra_info)}) + + try: + requests.get(stats_api, params) + except Exception: + pass + + return + + +def _note_one_stat(cls_name, params={}): + task = cls_name.replace('Executor', '').lower() # XXExecutor + extra_info = { + 'paddle_version': paddle.__version__, + } + + if 'model' in params: + model = params['model'] + else: + model = None + + if 'audio_file' in params: + try: + _, sr = paddleaudio.load(params['audio_file']) + except Exception: + sr = -1 + + if task == 'asr': + extra_info.update({ + 'lang': params['lang'], + 'inp_sr': sr, + 'model_sr': params['sample_rate'], + }) + elif task == 'st': + extra_info.update({ + 'lang': + params['src_lang'] + '-' + params['tgt_lang'], + 'inp_sr': + sr, + 'model_sr': + params['sample_rate'], + }) + elif task == 'tts': + model = params['am'] + extra_info.update({ + 'lang': params['lang'], + 'vocoder': params['voc'], + }) + elif task == 'cls': + extra_info.update({ + 'inp_sr': sr, + }) + elif task == 'text': + extra_info.update({ + 'sub_task': params['task'], + 'lang': params['lang'], + }) + else: + return + + StatsWorker( + task=task, + model=model, + version=__version__, + extra_info=extra_info, ).start() + + +def _parse_args(func, *args, **kwargs): + # FullArgSpec(args, varargs, varkw, defaults, kwonlyargs, kwonlydefaults, annotations) + argspec = inspect.getfullargspec(func) + + keys = argspec[0] + if keys[0] == 'self': # Remove self pointer. + keys = keys[1:] + + default_values = argspec[3] + values = [None] * (len(keys) - len(default_values)) + values.extend(list(default_values)) + params = dict(zip(keys, values)) + + for idx, v in enumerate(args): + params[keys[idx]] = v + for k, v in kwargs.items(): + params[k] = v + + return params + + +def stats_wrapper(executor_func): + def _warpper(self, *args, **kwargs): + try: + _note_one_stat( + type(self).__name__, _parse_args(executor_func, *args, + **kwargs)) + except Exception: + pass + return executor_func(self, *args, **kwargs) + + return _warpper diff --git a/paddlespeech/server/utils/__init__.py b/paddlespeech/server/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..97043fd7ba6885aac81cad5a49924c23c67d4d47 --- /dev/null +++ b/paddlespeech/server/utils/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/paddlespeech/server/utils/audio_process.py b/paddlespeech/server/utils/audio_process.py new file mode 100644 index 0000000000000000000000000000000000000000..3cbb495a67ffcb54444fd44173571eccb02addef --- /dev/null +++ b/paddlespeech/server/utils/audio_process.py @@ -0,0 +1,105 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import wave + +import numpy as np + +from paddlespeech.cli.log import logger + + +def wav2pcm(wavfile, pcmfile, data_type=np.int16): + """ Save the wav file as a pcm file + + Args: + wavfile (str): wav file path + pcmfile (str): pcm file save path + data_type (type, optional): pcm sample type. Defaults to np.int16. + """ + with open(wavfile, "rb") as f: + f.seek(0) + f.read(44) + data = np.fromfile(f, dtype=data_type) + data.tofile(pcmfile) + + +def pcm2wav(pcm_file, wav_file, channels=1, bits=16, sample_rate=16000): + """Save the pcm file as a wav file + + Args: + pcm_file (str): pcm file path + wav_file (str): wav file save path + channels (int, optional): audio channel. Defaults to 1. + bits (int, optional): Bit depth. Defaults to 16. + sample_rate (int, optional): sample rate. Defaults to 16000. + """ + pcmf = open(pcm_file, 'rb') + pcmdata = pcmf.read() + pcmf.close() + + if bits % 8 != 0: + logger.error("bits % 8 must == 0. now bits:" + str(bits)) + + wavfile = wave.open(wav_file, 'wb') + wavfile.setnchannels(channels) + wavfile.setsampwidth(bits // 8) + wavfile.setframerate(sample_rate) + wavfile.writeframes(pcmdata) + wavfile.close() + + +def change_speed(sample_raw, speed_rate, sample_rate): + """Change the audio speed by linear interpolation. + Note that this is an in-place transformation. + :param speed_rate: Rate of speed change: + speed_rate > 1.0, speed up the audio; + speed_rate = 1.0, unchanged; + speed_rate < 1.0, slow down the audio; + speed_rate <= 0.0, not allowed, raise ValueError. + :type speed_rate: float + :raises ValueError: If speed_rate <= 0.0. 
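+    :return: Audio samples with the changed speed.
+    :rtype: numpy.ndarray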
+ """ + if speed_rate == 1.0: + return sample_raw + if speed_rate <= 0: + raise ValueError("speed_rate should be greater than zero.") + + # numpy + # old_length = self._samples.shape[0] + # new_length = int(old_length / speed_rate) + # old_indices = np.arange(old_length) + # new_indices = np.linspace(start=0, stop=old_length, num=new_length) + # self._samples = np.interp(new_indices, old_indices, self._samples) + + # sox, slow + try: + import soxbindings as sox + except ImportError: + try: + from paddlespeech.s2t.utils import dynamic_pip_install + package = "sox" + dynamic_pip_install.install(package) + package = "soxbindings" + dynamic_pip_install.install(package) + import soxbindings as sox + except Exception: + raise RuntimeError("Can not install soxbindings on your system.") + + tfm = sox.Transformer() + tfm.set_globals(multithread=False) + tfm.tempo(speed_rate) + sample_speed = tfm.build_array( + input_array=sample_raw, + sample_rate_in=sample_rate).squeeze(-1).astype(np.float32).copy() + + return sample_speed diff --git a/paddlespeech/server/utils/config.py b/paddlespeech/server/utils/config.py new file mode 100644 index 0000000000000000000000000000000000000000..8c75f536f5de654f1a09fa82187cfef4ef442e90 --- /dev/null +++ b/paddlespeech/server/utils/config.py @@ -0,0 +1,30 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import yaml +from yacs.config import CfgNode + + +def get_config(config_file: str): + """[summary] + + Args: + config_file (str): config_file + + Returns: + CfgNode: + """ + with open(config_file, 'rt') as f: + config = CfgNode(yaml.safe_load(f)) + + return config diff --git a/paddlespeech/server/utils/errors.py b/paddlespeech/server/utils/errors.py new file mode 100644 index 0000000000000000000000000000000000000000..17ff75512cd447648ecedf9238809a42743b708c --- /dev/null +++ b/paddlespeech/server/utils/errors.py @@ -0,0 +1,57 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json +from enum import IntEnum + +from fastapi import Response + + +class ErrorCode(IntEnum): + SERVER_OK = 200 # success. + + SERVER_PARAM_ERR = 400 # Input parameters are not valid. + SERVER_TASK_NOT_EXIST = 404 # Task is not exist. + + SERVER_INTERNAL_ERR = 500 # Internal error. + SERVER_NETWORK_ERR = 502 # Network exception. + SERVER_UNKOWN_ERR = 509 # Unknown error occurred. 
+ + +ErrorMsg = { + ErrorCode.SERVER_OK: "success.", + ErrorCode.SERVER_PARAM_ERR: "Input parameters are not valid.", + ErrorCode.SERVER_TASK_NOT_EXIST: "Task is not exist.", + ErrorCode.SERVER_INTERNAL_ERR: "Internal error.", + ErrorCode.SERVER_NETWORK_ERR: "Network exception.", + ErrorCode.SERVER_UNKOWN_ERR: "Unknown error occurred." +} + + +def failed_response(code, msg=""): + """Interface call failure response + + Args: + code (int): error code number + msg (str, optional): Interface call failure information. Defaults to "". + + Returns: + Response (json): failure json information. + """ + + if not msg: + msg = ErrorMsg.get(code, "Unknown error occurred.") + + res = {"success": False, "code": int(code), "message": {"description": msg}} + + return Response(content=json.dumps(res), media_type="application/json") diff --git a/paddlespeech/server/utils/exception.py b/paddlespeech/server/utils/exception.py new file mode 100644 index 0000000000000000000000000000000000000000..58ea777ca520c489c7c88090f6c758cad3bac1df --- /dev/null +++ b/paddlespeech/server/utils/exception.py @@ -0,0 +1,30 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import traceback + +from paddlespeech.server.utils.errors import ErrorMsg + + +class ServerBaseException(Exception): + """ Server Base exception + """ + + def __init__(self, error_code, msg=None): + #if msg: + #log.error(msg) + msg = msg if msg else ErrorMsg.get(error_code, "") + super(ServerBaseException, self).__init__(error_code, msg) + self.error_code = error_code + self.msg = msg + traceback.print_exc() diff --git a/paddlespeech/server/utils/log.py b/paddlespeech/server/utils/log.py new file mode 100644 index 0000000000000000000000000000000000000000..8644064c73ef407476e7870e65d1149019762723 --- /dev/null +++ b/paddlespeech/server/utils/log.py @@ -0,0 +1,59 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
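+#
+# Usage sketch: the module-level `logger` instance below exposes one method
+# per level, e.g. logger.info("engine ready"), logger.error("engine failed").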
+import functools +import logging + +__all__ = [ + 'logger', +] + + +class Logger(object): + def __init__(self, name: str=None): + name = 'PaddleSpeech' if not name else name + self.logger = logging.getLogger(name) + + log_config = { + 'DEBUG': 10, + 'INFO': 20, + 'TRAIN': 21, + 'EVAL': 22, + 'WARNING': 30, + 'ERROR': 40, + 'CRITICAL': 50, + 'EXCEPTION': 100, + } + for key, level in log_config.items(): + logging.addLevelName(level, key) + if key == 'EXCEPTION': + self.__dict__[key.lower()] = self.logger.exception + else: + self.__dict__[key.lower()] = functools.partial(self.__call__, + level) + + self.format = logging.Formatter( + fmt='[%(asctime)-15s] [%(levelname)8s] - %(message)s') + + self.handler = logging.StreamHandler() + self.handler.setFormatter(self.format) + + self.logger.addHandler(self.handler) + self.logger.setLevel(logging.DEBUG) + self.logger.propagate = False + + def __call__(self, log_level: str, msg: str): + self.logger.log(log_level, msg) + + +logger = Logger() diff --git a/paddlespeech/server/utils/paddle_predictor.py b/paddlespeech/server/utils/paddle_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..f910161b88896e054439c855da3efcdad10b21ae --- /dev/null +++ b/paddlespeech/server/utils/paddle_predictor.py @@ -0,0 +1,82 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from typing import List +from typing import Optional + +from paddle.inference import Config +from paddle.inference import create_predictor + + +def init_predictor(model_dir: Optional[os.PathLike]=None, + model_file: Optional[os.PathLike]=None, + params_file: Optional[os.PathLike]=None, + predictor_conf: dict=None): + """Create predictor with Paddle inference + + Args: + model_dir (Optional[os.PathLike], optional): The path of the static model saved in the model layer. Defaults to None. + model_file (Optional[os.PathLike], optional): *.pdmodel file path. Defaults to None. + params_file (Optional[os.PathLike], optional): *.pdiparams file path.. Defaults to None. + predictor_conf (dict, optional): The configuration parameters of predictor. Defaults to None. 
+
+    Returns:
+        predictor (paddle.inference.Predictor): the created predictor
+    """
+
+    if model_dir is not None:
+        config = Config(model_dir)
+    else:
+        config = Config(model_file, params_file)
+
+    config.enable_memory_optim()
+    if predictor_conf["use_gpu"]:
+        config.enable_use_gpu(1000, 0)
+    if predictor_conf["enable_mkldnn"]:
+        config.enable_mkldnn()
+    if predictor_conf["switch_ir_optim"]:
+        config.switch_ir_optim()
+
+    predictor = create_predictor(config)
+
+    return predictor
+
+
+def run_model(predictor, input: List) -> List:
+    """Run the predictor on a list of input arrays.
+
+    Args:
+        predictor: paddle inference predictor
+        input (list): The input of predictor
+
+    Returns:
+        list: result list
+    """
+    input_names = predictor.get_input_names()
+    for i, name in enumerate(input_names):
+        input_handle = predictor.get_input_handle(name)
+        input_handle.copy_from_cpu(input[i])
+
+    # do the inference
+    predictor.run()
+
+    results = []
+    # get out data from output tensor
+    output_names = predictor.get_output_names()
+    for i, name in enumerate(output_names):
+        output_handle = predictor.get_output_handle(name)
+        output_data = output_handle.copy_to_cpu()
+        results.append(output_data)
+
+    return results
diff --git a/paddlespeech/server/utils/util.py b/paddlespeech/server/utils/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..e9104fa2d56283c48304d4676fae19e8dccd1ba5
--- /dev/null
+++ b/paddlespeech/server/utils/util.py
@@ -0,0 +1,33 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import base64
+
+
+def wav2base64(wav_file: str):
+    """
+    Read a wave file and convert it to a base64 string.
+    """
+    with open(wav_file, 'rb') as f:
+        base64_bytes = base64.b64encode(f.read())
+        base64_string = base64_bytes.decode('utf-8')
+    return base64_string
+
+
+def base64towav(base64_string: str):
+    pass
+
+
+def self_check():
+    """ self check resource
+    """
+    return True
diff --git a/setup.py b/setup.py
index b85ce9245da39dd47628db6a930c7e282edf9920..9bb11d0dd5829b9f551fca46867e752d71b03c43 100644
--- a/setup.py
+++ b/setup.py
@@ -63,6 +63,9 @@ requirements = {
         "visualdl",
         "webrtcvad",
         "yacs~=0.1.8",
+        # fastapi server
+        "fastapi",
+        "uvicorn",
     ],
     "develop": [
         "ConfigArgParse",
@@ -253,7 +256,11 @@ setup_info = dict(
         'Programming Language :: Python :: 3.9',
     ],
     entry_points={
-        'console_scripts': ['paddlespeech=paddlespeech.cli.entry:_execute']
+        'console_scripts': [
+            'paddlespeech=paddlespeech.cli.entry:_execute',
+            'paddlespeech_server=paddlespeech.server.entry:server_execute',
+            'paddlespeech_client=paddlespeech.server.entry:client_execute'
+        ]
     })
 
 setup(**setup_info)