utils.py 8.4 KB
Newer Older
P
pfZhu 已提交
1 2 3
import numpy as np
import paddle
import yaml
小湉湉's avatar
小湉湉 已提交
4
from sedit_arg_parser import parse_args
P
pfZhu 已提交
5 6
from yacs.config import CfgNode

小湉湉's avatar
小湉湉 已提交
7 8
from paddlespeech.s2t.utils.dynamic_import import dynamic_import
from paddlespeech.t2s.exps.syn_utils import get_frontend
P
pfZhu 已提交
9 10
from paddlespeech.t2s.exps.syn_utils import get_voc_inference
from paddlespeech.t2s.modules.normalizer import ZScore
O
oyjxer 已提交
11
from tools.parallel_wavegan_pretrained_vocoder import ParallelWaveGANPretrainedVocoder
小湉湉's avatar
小湉湉 已提交
12
# new add
P
pfZhu 已提交
13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30

# Lookup table from a short model name to an import target written as
# "<module path>:<class name>". Every acoustic model is listed twice: once
# under its own name and once under "<name>_inference".
model_alias = {
    # acoustic model
    "speedyspeech": "paddlespeech.t2s.models.speedyspeech:SpeedySpeech",
    "speedyspeech_inference": "paddlespeech.t2s.models.speedyspeech:SpeedySpeechInference",
    "fastspeech2": "paddlespeech.t2s.models.fastspeech2:FastSpeech2",
    "fastspeech2_inference": "paddlespeech.t2s.models.fastspeech2:FastSpeech2Inference",
    "tacotron2": "paddlespeech.t2s.models.tacotron2:Tacotron2",
    "tacotron2_inference": "paddlespeech.t2s.models.tacotron2:Tacotron2Inference",
}


O
oyjxer 已提交
31 32 33 34 35 36 37 38
def is_chinese(ch):
    """Tell whether ``ch`` lies between U+4E00 and U+9FFF (inclusive).

    The test is an ordinary string comparison against the two bounds, so it
    is intended to be called with a single character.

    Args:
        ch: The string (normally one character) to classify.

    Returns:
        True if ``ch`` compares inside the range, False otherwise.
    """
    return u'\u4e00' <= ch <= u'\u9fff'


def build_vocoder_from_file(
        vocoder_config_file=None,
        vocoder_file=None,
        model=None,
        device="cpu", ):
    """Build a pretrained vocoder from a checkpoint file.

    Only checkpoints whose path ends in ".pkl" are accepted.

    Args:
        vocoder_config_file: Config path forwarded, after the checkpoint
            path, to ``ParallelWaveGANPretrainedVocoder``.
        vocoder_file: Checkpoint path; its ``str()`` form must end in ".pkl".
        model: Not used by this function; kept so existing callers that pass
            it keep working.
        device: Forwarded to the vocoder's ``to()`` method.

    Returns:
        Whatever ``vocoder.to(device)`` returns.

    Raises:
        ValueError: If ``vocoder_file`` does not end in ".pkl".
    """
    # Build vocoder
    if not str(vocoder_file).endswith(".pkl"):
        raise ValueError(f"{vocoder_file} is not supported format.")

    # If the extension is ".pkl", the model is trained with parallel_wavegan
    vocoder = ParallelWaveGANPretrainedVocoder(vocoder_file,
                                               vocoder_config_file)
    return vocoder.to(device)


P
pfZhu 已提交
54 55 56 57 58
def get_voc_out(mel, target_language="chinese"):
    """Run the configured vocoder on a mel input and return the waveform.

    The vocoder is selected entirely by the command-line arguments returned
    by ``parse_args()``; ``target_language`` is only validated here.

    Args:
        mel: Mel input; converted with ``paddle.to_tensor`` before inference.
        target_language: Must be "chinese" or "english".

    Returns:
        The vocoder output passed through ``np.squeeze``.

    Raises:
        AssertionError: If ``target_language`` is not one of the two
            accepted values.
    """
    # Validate first so a bad call fails before arguments are parsed and the
    # vocoder is loaded.
    assert target_language == "chinese" or target_language == "english", "In get_voc_out function, target_language is illegal..."

    # vocoder
    args = parse_args()

    with open(args.voc_config) as f:
        voc_config = CfgNode(yaml.safe_load(f))

    voc_inference = get_voc_inference(args, voc_config)

    mel = paddle.to_tensor(mel)
    # Inference only: no gradients are needed.
    with paddle.no_grad():
        wav = voc_inference(mel)
    return np.squeeze(wav)

