tts_engine.py 6.7 KB
Newer Older
L
lym0302 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import base64
L
liangym 已提交
15
import io
L
lym0302 已提交
16 17 18

import librosa
import numpy as np
L
lym0302 已提交
19
import paddle
L
lym0302 已提交
20
import soundfile as sf
L
liangym 已提交
21
from scipy.io import wavfile
L
lym0302 已提交
22 23 24

from paddlespeech.cli.log import logger
from paddlespeech.cli.tts.infer import TTSExecutor
L
lym0302 已提交
25
from paddlespeech.server.engine.base_engine import BaseEngine
L
lym0302 已提交
26
from paddlespeech.server.utils.audio_process import change_speed
L
lym0302 已提交
27 28 29
from paddlespeech.server.utils.config import get_config
from paddlespeech.server.utils.errors import ErrorCode
from paddlespeech.server.utils.exception import ServerBaseException
L
lym0302 已提交
30 31 32 33 34 35 36

__all__ = ['TTSEngine']


class TTSServerExecutor(TTSExecutor):
    """Executor used by the TTS server engine.

    Currently adds nothing on top of ``TTSExecutor``; the constructor only
    delegates to the parent class.
    """

    def __init__(self):
        super(TTSServerExecutor, self).__init__()
L
lym0302 已提交
38 39 40 41 42 43 44 45 46 47 48 49 50 51


