Commit ab044887 authored by: lym0302

update server cli, test=doc

Parent: c64282e7
@@ -15,6 +15,17 @@ You can choose one way from easy, medium, and hard to install paddlespeech.
### 2. Prepare config File
The configuration files include the service-related settings and the model configurations for each speech task contained in the service. They are all under the `conf` folder.
**Note:** The `engine_backend` field in `application.yaml` lists all speech tasks included in the started service.
If the service you want to start should contain only one speech task, comment out the tasks you do not need. For example, if you only want to use the speech recognition (ASR) service, you can comment out the speech synthesis (TTS) service, as in the following example:
```bash
engine_backend:
    asr: 'conf/asr/asr.yaml'
    #tts: 'conf/tts/tts.yaml'
```
**Note:** The config file set in `engine_backend` in `application.yaml` must match the engine type set in `engine_type`.
When the `engine_backend` config file is `XXX.yaml`, set the corresponding `engine_type` to `python`; when the `engine_backend` config file is `XXX_pd.yaml`, set the corresponding `engine_type` to `inference`. A sketch of the `inference` pairing follows.
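For instance, a minimal sketch of the `inference` pairing in `application.yaml` (the `*_pd.yaml` paths are illustrative and must exist under your `conf` folder):
```bash
# Both tasks run static models with paddle inference, so every
# engine_backend entry points at an XXX_pd.yaml file and every
# engine_type entry is set to 'inference'.
engine_backend:
    asr: 'conf/asr/asr_pd.yaml'
    tts: 'conf/tts/tts_pd.yaml'

engine_type:
    asr: 'inference'
    tts: 'inference'
```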
The input of the ASR client demo should be a WAV file (`.wav`), and the sample rate must be the same as the model's.
Here are sample files for this ASR client demo that can be downloaded:
@@ -76,6 +87,7 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespee
### 4. ASR Client Usage
**Note:** The response time will be slightly longer when using the client for the first time.
- Command Line (Recommended)
```
paddlespeech_client asr --server_ip 127.0.0.1 --port 8090 --input ./zh.wav
```
@@ -122,6 +134,7 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespee
### 5. TTS Client Usage
**Note:** The response time will be slightly longer when using the client for the first time.
- Command Line (Recommended)
```bash
paddlespeech_client tts --server_ip 127.0.0.1 --port 8090 --input "您好,欢迎使用百度飞桨语音合成服务。" --output output.wav
```
@@ -147,8 +160,6 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespee
```
[2022-02-23 15:20:37,875] [ INFO] - Save synthesized audio successfully on output.wav.
[2022-02-23 15:20:37,875] [ INFO] - Audio duration: 3.612500 s.
[2022-02-23 15:20:37,875] [ INFO] - Response time: 0.348050 s.
```
@@ -174,51 +185,13 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespee
```
Save synthesized audio successfully on ./output.wav.
Audio duration: 3.612500 s.
Response time: 0.388317 s.
```
## Models supported by the service
### ASR model
Get all models supported by the ASR service via `paddlespeech_server stats --task asr`, where the static models can be used for inference with Paddle Inference.
### TTS model
Get all models supported by the TTS service via `paddlespeech_server stats --task tts`, where the static models can be used for inference with Paddle Inference. A usage sketch follows.
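As a quick reference, both queries can be run from the command line (the commands are exactly those named above; the listing format may vary by release):
```bash
# List all models supported by the ASR service.
paddlespeech_server stats --task asr

# List all models supported by the TTS service.
paddlespeech_server stats --task tts
```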
@@ -14,6 +14,15 @@
### 2. Prepare config File
The configuration files include the service-related settings and the model configurations for each speech task contained in the service. They are all under the `conf` folder.
**Note:** The `engine_backend` field in `application.yaml` lists all speech tasks included in the started service.
If the service you want to start should contain only one speech task, comment out the tasks you do not need. For example, if you only want to use the speech recognition (ASR) service, you can comment out the speech synthesis (TTS) service, as in the following example:
```bash
engine_backend:
    asr: 'conf/asr/asr.yaml'
    #tts: 'conf/tts/tts.yaml'
```
**Note:** The config file set in `engine_backend` in `application.yaml` must match the engine type set in `engine_type`.
When the `engine_backend` config file is `XXX.yaml`, set the corresponding `engine_type` to `python`; when the `engine_backend` config file is `XXX_pd.yaml`, set the corresponding `engine_type` to `inference`. The default `python` pairing is sketched below.
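For comparison, a minimal sketch of the default `python` pairing, matching the `application.yaml` shipped with this demo (paths illustrative):
```bash
# Dynamic-graph models: plain XXX.yaml config files paired with the
# 'python' engine type for both tasks.
engine_backend:
    asr: 'conf/asr/asr.yaml'
    tts: 'conf/tts/tts.yaml'

engine_type:
    asr: 'python'
    tts: 'python'
```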
The input of this ASR client should be a WAV file (`.wav`), and the sample rate must be the same as the model's.
@@ -75,6 +84,7 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespee
```
### 4. ASR Client Usage
**Note:** The response time will be slightly longer when using the client for the first time.
- Command Line (Recommended)
```
paddlespeech_client asr --server_ip 127.0.0.1 --port 8090 --input ./zh.wav
```
@@ -123,6 +133,7 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespee
### 5. TTS Client Usage
**Note:** The response time will be slightly longer when using the client for the first time.
```bash
paddlespeech_client tts --server_ip 127.0.0.1 --port 8090 --input "您好,欢迎使用百度飞桨语音合成服务。" --output output.wav
```
@@ -148,7 +159,6 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespee
```
[2022-02-23 15:20:37,875] [ INFO] - Save synthesized audio successfully on output.wav.
[2022-02-23 15:20:37,875] [ INFO] - Audio duration: 3.612500 s.
[2022-02-23 15:20:37,875] [ INFO] - Response time: 0.348050 s.
```
- Python API
@@ -173,50 +183,12 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespee
```
Save synthesized audio successfully on ./output.wav.
Audio duration: 3.612500 s.
Response time: 0.388317 s.
```
## Models supported by the service
### Models supported by ASR
Get all models supported by the ASR service via `paddlespeech_server stats --task asr`, where the static models can be used for inference with Paddle Inference.
### Models supported by TTS
Get all models supported by the TTS service via `paddlespeech_server stats --task tts`, where the static models can be used for inference with Paddle Inference. An end-to-end sketch follows.
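Putting the pieces together, a hedged end-to-end sketch; the `stats` and client commands are those documented above, while the `paddlespeech_server start` invocation and the `./conf/application.yaml` path are assumptions based on this demo's server docs:
```bash
# 1. Check which models the service can serve.
paddlespeech_server stats --task asr

# 2. Start the server with the prepared config; it launches every task
#    listed under engine_backend in application.yaml.
paddlespeech_server start --config_file ./conf/application.yaml

# 3. From another terminal, send the sample WAV file to the ASR endpoint.
paddlespeech_client asr --server_ip 127.0.0.1 --port 8090 --input ./zh.wav
```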
@@ -3,23 +3,25 @@
##################################################################
#                         SERVER SETTING                         #
##################################################################
host: '127.0.0.1'
port: 8090

##################################################################
#                          CONFIG FILE                           #
##################################################################
# add engine backend type (Options: asr, tts) and config file here.
# Adding a speech task to engine_backend means starting the service.
engine_backend:
    asr: 'conf/asr/asr.yaml'
    tts: 'conf/tts/tts.yaml'

# The engine_type of speech task needs to keep the same type as the config file of speech task.
# E.g: The engine_type of asr is 'python', the engine_backend of asr is 'XX/asr.yaml'
# E.g: The engine_type of asr is 'inference', the engine_backend of asr is 'XX/asr_pd.yaml'
#
# add engine type (Options: python, inference)
engine_type:
    asr: 'python'
    tts: 'python'
@@ -5,4 +5,4 @@ cfg_path: # [optional]
ckpt_path: # [optional]
decode_method: 'attention_rescoring'
force_yes: True
device: # set 'gpu:id' or 'cpu'
@@ -15,9 +15,10 @@ decode_method:
force_yes: True
am_predictor_conf:
    device: # set 'gpu:id' or 'cpu'
    switch_ir_optim: True
    glog_info: False # True -> print glog
    summary: True # False -> do not show predictor config

##################################################################
@@ -29,4 +29,4 @@ voc_stat:
# OTHERS #
##################################################################
lang: 'zh'
device: # set 'gpu:id' or 'cpu'
@@ -15,9 +15,10 @@ speaker_dict:
spk_id: 0
am_predictor_conf:
    device: # set 'gpu:id' or 'cpu'
    switch_ir_optim: True
    glog_info: False # True -> print glog
    summary: True # False -> do not show predictor config

##################################################################
@@ -30,9 +31,10 @@ voc_params: # the pdiparams file of your vocoder static model (XX.pdipparams)
voc_sample_rate: 24000
voc_predictor_conf:
    device: # set 'gpu:id' or 'cpu'
    switch_ir_optim: True
    glog_info: False # True -> print glog
    summary: True # False -> do not show predictor config

##################################################################
# OTHERS #
@@ -18,8 +18,8 @@ from .base_commands import BaseCommand
from .base_commands import HelpCommand
from .cls import CLSExecutor
from .st import STExecutor
from .stats import StatsExecutor
from .text import TextExecutor
from .tts import TTSExecutor

_locale._getdefaultlocale = (lambda *args: ['en_US', 'utf8'])
@@ -13,6 +13,7 @@
# limitations under the License.
import argparse
import os
import time
from collections import OrderedDict
from typing import Any
from typing import List
@@ -621,6 +622,7 @@ class TTSExecutor(BaseExecutor):
        am_dataset = am[am.rindex('_') + 1:]
        get_tone_ids = False
        merge_sentences = False
        frontend_st = time.time()
        if am_name == 'speedyspeech':
            get_tone_ids = True
        if lang == 'zh':
@@ -637,9 +639,13 @@ class TTSExecutor(BaseExecutor):
            phone_ids = input_ids["phone_ids"]
        else:
            print("lang should in {'zh', 'en'}!")
        self.frontend_time = time.time() - frontend_st
        self.am_time = 0
        self.voc_time = 0

        flags = 0
        for i in range(len(phone_ids)):
            am_st = time.time()
            part_phone_ids = phone_ids[i]
            # am
            if am_name == 'speedyspeech':
@@ -653,13 +659,16 @@ class TTSExecutor(BaseExecutor):
                    part_phone_ids, spk_id=paddle.to_tensor(spk_id))
            else:
                mel = self.am_inference(part_phone_ids)
            self.am_time += (time.time() - am_st)

            # voc
            voc_st = time.time()
            wav = self.voc_inference(mel)
            if flags == 0:
                wav_all = wav
                flags = 1
            else:
                wav_all = paddle.concat([wav_all, wav])
            self.voc_time += (time.time() - voc_st)
        self._outputs['wav'] = wav_all

    def postprocess(self, output: str='output.wav') -> Union[str, os.PathLike]:
@@ -121,7 +121,6 @@ class TTSClientExecutor(BaseExecutor):
                        (args.output))
            logger.info("Audio duration: %f s." % (duration))
            logger.info("Response time: %f s." % (time_consume))
            return True
        except BaseException:
@@ -3,7 +3,7 @@
##################################################################
#                         SERVER SETTING                         #
##################################################################
host: '127.0.0.1'
port: 8090

##################################################################
@@ -5,4 +5,4 @@ cfg_path: # [optional]
ckpt_path: # [optional]
decode_method: 'attention_rescoring'
force_yes: True
device: # set 'gpu:id' or 'cpu'
@@ -15,9 +15,10 @@ decode_method:
force_yes: True
am_predictor_conf:
    device: # set 'gpu:id' or 'cpu'
    switch_ir_optim: True
    glog_info: False # True -> print glog
    summary: True # False -> do not show predictor config

##################################################################
@@ -29,4 +29,4 @@ voc_stat:
# OTHERS #
##################################################################
lang: 'zh'
device: # set 'gpu:id' or 'cpu'
@@ -8,16 +8,17 @@
am: 'fastspeech2_csmsc'
am_model: # the pdmodel file of your am static model (XX.pdmodel)
am_params: # the pdiparams file of your am static model (XX.pdipparams)
am_sample_rate: 24000
phones_dict:
tones_dict:
speaker_dict:
spk_id: 0

am_predictor_conf:
    device: # set 'gpu:id' or 'cpu'
    switch_ir_optim: True
    glog_info: False # True -> print glog
    summary: True # False -> do not show predictor config

##################################################################
@@ -27,12 +28,13 @@ am_predictor_conf:
voc: 'pwgan_csmsc'
voc_model: # the pdmodel file of your vocoder static model (XX.pdmodel)
voc_params: # the pdiparams file of your vocoder static model (XX.pdipparams)
voc_sample_rate: 24000
voc_predictor_conf:
    device: # set 'gpu:id' or 'cpu'
    switch_ir_optim: True
    glog_info: False # True -> print glog
    summary: True # False -> do not show predictor config

##################################################################
# OTHERS #
@@ -13,6 +13,7 @@
# limitations under the License.
import io
import os
import time
from typing import Optional

import paddle
@@ -197,7 +198,6 @@ class ASREngine(BaseEngine):
        self.executor = ASRServerExecutor()
        self.config = get_config(config_file)
        self.executor._init_from_path(
            model_type=self.config.model_type,
            am_model=self.config.am_model,
@@ -223,13 +223,18 @@ class ASREngine(BaseEngine):
            logger.info("start running asr engine")
            self.executor.preprocess(self.config.model_type,
                                     io.BytesIO(audio_data))
            st = time.time()
            self.executor.infer(self.config.model_type)
            infer_time = time.time() - st
            self.output = self.executor.postprocess()  # Retrieve result of asr.
            logger.info("end inferring asr engine")
            logger.info("inference time: {}".format(infer_time))
        else:
            logger.info("file check failed!")
            self.output = None
        logger.info("asr engine type: paddle inference")

    def postprocess(self):
        """postprocess
        """
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import time

import paddle
@@ -53,16 +54,24 @@ class ASREngine(BaseEngine):
        self.executor = ASRServerExecutor()
        self.config = get_config(config_file)
        try:
            if self.config.device:
                self.device = self.config.device
            else:
                self.device = paddle.get_device()
            paddle.set_device(self.device)
        except BaseException:
            logger.error(
                "Set device failed, please check if device is already used and the parameter 'device' in the yaml file"
            )

        self.executor._init_from_path(
            self.config.model, self.config.lang, self.config.sample_rate,
            self.config.cfg_path, self.config.decode_method,
            self.config.ckpt_path)

        logger.info("Initialize ASR server engine successfully on device: %s." %
                    (self.device))
        return True

    def run(self, audio_data):
@@ -76,12 +85,17 @@ class ASREngine(BaseEngine):
                                self.config.force_yes):
            logger.info("start run asr engine")
            self.executor.preprocess(self.config.model, io.BytesIO(audio_data))
            st = time.time()
            self.executor.infer(self.config.model)
            infer_time = time.time() - st
            self.output = self.executor.postprocess()  # Retrieve result of asr.
            logger.info("inference time: {}".format(infer_time))
        else:
            logger.info("file check failed!")
            self.output = None
        logger.info("asr engine type: python")

    def postprocess(self):
        """postprocess
        """
@@ -14,6 +14,7 @@
import base64
import io
import os
import time
from typing import Optional

import librosa
@@ -179,7 +180,7 @@ class TTSServerExecutor(TTSExecutor):
        self.phones_dict = os.path.abspath(phones_dict)
        self.am_sample_rate = am_sample_rate
        self.am_res_path = os.path.dirname(os.path.abspath(self.am_model))
        logger.info("self.phones_dict: {}".format(self.phones_dict))

        # for speedyspeech
        self.tones_dict = None
@@ -224,21 +225,21 @@ class TTSServerExecutor(TTSExecutor):
        with open(self.phones_dict, "r") as f:
            phn_id = [line.strip().split() for line in f.readlines()]
        vocab_size = len(phn_id)
        logger.info("vocab_size: {}".format(vocab_size))

        tone_size = None
        if self.tones_dict:
            with open(self.tones_dict, "r") as f:
                tone_id = [line.strip().split() for line in f.readlines()]
            tone_size = len(tone_id)
            logger.info("tone_size: {}".format(tone_size))

        spk_num = None
        if self.speaker_dict:
            with open(self.speaker_dict, 'rt') as f:
                spk_id = [line.strip().split() for line in f.readlines()]
            spk_num = len(spk_id)
            logger.info("spk_num: {}".format(spk_num))

        # frontend
        if lang == 'zh':
@@ -248,21 +249,29 @@ class TTSServerExecutor(TTSExecutor):
        elif lang == 'en':
            self.frontend = English(phone_vocab_path=self.phones_dict)
        logger.info("frontend done!")

        try:
            # am predictor
            self.am_predictor_conf = am_predictor_conf
            self.am_predictor = init_predictor(
                model_file=self.am_model,
                params_file=self.am_params,
                predictor_conf=self.am_predictor_conf)
            logger.info("Create AM predictor successfully.")
        except BaseException:
            logger.error("Failed to create AM predictor.")

        try:
            # voc predictor
            self.voc_predictor_conf = voc_predictor_conf
            self.voc_predictor = init_predictor(
                model_file=self.voc_model,
                params_file=self.voc_params,
                predictor_conf=self.voc_predictor_conf)
            logger.info("Create Vocoder predictor successfully.")
        except BaseException:
            logger.error("Failed to create Vocoder predictor.")

    @paddle.no_grad()
    def infer(self,
@@ -277,6 +286,7 @@ class TTSServerExecutor(TTSExecutor):
        am_dataset = am[am.rindex('_') + 1:]
        get_tone_ids = False
        merge_sentences = False
        frontend_st = time.time()
        if am_name == 'speedyspeech':
            get_tone_ids = True
        if lang == 'zh':
@@ -292,10 +302,14 @@ class TTSServerExecutor(TTSExecutor):
                text, merge_sentences=merge_sentences)
            phone_ids = input_ids["phone_ids"]
        else:
            logger.error("lang should in {'zh', 'en'}!")
        self.frontend_time = time.time() - frontend_st
        self.am_time = 0
        self.voc_time = 0

        flags = 0
        for i in range(len(phone_ids)):
            am_st = time.time()
            part_phone_ids = phone_ids[i]
            # am
            if am_name == 'speedyspeech':
@@ -314,7 +328,10 @@ class TTSServerExecutor(TTSExecutor):
                am_result = run_model(self.am_predictor,
                                      [part_phone_ids.numpy()])
                mel = am_result[0]
            self.am_time += (time.time() - am_st)

            # voc
            voc_st = time.time()
            voc_result = run_model(self.voc_predictor, [mel])
            wav = voc_result[0]
            wav = paddle.to_tensor(wav)
@@ -324,6 +341,7 @@ class TTSServerExecutor(TTSExecutor):
                flags = 1
            else:
                wav_all = paddle.concat([wav_all, wav])
            self.voc_time += (time.time() - voc_st)
        self._outputs['wav'] = wav_all
@@ -370,7 +388,7 @@ class TTSEngine(BaseEngine):
    def postprocess(self,
                    wav,
                    original_fs: int,
                    target_fs: int=0,
                    volume: float=1.0,
                    speed: float=1.0,
                    audio_path: str=None):
@@ -395,38 +413,50 @@ class TTSEngine(BaseEngine):
        if target_fs == 0 or target_fs > original_fs:
            target_fs = original_fs
            wav_tar_fs = wav
            logger.info(
                "The sample rate of synthesized audio is the same as model, which is {}Hz".
                format(original_fs))
        else:
            wav_tar_fs = librosa.resample(
                np.squeeze(wav), original_fs, target_fs)
            logger.info(
                "The sample rate of model is {}Hz and the target sample rate is {}Hz. Converting the sample rate of the synthesized audio successfully.".
                format(original_fs, target_fs))

        # transform volume
        wav_vol = wav_tar_fs * volume
        logger.info("Transform the volume of the audio successfully.")

        # transform speed
        try:  # windows not support soxbindings
            wav_speed = change_speed(wav_vol, speed, target_fs)
            logger.info("Transform the speed of the audio successfully.")
        except ServerBaseException:
            raise ServerBaseException(
                ErrorCode.SERVER_INTERNAL_ERR,
                "Failed to transform speed. Can not install soxbindings on your system. \
                You need to set speed value 1.0.")
        except BaseException:
            logger.error("Failed to transform speed.")

        # wav to base64
        buf = io.BytesIO()
        wavfile.write(buf, target_fs, wav_speed)
        base64_bytes = base64.b64encode(buf.read())
        wav_base64 = base64_bytes.decode('utf-8')
        logger.info("Audio to string successfully.")

        # save audio
        if audio_path is not None:
            if audio_path.endswith(".wav"):
                sf.write(audio_path, wav_speed, target_fs)
            elif audio_path.endswith(".pcm"):
                wav_norm = wav_speed * (32767 / max(0.001,
                                                    np.max(np.abs(wav_speed))))
                with open(audio_path, "wb") as f:
                    f.write(wav_norm.astype(np.int16))
            logger.info("Save audio to {} successfully.".format(audio_path))
        else:
            logger.info("There is no need to save audio.")

        return target_fs, wav_base64
@@ -462,8 +492,12 @@ class TTSEngine(BaseEngine):
        lang = self.config.lang

        try:
            infer_st = time.time()
            self.executor.infer(
                text=sentence, lang=lang, am=self.config.am, spk_id=spk_id)
            infer_et = time.time()
            infer_time = infer_et - infer_st
        except ServerBaseException:
            raise ServerBaseException(ErrorCode.SERVER_INTERNAL_ERR,
                                      "tts infer failed.")
@@ -471,6 +505,7 @@ class TTSEngine(BaseEngine):
            logger.error("tts infer failed.")

        try:
            postprocess_st = time.time()
            target_sample_rate, wav_base64 = self.postprocess(
                wav=self.executor._outputs['wav'].numpy(),
                original_fs=self.executor.am_sample_rate,
@@ -478,10 +513,34 @@ class TTSEngine(BaseEngine):
                volume=volume,
                speed=speed,
                audio_path=save_path)
            postprocess_et = time.time()
            postprocess_time = postprocess_et - postprocess_st
            duration = len(self.executor._outputs['wav']
                           .numpy()) / self.executor.am_sample_rate
            rtf = infer_time / duration
        except ServerBaseException:
            raise ServerBaseException(ErrorCode.SERVER_INTERNAL_ERR,
                                      "tts postprocess failed.")
        except BaseException:
            logger.error("tts postprocess failed.")

        logger.info("AM model: {}".format(self.config.am))
        logger.info("Vocoder model: {}".format(self.config.voc))
        logger.info("Language: {}".format(lang))
        logger.info("tts engine type: paddle inference")
        logger.info("audio duration: {}".format(duration))
        logger.info(
            "frontend inference time: {}".format(self.executor.frontend_time))
        logger.info("AM inference time: {}".format(self.executor.am_time))
        logger.info("Vocoder inference time: {}".format(self.executor.voc_time))
        logger.info("total inference time: {}".format(infer_time))
        logger.info(
            "postprocess (change speed, volume, target sample rate) time: {}".
            format(postprocess_time))
        logger.info("total generate audio time: {}".format(infer_time +
                                                           postprocess_time))
        logger.info("RTF: {}".format(rtf))

        return lang, target_sample_rate, wav_base64
@@ -13,6 +13,7 @@
# limitations under the License.
import base64
import io
import time

import librosa
import numpy as np
@@ -54,11 +55,20 @@ class TTSEngine(BaseEngine):
        try:
            self.config = get_config(config_file)
            if self.config.device:
                self.device = self.config.device
            else:
                self.device = paddle.get_device()
            paddle.set_device(self.device)
        except BaseException:
            logger.error(
                "Set device failed, please check if device is already used and the parameter 'device' in the yaml file"
            )
            logger.error("Initialize TTS server engine Failed on device: %s." %
                         (self.device))
            return False

        try:
            self.executor._init_from_path(
                am=self.config.am,
                am_config=self.config.am_config,
@@ -73,16 +83,19 @@ class TTSEngine(BaseEngine):
                voc_stat=self.config.voc_stat,
                lang=self.config.lang)
        except BaseException:
            logger.error("Failed to get model related files.")
            logger.error("Initialize TTS server engine Failed on device: %s." %
                         (self.device))
            return False

        logger.info("Initialize TTS server engine successfully on device: %s." %
                    (self.device))
        return True
    def postprocess(self,
                    wav,
                    original_fs: int,
                    target_fs: int=0,
                    volume: float=1.0,
                    speed: float=1.0,
                    audio_path: str=None):
@@ -107,38 +120,50 @@ class TTSEngine(BaseEngine):
        if target_fs == 0 or target_fs > original_fs:
            target_fs = original_fs
            wav_tar_fs = wav
            logger.info(
                "The sample rate of synthesized audio is the same as model, which is {}Hz".
                format(original_fs))
        else:
            wav_tar_fs = librosa.resample(
                np.squeeze(wav), original_fs, target_fs)
            logger.info(
                "The sample rate of model is {}Hz and the target sample rate is {}Hz. Converting the sample rate of the synthesized audio successfully.".
                format(original_fs, target_fs))

        # transform volume
        wav_vol = wav_tar_fs * volume
        logger.info("Transform the volume of the audio successfully.")

        # transform speed
        try:  # windows not support soxbindings
            wav_speed = change_speed(wav_vol, speed, target_fs)
            logger.info("Transform the speed of the audio successfully.")
        except ServerBaseException:
            raise ServerBaseException(
                ErrorCode.SERVER_INTERNAL_ERR,
                "Failed to transform speed. Can not install soxbindings on your system. \
                You need to set speed value 1.0.")
        except BaseException:
            logger.error("Failed to transform speed.")

        # wav to base64
        buf = io.BytesIO()
        wavfile.write(buf, target_fs, wav_speed)
        base64_bytes = base64.b64encode(buf.read())
        wav_base64 = base64_bytes.decode('utf-8')
        logger.info("Audio to string successfully.")

        # save audio
        if audio_path is not None:
            if audio_path.endswith(".wav"):
                sf.write(audio_path, wav_speed, target_fs)
            elif audio_path.endswith(".pcm"):
                wav_norm = wav_speed * (32767 / max(0.001,
                                                    np.max(np.abs(wav_speed))))
                with open(audio_path, "wb") as f:
                    f.write(wav_norm.astype(np.int16))
            logger.info("Save audio to {} successfully.".format(audio_path))
        else:
            logger.info("There is no need to save audio.")

        return target_fs, wav_base64
@@ -174,8 +199,15 @@ class TTSEngine(BaseEngine):
        lang = self.config.lang

        try:
            infer_st = time.time()
            self.executor.infer(
                text=sentence, lang=lang, am=self.config.am, spk_id=spk_id)
            infer_et = time.time()
            infer_time = infer_et - infer_st
            duration = len(self.executor._outputs['wav']
                           .numpy()) / self.executor.am_config.fs
            rtf = infer_time / duration
        except ServerBaseException:
            raise ServerBaseException(ErrorCode.SERVER_INTERNAL_ERR,
                                      "tts infer failed.")
@@ -183,6 +215,7 @@ class TTSEngine(BaseEngine):
            logger.error("tts infer failed.")

        try:
            postprocess_st = time.time()
            target_sample_rate, wav_base64 = self.postprocess(
                wav=self.executor._outputs['wav'].numpy(),
                original_fs=self.executor.am_config.fs,
@@ -190,10 +223,32 @@ class TTSEngine(BaseEngine):
                volume=volume,
                speed=speed,
                audio_path=save_path)
            postprocess_et = time.time()
            postprocess_time = postprocess_et - postprocess_st
        except ServerBaseException:
            raise ServerBaseException(ErrorCode.SERVER_INTERNAL_ERR,
                                      "tts postprocess failed.")
        except BaseException:
            logger.error("tts postprocess failed.")

        logger.info("AM model: {}".format(self.config.am))
        logger.info("Vocoder model: {}".format(self.config.voc))
        logger.info("Language: {}".format(lang))
        logger.info("tts engine type: python")
        logger.info("audio duration: {}".format(duration))
        logger.info(
            "frontend inference time: {}".format(self.executor.frontend_time))
        logger.info("AM inference time: {}".format(self.executor.am_time))
        logger.info("Vocoder inference time: {}".format(self.executor.voc_time))
        logger.info("total inference time: {}".format(infer_time))
        logger.info(
            "postprocess (change speed, volume, target sample rate) time: {}".
            format(postprocess_time))
        logger.info("total generate audio time: {}".format(infer_time +
                                                           postprocess_time))
        logger.info("RTF: {}".format(rtf))
        logger.info("device: {}".format(self.device))

        return lang, target_sample_rate, wav_base64
@@ -16,6 +16,7 @@ from typing import Union
from fastapi import APIRouter

from paddlespeech.cli.log import logger
from paddlespeech.server.engine.engine_pool import get_engine_pool
from paddlespeech.server.restful.request import TTSRequest
from paddlespeech.server.restful.response import ErrorResponse
@@ -60,6 +61,9 @@ def tts(request_body: TTSRequest):
    Returns:
        json: [description]
    """
    logger.info("request: {}".format(request_body))

    # get params
    text = request_body.text
    spk_id = request_body.spk_id
@@ -92,6 +96,7 @@ def tts(request_body: TTSRequest):
        # get single engine from engine pool
        engine_pool = get_engine_pool()
        tts_engine = engine_pool['tts']
        logger.info("Get tts engine successfully.")

        lang, target_sample_rate, wav_base64 = tts_engine.run(
            text, spk_id, speed, volume, sample_rate, save_path)
@@ -15,6 +15,7 @@ import os
from typing import List
from typing import Optional

import paddle
from paddle.inference import Config
from paddle.inference import create_predictor
@@ -40,15 +41,30 @@ def init_predictor(model_dir: Optional[os.PathLike]=None,
    else:
        config = Config(model_file, params_file)

    # set device
    if predictor_conf["device"]:
        device = predictor_conf["device"]
    else:
        device = paddle.get_device()
    if "gpu" in device:
        gpu_id = device.split(":")[-1]
        config.enable_use_gpu(1000, int(gpu_id))

    # IR optim
    if predictor_conf["switch_ir_optim"]:
        config.switch_ir_optim()

    # glog
    if not predictor_conf["glog_info"]:
        config.disable_glog_info()

    # config summary
    if predictor_conf["summary"]:
        print(config.summary())

    # memory optim
    config.enable_memory_optim()

    predictor = create_predictor(config)
    return predictor