小湉湉's avatar
小湉湉 已提交
74

P
pfZhu 已提交
75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127
def _read_id_map(path):
    """Read a text file and return one whitespace-split token list per line."""
    with open(path, "r") as f:
        return [line.strip().split() for line in f.readlines()]


# dygraph
def get_am_inference(args, am_config):
    """Build the acoustic model named by ``args.am`` and its inference wrapper.

    Args:
        args: Argument namespace. Reads ``phones_dict``, ``am`` (formatted as
            ``{model_name}_{dataset}``), ``am_ckpt`` and ``am_stat``, plus the
            optional ``tones_dict`` and ``speaker_dict``.
        am_config: Model config. Reads ``n_mels`` and the ``"model"`` entry,
            which is expanded into the model constructor's keyword arguments.

    Returns:
        Tuple ``(am, am_inference, am_name, am_dataset, phn_id)`` where
        ``phn_id`` is the list of split lines read from ``args.phones_dict``.

    Raises:
        ValueError: If ``args.am`` contains no "_" or names a model other
            than fastspeech2, speedyspeech or tacotron2.
    """
    phn_id = _read_id_map(args.phones_dict)
    vocab_size = len(phn_id)

    tone_size = None
    if 'tones_dict' in args and args.tones_dict:
        tone_size = len(_read_id_map(args.tones_dict))
        print("tone_size:", tone_size)

    spk_num = None
    if 'speaker_dict' in args and args.speaker_dict:
        spk_num = len(_read_id_map(args.speaker_dict))
        print("spk_num:", spk_num)

    odim = am_config.n_mels
    # model: {model_name}_{dataset}
    # rindex raises ValueError when there is no "_" separator.
    split_at = args.am.rindex('_')
    am_name = args.am[:split_at]
    am_dataset = args.am[split_at + 1:]

    am_class = dynamic_import(am_name, model_alias)
    am_inference_class = dynamic_import(am_name + '_inference', model_alias)

    if am_name == 'fastspeech2':
        am = am_class(
            idim=vocab_size, odim=odim, spk_num=spk_num, **am_config["model"])
    elif am_name == 'speedyspeech':
        am = am_class(
            vocab_size=vocab_size,
            tone_size=tone_size,
            spk_num=spk_num,
            **am_config["model"])
    elif am_name == 'tacotron2':
        am = am_class(idim=vocab_size, odim=odim, **am_config["model"])
    else:
        # Without this branch `am` would stay unbound and the failure would
        # surface later as an UnboundLocalError.
        raise ValueError(f"Unsupported acoustic model: {am_name!r}")

    am.set_state_dict(paddle.load(args.am_ckpt)["main_params"])
    am.eval()
    # The stats file unpacks into two arrays used as the normalizer's
    # mean and standard deviation.
    am_mu, am_std = np.load(args.am_stat)
    am_mu = paddle.to_tensor(am_mu)
    am_std = paddle.to_tensor(am_std)
    am_normalizer = ZScore(am_mu, am_std)
    am_inference = am_inference_class(am_normalizer, am)
    am_inference.eval()
    print("acoustic model done!")
    return am, am_inference, am_name, am_dataset, phn_id


