# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import base64
import math
import os
import time
from typing import Optional

import numpy as np
import paddle

from paddlespeech.cli.log import logger
from paddlespeech.cli.tts.infer import TTSExecutor
from paddlespeech.resource import CommonTaskResource
from paddlespeech.server.engine.base_engine import BaseEngine
from paddlespeech.server.utils.audio_process import float2pcm
from paddlespeech.server.utils.onnx_infer import get_sess
from paddlespeech.server.utils.util import denorm
from paddlespeech.server.utils.util import get_chunks
from paddlespeech.t2s.frontend import English
from paddlespeech.t2s.frontend.zh_frontend import Frontend

__all__ = ['TTSEngine', 'PaddleTTSConnectionHandler']


class TTSServerExecutor(TTSExecutor):
    def __init__(self):
        super().__init__()
        self.task_resource = CommonTaskResource(task='tts', model_format='onnx')

    def _init_from_path(
            self,
            am: str='fastspeech2_csmsc_onnx',
            am_ckpt: Optional[list]=None,
            am_stat: Optional[os.PathLike]=None,
            phones_dict: Optional[os.PathLike]=None,
            tones_dict: Optional[os.PathLike]=None,
            speaker_dict: Optional[os.PathLike]=None,
            am_sample_rate: int=24000,
            am_sess_conf: dict=None,
            voc: str='mb_melgan_csmsc_onnx',
            voc_ckpt: Optional[os.PathLike]=None,
            voc_sample_rate: int=24000,
            voc_sess_conf: dict=None,
            lang: str='zh', ):
        """
        Init model and other resources from a specific path.
        """

        if (hasattr(self, 'am_sess') or
            (hasattr(self, 'am_encoder_infer_sess') and
             hasattr(self, 'am_decoder_sess') and hasattr(
                 self, 'am_postnet_sess'))) and hasattr(self, 'voc_inference'):
            logger.info('Models have already been initialized.')
            return
        # am
        am_tag = am + '-' + lang
        self.task_resource.set_task_model(
            model_tag=am_tag,
            model_type=0,  # am
            version=None,  # default version
        )
        self.am_res_path = self.task_resource.res_dir
        if am == "fastspeech2_csmsc_onnx":
            # get model info
            if am_ckpt is None or phones_dict is None:
                self.am_ckpt = os.path.join(
                    self.am_res_path, self.task_resource.res_dict['ckpt'][0])
                # the acoustic model must provide a phones_dict
                self.phones_dict = os.path.join(
                    self.am_res_path,
                    self.task_resource.res_dict['phones_dict'])

            else:
                self.am_ckpt = os.path.abspath(am_ckpt[0])
                self.phones_dict = os.path.abspath(phones_dict)
                self.am_res_path = os.path.dirname(
                    os.path.abspath(self.am_ckpt))

            # create am sess
            self.am_sess = get_sess(self.am_ckpt, am_sess_conf)

        elif am == "fastspeech2_cnndecoder_csmsc_onnx":
            if am_ckpt is None or am_stat is None or phones_dict is None:
                self.am_encoder_infer = os.path.join(
                    self.am_res_path, self.task_resource.res_dict['ckpt'][0])
                self.am_decoder = os.path.join(
                    self.am_res_path, self.task_resource.res_dict['ckpt'][1])
                self.am_postnet = os.path.join(
                    self.am_res_path, self.task_resource.res_dict['ckpt'][2])
                # the acoustic model must provide a phones_dict
                self.phones_dict = os.path.join(
                    self.am_res_path,
                    self.task_resource.res_dict['phones_dict'])
                self.am_stat = os.path.join(
                    self.am_res_path,
                    self.task_resource.res_dict['speech_stats'])

            else:
                self.am_encoder_infer = os.path.abspath(am_ckpt[0])
                self.am_decoder = os.path.abspath(am_ckpt[1])
                self.am_postnet = os.path.abspath(am_ckpt[2])
                self.phones_dict = os.path.abspath(phones_dict)
                self.am_stat = os.path.abspath(am_stat)
                self.am_res_path = os.path.dirname(
                    os.path.abspath(self.am_encoder_infer))

            # create am sess
            self.am_encoder_infer_sess = get_sess(self.am_encoder_infer,
                                                  am_sess_conf)
            self.am_decoder_sess = get_sess(self.am_decoder, am_sess_conf)
            self.am_postnet_sess = get_sess(self.am_postnet, am_sess_conf)

            self.am_mu, self.am_std = np.load(self.am_stat)

        logger.info(f"self.phones_dict: {self.phones_dict}")
        logger.info(f"am model dir: {self.am_res_path}")
        logger.info("Create am sess successfully.")

        # voc model info
        voc_tag = voc + '-' + lang
        self.task_resource.set_task_model(
            model_tag=voc_tag,
            model_type=1,  # vocoder
            version=None,  # default version
        )
        if voc_ckpt is None:
            self.voc_res_path = self.task_resource.voc_res_dir
            self.voc_ckpt = os.path.join(
                self.voc_res_path, self.task_resource.voc_res_dict['ckpt'])
        else:
            self.voc_ckpt = os.path.abspath(voc_ckpt)
            self.voc_res_path = os.path.dirname(os.path.abspath(self.voc_ckpt))
        logger.info(f"voc model dir: {self.voc_res_path}")

        # create voc sess
        self.voc_sess = get_sess(self.voc_ckpt, voc_sess_conf)
        logger.info("Create voc sess successfully.")

        with open(self.phones_dict, "r") as f:
            phn_id = [line.strip().split() for line in f.readlines()]
        self.vocab_size = len(phn_id)
        logger.info(f"vocab_size: {self.vocab_size}")

        # frontend
        self.tones_dict = None
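        # The frontend turns raw text into phone-id (and optionally tone-id)
        # sequences; the zh frontend also performs text normalization and
        # polyphone disambiguation.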
        if lang == 'zh':
            self.frontend = Frontend(
                phone_vocab_path=self.phones_dict,
                tone_vocab_path=self.tones_dict)

        elif lang == 'en':
            self.frontend = English(phone_vocab_path=self.phones_dict)
        logger.info("frontend done!")


class TTSEngine(BaseEngine):
    """TTS server engine

    Args:
        metaclass: Defaults to Singleton.
    """

    def __init__(self, name=None):
        """Initialize TTS server engine
        """
        super().__init__()

    def init(self, config: dict) -> bool:
        self.executor = TTSServerExecutor()
        self.config = config
        self.lang = self.config.lang
        self.engine_type = "online-onnx"

        self.am_block = self.config.am_block
        self.am_pad = self.config.am_pad
        self.voc_block = self.config.voc_block
        self.voc_pad = self.config.voc_pad
        self.am_upsample = 1
        self.voc_upsample = self.config.voc_upsample
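
        # Streaming granularity: *_block is the number of frames a chunk
        # contributes to the final result and *_pad is extra context fed to
        # the model on each side of a chunk, stripped again by depadding().
        # voc_upsample is the number of wav samples the vocoder produces per
        # mel frame (the hop size, typically 300 for the 24 kHz csmsc models).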

        assert (
            self.config.am == "fastspeech2_csmsc_onnx" or
            self.config.am == "fastspeech2_cnndecoder_csmsc_onnx"
        ) and (
            self.config.voc == "hifigan_csmsc_onnx" or
            self.config.voc == "mb_melgan_csmsc_onnx"
        ), 'Please check config: am supports fastspeech2_csmsc_onnx or fastspeech2_cnndecoder_csmsc_onnx; voc supports hifigan_csmsc_onnx or mb_melgan_csmsc_onnx.'

        assert (
            self.config.voc_block > 0 and self.config.voc_pad > 0
        ), "Please set correct voc_block and voc_pad, they should be more than 0."

        assert (
            self.config.voc_sample_rate == self.config.am_sample_rate
        ), "The sample rate of AM and Vocoder model are different, please check model."

        try:
            if self.config.am_sess_conf.device is not None:
                self.device = self.config.am_sess_conf.device
            elif self.config.voc_sess_conf.device is not None:
                self.device = self.config.voc_sess_conf.device
            else:
                self.device = paddle.get_device()
            paddle.set_device(self.device)
        except Exception as e:
            logger.error(
                "Failed to set device; please check whether the device is available and check the 'device' parameter in the yaml file."
            )
            logger.error("Initialize TTS server engine Failed on device: %s." %
                         (self.device))
            logger.error(e)
            return False

        try:
            self.executor._init_from_path(
                am=self.config.am,
                am_ckpt=self.config.am_ckpt,
                am_stat=self.config.am_stat,
                phones_dict=self.config.phones_dict,
                tones_dict=self.config.tones_dict,
                speaker_dict=self.config.speaker_dict,
                am_sample_rate=self.config.am_sample_rate,
                am_sess_conf=self.config.am_sess_conf,
                voc=self.config.voc,
                voc_ckpt=self.config.voc_ckpt,
                voc_sample_rate=self.config.voc_sample_rate,
                voc_sess_conf=self.config.voc_sess_conf,
                lang=self.config.lang)

        except Exception as e:
            logger.error("Failed to get model related files.")
            logger.error("Failed to initialize TTS server engine on device: %s."
                         % (self.config.voc_sess_conf.device))
            logger.error(e)
            return False

        logger.info("Initialize TTS server engine successfully on device: %s." %
                    (self.config.voc_sess_conf.device))

        return True


class PaddleTTSConnectionHandler:
    def __init__(self, tts_engine):
        """The PaddleSpeech TTS Server Connection Handler
           This connection process every tts server request
        Args:
            tts_engine (TTSEngine): The TTS engine
        """
        super().__init__()
        logger.info(
            "Create PaddleTTSConnectionHandler to process the tts request")

        self.tts_engine = tts_engine
        self.executor = self.tts_engine.executor
        self.config = self.tts_engine.config
        self.am_block = self.tts_engine.am_block
        self.am_pad = self.tts_engine.am_pad
        self.voc_block = self.tts_engine.voc_block
        self.voc_pad = self.tts_engine.voc_pad
        self.am_upsample = self.tts_engine.am_upsample
        self.voc_upsample = self.tts_engine.voc_upsample

    def depadding(self, data, chunk_num, chunk_id, block, pad, upsample):
        """ 
        Streaming inference removes the result of pad inference
        """
        front_pad = min(chunk_id * block, pad)
        # first chunk
        if chunk_id == 0:
            data = data[:block * upsample]
        # last chunk
        elif chunk_id == chunk_num - 1:
            data = data[front_pad * upsample:]
        # middle chunk
        else:
            data = data[front_pad * upsample:(front_pad + block) * upsample]

        return data
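
    # Worked example (illustrative numbers, not from any shipped config):
    # with block=36, pad=14 and upsample=300, a middle chunk gets
    # front_pad = 14 and keeps samples [14*300, (14+36)*300), i.e. exactly
    # the audio of its own 36 frames, discarding audio synthesized from the
    # padding context on both sides.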

    @paddle.no_grad()
    def infer(
            self,
            text: str,
            lang: str='zh',
            am: str='fastspeech2_csmsc_onnx',
            spk_id: int=0, ):
        """
        Model inference and result stored in self.output.
        """

        # first_flag marks the first audio packet of this response
        first_flag = 1
        get_tone_ids = False
        merge_sentences = False

        # frontend: convert text to input ids
        frontend_st = time.time()
        if lang == 'zh':
            input_ids = self.executor.frontend.get_input_ids(
                text,
                merge_sentences=merge_sentences,
                get_tone_ids=get_tone_ids)
            phone_ids = input_ids["phone_ids"]
            if get_tone_ids:
                tone_ids = input_ids["tone_ids"]
        elif lang == 'en':
            input_ids = self.executor.frontend.get_input_ids(
                text, merge_sentences=merge_sentences)
            phone_ids = input_ids["phone_ids"]
        else:
            logger.error("lang must be one of {'zh', 'en'}!")
            return
        frontend_et = time.time()
        self.frontend_time = frontend_et - frontend_st

        for i in range(len(phone_ids)):
            part_phone_ids = phone_ids[i].numpy()
            voc_chunk_id = 0

            # fastspeech2_csmsc
            if am == "fastspeech2_csmsc_onnx":
                # am: phone ids -> mel spectrogram
                mel = self.executor.am_sess.run(
                    output_names=None, input_feed={'text': part_phone_ids})
                mel = mel[0]
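                # session.run returns a list of outputs; index 0 is the full
                # (T, 80) mel spectrogram for this sentence.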
                if first_flag == 1:
                    first_am_et = time.time()
                    self.first_am_infer = first_am_et - frontend_et

                # voc streaming
                mel_chunks = get_chunks(mel, self.voc_block, self.voc_pad,
                                        "voc")
                voc_chunk_num = len(mel_chunks)
                voc_st = time.time()
                for i, mel_chunk in enumerate(mel_chunks):
                    sub_wav = self.executor.voc_sess.run(
                        output_names=None, input_feed={'logmel': mel_chunk})
                    sub_wav = self.depadding(sub_wav[0], voc_chunk_num, i,
                                             self.voc_block, self.voc_pad,
                                             self.voc_upsample)
                    if first_flag == 1:
                        first_voc_et = time.time()
                        self.first_voc_infer = first_voc_et - first_am_et
                        self.first_response_time = first_voc_et - frontend_st
                        first_flag = 0

                    yield sub_wav

            # fastspeech2_cnndecoder_csmsc 
            elif am == "fastspeech2_cnndecoder_csmsc_onnx":
                # am encoder: run once over the full phone sequence
                orig_hs = self.executor.am_encoder_infer_sess.run(
                    None, input_feed={'text': part_phone_ids})
                orig_hs = orig_hs[0]
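                # orig_hs holds the (1, T, hidden) encoder states for the
                # whole sentence; the decoder and postnet below run over
                # overlapping chunks of it.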

                # streaming voc chunk info
                mel_len = orig_hs.shape[1]
                voc_chunk_num = math.ceil(mel_len / self.voc_block)
                start = 0
                end = min(self.voc_block + self.voc_pad, mel_len)
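                # The vocoder consumes overlapping windows of the growing mel
                # buffer: chunk k covers frames [k*voc_block - voc_pad,
                # (k+1)*voc_block + voc_pad), clipped to [0, mel_len).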

                # streaming am
                hss = get_chunks(orig_hs, self.am_block, self.am_pad, "am")
                am_chunk_num = len(hss)
                for i, hs in enumerate(hss):
                    am_decoder_output = self.executor.am_decoder_sess.run(
                        None, input_feed={'xs': hs})
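                    # The postnet expects (batch, channels, time), hence the
                    # transposes around it; its output is added back to the
                    # decoder output as a residual mel refinement.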
                    am_postnet_output = self.executor.am_postnet_sess.run(
                        None,
                        input_feed={
                            'xs': np.transpose(am_decoder_output[0], (0, 2, 1))
                        })
                    am_output_data = am_decoder_output + np.transpose(
                        am_postnet_output[0], (0, 2, 1))
                    normalized_mel = am_output_data[0][0]

                    sub_mel = denorm(normalized_mel, self.executor.am_mu,
                                     self.executor.am_std)
                    sub_mel = self.depadding(sub_mel, am_chunk_num, i,
                                             self.am_block, self.am_pad,
                                             self.am_upsample)

                    if i == 0:
                        mel_streaming = sub_mel
                    else:
                        mel_streaming = np.concatenate(
                            (mel_streaming, sub_mel), axis=0)

                    # streaming voc
                    # Once the streaming AM has produced at least one vocoder
                    # window of mel frames (and chunks remain), run streaming
                    # vocoder inference.
                    while (mel_streaming.shape[0] >= end and
                           voc_chunk_id < voc_chunk_num):
                        if first_flag == 1:
                            first_am_et = time.time()
                            self.first_am_infer = first_am_et - frontend_et
                        voc_chunk = mel_streaming[start:end, :]

                        sub_wav = self.executor.voc_sess.run(
                            output_names=None, input_feed={'logmel': voc_chunk})
                        sub_wav = self.depadding(
                            sub_wav[0], voc_chunk_num, voc_chunk_id,
                            self.voc_block, self.voc_pad, self.voc_upsample)
                        if first_flag == 1:
                            first_voc_et = time.time()
                            self.first_voc_infer = first_voc_et - first_am_et
                            self.first_response_time = first_voc_et - frontend_st
                            first_flag = 0

                        yield sub_wav

                        voc_chunk_id += 1
                        start = max(
                            0, voc_chunk_id * self.voc_block - self.voc_pad)
                        end = min(
                            (voc_chunk_id + 1) * self.voc_block + self.voc_pad,
                            mel_len)

            else:
                logger.error(
                    "Streaming TTS only supports fastspeech2_csmsc_onnx or fastspeech2_cnndecoder_csmsc_onnx."
                )

        self.final_response_time = time.time() - frontend_st

    def preprocess(self, text_base64: str=None, text_bytes: bytes=None):
        # Convert base64 or raw bytes input to text
        if text_base64:
            text_bytes = base64.b64decode(text_base64)  # base64 to bytes
        text = text_bytes.decode('utf-8')  # bytes to text

        return text
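
    # For reference, a client builds the request by mirroring this step
    # (hypothetical snippet, not part of this module):
    #     text_base64 = base64.b64encode(text.encode('utf-8')).decode('utf8')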

    def run(self,
            sentence: str,
            spk_id: int=0,
            speed: float=1.0,
            volume: float=1.0,
            sample_rate: int=0,
            save_path: str=None):
        """ run include inference and postprocess.

        Args:
            sentence (str): text to be synthesized
            spk_id (int, optional): speaker id for multi-speaker speech synthesis. Defaults to 0.
            speed (float, optional): speed. Defaults to 1.0.
            volume (float, optional): volume. Defaults to 1.0.
            sample_rate (int, optional): target sample rate for the synthesized
                audio, 0 means the same as the model sampling rate. Defaults to 0.
            save_path (str, optional): the save path of the synthesized audio,
                None means do not save audio. Defaults to None.

        Yields:
            wav_base64: the synthesized audio chunks in base64 format.
        """
        wav_list = []

        for wav in self.infer(
                text=sentence,
                lang=self.config.lang,
                am=self.config.am,
                spk_id=spk_id, ):

            # wav type: <class 'numpy.ndarray'>  float32, convert to pcm (base64)
            wav = float2pcm(wav)  # float32 to int16
            wav_bytes = wav.tobytes()  # to bytes
            wav_base64 = base64.b64encode(wav_bytes).decode('utf8')  # to base64
            wav_list.append(wav)

            yield wav_base64

        wav_all = np.concatenate(wav_list, axis=0)
        duration = len(wav_all) / self.config.voc_sample_rate
        logger.info(f"sentence: {sentence}")
        logger.info(f"The durations of audio is: {duration} s")
        logger.info(f"first response time: {self.first_response_time} s")
        logger.info(f"final response time: {self.final_response_time} s")
        logger.info(f"RTF: {self.final_response_time / duration}")
        logger.info(
            f"Other info: front time: {self.frontend_time} s, first am infer time: {self.first_am_infer} s, first voc infer time: {self.first_voc_infer} s,"
        )
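

# A minimal usage sketch, assuming the yaml layout of the streaming onnx TTS
# demo and the get_config helper (exact helper names and config keys may
# differ across PaddleSpeech versions):
#
#     from paddlespeech.server.utils.config import get_config
#
#     config = get_config("conf/tts_online_application.yaml")
#     engine = TTSEngine()
#     if engine.init(config["tts_online-onnx"]):
#         handler = PaddleTTSConnectionHandler(engine)
#         for wav_base64 in handler.run(sentence="您好,欢迎使用语音合成。"):
#             pass  # stream each base64-encoded int16 pcm chunk to the client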