# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import base64
import io
import os
import time
from typing import Optional

import librosa
import numpy as np
import paddle
import soundfile as sf
from scipy.io import wavfile

from .pretrained_models import pretrained_models
from paddlespeech.cli.log import logger
from paddlespeech.cli.tts.infer import TTSExecutor
from paddlespeech.server.engine.base_engine import BaseEngine
from paddlespeech.server.utils.audio_process import change_speed
from paddlespeech.server.utils.errors import ErrorCode
from paddlespeech.server.utils.exception import ServerBaseException
from paddlespeech.server.utils.paddle_predictor import init_predictor
from paddlespeech.server.utils.paddle_predictor import run_model
from paddlespeech.t2s.frontend import English
from paddlespeech.t2s.frontend.zh_frontend import Frontend

__all__ = ['TTSEngine']


class TTSServerExecutor(TTSExecutor):
    def __init__(self):
        super().__init__()
        self.pretrained_models = pretrained_models

    def _init_from_path(
            self,
            am: str='fastspeech2_csmsc',
            am_model: Optional[os.PathLike]=None,
            am_params: Optional[os.PathLike]=None,
            am_sample_rate: int=24000,
            phones_dict: Optional[os.PathLike]=None,
            tones_dict: Optional[os.PathLike]=None,
            speaker_dict: Optional[os.PathLike]=None,
            voc: str='pwgan_csmsc',
            voc_model: Optional[os.PathLike]=None,
            voc_params: Optional[os.PathLike]=None,
            voc_sample_rate: int=24000,
            lang: str='zh',
            am_predictor_conf: dict=None,
            voc_predictor_conf: dict=None, ):
        """
        Init model and other resources from a specific path.
        """
        if hasattr(self, 'am_predictor') and hasattr(self, 'voc_predictor'):
            logger.info('Models have already been initialized.')
            return
        # am
        am_tag = am + '-' + lang
        if am_model is None or am_params is None or phones_dict is None:
            am_res_path = self._get_pretrained_path(am_tag)
            self.am_res_path = am_res_path
            self.am_model = os.path.join(
                am_res_path, self.pretrained_models[am_tag]['model'])
            self.am_params = os.path.join(
                am_res_path, self.pretrained_models[am_tag]['params'])
            # the acoustic model always requires a phones_dict
            self.phones_dict = os.path.join(
                am_res_path, self.pretrained_models[am_tag]['phones_dict'])
            self.am_sample_rate = self.pretrained_models[am_tag]['sample_rate']

            logger.info(am_res_path)
            logger.info(self.am_model)
            logger.info(self.am_params)
        else:
            self.am_model = os.path.abspath(am_model)
            self.am_params = os.path.abspath(am_params)
            self.phones_dict = os.path.abspath(phones_dict)
            self.am_sample_rate = am_sample_rate
            self.am_res_path = os.path.dirname(os.path.abspath(self.am_model))
        logger.info("self.phones_dict: {}".format(self.phones_dict))

        # for speedyspeech
        self.tones_dict = None
        if 'tones_dict' in self.pretrained_models[am_tag]:
            self.tones_dict = os.path.join(
                am_res_path, self.pretrained_models[am_tag]['tones_dict'])
            if tones_dict:
                self.tones_dict = tones_dict

        # for multi speaker fastspeech2
        self.speaker_dict = None
        if 'speaker_dict' in self.pretrained_models[am_tag]:
            self.speaker_dict = os.path.join(
                am_res_path, self.pretrained_models[am_tag]['speaker_dict'])
            if speaker_dict:
                self.speaker_dict = speaker_dict

        # voc
        voc_tag = voc + '-' + lang
        if voc_model is None or voc_params is None:
            voc_res_path = self._get_pretrained_path(voc_tag)
            self.voc_res_path = voc_res_path
            self.voc_model = os.path.join(
                voc_res_path, self.pretrained_models[voc_tag]['model'])
            self.voc_params = os.path.join(
                voc_res_path, self.pretrained_models[voc_tag]['params'])
            self.voc_sample_rate = self.pretrained_models[voc_tag][
                'sample_rate']
            logger.info(voc_res_path)
            logger.info(self.voc_model)
            logger.info(self.voc_params)
        else:
            self.voc_model = os.path.abspath(voc_model)
            self.voc_params = os.path.abspath(voc_params)
            self.voc_sample_rate = voc_sample_rate
            self.voc_res_path = os.path.dirname(os.path.abspath(self.voc_model))

        assert (
            self.voc_sample_rate == self.am_sample_rate
        ), "The sample rates of the AM and the vocoder are different, please check the models."

        # load the phone / tone / speaker vocabularies needed by the models
        with open(self.phones_dict, "r") as f:
            phn_id = [line.strip().split() for line in f.readlines()]
        vocab_size = len(phn_id)
        logger.info("vocab_size: {}".format(vocab_size))

        tone_size = None
        if self.tones_dict:
            with open(self.tones_dict, "r") as f:
                tone_id = [line.strip().split() for line in f.readlines()]
            tone_size = len(tone_id)
            logger.info("tone_size: {}".format(tone_size))

        spk_num = None
        if self.speaker_dict:
            with open(self.speaker_dict, 'rt') as f:
                spk_id = [line.strip().split() for line in f.readlines()]
            spk_num = len(spk_id)
            logger.info("spk_num: {}".format(spk_num))

        # frontend
        if lang == 'zh':
            self.frontend = Frontend(
                phone_vocab_path=self.phones_dict,
                tone_vocab_path=self.tones_dict)

        elif lang == 'en':
            self.frontend = English(phone_vocab_path=self.phones_dict)
        logger.info("frontend done!")

        # Create am predictor
        self.am_predictor_conf = am_predictor_conf
        self.am_predictor = init_predictor(
            model_file=self.am_model,
            params_file=self.am_params,
            predictor_conf=self.am_predictor_conf)
        logger.info("Create AM predictor successfully.")

        # Create voc predictor
        self.voc_predictor_conf = voc_predictor_conf
        self.voc_predictor = init_predictor(
            model_file=self.voc_model,
            params_file=self.voc_params,
            predictor_conf=self.voc_predictor_conf)
        logger.info("Create Vocoder predictor successfully.")

    @paddle.no_grad()
    def infer(self,
              text: str,
              lang: str='zh',
              am: str='fastspeech2_csmsc',
              spk_id: int=0):
        """
        Model inference; the synthesized waveform is stored in self._outputs['wav'].
        """
        am_name = am[:am.rindex('_')]
        am_dataset = am[am.rindex('_') + 1:]
        get_tone_ids = False
        merge_sentences = False
        frontend_st = time.time()
        if am_name == 'speedyspeech':
            get_tone_ids = True
        if lang == 'zh':
            input_ids = self.frontend.get_input_ids(
                text,
                merge_sentences=merge_sentences,
                get_tone_ids=get_tone_ids)
            phone_ids = input_ids["phone_ids"]
            if get_tone_ids:
                tone_ids = input_ids["tone_ids"]
        elif lang == 'en':
            input_ids = self.frontend.get_input_ids(
                text, merge_sentences=merge_sentences)
            phone_ids = input_ids["phone_ids"]
        else:
            logger.error("lang should be in {'zh', 'en'}!")
        self.frontend_time = time.time() - frontend_st

        self.am_time = 0
        self.voc_time = 0
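        # sentences are synthesized one by one (merge_sentences=False above) and
        # the per-sentence waveforms are concatenated into a single output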
        flags = 0
        for i in range(len(phone_ids)):
            am_st = time.time()
            part_phone_ids = phone_ids[i]
            # am
            if am_name == 'speedyspeech':
                part_tone_ids = tone_ids[i]
                am_result = run_model(
                    self.am_predictor,
                    [part_phone_ids.numpy(), part_tone_ids.numpy()])
                mel = am_result[0]

            # fastspeech2
            else:
                # multi-speaker acoustic models (aishell3, vctk) have no exported
                # static model yet, so no mel is produced for them here
                if am_dataset in {"aishell3", "vctk"}:
                    pass
                else:
                    am_result = run_model(self.am_predictor,
                                          [part_phone_ids.numpy()])
                    mel = am_result[0]
            self.am_time += (time.time() - am_st)

            # voc
            voc_st = time.time()
            voc_result = run_model(self.voc_predictor, [mel])
            wav = voc_result[0]
            wav = paddle.to_tensor(wav)

            if flags == 0:
                wav_all = wav
                flags = 1
            else:
                wav_all = paddle.concat([wav_all, wav])
            self.voc_time += (time.time() - voc_st)
        self._outputs['wav'] = wav_all


class TTSEngine(BaseEngine):
    """TTS server engine

    Args:
        metaclass: Defaults to Singleton.
    """

    def __init__(self):
        """Initialize TTS server engine
        """
        super(TTSEngine, self).__init__()

    def init(self, config: dict) -> bool:
        self.executor = TTSServerExecutor()
        self.config = config

        try:
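            # pick the inference device: AM predictor config first, then the
            # vocoder predictor config, otherwise paddle's default device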
            if self.config.am_predictor_conf.device is not None:
                self.device = self.config.am_predictor_conf.device
            elif self.config.voc_predictor_conf.device is not None:
                self.device = self.config.voc_predictor_conf.device
            else:
                self.device = paddle.get_device()
            paddle.set_device(self.device)
        except BaseException as e:
            logger.error(
                "Failed to set the device. Please check whether the device is "
                "already in use and whether the 'device' parameter in the yaml "
                "file is correct.")
            logger.error("Initializing TTS server engine failed on device: %s." %
                         (getattr(self, "device", "unknown")))
            return False

        self.executor._init_from_path(
            am=self.config.am,
            am_model=self.config.am_model,
            am_params=self.config.am_params,
            am_sample_rate=self.config.am_sample_rate,
            phones_dict=self.config.phones_dict,
            tones_dict=self.config.tones_dict,
            speaker_dict=self.config.speaker_dict,
            voc=self.config.voc,
            voc_model=self.config.voc_model,
            voc_params=self.config.voc_params,
            voc_sample_rate=self.config.voc_sample_rate,
            lang=self.config.lang,
            am_predictor_conf=self.config.am_predictor_conf,
            voc_predictor_conf=self.config.voc_predictor_conf, )

        # warm up
        try:
            self.warm_up()
            logger.info("Warm up successfully.")
        except Exception as e:
            logger.error("Failed to warm up the tts engine: {}".format(e))
            return False

        logger.info("Initialize TTS server engine successfully.")
        return True

    def warm_up(self):
        """warm up
        """
        if self.config.lang == 'zh':
            sentence = "您好,欢迎使用语音合成服务。"
        elif self.config.lang == 'en':
            sentence = "Hello and welcome to the speech synthesis service."
        logger.info("Start to warm up.")
        for i in range(3):
            st = time.time()
            self.executor.infer(
                text=sentence,
                lang=self.config.lang,
                am=self.config.am,
                spk_id=0, )
            logger.info(
                f"The response time of warm-up {i}: {time.time() - st} s")

    def postprocess(self,
                    wav,
                    original_fs: int,
                    target_fs: int=0,
                    volume: float=1.0,
                    speed: float=1.0,
                    audio_path: str=None):
        """Post-processing operations, including speech, volume, sample rate, save audio file

        Args:
            wav (numpy(float)): Synthesized audio sample points
            original_fs (int): original audio sample rate
            target_fs (int): target audio sample rate
            volume (float): target volume
            speed (float): target speed

        Raises:
            ServerBaseException: Raised if changing the speed fails.

        Returns:
            target_fs: target sample rate for synthesized audio.
            wav_base64: The base64 format of the synthesized audio.
        """

        # transform sample_rate
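        # target_fs == 0 means "keep the model's sample rate"; upsampling above
        # the model rate is not supported, so it also falls back to original_fs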
        if target_fs == 0 or target_fs > original_fs:
            target_fs = original_fs
            wav_tar_fs = wav
            logger.info(
                "The sample rate of synthesized audio is the same as model, which is {}Hz".
                format(original_fs))
        else:
            wav_tar_fs = librosa.resample(
                np.squeeze(wav), orig_sr=original_fs, target_sr=target_fs)
            logger.info(
                "The sample rate of model is {}Hz and the target sample rate is {}Hz. Converting the sample rate of the synthesized audio successfully.".
                format(original_fs, target_fs))
        # transform volume
        wav_vol = wav_tar_fs * volume
        logger.info("Transform the volume of the audio successfully.")

        # transform speed
        try:  # soxbindings is not supported on Windows
            wav_speed = change_speed(wav_vol, speed, target_fs)
            logger.info("Transform the speed of the audio successfully.")
        except ServerBaseException:
            raise ServerBaseException(
                ErrorCode.SERVER_INTERNAL_ERR,
                "Failed to transform speed. soxbindings cannot be installed on "
                "your system, so the speed value must be kept at 1.0.")
        except BaseException:
            logger.error("Failed to transform speed.")
            # fall back to the unmodified waveform so the request can still be served
            wav_speed = wav_vol

        # wav to base64
        buf = io.BytesIO()
        wavfile.write(buf, target_fs, wav_speed)
        buf.seek(0)  # rewind the buffer before reading, otherwise read() returns nothing
        base64_bytes = base64.b64encode(buf.read())
        wav_base64 = base64_bytes.decode('utf-8')
        logger.info("Audio to string successfully.")

        # save audio
        if audio_path is not None:
            if audio_path.endswith(".wav"):
                sf.write(audio_path, wav_speed, target_fs)
            elif audio_path.endswith(".pcm"):
                wav_norm = wav_speed * (32767 / max(0.001,
                                                    np.max(np.abs(wav_speed))))
                with open(audio_path, "wb") as f:
                    f.write(wav_norm.astype(np.int16))
            logger.info("Save audio to {} successfully.".format(audio_path))
        else:
            logger.info("There is no need to save audio.")

        return target_fs, wav_base64

    def run(self,
            sentence: str,
            spk_id: int=0,
            speed: float=1.0,
            volume: float=1.0,
            sample_rate: int=0,
            save_path: str=None):
        """get the result of the server response

        Args:
            sentence (str): sentence to be synthesized
            spk_id (int, optional): speaker id. Defaults to 0.
            speed (float, optional): audio speed, 0 < speed <=3.0. Defaults to 1.0.
            volume (float, optional): The volume relative to the audio synthesized by the model, 
            0 < volume <=3.0. Defaults to 1.0.
            sample_rate (int, optional): Set the sample rate of the synthesized audio. 
            0 represents the sample rate for model synthesis. Defaults to 0.
            save_path (str, optional): The save path of the synthesized audio. Defaults to None.

        Raises:
            ServerBaseException: Raised if TTS inference fails.
            ServerBaseException: Raised if postprocessing fails.

        Returns:
            lang: model language.
            target_sample_rate: target sample rate of the synthesized audio.
            duration: duration of the synthesized audio in seconds.
            wav_base64: the synthesized audio encoded as a base64 string.
        """

        lang = self.config.lang

        try:
            infer_st = time.time()
            self.executor.infer(
                text=sentence, lang=lang, am=self.config.am, spk_id=spk_id)
            infer_et = time.time()
            infer_time = infer_et - infer_st

        except ServerBaseException:
            raise ServerBaseException(ErrorCode.SERVER_INTERNAL_ERR,
                                      "tts infer failed.")
        except BaseException:
            logger.error("tts infer failed.")
            # re-raise so the caller gets a clear error instead of a crash on
            # undefined results further down
            raise ServerBaseException(ErrorCode.SERVER_INTERNAL_ERR,
                                      "tts infer failed.")

        try:
            postprocess_st = time.time()
            target_sample_rate, wav_base64 = self.postprocess(
                wav=self.executor._outputs['wav'].numpy(),
                original_fs=self.executor.am_sample_rate,
                target_fs=sample_rate,
                volume=volume,
                speed=speed,
                audio_path=save_path)
            postprocess_et = time.time()
            postprocess_time = postprocess_et - postprocess_st
            duration = len(self.executor._outputs['wav']
                           .numpy()) / self.executor.am_sample_rate
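            # real-time factor: inference time per second of synthesized audio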
            rtf = infer_time / duration

        except ServerBaseException:
            raise ServerBaseException(ErrorCode.SERVER_INTERNAL_ERR,
                                      "tts postprocess failed.")
        except BaseException:
            logger.error("tts postprocess failed.")
            raise ServerBaseException(ErrorCode.SERVER_INTERNAL_ERR,
                                      "tts postprocess failed.")

        logger.info("AM model: {}".format(self.config.am))
        logger.info("Vocoder model: {}".format(self.config.voc))
        logger.info("Language: {}".format(lang))
        logger.info("tts engine type: paddle inference")

        logger.info("audio duration: {}".format(duration))
        logger.info(
            "frontend inference time: {}".format(self.executor.frontend_time))
        logger.info("AM inference time: {}".format(self.executor.am_time))
        logger.info("Vocoder inference time: {}".format(self.executor.voc_time))
        logger.info("total inference time: {}".format(infer_time))
        logger.info(
            "postprocess (change speed, volume, target sample rate) time: {}".
            format(postprocess_time))
        logger.info("total generate audio time: {}".format(infer_time +
                                                           postprocess_time))
        logger.info("RTF: {}".format(rtf))

        return lang, target_sample_rate, duration, wav_base64
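

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the engine). It assumes
# `config` is the tts section of the server YAML loaded into an attribute-style
# object exposing the fields read in `TTSEngine.init()` above (am, am_model,
# am_params, ..., voc, lang, am_predictor_conf, voc_predictor_conf).
#
#   engine = TTSEngine()
#   if engine.init(config):
#       lang, target_fs, duration, wav_base64 = engine.run(
#           sentence="您好,欢迎使用语音合成服务。", spk_id=0, speed=1.0, volume=1.0)
#       # wav_base64 holds the synthesized audio as a base64-encoded WAV payload
# ---------------------------------------------------------------------------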