utils.py 7.2 KB
Newer Older
小湉湉's avatar
小湉湉 已提交
1 2 3 4
import os
from typing import List
from typing import Optional

P
pfZhu 已提交
5 6 7
import numpy as np
import paddle
import yaml
小湉湉's avatar
小湉湉 已提交
8
from sedit_arg_parser import parse_args
P
pfZhu 已提交
9 10
from yacs.config import CfgNode

小湉湉's avatar
小湉湉 已提交
11
from paddlespeech.s2t.utils.dynamic_import import dynamic_import
P
pfZhu 已提交
12
from paddlespeech.t2s.modules.normalizer import ZScore
小湉湉's avatar
小湉湉 已提交
13
from tools.torch_pwgan import TorchPWGAN
P
pfZhu 已提交
14 15 16 17 18 19 20 21 22 23 24 25 26 27 28

# Short model name -> "module:ClassName" import path, resolved via dynamic_import.
# Each "<name>" entry is the network class; "<name>_inference" is its inference wrapper.
model_alias = {
    # --- acoustic models ---
    "speedyspeech": "paddlespeech.t2s.models.speedyspeech:SpeedySpeech",
    "speedyspeech_inference": "paddlespeech.t2s.models.speedyspeech:SpeedySpeechInference",
    "fastspeech2": "paddlespeech.t2s.models.fastspeech2:FastSpeech2",
    "fastspeech2_inference": "paddlespeech.t2s.models.fastspeech2:FastSpeech2Inference",
    "tacotron2": "paddlespeech.t2s.models.tacotron2:Tacotron2",
    "tacotron2_inference": "paddlespeech.t2s.models.tacotron2:Tacotron2Inference",
    # --- vocoders ---
    "pwgan": "paddlespeech.t2s.models.parallel_wavegan:PWGGenerator",
    "pwgan_inference": "paddlespeech.t2s.models.parallel_wavegan:PWGInference",
}


O
oyjxer 已提交
36 37 38 39 40 41 42 43
def is_chinese(ch):
    """Return True if ``ch`` lies in the CJK Unified Ideographs range U+4E00..U+9FFF.

    The check is a plain string comparison against the range bounds, so it is
    intended to be called with a single character.
    """
    return u'\u4e00' <= ch <= u'\u9fff'


def build_vocoder_from_file(
        vocoder_config_file=None,
        vocoder_file=None,
        model=None,
        device="cpu", ):
    """Load a vocoder from a checkpoint file and move it to ``device``.

    Only ``.pkl`` checkpoints are accepted; they are loaded through
    ``TorchPWGAN``. The ``model`` argument is accepted but not used.

    Raises:
        ValueError: if ``vocoder_file`` does not end with ``.pkl``.
    """
    if not str(vocoder_file).endswith(".pkl"):
        raise ValueError(f"{vocoder_file} is not supported format.")
    # A ".pkl" extension means the model was trained with parallel_wavegan.
    loaded = TorchPWGAN(vocoder_file, vocoder_config_file)
    return loaded.to(device)


小湉湉's avatar
小湉湉 已提交
58
def get_voc_out(mel, target_lang: str="chinese"):
    """Run the command-line-configured vocoder on ``mel`` and return the waveform.

    The vocoder name, config file, checkpoint and stats file are read from
    ``parse_args()`` (``voc``, ``voc_config``, ``voc_ckpt``, ``voc_stat``).
    The vocoder is rebuilt via ``get_voc_inference`` on every call.

    Args:
        mel: Input passed directly to the vocoder inference wrapper
            (presumably a mel-spectrogram tensor -- TODO confirm shape).
        target_lang: Must be ``"chinese"`` or ``"english"``. It is only
            validated here; vocoder selection does not depend on it.

    Returns:
        The vocoder output passed through ``np.squeeze``.
    """
    # Validate before parsing arguments or reading any config files.
    assert target_lang == "chinese" or target_lang == "english", "In get_voc_out function, target_lang is illegal..."
    args = parse_args()

    with open(args.voc_config) as f:
        voc_config = CfgNode(yaml.safe_load(f))
    voc_inference = get_voc_inference(
        voc=args.voc,
        voc_config=voc_config,
        voc_ckpt=args.voc_ckpt,
        voc_stat=args.voc_stat)

    # Inference only: no gradient tracking needed.
    with paddle.no_grad():
        wav = voc_inference(mel)
    return np.squeeze(wav)