class TTSEngine(BaseEngine):
    """TTS server engine.

    Wraps a ``TTSServerExecutor``: ``init`` loads the models described by a
    config file, ``run`` synthesizes a sentence and returns the audio as a
    base64-encoded WAV string.

    NOTE(review): the previous docstring mentioned a Singleton metaclass;
    nothing in this class sets one -- confirm against ``BaseEngine``.
    """

    def __init__(self, name=None):
        """Initialize TTS server engine.

        Args:
            name: Not used by this class.
        """
        super(TTSEngine, self).__init__()

    def init(self, config_file: str) -> bool:
        """Load the config, select the device and initialize the executor.

        Args:
            config_file (str): Path handed to ``get_config``.

        Returns:
            bool: True when initialization succeeded, False otherwise. A
            failure is logged, not raised.
        """
        self.executor = TTSServerExecutor()

        try:
            self.config = get_config(config_file)

            # Fall back to paddle's current device when none is configured.
            device = self.config.device
            if device is None:
                device = paddle.get_device()
            paddle.set_device(device)

            self.executor._init_from_path(
                am=self.config.am,
                am_config=self.config.am_config,
                am_ckpt=self.config.am_ckpt,
                am_stat=self.config.am_stat,
                phones_dict=self.config.phones_dict,
                tones_dict=self.config.tones_dict,
                speaker_dict=self.config.speaker_dict,
                voc=self.config.voc,
                voc_config=self.config.voc_config,
                voc_ckpt=self.config.voc_ckpt,
                voc_stat=self.config.voc_stat,
                lang=self.config.lang)
        except Exception as e:
            # Keep the boolean contract, but record why initialization failed
            # so the operator can diagnose it.
            logger.error(
                f"Initialize TTS server engine Failed. Reason: {e!r}")
            return False

        logger.info("Initialize TTS server engine successfully.")
        return True

    def postprocess(self,
                    wav,
                    original_fs: int,
                    target_fs: int=16000,
                    volume: float=1.0,
                    speed: float=1.0,
                    audio_path: str=None):
        """Post-processing operations, including speed, volume, sample rate, save audio file

        Args:
            wav (numpy(float)): Synthesized audio sample points
            original_fs (int): original audio sample rate
            target_fs (int): target audio sample rate. 0, or a value above
                ``original_fs``, keeps the original sample rate.
            volume (float): target volume (multiplied into the samples)
            speed (float): target speed
            audio_path (str): where to save the audio. Only paths ending in
                ".wav" or ".pcm" are written; None means do not save.

        Raises:
            ServerBaseException: Throws an exception if the change speed unsuccessfully.

        Returns:
            target_fs: target sample rate for synthesized audio.
            wav_base64: The base64 format of the synthesized audio.
        """

        # transform sample_rate (no resampling for 0 or for upsampling requests)
        if target_fs == 0 or target_fs > original_fs:
            target_fs = original_fs
            wav_tar_fs = wav
        else:
            # Keyword arguments work on both old and new librosa releases;
            # newer releases reject the positional form.
            wav_tar_fs = librosa.resample(
                np.squeeze(wav), orig_sr=original_fs, target_sr=target_fs)

        # transform volume
        wav_vol = wav_tar_fs * volume

        # transform speed
        try:  # windows not support soxbindings
            wav_speed = change_speed(wav_vol, speed, target_fs)
        except Exception as e:
            # Chain the cause so the real failure stays visible in the
            # traceback.
            raise ServerBaseException(
                ErrorCode.SERVER_INTERNAL_ERR,
                "Can not install soxbindings on your system.") from e

        # wav to base64
        buf = io.BytesIO()
        wavfile.write(buf, target_fs, wav_speed)
        # getvalue() returns the whole buffer regardless of the current
        # stream position, so no rewind is needed after writing.
        wav_base64 = base64.b64encode(buf.getvalue()).decode('utf-8')

        # save audio
        if audio_path is not None:
            if audio_path.endswith(".wav"):
                sf.write(audio_path, wav_speed, target_fs)
            elif audio_path.endswith(".pcm"):
                # Peak-normalize to the int16 range; the 0.001 floor avoids
                # dividing by zero on all-zero audio.
                peak = max(0.001, np.max(np.abs(wav_speed)))
                wav_norm = wav_speed * (32767 / peak)
                with open(audio_path, "wb") as f:
                    f.write(wav_norm.astype(np.int16).tobytes())
            else:
                logger.warning(
                    f"Audio not saved: unsupported file extension in "
                    f"'{audio_path}', expected .wav or .pcm.")

        return target_fs, wav_base64

    def run(self,
            sentence: str,
            spk_id: int=0,
            speed: float=1.0,
            volume: float=1.0,
            sample_rate: int=0,
            save_path: str=None):
        """ run include inference and postprocess.

        Args:
            sentence (str): text to be synthesized
            spk_id (int, optional): speaker id for multi-speaker speech synthesis. Defaults to 0.
            speed (float, optional): speed. Defaults to 1.0.
            volume (float, optional): volume. Defaults to 1.0.
            sample_rate (int, optional): target sample rate for synthesized audio,
            0 means the same as the model sampling rate. Defaults to 0.
            save_path (str, optional): The save path of the synthesized audio.
            None means do not save audio. Defaults to None.

        Raises:
            ServerBaseException: Throws an exception if tts inference unsuccessfully.
            ServerBaseException: Throws an exception if postprocess unsuccessfully.

        Returns:
            lang: model language
            target_sample_rate: target sample rate for synthesized audio.
            wav_base64: The base64 format of the synthesized audio.
        """

        lang = self.config.lang

        try:
            self.executor.infer(
                text=sentence, lang=lang, am=self.config.am, spk_id=spk_id)
        except Exception as e:
            logger.error(f"tts infer failed. Reason: {e!r}")
            raise ServerBaseException(ErrorCode.SERVER_INTERNAL_ERR,
                                      "tts infer failed.") from e

        try:
            target_sample_rate, wav_base64 = self.postprocess(
                wav=self.executor._outputs['wav'].numpy(),
                original_fs=self.executor.am_config.fs,
                target_fs=sample_rate,
                volume=volume,
                speed=speed,
                audio_path=save_path)
        except Exception as e:
            logger.error(f"tts postprocess failed. Reason: {e!r}")
            raise ServerBaseException(ErrorCode.SERVER_INTERNAL_ERR,
                                      "tts postprocess failed.") from e

        return lang, target_sample_rate, wav_base64