小湉湉's avatar
小湉湉 已提交
128 129 130 131
def evaluate_durations(phns,
                       target_language="chinese",
                       fs=24000,
                       hop_length=300):
    """Predict a duration for each phone with a FastSpeech2 acoustic model.

    The model files are chosen from hard-coded ``download/...`` paths
    according to ``target_language``; the remaining settings (e.g. ``ngpu``)
    come from ``parse_args()``.

    Args:
        phns: Iterable of phone strings. Phones missing from the model's
            phone map are replaced by "sp".
        target_language: "chinese" or "english".
        fs: Divisor applied to the scaled duration output (default 24000).
        hop_length: Multiplier applied to the model's duration output
            (default 300).

    Returns:
        List with the values of ``d_outs * hop_length / fs``, excluding the
        entry that belongs to the extra token appended to the input.

    Raises:
        AssertionError: If ``target_language`` is not "chinese" or "english".
    """
    # Validate first: previously an unsupported language was only rejected
    # after arguments had been parsed and the device selected.
    assert target_language == "chinese" or target_language == "english", "In evaluate_durations function, target_language is illegal..."

    args = parse_args()

    if target_language == 'english':
        args.lang = 'en'
        args.am = "fastspeech2_ljspeech"
        args.am_config = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/default.yaml"
        args.am_ckpt = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/snapshot_iter_100000.pdz"
        args.am_stat = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/speech_stats.npy"
        args.phones_dict = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/phone_id_map.txt"

    elif target_language == 'chinese':
        args.lang = 'zh'
        args.am = "fastspeech2_csmsc"
        args.am_config = "download/fastspeech2_conformer_baker_ckpt_0.5/conformer.yaml"
        args.am_ckpt = "download/fastspeech2_conformer_baker_ckpt_0.5/snapshot_iter_76000.pdz"
        args.am_stat = "download/fastspeech2_conformer_baker_ckpt_0.5/speech_stats.npy"
        args.phones_dict = "download/fastspeech2_conformer_baker_ckpt_0.5/phone_id_map.txt"

    if args.ngpu == 0:
        paddle.set_device("cpu")
    elif args.ngpu > 0:
        paddle.set_device("gpu")
    else:
        print("ngpu should >= 0 !")

    # Init body.
    with open(args.am_config) as f:
        am_config = CfgNode(yaml.safe_load(f))

    # acoustic model
    am, _, _, _, phn_id = get_am_inference(args, am_config)

    # Each phone-map row is a (phone, id) pair.
    vocab_phones = {phone: int(index) for phone, index in phn_id}
    vocab_size = len(vocab_phones)

    # Unknown phones fall back to "sp".
    phone_ids = [
        vocab_phones[phn if phn in vocab_phones else "sp"] for phn in phns
    ]
    # One extra token with id `vocab_size - 1` is appended (presumably the
    # end-of-sequence id -- TODO confirm); its duration is dropped below.
    phone_ids.append(vocab_size - 1)
    phone_ids = paddle.to_tensor(np.array(phone_ids, np.int64))

    _, d_outs, _, _ = am.inference(phone_ids, spk_id=None, spk_emb=None)
    # Presumably converts frame counts to seconds -- verify units of d_outs.
    phoneme_durations = d_outs * hop_length / fs
    return phoneme_durations.tolist()[:-1]


def sentence2phns(sentence, target_language="en"):
    """Convert a sentence into its phone sequence and phone ids.

    Args:
        sentence: Text to convert.
        target_language: "en" or "zh".

    Returns:
        ``(phones, phone_ids)`` where ``phone_ids`` is the first element of
        the frontend's ``"phone_ids"`` output. For "en", phones that are not
        in the phone map or that occur in the punctuation string are replaced
        by "sp". For an unsupported language a message is printed and
        ``None`` is returned.
    """
    # Bail out before parsing arguments or building a frontend when the
    # language is not supported; the result for that case stays `None`.
    if target_language != 'en' and target_language != 'zh':
        print("target_language should in {'zh', 'en'}!")
        return None

    args = parse_args()
    if target_language == 'en':
        args.lang = 'en'
        args.phones_dict = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/phone_id_map.txt"
    else:
        args.lang = 'zh'
        args.phones_dict = "download/fastspeech2_conformer_baker_ckpt_0.5/phone_id_map.txt"

    frontend = get_frontend(args)
    merge_sentences = True

    if target_language == 'zh':
        input_ids = frontend.get_input_ids(
            sentence,
            merge_sentences=merge_sentences,
            get_tone_ids=False,
            print_info=False)
        phonemes = frontend.get_phonemes(
            sentence, merge_sentences=merge_sentences, print_info=False)
        return phonemes[0], input_ids["phone_ids"][0]

    # target_language == 'en'
    phonemes = frontend.phoneticize(sentence)
    input_ids = frontend.get_input_ids(
        sentence, merge_sentences=merge_sentences)

    punc = ":,;。?!“”‘’':,;.?!"
    with open(args.phones_dict, 'rt') as f:
        phn_id = [line.strip().split() for line in f.readlines()]
    # Each phone-map row is a (phone, id) pair.
    vocab_phones = {phone: int(index) for phone, index in phn_id}

    # Drop the first and last items of the phoneticized output (presumably
    # sentence-boundary markers -- TODO confirm), then whitespace-only items.
    phones = [phn for phn in phonemes[1:-1] if not phn.isspace()]
    # replace unk phone with sp
    phones = [
        phn if (phn in vocab_phones and phn not in punc) else "sp"
        for phn in phones
    ]
    return phones, input_ids["phone_ids"][0]