小湉湉's avatar
小湉湉 已提交
77

P
pfZhu 已提交
78
# dygraph
小湉湉's avatar
小湉湉 已提交
79 80 81 82 83 84 85 86 87
def _count_entries(path):
    """Return the number of lines in the text file at ``path``."""
    with open(path, "r") as f:
        return sum(1 for _ in f)


def get_am_inference(am: str='fastspeech2_csmsc',
                     am_config: CfgNode=None,
                     am_ckpt: Optional[os.PathLike]=None,
                     am_stat: Optional[os.PathLike]=None,
                     phones_dict: Optional[os.PathLike]=None,
                     tones_dict: Optional[os.PathLike]=None,
                     speaker_dict: Optional[os.PathLike]=None,
                     return_am: bool=False):
    """Build an acoustic model and wrap it with its feature normalizer.

    Args:
        am: Model identifier ``{model_name}_{dataset}``. The part before the
            last ``_`` must be ``fastspeech2``, ``speedyspeech`` or ``tacotron2``.
        am_config: Parsed config. ``n_mels`` gives the output dim and
            ``am_config["model"]`` is forwarded to the model constructor.
        am_ckpt: Checkpoint loaded with ``paddle.load``; its ``main_params``
            entry is applied to the model.
        am_stat: ``.npy`` file unpacked as ``(mean, std)`` for ``ZScore``.
        phones_dict: Text file whose line count is the vocabulary size.
        tones_dict: Optional text file whose line count is the tone count.
        speaker_dict: Optional text file whose line count is the speaker count.
        return_am: If True, also return the underlying model.

    Returns:
        The inference wrapper, or ``(inference_wrapper, model)`` when
        ``return_am`` is True. Both are in eval mode.

    Raises:
        ValueError: if ``am`` has no ``_`` or names an unsupported model.
    """
    # model: {model_name}_{dataset}; only the model name selects the class.
    am_name = am[:am.rindex('_')]
    # Fail fast: an unhandled (but importable) name would otherwise leave the
    # model unbound and crash later with an unrelated error.
    if am_name not in ('fastspeech2', 'speedyspeech', 'tacotron2'):
        raise ValueError(f"Unsupported acoustic model: {am!r}")

    vocab_size = _count_entries(phones_dict)
    print("vocab_size:", vocab_size)

    tone_size = None
    if tones_dict is not None:
        tone_size = _count_entries(tones_dict)
        print("tone_size:", tone_size)

    spk_num = None
    if speaker_dict is not None:
        spk_num = _count_entries(speaker_dict)
        print("spk_num:", spk_num)

    odim = am_config.n_mels
    am_class = dynamic_import(am_name, model_alias)
    am_inference_class = dynamic_import(am_name + '_inference', model_alias)

    if am_name == 'fastspeech2':
        model = am_class(
            idim=vocab_size, odim=odim, spk_num=spk_num, **am_config["model"])
    elif am_name == 'speedyspeech':
        model = am_class(
            vocab_size=vocab_size,
            tone_size=tone_size,
            spk_num=spk_num,
            **am_config["model"])
    else:  # 'tacotron2' -- guaranteed by the check above
        model = am_class(idim=vocab_size, odim=odim, **am_config["model"])

    model.set_state_dict(paddle.load(am_ckpt)["main_params"])
    model.eval()

    am_mu, am_std = np.load(am_stat)
    am_normalizer = ZScore(paddle.to_tensor(am_mu), paddle.to_tensor(am_std))
    am_inference = am_inference_class(am_normalizer, model)
    am_inference.eval()
    print("acoustic model done!")
    if return_am:
        return am_inference, model
    return am_inference
P
pfZhu 已提交
139 140


