diff --git a/demos/streaming_tts_serving_fastdeploy/README.md b/demos/streaming_tts_serving_fastdeploy/README.md new file mode 100644 index 0000000000000000000000000000000000000000..3e983a06da4b18f663967ac982b780fe44619af1 --- /dev/null +++ b/demos/streaming_tts_serving_fastdeploy/README.md @@ -0,0 +1,67 @@ +([简体中文](./README_cn.md)|English) + +# Streaming Speech Synthesis Service + +## Introduction +This demo is an implementation of starting the streaming speech synthesis service and accessing the service. + +`Server` must be started in the docker, while `Client` does not have to be in the docker. + +**The streaming_tts_serving under the path of this article ($PWD) contains the configuration and code of the model, which needs to be mapped to the docker for use.** + +## Usage +### 1. Server +#### 1.1 Docker + +```bash +docker pull registry.baidubce.com/paddlepaddle/fastdeploy_serving_cpu_only:22.09 +docker run -dit --net=host --name fastdeploy --shm-size="1g" -v $PWD:/models registry.baidubce.com/paddlepaddle/fastdeploy_serving_cpu_only:22.09 +docker exec -it -u root fastdeploy bash +``` + +#### 1.2 Installation(inside the docker) +```bash +apt-get install build-essential python3-dev libssl-dev libffi-dev libxml2 libxml2-dev libxslt1-dev zlib1g-dev libsndfile1 language-pack-zh-hans wget zip +pip3 install paddlespeech +export LC_ALL="zh_CN.UTF-8" +export LANG="zh_CN.UTF-8" +export LANGUAGE="zh_CN:zh:en_US:en" +``` + +#### 1.3 Download models(inside the docker) +```bash +cd /models/streaming_tts_serving/1 +wget https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_cnndecoder_csmsc_streaming_onnx_1.0.0.zip +wget https://paddlespeech.bj.bcebos.com/Parakeet/released_models/mb_melgan/mb_melgan_csmsc_onnx_0.2.0.zip +unzip fastspeech2_cnndecoder_csmsc_streaming_onnx_1.0.0.zip +unzip mb_melgan_csmsc_onnx_0.2.0.zip +``` +**For the convenience of users, we recommend that you use the command `docker -v` to map $PWD (streaming_tts_service and the configuration and code of the model contained therein) to the docker path `/models`. You can also use other methods, but regardless of which method you use, the final model directory and structure in the docker are shown in the following figure.** + +

+ +

+ +#### 1.4 Start the server(inside the docker) + +```bash +fastdeployserver --model-repository=/models --model-control-mode=explicit --load-model=streaming_tts_serving +``` +Arguments: + - `model-repository`(required): Path of model storage. + - `model-control-mode`(required): The mode of loading the model. At present, you can use 'explicit'. + - `load-model`(required): Name of the model to be loaded. + - `http-port`(optional): Port for http service. Default: `8000`. This is not used in our example. + - `grpc-port`(optional): Port for grpc service. Default: `8001`. + - `metrics-port`(optional): Port for metrics service. Default: `8002`. This is not used in our example. + +### 2. Client +#### 2.1 Installation +```bash +pip3 install tritonclient[all] +``` + +#### 2.2 Send request +```bash +python3 /models/streaming_tts_serving/stream_client.py +``` diff --git a/demos/streaming_tts_serving_fastdeploy/README_cn.md b/demos/streaming_tts_serving_fastdeploy/README_cn.md new file mode 100644 index 0000000000000000000000000000000000000000..7edd32830c99b9e6ce3289774bdde7b9777ae891 --- /dev/null +++ b/demos/streaming_tts_serving_fastdeploy/README_cn.md @@ -0,0 +1,67 @@ +(简体中文|[English](./README.md)) + +# 流式语音合成服务 + +## 介绍 + +本文介绍了使用FastDeploy搭建流式语音合成服务的方法。 + +`服务端`必须在docker内启动,而`客户端`不是必须在docker容器内. + +**本文所在路径`($PWD)下的streaming_tts_serving里包含模型的配置和代码`(服务端会加载模型和代码以启动服务),需要将其映射到docker中使用。** + +## 使用 +### 1. 服务端 +#### 1.1 Docker +```bash +docker pull registry.baidubce.com/paddlepaddle/fastdeploy_serving_cpu_only:22.09 +docker run -dit --net=host --name fastdeploy --shm-size="1g" -v $PWD:/models registry.baidubce.com/paddlepaddle/fastdeploy_serving_cpu_only:22.09 +docker exec -it -u root fastdeploy bash +``` + +#### 1.2 安装(在docker内) +```bash +apt-get install build-essential python3-dev libssl-dev libffi-dev libxml2 libxml2-dev libxslt1-dev zlib1g-dev libsndfile1 language-pack-zh-hans wget zip +pip3 install paddlespeech +export LC_ALL="zh_CN.UTF-8" +export LANG="zh_CN.UTF-8" +export LANGUAGE="zh_CN:zh:en_US:en" +``` + +#### 1.3 下载模型(在docker内) +```bash +cd /models/streaming_tts_serving/1 +wget https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_cnndecoder_csmsc_streaming_onnx_1.0.0.zip +wget https://paddlespeech.bj.bcebos.com/Parakeet/released_models/mb_melgan/mb_melgan_csmsc_onnx_0.2.0.zip +unzip fastspeech2_cnndecoder_csmsc_streaming_onnx_1.0.0.zip +unzip mb_melgan_csmsc_onnx_0.2.0.zip +``` +**为了方便用户使用,我们推荐用户使用1.1中的`docker -v`命令将`$PWD(streaming_tts_serving及里面包含的模型的配置和代码)映射到了docker内的/models路径`,用户也可以使用其他办法,但无论使用哪种方法,最终在docker内的模型目录及结构如下图所示。** + +