小湉湉's avatar
小湉湉 已提交
141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173
def get_voc_inference(
        voc: str='pwgan_csmsc',
        voc_config: CfgNode=None,
        voc_ckpt: Optional[os.PathLike]=None,
        voc_stat: Optional[os.PathLike]=None, ):
    """Build a vocoder and wrap it with its feature normalizer.

    Args:
        voc: Model identifier ``{model_name}_{dataset}``. The part before the
            last ``_`` is looked up (with an ``_inference`` suffix for the
            wrapper) in ``model_alias``.
        voc_config: Parsed config (a mapping, e.g. ``CfgNode``).
            ``voc_config["generator_params"]`` (or ``["model"]`` for wavernn)
            is forwarded to the model constructor.
        voc_ckpt: Checkpoint loaded with ``paddle.load``.
        voc_stat: ``.npy`` file unpacked as ``(mean, std)`` for ``ZScore``.

    Returns:
        The vocoder inference wrapper in eval mode.
    """
    # model: {model_name}_{dataset}
    voc_name = voc[:voc.rindex('_')]
    voc_class = dynamic_import(voc_name, model_alias)
    voc_inference_class = dynamic_import(voc_name + '_inference', model_alias)
    if voc_name != 'wavernn':
        vocoder = voc_class(**voc_config["generator_params"])
        vocoder.set_state_dict(paddle.load(voc_ckpt)["generator_params"])
        vocoder.remove_weight_norm()
    else:
        # NOTE(review): 'wavernn' has no entry in model_alias in this file;
        # confirm this branch is reachable.
        vocoder = voc_class(**voc_config["model"])
        vocoder.set_state_dict(paddle.load(voc_ckpt)["main_params"])
    vocoder.eval()

    voc_mu, voc_std = np.load(voc_stat)
    voc_normalizer = ZScore(paddle.to_tensor(voc_mu), paddle.to_tensor(voc_std))
    voc_inference = voc_inference_class(voc_normalizer, vocoder)
    voc_inference.eval()
    print("voc done!")
    return voc_inference


def evaluate_durations(phns: List[str],
                       target_lang: str="chinese",
                       fs: int=24000,
                       hop_length: int=300):
    """Predict a duration for each phone with the command-line-configured acoustic model.

    Model, config, checkpoint, stats and dictionaries come from ``parse_args()``.
    The device is set from ``args.ngpu``.

    Args:
        phns: Phone symbols. Symbols missing from ``args.phones_dict`` are
            replaced by ``"sp"``.
        target_lang: ``"chinese"`` or ``"english"``.
        fs: Sample rate used in the frames-to-seconds conversion.
        hop_length: Hop size used in the frames-to-seconds conversion.

    Returns:
        A list of ``d_outs * hop_length / fs`` values with the last entry
        (the appended terminal id) removed. This assumes ``am.inference``
        returns one duration per input id -- TODO confirm.
    """
    # Validate before any global side effect such as paddle.set_device.
    assert target_lang == "chinese" or target_lang == "english", "In evaluate_durations function, target_lang is illegal..."

    args = parse_args()
    # NOTE(review): args.lang is not read again in this function; kept in case
    # other code observes this namespace -- confirm whether it is needed.
    if target_lang == 'english':
        args.lang = 'en'
    elif target_lang == 'chinese':
        args.lang = 'zh'

    if args.ngpu == 0:
        paddle.set_device("cpu")
    elif args.ngpu > 0:
        paddle.set_device("gpu")
    else:
        print("ngpu should >= 0 !")

    with open(args.am_config) as f:
        am_config = CfgNode(yaml.safe_load(f))

    # Only the raw model is needed here: its inference() output carries durations.
    _, am = get_am_inference(
        am=args.am,
        am_config=am_config,
        am_ckpt=args.am_ckpt,
        am_stat=args.am_stat,
        phones_dict=args.phones_dict,
        tones_dict=args.tones_dict,
        speaker_dict=args.speaker_dict,
        return_am=True)

    # Each phones_dict line is "<phone> <id>".
    with open(args.phones_dict, "r") as f:
        vocab_phones = {
            phone: int(phone_id)
            for phone, phone_id in (line.strip().split() for line in f)
        }
    vocab_size = len(vocab_phones)

    phone_ids = [
        vocab_phones[phn if phn in vocab_phones else "sp"] for phn in phns
    ]
    # Append id ``vocab_size - 1`` (presumably the end-of-sentence symbol --
    # confirm against phones_dict); its duration is dropped before returning.
    phone_ids.append(vocab_size - 1)
    phone_ids = paddle.to_tensor(np.array(phone_ids, np.int64))

    _, d_outs, _, _ = am.inference(phone_ids, spk_id=None, spk_emb=None)
    # Presumably d_outs counts frames, so this converts to seconds.
    durations = d_outs * hop_length / fs
    return durations.tolist()[:-1]