+ +

+ +#### 1.4 启动服务端(在docker内) +```bash +fastdeployserver --model-repository=/models --model-control-mode=explicit --load-model=streaming_tts_serving +``` + +参数: + - `model-repository`(required): 整套模型streaming_tts_serving存放的路径. + - `model-control-mode`(required): 模型加载的方式,现阶段, 使用'explicit'即可. + - `load-model`(required): 需要加载的模型的名称. + - `http-port`(optional): HTTP服务的端口号. 默认: `8000`. 本示例中未使用该端口. + - `grpc-port`(optional): GRPC服务的端口号. 默认: `8001`. + - `metrics-port`(optional): 服务端指标的端口号. 默认: `8002`. 本示例中未使用该端口. + +### 2. 客户端 +#### 2.1 安装 +```bash +pip3 install tritonclient[all] +``` + +#### 2.2 发送请求 +```bash +python3 /models/streaming_tts_serving/stream_client.py +``` diff --git a/demos/streaming_tts_serving_fastdeploy/streaming_tts_serving/1/model.py b/demos/streaming_tts_serving_fastdeploy/streaming_tts_serving/1/model.py new file mode 100644 index 0000000000000000000000000000000000000000..46473fdb2a9321a3a0a7e9aec2f2290098cb056e --- /dev/null +++ b/demos/streaming_tts_serving_fastdeploy/streaming_tts_serving/1/model.py @@ -0,0 +1,289 @@ +import codecs +import json +import math +import sys +import threading +import time + +import numpy as np +import onnxruntime as ort +import triton_python_backend_utils as pb_utils + +from paddlespeech.server.utils.util import denorm +from paddlespeech.server.utils.util import get_chunks +from paddlespeech.t2s.frontend.zh_frontend import Frontend + +voc_block = 36 +voc_pad = 14 +am_block = 72 +am_pad = 12 +voc_upsample = 300 + +# 模型路径 +dir_name = "/models/streaming_tts_serving/1/" +phones_dict = dir_name + "fastspeech2_cnndecoder_csmsc_streaming_onnx_1.0.0/phone_id_map.txt" +am_stat_path = dir_name + "fastspeech2_cnndecoder_csmsc_streaming_onnx_1.0.0/speech_stats.npy" + +onnx_am_encoder = dir_name + "fastspeech2_cnndecoder_csmsc_streaming_onnx_1.0.0/fastspeech2_csmsc_am_encoder_infer.onnx" +onnx_am_decoder = dir_name + "fastspeech2_cnndecoder_csmsc_streaming_onnx_1.0.0/fastspeech2_csmsc_am_decoder.onnx" +onnx_am_postnet = dir_name + "fastspeech2_cnndecoder_csmsc_streaming_onnx_1.0.0/fastspeech2_csmsc_am_postnet.onnx" +onnx_voc_melgan = dir_name + "mb_melgan_csmsc_onnx_0.2.0/mb_melgan_csmsc.onnx" + +frontend = Frontend(phone_vocab_path=phones_dict, tone_vocab_path=None) +am_mu, am_std = np.load(am_stat_path) + +# 用CPU推理 +providers = ['CPUExecutionProvider'] + +# 配置ort session +sess_options = ort.SessionOptions() + +# 创建session +am_encoder_infer_sess = ort.InferenceSession( + onnx_am_encoder, providers=providers, sess_options=sess_options) +am_decoder_sess = ort.InferenceSession( + onnx_am_decoder, providers=providers, sess_options=sess_options) +am_postnet_sess = ort.InferenceSession( + onnx_am_postnet, providers=providers, sess_options=sess_options) +voc_melgan_sess = ort.InferenceSession( + onnx_voc_melgan, providers=providers, sess_options=sess_options) + + +def depadding(data, chunk_num, chunk_id, block, pad, upsample): + """ + Streaming inference removes the result of pad inference + """ + front_pad = min(chunk_id * block, pad) + # first chunk + if chunk_id == 0: + data = data[:block * upsample] + # last chunk + elif chunk_id == chunk_num - 1: + data = data[front_pad * upsample:] + # middle chunk + else: + data = data[front_pad * upsample:(front_pad + block) * upsample] + + return data + + +class TritonPythonModel: + """Your Python model must use the same class name. Every Python model + that is created must have "TritonPythonModel" as the class name. + """ + + def initialize(self, args): + """`initialize` is called only once when the model is being loaded. + Implementing `initialize` function is optional. This function allows + the model to intialize any state associated with this model. + Parameters + ---------- + args : dict + Both keys and values are strings. The dictionary keys and values are: + * model_config: A JSON string containing the model configuration + * model_instance_kind: A string containing model instance kind + * model_instance_device_id: A string containing model instance device ID + * model_repository: Model repository path + * model_version: Model version + * model_name: Model name + """ + sys.stdout = codecs.getwriter("utf-8")(sys.stdout.detach()) + print(sys.getdefaultencoding()) + # You must parse model_config. JSON string is not parsed here + self.model_config = model_config = json.loads(args['model_config']) + print("model_config:", self.model_config) + + using_decoupled = pb_utils.using_decoupled_model_transaction_policy( + model_config) + + if not using_decoupled: + raise pb_utils.TritonModelException( + """the model `{}` can generate any number of responses per request, + enable decoupled transaction policy in model configuration to + serve this model""".format(args['model_name'])) + + self.input_names = [] + for input_config in self.model_config["input"]: + self.input_names.append(input_config["name"]) + print("input:", self.input_names) + + self.output_names = [] + self.output_dtype = [] + for output_config in self.model_config["output"]: + self.output_names.append(output_config["name"]) + dtype = pb_utils.triton_string_to_numpy(output_config["data_type"]) + self.output_dtype.append(dtype) + print("output:", self.output_names) + + # To keep track of response threads so that we can delay + # the finalizing the model until all response threads + # have completed. + self.inflight_thread_count = 0 + self.inflight_thread_count_lck = threading.Lock() + + def execute(self, requests): + """`execute` must be implemented in every Python model. `execute` + function receives a list of pb_utils.InferenceRequest as the only + argument. This function is called when an inference is requested + for this model. Depending on the batching configuration (e.g. Dynamic + Batching) used, `requests` may contain multiple requests. Every + Python model, must create one pb_utils.InferenceResponse for every + pb_utils.InferenceRequest in `requests`. If there is an error, you can + set the error argument when creating a pb_utils.InferenceResponse. + Parameters + ---------- + requests : list + A list of pb_utils.InferenceRequest + Returns + ------- + list + A list of pb_utils.InferenceResponse. The length of this list must + be the same as `requests` + """ + + # This model does not support batching, so 'request_count' should always + # be 1. + if len(requests) != 1: + raise pb_utils.TritonModelException("unsupported batch size " + len( + requests)) + + input_data = [] + for idx in range(len(self.input_names)): + data = pb_utils.get_input_tensor_by_name(requests[0], + self.input_names[idx]) + data = data.as_numpy() + data = data[0].decode('utf-8') + input_data.append(data) + text = input_data[0] + + # Start a separate thread to send the responses for the request. The + # sending back the responses is delegated to this thread. + thread = threading.Thread( + target=self.response_thread, + args=(requests[0].get_response_sender(), text)) + thread.daemon = True + with self.inflight_thread_count_lck: + self.inflight_thread_count += 1 + + thread.start() + # Unlike in non-decoupled model transaction policy, execute function + # here returns no response. A return from this function only notifies + # Triton that the model instance is ready to receive another request. As + # we are not waiting for the response thread to complete here, it is + # possible that at any give time the model may be processing multiple + # requests. Depending upon the request workload, this may lead to a lot + # of requests being processed by a single model instance at a time. In + # real-world models, the developer should be mindful of when to return + # from execute and be willing to accept next request. + return None + + def response_thread(self, response_sender, text): + input_ids = frontend.get_input_ids( + text, merge_sentences=False, get_tone_ids=False) + phone_ids = input_ids["phone_ids"] + for i in range(len(phone_ids)): + part_phone_ids = phone_ids[i].numpy() + voc_chunk_id = 0 + + orig_hs = am_encoder_infer_sess.run( + None, input_feed={'text': part_phone_ids}) + orig_hs = orig_hs[0] + + # streaming voc chunk info + mel_len = orig_hs.shape[1] + voc_chunk_num = math.ceil(mel_len / voc_block) + start = 0 + end = min(voc_block + voc_pad, mel_len) + + # streaming am + hss = get_chunks(orig_hs, am_block, am_pad, "am") + am_chunk_num = len(hss) + for i, hs in enumerate(hss): + am_decoder_output = am_decoder_sess.run( + None, input_feed={'xs': hs}) + am_postnet_output = am_postnet_sess.run( + None, + input_feed={ + 'xs': np.transpose(am_decoder_output[0], (0, 2, 1)) + }) + am_output_data = am_decoder_output + np.transpose( + am_postnet_output[0], (0, 2, 1)) + normalized_mel = am_output_data[0][0] + + sub_mel = denorm(normalized_mel, am_mu, am_std) + sub_mel = depadding(sub_mel, am_chunk_num, i, am_block, am_pad, + 1) + + if i == 0: + mel_streaming = sub_mel + else: + mel_streaming = np.concatenate( + (mel_streaming, sub_mel), axis=0) + + # streaming voc + # 当流式AM推理的mel帧数大于流式voc推理的chunk size,开始进行流式voc 推理 + while (mel_streaming.shape[0] >= end and + voc_chunk_id < voc_chunk_num): + voc_chunk = mel_streaming[start:end, :] + + sub_wav = voc_melgan_sess.run( + output_names=None, input_feed={'logmel': voc_chunk}) + sub_wav = depadding(sub_wav[0], voc_chunk_num, voc_chunk_id, + voc_block, voc_pad, voc_upsample) + + output_np = np.array(sub_wav, dtype=self.output_dtype[0]) + out_tensor1 = pb_utils.Tensor(self.output_names[0], + output_np) + + status = 0 if voc_chunk_id != (voc_chunk_num - 1) else 1 + output_status = np.array( + [status], dtype=self.output_dtype[1]) + out_tensor2 = pb_utils.Tensor(self.output_names[1], + output_status) + + inference_response = pb_utils.InferenceResponse( + output_tensors=[out_tensor1, out_tensor2]) + + #yield sub_wav + response_sender.send(inference_response) + + voc_chunk_id += 1 + start = max(0, voc_chunk_id * voc_block - voc_pad) + end = min((voc_chunk_id + 1) * voc_block + voc_pad, mel_len) + + # We must close the response sender to indicate to Triton that we are + # done sending responses for the corresponding request. We can't use the + # response sender after closing it. The response sender is closed by + # setting the TRITONSERVER_RESPONSE_COMPLETE_FINAL. + response_sender.send( + flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL) + + with self.inflight_thread_count_lck: + self.inflight_thread_count -= 1 + + def finalize(self): + """`finalize` is called only once when the model is being unloaded. + Implementing `finalize` function is OPTIONAL. This function allows + the model to perform any necessary clean ups before exit. + Here we will wait for all response threads to complete sending + responses. + """ + print('Finalize invoked') + + inflight_threads = True + cycles = 0 + logging_time_sec = 5 + sleep_time_sec = 0.1 + cycle_to_log = (logging_time_sec / sleep_time_sec) + while inflight_threads: + with self.inflight_thread_count_lck: + inflight_threads = (self.inflight_thread_count != 0) + if (cycles % cycle_to_log == 0): + print( + f"Waiting for {self.inflight_thread_count} response threads to complete..." + ) + if inflight_threads: + time.sleep(sleep_time_sec) + cycles += 1 + + print('Finalize complete...') diff --git a/demos/streaming_tts_serving_fastdeploy/streaming_tts_serving/config.pbtxt b/demos/streaming_tts_serving_fastdeploy/streaming_tts_serving/config.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..e63721d1ca15183a80453a3d39c8a9f4c9262687 --- /dev/null +++ b/demos/streaming_tts_serving_fastdeploy/streaming_tts_serving/config.pbtxt @@ -0,0 +1,33 @@ +name: "streaming_tts_serving" +backend: "python" +max_batch_size: 0 +model_transaction_policy { + decoupled: True +} +input [ + { + name: "INPUT_0" + data_type: TYPE_STRING + dims: [ 1 ] + } +] + +output [ + { + name: "OUTPUT_0" + data_type: TYPE_FP32 + dims: [ -1, 1 ] + }, + { + name: "status" + data_type: TYPE_BOOL + dims: [ 1 ] + } +] + +instance_group [ + { + count: 1 + kind: KIND_CPU + } +] diff --git a/demos/streaming_tts_serving_fastdeploy/streaming_tts_serving/stream_client.py b/demos/streaming_tts_serving_fastdeploy/streaming_tts_serving/stream_client.py new file mode 100644 index 0000000000000000000000000000000000000000..e7f120b7d1c49813a1c248e48505828a2b6a638c --- /dev/null +++ b/demos/streaming_tts_serving_fastdeploy/streaming_tts_serving/stream_client.py @@ -0,0 +1,117 @@ +#!/usr/bin/env python +import argparse +import queue +import sys +from functools import partial + +import numpy as np +import tritonclient.grpc as grpcclient +from tritonclient.utils import * + +FLAGS = None + + +class UserData: + def __init__(self): + self._completed_requests = queue.Queue() + + +# Define the callback function. Note the last two parameters should be +# result and error. InferenceServerClient would povide the results of an +# inference as grpcclient.InferResult in result. For successful +# inference, error will be None, otherwise it will be an object of +# tritonclientutils.InferenceServerException holding the error details +def callback(user_data, result, error): + if error: + user_data._completed_requests.put(error) + else: + user_data._completed_requests.put(result) + + +def async_stream_send(triton_client, values, request_id, model_name): + + infer_inputs = [] + outputs = [] + for idx, data in enumerate(values): + data = np.array([data.encode('utf-8')], dtype=np.object_) + infer_input = grpcclient.InferInput('INPUT_0', [len(data)], "BYTES") + infer_input.set_data_from_numpy(data) + infer_inputs.append(infer_input) + + outputs.append(grpcclient.InferRequestedOutput('OUTPUT_0')) + # Issue the asynchronous sequence inference. + triton_client.async_stream_infer( + model_name=model_name, + inputs=infer_inputs, + outputs=outputs, + request_id=request_id) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument( + '-v', + '--verbose', + action="store_true", + required=False, + default=False, + help='Enable verbose output') + parser.add_argument( + '-u', + '--url', + type=str, + required=False, + default='localhost:8001', + help='Inference server URL and it gRPC port. Default is localhost:8001.') + + FLAGS = parser.parse_args() + + # We use custom "sequence" models which take 1 input + # value. The output is the accumulated value of the inputs. See + # src/custom/sequence. + model_name = "streaming_tts_serving" + + values = ["哈哈哈哈"] + + request_id = "0" + + string_result0_list = [] + + user_data = UserData() + + # It is advisable to use client object within with..as clause + # when sending streaming requests. This ensures the client + # is closed when the block inside with exits. + with grpcclient.InferenceServerClient( + url=FLAGS.url, verbose=FLAGS.verbose) as triton_client: + try: + # Establish stream + triton_client.start_stream(callback=partial(callback, user_data)) + # Now send the inference sequences... + async_stream_send(triton_client, values, request_id, model_name) + except InferenceServerException as error: + print(error) + sys.exit(1) + + # Retrieve results... + recv_count = 0 + result_dict = {} + status = True + while True: + data_item = user_data._completed_requests.get() + if type(data_item) == InferenceServerException: + raise data_item + else: + this_id = data_item.get_response().id + if this_id not in result_dict.keys(): + result_dict[this_id] = [] + result_dict[this_id].append((recv_count, data_item)) + sub_wav = data_item.as_numpy('OUTPUT_0') + status = data_item.as_numpy('status') + print('sub_wav = ', sub_wav, "subwav.shape = ", sub_wav.shape) + print('status = ', status) + if status[0] == 1: + break + recv_count += 1 + + print("PASS: stream_client") diff --git a/demos/streaming_tts_serving_fastdeploy/tree.png b/demos/streaming_tts_serving_fastdeploy/tree.png new file mode 100644 index 0000000000000000000000000000000000000000..b8d61686aa76f0e1270172b55b0e0ae64560d45e Binary files /dev/null and b/demos/streaming_tts_serving_fastdeploy/tree.png differ