inference.py 37.2 KB
Newer Older
P
pfZhu 已提交
1 2
#!/usr/bin/env python3
import argparse
小湉湉's avatar
小湉湉 已提交
3 4 5
import os
import random
from pathlib import Path
P
pfZhu 已提交
6 7 8 9 10
from typing import Collection
from typing import Dict
from typing import List
from typing import Tuple
from typing import Union
小湉湉's avatar
小湉湉 已提交
11 12 13

import librosa
import numpy as np
O
oyjxer 已提交
14
import paddle
小湉湉's avatar
小湉湉 已提交
15
import soundfile as sf
O
oyjxer 已提交
16
import torch
小湉湉's avatar
小湉湉 已提交
17
from paddle import nn
小湉湉's avatar
小湉湉 已提交
18

O
oyjxer 已提交
19
from ParallelWaveGAN.parallel_wavegan.utils.utils import download_pretrained_model
P
pfZhu 已提交
20

小湉湉's avatar
小湉湉 已提交
21 22 23 24
from align import alignment
from align import alignment_zh
from dataset import get_seg_pos
from dataset import get_seg_pos_reduce_duration
小湉湉's avatar
小湉湉 已提交
25 26
from dataset import pad_to_longformer_att_window
from dataset import phones_masking
小湉湉's avatar
小湉湉 已提交
27
from dataset import phones_text_masking
小湉湉's avatar
小湉湉 已提交
28 29 30 31 32 33 34 35 36
from model_paddle import build_model_from_file
from read_text import load_num_sequence_text
from read_text import read_2column_text
from sedit_arg_parser import parse_args
from utils import build_vocoder_from_file
from utils import evaluate_durations
from utils import get_voc_out
from utils import is_chinese
from paddlespeech.t2s.datasets.get_feats import LogMelFBank
小湉湉's avatar
小湉湉 已提交
37 38
from paddlespeech.t2s.modules.nets_utils import pad_list
from paddlespeech.t2s.modules.nets_utils import make_non_pad_mask
P
pfZhu 已提交
39 40 41
# Seed Python's and NumPy's global random generators with a fixed value.
random.seed(0)
np.random.seed(0)

O
oyjxer 已提交
42 43 44 45 46
# Path of the external command that get_unk_phns runs to phonemize words that
# are missing from the English dictionary.
PHONEME = 'tools/aligner/english_envir/english2phoneme/phoneme'
# Directories whose 'dict' file is read by words2phns (English) and
# words2phns_zh (Mandarin) respectively.
MODEL_DIR_EN = 'tools/aligner/english'
MODEL_DIR_ZH = 'tools/aligner/mandarin'


小湉湉's avatar
小湉湉 已提交
47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69
def plot_mel_and_vocode_wav(uid: str,
                            wav_path: str,
                            prefix: str="./prompt/dev/",
                            source_lang: str='english',
                            target_lang: str='english',
                            model_name: str="conformer",
                            full_origin_str: str="",
                            old_str: str="",
                            new_str: str="",
                            duration_preditor_path: str=None,
                            use_pt_vocoder: bool=False,
                            sid: str=None,
                            non_autoreg: bool=True):
    """Run the MLM model, vocode its output and splice it into the original wav.

    The features returned by ``get_mlm_output`` are turned into audio and the
    edited span of the original waveform is replaced by the new audio:

    * ``target_lang == 'english'``: the whole ``output_feat`` is vocoded
      (with the PyTorch ParallelWaveGAN vocoder when ``use_pt_vocoder`` is
      set, otherwise with ``get_voc_out``) and the samples between the new
      span boundaries are cut out of the result.
    * ``target_lang == 'chinese'``: only the frames inside the new span are
      vocoded with ``get_voc_out`` and used as the replacement.

    Args:
        uid: Utterance id, forwarded to ``get_mlm_output``.
        wav_path: Path of the original audio, forwarded to ``get_mlm_output``.
        prefix: Forwarded to ``get_mlm_output``.
        source_lang: Language of the original utterance.
        target_lang: Language of the edited text; ``'english'`` or
            ``'chinese'``.
        model_name: Forwarded to ``get_mlm_output``.
        full_origin_str: Accepted for interface compatibility; not used here.
        old_str: Original text.
        new_str: Edited text.
        duration_preditor_path: Forwarded to ``get_mlm_output``.
        use_pt_vocoder: Use the PyTorch vocoder for English output.
        sid: Forwarded to ``get_mlm_output``.
        non_autoreg: Passed to ``get_mlm_output`` as ``use_teacher_forcing``.

    Returns:
        ``(data_dict, old_span_bdy)`` where ``data_dict`` maps ``"origin"`` to
        the original waveform and ``"output"`` to the spliced waveform.

    Raises:
        ValueError: If ``target_lang`` is neither ``'english'`` nor
            ``'chinese'``.
    """
    # Fail before running the model: previously an unsupported language fell
    # through both if/elif chains and ended in an UnboundLocalError.
    if target_lang not in ('english', 'chinese'):
        raise ValueError(
            "target_lang must be 'english' or 'chinese', got {!r}".format(
                target_lang))

    wav_org, input_feat, output_feat, old_span_bdy, new_span_bdy, fs, hop_length = get_mlm_output(
        uid=uid,
        prefix=prefix,
        source_lang=source_lang,
        target_lang=target_lang,
        model_name=model_name,
        wav_path=wav_path,
        old_str=old_str,
        new_str=new_str,
        duration_preditor_path=duration_preditor_path,
        use_teacher_forcing=non_autoreg,
        sid=sid)

    # Span boundaries are scaled by hop_length so they can index the waveform.
    old_time_bdy = [hop_length * x for x in old_span_bdy]
    new_time_bdy = [hop_length * x for x in new_span_bdy]

    if target_lang == 'english':
        if use_pt_vocoder:
            output_feat = output_feat.cpu().numpy()
            output_feat = torch.tensor(output_feat, dtype=torch.float)
            vocoder = load_vocoder('vctk_parallel_wavegan.v1.long')
            replaced_wav = vocoder(output_feat).cpu().numpy()
        else:
            replaced_wav = get_voc_out(output_feat, target_lang)
        # The full utterance was vocoded; keep only the edited span.
        new_segment = replaced_wav[new_time_bdy[0]:new_time_bdy[1]]
    else:
        # Chinese: vocode just the frames of the edited span.
        masked_feat = output_feat[new_span_bdy[0]:new_span_bdy[1]]
        new_segment = get_voc_out(masked_feat, target_lang)

    edited_wav = np.concatenate([
        wav_org[:old_time_bdy[0]], new_segment, wav_org[old_time_bdy[1]:]
    ])
    data_dict = {"origin": wav_org, "output": edited_wav}

    return data_dict, old_span_bdy
P
pfZhu 已提交
110

O
oyjxer 已提交
111

小湉湉's avatar
小湉湉 已提交
112
def get_unk_phns(word_str: str):
    """Phonemize a word that is missing from the dictionary.

    ``word_str`` is written to a scratch file, the external ``PHONEME``
    command is run with the scratch input and output paths as its two
    arguments, and the first line of the output file is converted to phoneme
    symbols.

    Each whitespace-separated token of that line is scanned left to right:

    * a character that compares greater than ``'Z'`` is consumed alone:
      ``j`` becomes ``JH``, ``h`` becomes ``HH``, anything else is upper-cased;
    * otherwise two characters are consumed: ``WH`` becomes ``W``; ``TH``,
      ``SH``, ``HH``, ``DH``, ``CH``, ``ZH`` and ``NG`` are kept; ``AX``
      becomes ``AH0``; every other pair gets ``1`` appended.

    Args:
        word_str: The word to phonemize.

    Returns:
        A list of phoneme strings.

    Raises:
        FileNotFoundError: If the tool did not produce an output file.
    """
    import subprocess
    import tempfile

    # A private scratch directory replaces the fixed /tmp/tp.temp.* paths, so
    # concurrent runs cannot clobber each other and a failed tool run can no
    # longer be masked by a stale output file from an earlier word.
    with tempfile.TemporaryDirectory(prefix='tp.') as tmpdir:
        words_path = os.path.join(tmpdir, 'temp.words')
        phons_path = os.path.join(tmpdir, 'temp.phons')
        with open(words_path, 'w') as f:
            f.write(word_str)
        # Argument list (no shell) so scratch paths with spaces stay intact.
        subprocess.run([PHONEME, words_path, phons_path])
        with open(phons_path, 'r') as f:
            tokens = f.readline().strip().split()

    two_char_phones = ('TH', 'SH', 'HH', 'DH', 'CH', 'ZH', 'NG')
    phns = []
    # Tokens come from str.split(), so they contain no whitespace to strip.
    for phons in tokens:
        j = 0
        while j < len(phons):
            ch = phons[j]
            if ch > 'Z':
                if ch == 'j':
                    phns.append('JH')
                elif ch == 'h':
                    phns.append('HH')
                else:
                    phns.append(ch.upper())
                j += 1
            else:
                p = phons[j:j + 2]
                if p == 'WH':
                    phns.append('W')
                elif p in two_char_phones:
                    phns.append(p)
                elif p == 'AX':
                    phns.append('AH0')
                else:
                    phns.append(p + '1')
                j += 2
    return phns

小湉湉's avatar
小湉湉 已提交
150

小湉湉's avatar
小湉湉 已提交
151
def words2phns(line: str):
    """Convert an English sentence to phonemes with the MFA dictionary.

    Punctuation is replaced by spaces, one trailing ``-`` and one leading
    ``'`` are removed from each token, and every remaining word is looked up
    (upper-cased) in ``MODEL_DIR_EN + '/dict'``. Words missing from the
    dictionary are phonemized with ``get_unk_phns``; the literal token
    ``[MASK]`` is passed through unchanged.

    Args:
        line: The sentence to convert.

    Returns:
        ``(phns, wrd2phns)``: the flat phoneme list, and a dict mapping
        ``"<word index>_<WORD>"`` to that word's phoneme list.
    """
    dictfile = MODEL_DIR_EN + '/dict'
    line = line.strip()
    # '---' must be handled before '--'; otherwise '--' consumes part of it
    # and a stray '-' is left behind.
    for pun in [',', '.', ':', ';', '!', '?', '"', '(', ')', '---', '--']:
        line = line.replace(pun, ' ')

    words = []
    for wrd in line.split():
        # startswith/endswith are safe on the empty string, unlike wrd[0]
        # after a lone '-' token has been reduced to ''.
        if wrd.endswith('-'):
            wrd = wrd[:-1]
        if wrd.startswith("'"):
            wrd = wrd[1:]
        if wrd:
            words.append(wrd)

    # First pronunciation listed for a word wins.
    word2phns_dict = {}
    with open(dictfile, 'r') as fid:
        for entry in fid:
            fields = entry.split()
            if not fields:
                continue  # tolerate blank lines in the dictionary
            word2phns_dict.setdefault(fields[0], " ".join(fields[1:]))

    phns = []
    wrd2phns = {}
    for index, wrd in enumerate(words):
        if wrd == '[MASK]':
            wrd2phns[str(index) + "_" + wrd] = [wrd]
            phns.append(wrd)
            continue
        upper = wrd.upper()
        if upper not in word2phns_dict:
            # Run the external phonemizer only once per unknown word.
            word_phns = get_unk_phns(wrd)
        else:
            word_phns = word2phns_dict[upper].split()
        wrd2phns[str(index) + "_" + upper] = word_phns
        phns.extend(word_phns)

    return phns, wrd2phns


小湉湉's avatar
小湉湉 已提交
190
def words2phns_zh(line: str):
    """Convert a Mandarin sentence to phonemes with the MFA dictionary.

    ASCII and full-width punctuation is replaced by spaces, one trailing
    ``-`` and one leading ``'`` are removed from each token, and every
    remaining word is looked up in ``MODEL_DIR_ZH + '/dict'``. The literal
    token ``[MASK]`` is passed through unchanged. A word that is not in the
    dictionary is reported with ``print`` and skipped.

    Args:
        line: The sentence to convert; words are expected to be separated by
            whitespace or punctuation.

    Returns:
        ``(phns, wrd2phns)``: the flat phoneme list, and a dict mapping
        ``"<word index>_<word>"`` to that word's phoneme list.
    """
    dictfile = MODEL_DIR_ZH + '/dict'
    line = line.strip()
    puncts = [
        # '---' before '--' so the longer run is matched first.
        ',', '.', ':', ';', '!', '?', '"', '(', ')', '---', '--',
        # Ideographic full stop and the full-width forms of , : ; ! ? ( ),
        # written as escapes so they survive character normalization.
        u'。', u'\uff0c', u'\uff1a', u'\uff1b', u'\uff01', u'\uff1f',
        u'\uff08', u'\uff09'
    ]
    for pun in puncts:
        line = line.replace(pun, ' ')

    words = []
    for wrd in line.split():
        # startswith/endswith are safe on the empty string, unlike wrd[0]
        # after a lone '-' token has been reduced to ''.
        if wrd.endswith('-'):
            wrd = wrd[:-1]
        if wrd.startswith("'"):
            wrd = wrd[1:]
        if wrd:
            words.append(wrd)

    # First pronunciation listed for a word wins.
    word2phns_dict = {}
    with open(dictfile, 'r') as fid:
        for entry in fid:
            fields = entry.split()
            if not fields:
                continue  # tolerate blank lines in the dictionary
            word2phns_dict.setdefault(fields[0], " ".join(fields[1:]))

    phns = []
    wrd2phns = {}
    for index, wrd in enumerate(words):
        if wrd == '[MASK]':
            wrd2phns[str(index) + "_" + wrd] = [wrd]
            phns.append(wrd)
        # NOTE(review): membership is tested on the upper-cased word while the
        # lookup below uses the word as written; kept as in the original.
        elif wrd.upper() not in word2phns_dict:
            print("出现非法词错误,请输入正确的文本...")
        else:
            word_phns = word2phns_dict[wrd].split()
            wrd2phns[str(index) + "_" + wrd] = word_phns
            phns.extend(word_phns)

    return phns, wrd2phns


小湉湉's avatar
小湉湉 已提交
231
def load_vocoder(vocoder_tag: str="vctk_parallel_wavegan.v1.long"):
    """Download a pretrained ParallelWaveGAN vocoder and build it on CPU.

    Args:
        vocoder_tag: Name of the pretrained model; a ``parallel_wavegan/``
            substring is removed before the download.

    Returns:
        The object returned by ``build_vocoder_from_file``.
    """
    tag = vocoder_tag.replace("parallel_wavegan/", "")
    checkpoint = download_pretrained_model(tag)
    # The config file is expected next to the downloaded checkpoint.
    config = Path(checkpoint).parent / "config.yml"
    return build_vocoder_from_file(config, checkpoint, None, 'cpu')

小湉湉's avatar
小湉湉 已提交
238

小湉湉's avatar
小湉湉 已提交
239
def load_model(model_name: str):
    """Build the MLM model stored under ``./pretrained_model/<model_name>``.

    Args:
        model_name: Sub-directory of ``./pretrained_model`` holding
            ``config.yaml`` and ``model.pdparams``.

    Returns:
        ``(mlm_model, args)`` as returned by ``build_model_from_file``.
    """
    model_dir = './pretrained_model/{}'.format(model_name)
    mlm_model, args = build_model_from_file(
        config_file=model_dir + '/config.yaml',
        model_file=model_dir + '/model.pdparams')
    return mlm_model, args


小湉湉's avatar
小湉湉 已提交
247
def read_data(uid: str, prefix: str):
    """Look up the text and wav path of an utterance.

    Args:
        uid: Utterance id used as key in ``<prefix>/text`` and
            ``<prefix>/wav.scp``.
        prefix: Directory containing the two files.

    Returns:
        ``(text, wav_path)``. When the wav path does not contain ``'mnt'`` it
        is prepended with the part of ``prefix`` before the first ``'dump'``.
    """
    mfa_text = read_2column_text(prefix + '/text')[uid]
    wav_entry = read_2column_text(prefix + '/wav.scp')[uid]
    if 'mnt' in wav_entry:
        return mfa_text, wav_entry
    return mfa_text, prefix.split('dump')[0] + wav_entry
小湉湉's avatar
小湉湉 已提交
253 254


小湉湉's avatar
小湉湉 已提交
255
def get_align_data(uid: str, prefix: str):
    """Read the stored alignment data of an utterance.

    The files ``<prefix>mfa_text``, ``<prefix>mfa_start``, ``<prefix>mfa_end``
    and ``<prefix>mfa_wav.scp`` are read and indexed by ``uid``.

    Args:
        uid: Utterance id.
        prefix: Path prefix of the ``mfa_*`` files.

    Returns:
        ``(text, start, end, wav_path)`` for the utterance.
    """
    base = prefix + "mfa_"
    text = read_2column_text(base + 'text')[uid]
    # 'start' is loaded before 'end', as in the original statement order.
    start, end = (load_num_sequence_text(
        base + name, loader_type='text_float')[uid]
                  for name in ('start', 'end'))
    wav_path = read_2column_text(base + 'wav.scp')[uid]
    return text, start, end, wav_path


小湉湉's avatar
小湉湉 已提交
266 267 268 269 270
def get_masked_mel_bdy(mfa_start: List[float],
                       mfa_end: List[float],
                       fs: int,
                       hop_length: int,
                       span_to_repl: List[List[int]]):
    """Map a phoneme span to start/end frame indices.

    Every alignment time ``t`` is converted to ``floor(fs * t / hop_length)``.

    Args:
        mfa_start: Start time of each phoneme.
        mfa_end: End time of each phoneme.
        fs: Sampling rate used in the conversion.
        hop_length: Hop size used in the conversion.
        span_to_repl: Indexed here as a flat ``[first, last_exclusive]`` pair
            of phoneme indices (NOTE(review): the annotation says nested
            lists; confirm against callers).

    Returns:
        ``[start_frame, end_frame]``. When the span starts at or beyond the
        last phoneme, both entries are the end frame of the last phoneme.
    """

    def _to_frames(times):
        stamps = paddle.to_tensor(times).unsqueeze(0)
        return paddle.floor(fs * stamps / hop_length).int()

    start_frames = _to_frames(mfa_start)[0].tolist()
    end_frames = _to_frames(mfa_end)[0].tolist()

    if span_to_repl[0] >= len(mfa_start):
        # Span lies past the utterance: collapse onto the final end frame.
        return [end_frames[-1], end_frames[-1]]
    return [start_frames[span_to_repl[0]], end_frames[span_to_repl[1] - 1]]
P
pfZhu 已提交
283 284


小湉湉's avatar
小湉湉 已提交
285
def recover_dict(word2phns: Dict[str, str], tp_word2phns: Dict[str, str]):
    """Re-insert the ``sp`` entries of an alignment into a word->phones dict.

    ``word2phns`` is scanned for keys of the form ``"<idx>_sp"``; every other
    key is removed from it **in place**. The entries of ``tp_word2phns`` are
    then renumbered from 0, with an ``"<idx>_sp": 'sp'`` entry inserted
    wherever ``word2phns`` had one at that index. One trailing ``sp`` is
    appended if exactly one is still missing.

    Args:
        word2phns: Alignment result keyed by ``"<idx>_<word>"``; mutated so
            that only its ``sp`` entries remain.
        tp_word2phns: Word->phones mapping without ``sp`` entries, keyed by
            ``"<idx>_<word>"``.

    Returns:
        A new dict with the renumbered words and the re-inserted ``sp``
        entries.

    Raises:
        AssertionError: If not every ``sp`` entry could be placed.
    """
    sp_positions = set()
    non_sp_keys = []
    sp_count = 0
    for key in word2phns:
        # maxsplit=1 keeps words that themselves contain '_' intact.
        idx, wrd = key.split('_', 1)
        if wrd == 'sp':
            sp_count += 1
            sp_positions.add(int(idx))
        else:
            non_sp_keys.append(key)

    for key in non_sp_keys:
        del word2phns[key]

    dic = {}
    add_sp_count = 0
    cur_id = 0
    for key, phones in tp_word2phns.items():
        if cur_id in sp_positions:
            dic[str(cur_id) + "_sp"] = 'sp'
            cur_id += 1
            add_sp_count += 1
        _, wrd = key.split('_', 1)
        dic[str(cur_id) + "_" + wrd] = phones
        cur_id += 1

    # A single 'sp' after the last word is not reached by the loop above.
    if add_sp_count + 1 == sp_count:
        dic[str(cur_id) + "_sp"] = 'sp'
        add_sp_count += 1

    # Explicit raise so the check also runs under ``python -O``.
    if add_sp_count != sp_count:
        raise AssertionError("sp are not added in dic")
    return dic
P
pfZhu 已提交
318 319


小湉湉's avatar
小湉湉 已提交
320 321 322 323 324
def get_phns_and_spans(wav_path: str,
                       old_str: str="",
                       new_str: str="",
                       source_lang: str="english",
                       target_lang: str="english"):
    """Align the old text and locate the phoneme spans that differ.

    The original audio is aligned against ``old_str``; ``new_str`` is
    converted to phonemes; then the words of both texts are compared from the
    left and (unless text is only appended) from the right to find the region
    that changed.

    Args:
        wav_path: Path of the original audio.
        old_str: Text spoken in the original audio.
        new_str: Edited text. When it starts with ``old_str`` and the two
            languages differ, the remainder is phonemized in ``target_lang``.
        source_lang: ``"english"`` or ``"chinese"``.
        target_lang: ``"english"`` or ``"chinese"``.

    Returns:
        ``(mfa_start, mfa_end, old_phns, new_phns, span_to_repl,
        span_to_add)``: start times, end times and phonemes of the old
        alignment, the new phoneme sequence (with the old ``sp`` tokens of
        the unchanged parts), and two ``[begin, end]`` index pairs giving the
        changed span in ``old_phns`` and ``new_phns``.

    Raises:
        AssertionError: For an unsupported language combination.
    """
    # True when new_str begins with old_str, i.e. text is only appended.
    append_new_str = (old_str == new_str[:len(old_str)])
    old_phns, mfa_start, mfa_end = [], [], []

    if source_lang == "english":
        times2, word2phns = alignment(wav_path, old_str)
    elif source_lang == "chinese":
        times2, word2phns = alignment_zh(wav_path, old_str)
        _, tp_word2phns = words2phns_zh(old_str)
        # recover_dict expects space-joined phoneme strings as values.
        tp_word2phns = {
            key: " ".join(value)
            for key, value in tp_word2phns.items()
        }
        word2phns = recover_dict(word2phns, tp_word2phns)
    else:
        # Explicit raise so the check also runs under ``python -O``.
        raise AssertionError("source_lang is wrong...")

    for item in times2:
        mfa_start.append(float(item[1]))
        mfa_end.append(float(item[2]))
        old_phns.append(item[0])

    is_cross_lingual_clone = append_new_str and (source_lang != target_lang)

    if is_cross_lingual_clone:
        new_str_origin = new_str[:len(old_str)]
        new_str_append = new_str[len(old_str):]

        if target_lang == "chinese":
            new_phns_origin, new_origin_word2phns = words2phns(new_str_origin)
            new_phns_append, temp_new_append_word2phns = words2phns_zh(
                new_str_append)
        elif target_lang == "english":
            # original sentence
            new_phns_origin, new_origin_word2phns = words2phns_zh(
                new_str_origin)
            # cloned sentence
            new_phns_append, temp_new_append_word2phns = words2phns(
                new_str_append)
        else:
            raise AssertionError(
                "cloning is not support for this language, please check it.")

        new_phns = new_phns_origin + new_phns_append

        # Shift the word indices of the appended part past the original part.
        new_append_word2phns = {}
        length = len(new_origin_word2phns)
        for key, value in temp_new_append_word2phns.items():
            idx, wrd = key.split('_', 1)
            new_append_word2phns[str(int(idx) + length) + '_' + wrd] = value

        new_word2phns = dict(
            list(new_origin_word2phns.items()) + list(
                new_append_word2phns.items()))

    else:
        if source_lang == target_lang and target_lang == "english":
            new_phns, new_word2phns = words2phns(new_str)
        elif source_lang == target_lang and target_lang == "chinese":
            new_phns, new_word2phns = words2phns_zh(new_str)
        else:
            raise AssertionError(
                "source language is not same with target language...")

    span_to_repl = [0, len(old_phns) - 1]
    span_to_add = [0, len(new_phns) - 1]
    left_idx = 0
    new_phns_left = []
    sp_count = 0
    # find the left different index
    for key in word2phns.keys():
        idx, wrd = key.split('_', 1)
        if wrd == 'sp':
            sp_count += 1
            new_phns_left.append('sp')
        else:
            # new_word2phns has no 'sp' entries, so its indices lag by the
            # number of 'sp' seen so far.
            idx = str(int(idx) - sp_count)
            if idx + '_' + wrd in new_word2phns:
                left_idx += len(new_word2phns[idx + '_' + wrd])
                new_phns_left.extend(word2phns[key].split())
            else:
                span_to_repl[0] = len(new_phns_left)
                span_to_add[0] = len(new_phns_left)
                break

    # reverse word2phns and new_word2phns
    right_idx = 0
    new_phns_right = []
    sp_count = 0
    word2phns_max_idx = int(list(word2phns.keys())[-1].split('_')[0])
    new_word2phns_max_idx = int(list(new_word2phns.keys())[-1].split('_')[0])
    new_phns_mid = []
    if append_new_str:
        new_phns_right = []
        new_phns_mid = new_phns[left_idx:]
        span_to_repl[0] = len(new_phns_left)
        span_to_add[0] = len(new_phns_left)
        span_to_add[1] = len(new_phns_left) + len(new_phns_mid)
        span_to_repl[1] = len(old_phns) - len(new_phns_right)
    else:
        for key in list(word2phns.keys())[::-1]:
            idx, wrd = key.split('_', 1)
            if wrd == 'sp':
                sp_count += 1
                new_phns_right = ['sp'] + new_phns_right
            else:
                idx = str(new_word2phns_max_idx - (word2phns_max_idx - int(idx)
                                                   - sp_count))
                if idx + '_' + wrd in new_word2phns:
                    right_idx -= len(new_word2phns[idx + '_' + wrd])
                    new_phns_right = word2phns[key].split() + new_phns_right
                else:
                    span_to_repl[1] = len(old_phns) - len(new_phns_right)
                    # right_idx is 0 when no trailing word matched; a stop of
                    # 0 would give an empty slice, so use None to take the
                    # whole tail in that case.
                    new_phns_mid = new_phns[left_idx:right_idx or None]
                    span_to_add[1] = len(new_phns_left) + len(new_phns_mid)
                    if len(new_phns_mid) == 0:
                        # Pure deletion: widen both spans by one phoneme on
                        # each side, clamped to the sequence bounds.
                        span_to_add[1] = min(span_to_add[1] + 1, len(new_phns))
                        span_to_add[0] = max(0, span_to_add[0] - 1)
                        span_to_repl[0] = max(0, span_to_repl[0] - 1)
                        span_to_repl[1] = min(span_to_repl[1] + 1,
                                              len(old_phns))
                    break
    new_phns = new_phns_left + new_phns_mid + new_phns_right

    return mfa_start, mfa_end, old_phns, new_phns, span_to_repl, span_to_add
P
pfZhu 已提交
456 457


小湉湉's avatar
小湉湉 已提交
458 459 460
def duration_adjust_factor(original_dur: List[int],
                           pred_dur: List[int],
                           phns: List[str]):
    """Return a trimmed mean of the ``original / predicted`` duration ratios.

    Entries whose predicted duration is 0 or whose phoneme is ``'sp'`` are
    ignored. With fewer than five usable ratios the factor is ``1``;
    otherwise the two smallest and two largest ratios are dropped and the
    rest are averaged.

    Args:
        original_dur: Measured duration per phoneme.
        pred_dur: Predicted duration per phoneme.
        phns: Phoneme label per entry.

    Returns:
        ``1`` or the trimmed mean as returned by ``np.average``.
    """
    ratios = np.array([
        ori / pred for ori, pred, phn in zip(original_dur, pred_dur, phns)
        if pred != 0 and phn != 'sp'
    ])
    if len(ratios) < 5:
        return 1

    trim = 2
    return np.average(np.sort(ratios)[trim:-trim])

小湉湉's avatar
小湉湉 已提交
476

小湉湉's avatar
小湉湉 已提交
477 478 479 480 481 482 483 484 485 486 487 488 489
def prepare_features_with_duration(uid: str,
                                   prefix: str,
                                   wav_path: str,
                                   mlm_model: nn.Layer,
                                   source_lang: str="english",
                                   target_lang: str="english",
                                   old_str: str="",
                                   new_str: str="",
                                   duration_preditor_path: str=None,
                                   sid: str=None,
                                   mask_reconstruct: bool=False,
                                   duration_adjust: bool=True,
                                   start_end_sp: bool=False,
                                   train_args=None):
    """Build the edited waveform, phonemes and alignment for one utterance.

    Steps performed:

    1. Load the audio and get the old/new phonemes and changed spans from
       ``get_phns_and_spans``.
    2. Predict phoneme durations with ``evaluate_durations`` and scale them
       by a factor derived from the measured durations.
    3. Build a new alignment in which the changed span uses the adjusted
       durations and everything after it is shifted accordingly.
    4. Build a new waveform in which the changed span is replaced by zeros
       of the new span length.
    5. Convert the old and new spans to frame boundaries.

    Args:
        uid, prefix, mlm_model, duration_preditor_path, sid,
            mask_reconstruct: Accepted for interface compatibility; not used
            in this function.
        wav_path: Path of the original audio.
        source_lang: ``"english"`` or ``"chinese"`` (lower case, as compared
            throughout this file).
        target_lang: ``"english"`` or ``"chinese"``.
        old_str: Original text.
        new_str: Edited text; if it contains ``[MASK]`` the old phonemes and
            span are reused.
        duration_adjust: Scale predicted durations by the measured/predicted
            ratio (times 1.25) instead of using them as is.
        start_end_sp: Append ``'sp'`` to the new phonemes when missing.
        train_args: Object whose ``feats_extract_conf`` provides ``'fs'`` and
            ``'hop_length'``.

    Returns:
        ``(new_wav_org, new_phns, new_mfa_start, new_mfa_end, old_span_bdy,
        new_span_bdy)``.

    Raises:
        AssertionError: If ``target_lang`` is not supported.
    """
    fs = train_args.feats_extract_conf['fs']
    hop_length = train_args.feats_extract_conf['hop_length']
    wav_org, _ = librosa.load(wav_path, sr=fs)

    mfa_start, mfa_end, old_phns, new_phns, span_to_repl, span_to_add = get_phns_and_spans(
        wav_path=wav_path,
        old_str=old_str,
        new_str=new_str,
        source_lang=source_lang,
        target_lang=target_lang)

    if start_end_sp and new_phns[-1] != 'sp':
        new_phns = new_phns + ['sp']

    if target_lang == "english":
        old_durations = evaluate_durations(old_phns, target_lang=target_lang)
    elif target_lang == "chinese":
        # For a Chinese target the old phonemes are evaluated with the
        # source language passed as target_lang.
        if source_lang in ("english", "chinese"):
            old_durations = evaluate_durations(
                old_phns, target_lang=source_lang)
    else:
        # Explicit raise so the check also runs under ``python -O``.
        raise AssertionError(
            "calculate duration_predict is not support for this language...")

    original_old_durations = [e - s for e, s in zip(mfa_end, mfa_start)]
    if '[MASK]' in new_str:
        new_phns = old_phns
        span_to_add = span_to_repl
        # Average the factors measured left and right of the masked span.
        d_factor_left = duration_adjust_factor(
            original_old_durations[:span_to_repl[0]],
            old_durations[:span_to_repl[0]], old_phns[:span_to_repl[0]])
        d_factor_right = duration_adjust_factor(
            original_old_durations[span_to_repl[1]:],
            old_durations[span_to_repl[1]:], old_phns[span_to_repl[1]:])
        d_factor = (d_factor_left + d_factor_right) / 2
        new_durations_adjusted = [d_factor * i for i in old_durations]
    else:
        if duration_adjust:
            d_factor = duration_adjust_factor(original_old_durations,
                                              old_durations, old_phns)
            d_factor = d_factor * 1.25
        else:
            d_factor = 1

        # target_lang is "english" or "chinese" here (checked above); both
        # cases made the same call.
        new_durations = evaluate_durations(new_phns, target_lang=target_lang)
        new_durations_adjusted = [d_factor * i for i in new_durations]

        # Where a span boundary phoneme is unchanged, keep its measured
        # duration instead of the predicted one.
        if span_to_repl[0] < len(old_phns) and old_phns[span_to_repl[
                0]] == new_phns[span_to_add[0]]:
            new_durations_adjusted[span_to_add[0]] = original_old_durations[
                span_to_repl[0]]
        if span_to_repl[1] < len(old_phns) and span_to_add[1] < len(new_phns):
            if old_phns[span_to_repl[1]] == new_phns[span_to_add[1]]:
                new_durations_adjusted[span_to_add[1]] = original_old_durations[
                    span_to_repl[1]]

    new_span_duration_sum = sum(
        new_durations_adjusted[span_to_add[0]:span_to_add[1]])
    old_span_duration_sum = sum(
        original_old_durations[span_to_repl[0]:span_to_repl[1]])
    duration_offset = new_span_duration_sum - old_span_duration_sum

    # Alignment before the span is kept, the span is laid out back to back
    # from the adjusted durations, and the rest is shifted by the difference.
    new_mfa_start = mfa_start[:span_to_repl[0]]
    new_mfa_end = mfa_end[:span_to_repl[0]]
    for i in new_durations_adjusted[span_to_add[0]:span_to_add[1]]:
        if len(new_mfa_end) == 0:
            new_mfa_start.append(0)
            new_mfa_end.append(i)
        else:
            new_mfa_start.append(new_mfa_end[-1])
            new_mfa_end.append(new_mfa_end[-1] + i)
    new_mfa_start += [i + duration_offset for i in mfa_start[span_to_repl[1]:]]
    new_mfa_end += [i + duration_offset for i in mfa_end[span_to_repl[1]:]]

    # 3. get new wav
    if span_to_repl[0] >= len(mfa_start):
        # Span starts after the last phoneme: insert at the end of the audio.
        left_idx = len(wav_org)
        right_idx = left_idx
    else:
        left_idx = int(np.floor(mfa_start[span_to_repl[0]] * fs))
        right_idx = int(np.ceil(mfa_end[span_to_repl[1] - 1] * fs))
    new_blank_wav = np.zeros(
        (int(np.ceil(new_span_duration_sum * fs)), ), dtype=wav_org.dtype)
    new_wav_org = np.concatenate(
        [wav_org[:left_idx], new_blank_wav, wav_org[right_idx:]])

    # 4. get old and new mel span to be mask
    # e.g. [92, 92]
    old_span_bdy = get_masked_mel_bdy(mfa_start, mfa_end, fs, hop_length,
                                      span_to_repl)
    # e.g. [92, 174]
    new_span_bdy = get_masked_mel_bdy(new_mfa_start, new_mfa_end, fs,
                                      hop_length, span_to_add)

    return new_wav_org, new_phns, new_mfa_start, new_mfa_end, old_span_bdy, new_span_bdy


def prepare_features(uid: str,
                     mlm_model: nn.Layer,
                     processor,
                     wav_path: str,
                     prefix: str="./prompt/dev/",
                     source_lang: str="english",
                     target_lang: str="english",
                     old_str: str="",
                     new_str: str="",
                     duration_preditor_path: str=None,
                     sid: str=None,
                     duration_adjust: bool=True,
                     start_end_sp: bool=False,
                     mask_reconstruct: bool=False,
                     train_args=None):
    """Assemble a one-item batch for the MLM model.

    Calls ``prepare_features_with_duration`` and packs its waveform,
    alignment, token ids and new span boundaries into a single sample with
    the id ``'1'``. Phonemes missing from ``train_args.token_list`` are
    mapped to the id of ``'<unk>'``. ``processor`` is not used here.

    Returns:
        ``(batch, old_span_bdy, new_span_bdy)`` where ``batch`` is
        ``[('1', sample_dict)]``.
    """
    (wav_org, phns_list, mfa_start, mfa_end, old_span_bdy,
     new_span_bdy) = prepare_features_with_duration(
         uid=uid,
         prefix=prefix,
         source_lang=source_lang,
         target_lang=target_lang,
         mlm_model=mlm_model,
         old_str=old_str,
         new_str=new_str,
         wav_path=wav_path,
         duration_preditor_path=duration_preditor_path,
         sid=sid,
         duration_adjust=duration_adjust,
         start_end_sp=start_end_sp,
         mask_reconstruct=mask_reconstruct,
         train_args=train_args)

    sample = {"speech": wav_org}
    sample["align_start"] = np.array(mfa_start)
    sample["align_end"] = np.array(mfa_end)

    token_to_id = {tok: pos for pos, tok in enumerate(train_args.token_list)}
    sample["text"] = np.array(
        [token_to_id.get(phn, token_to_id['<unk>']) for phn in phns_list])
    sample["span_bdy"] = np.array(new_span_bdy)

    batch = [('1', sample)]
    return batch, old_span_bdy, new_span_bdy
P
pfZhu 已提交
649 650


小湉湉's avatar
小湉湉 已提交
651 652
def decode_with_model(uid: str,
                      mlm_model: nn.Layer,
                      processor,
                      collate_fn,
                      wav_path: str,
                      prefix: str="./prompt/dev/",
                      source_lang: str="english",
                      target_lang: str="english",
                      old_str: str="",
                      new_str: str="",
                      duration_preditor_path: str=None,
                      sid: str=None,
                      decoder: bool=False,
                      use_teacher_forcing: bool=False,
                      duration_adjust: bool=True,
                      start_end_sp: bool=False,
                      train_args=None):
    """Run the MLM model on one utterance and return the generated features.

    The arguments describing the utterance and the edit (``uid``, ``prefix``,
    ``source_lang``, ``target_lang``, ``old_str``, ``new_str``,
    ``duration_preditor_path``, ``sid``, ``duration_adjust``,
    ``start_end_sp``, ``processor``) are forwarded unchanged to
    ``prepare_features``. The resulting batch is collated with ``collate_fn``
    and fed to ``mlm_model.inference``.

    Args:
        collate_fn: Callable returning ``(uttids, feature_dict)``; only the
            feature dict (index 1) is used.
        wav_path: Path of the original audio, reloaded with ``librosa`` at the
            sample rate found in ``train_args.feats_extract_conf['fs']``.
        decoder: Accepted for interface compatibility; not read in this body.
        use_teacher_forcing: Forwarded to ``mlm_model.inference``.
        train_args: Training configuration; must provide
            ``feats_extract_conf`` with ``'fs'`` and ``'hop_length'``.

    Returns:
        The tuple ``(wav_org, None, output_feat, old_span_bdy, new_span_bdy,
        fs, hop_length)``.
    """
    fs, hop_length = train_args.feats_extract_conf[
        'fs'], train_args.feats_extract_conf['hop_length']

    batch, old_span_bdy, new_span_bdy = prepare_features(
        uid=uid,
        prefix=prefix,
        source_lang=source_lang,
        target_lang=target_lang,
        mlm_model=mlm_model,
        processor=processor,
        wav_path=wav_path,
        old_str=old_str,
        new_str=new_str,
        duration_preditor_path=duration_preditor_path,
        sid=sid,
        duration_adjust=duration_adjust,
        start_end_sp=start_end_sp,
        train_args=train_args)

    feats = collate_fn(batch)[1]

    # The text mask positions are not passed to inference.
    feats.pop('text_masked_pos', None)
    for name in list(feats):
        feats[name] = paddle.to_tensor(feats[name])

    model_out = mlm_model.inference(
        **feats, span_bdy=new_span_bdy, use_teacher_forcing=use_teacher_forcing)
    segments = model_out['feat_gen']

    # The first / last segment may be empty (a 0 in its shape); empty
    # segments are left out of the concatenation.
    head_empty = 0 in segments[0].shape
    tail_empty = 0 in segments[-1].shape
    middle = segments[1:-1]
    if head_empty and tail_empty:
        pieces = middle
    elif head_empty:
        # NOTE(review): bare squeeze() here vs squeeze(0) in the last branch;
        # kept as-is, confirm whether the difference is intentional.
        pieces = middle + [segments[-1].squeeze()]
    elif tail_empty:
        pieces = [segments[0].squeeze()] + middle
    else:
        pieces = [segments[0].squeeze(0)] + middle + [segments[-1].squeeze(0)]
    output_feat = paddle.concat(pieces, axis=0).cpu()

    wav_org, _ = librosa.load(
        wav_path, sr=train_args.feats_extract_conf['fs'])
    return wav_org, None, output_feat, old_span_bdy, new_span_bdy, fs, hop_length
P
pfZhu 已提交
712 713


O
oyjxer 已提交
714 715 716
class MLMCollateFn:
    """Callable object that stores collate settings for mlm_collate_fn().

    Every constructor argument is kept as an attribute of the same name
    (``not_sequence`` is converted to a ``set``) and passed on as a keyword
    argument each time the instance is called.
    """

    def __init__(self,
                 feats_extract,
                 float_pad_value: Union[float, int]=0.0,
                 int_pad_value: int=-32768,
                 not_sequence: Collection[str]=(),
                 mlm_prob: float=0.8,
                 mean_phn_span: int=8,
                 attention_window: int=0,
                 pad_speech: bool=False,
                 sega_emb: bool=False,
                 duration_collect: bool=False,
                 text_masking: bool=False):
        self.mlm_prob = mlm_prob
        self.mean_phn_span = mean_phn_span
        self.feats_extract = feats_extract
        self.float_pad_value = float_pad_value
        self.int_pad_value = int_pad_value
        self.not_sequence = set(not_sequence)
        self.attention_window = attention_window
        self.pad_speech = pad_speech
        self.sega_emb = sega_emb
        self.duration_collect = duration_collect
        self.text_masking = text_masking

    def __repr__(self):
        # Report each pad value under its own label (the integer pad value
        # was previously printed from float_pad_value).
        return (f"{self.__class__}(float_pad_value={self.float_pad_value}, "
                f"int_pad_value={self.int_pad_value})")

    def __call__(self, data: Collection[Tuple[str, Dict[str, np.ndarray]]]
                 ) -> Tuple[List[str], Dict[str, "paddle.Tensor"]]:
        """Collate ``data`` by delegating to mlm_collate_fn() with the stored settings."""
        return mlm_collate_fn(
            data,
            float_pad_value=self.float_pad_value,
            int_pad_value=self.int_pad_value,
            not_sequence=self.not_sequence,
            mlm_prob=self.mlm_prob,
            mean_phn_span=self.mean_phn_span,
            feats_extract=self.feats_extract,
            attention_window=self.attention_window,
            pad_speech=self.pad_speech,
            sega_emb=self.sega_emb,
            duration_collect=self.duration_collect,
            text_masking=self.text_masking)

O
oyjxer 已提交
761 762

def mlm_collate_fn(
        data: Collection[Tuple[str, Dict[str, np.ndarray]]],
        float_pad_value: Union[float, int]=0.0,
        int_pad_value: int=-32768,
        not_sequence: Collection[str]=(),
        mlm_prob: float=0.8,
        mean_phn_span: int=8,
        feats_extract=None,
        attention_window: int=0,
        pad_speech: bool=False,
        sega_emb: bool=False,
        duration_collect: bool=False,
        text_masking: bool=False) -> Tuple[List[str], Dict[str, paddle.Tensor]]:
    """Pad a list of ``(utt_id, arrays)`` items and build the MLM model inputs.

    Each array is converted to a tensor and padded with ``pad_list``; integer
    arrays use ``int_pad_value``, everything else ``float_pad_value``. For
    keys not listed in ``not_sequence`` a ``<key>_lens`` entry holding the
    first-axis lengths is added.

    Log-mel features are computed with ``feats_extract`` from the padded
    speech of batch item 0 only.
    NOTE(review): this looks like it assumes a batch of one utterance -
    confirm against callers.

    When the items carry no ``'text'`` key, placeholder text/alignment
    tensors are created and ``sega_emb``, ``mean_phn_span`` and ``mlm_prob``
    are overridden with ``False``, ``0`` and ``0.15``.

    Returns:
        ``(uttids, output_dict)`` where ``output_dict`` holds ``speech``,
        ``text``, ``masked_pos``, ``text_masked_pos``, ``speech_mask``,
        ``text_mask``, ``speech_seg_pos``, ``text_seg_pos``, ``speech_lens``
        and ``text_lens``, plus ``durations`` and ``reordered_idx`` when
        ``duration_collect`` is set and text is present.
    """
    uttids = [utt_id for utt_id, _ in data]
    samples = [sample for _, sample in data]
    first = samples[0]

    assert all(set(first) == set(s) for s in samples), "dict-keys mismatching"
    assert all(not k.endswith("_lens")
               for k in first), f"*_lens is reserved: {list(first)}"

    padded = {}
    for key in first:
        # The consuming model is responsible for repainting the pad value to
        # whatever each task needs.
        is_int = first[key].dtype.kind == "i"
        pad_value = int_pad_value if is_int else float_pad_value

        # The first axis is taken as the length axis:
        # per-item (Length, ...) tensors -> (Batch, Length, ...).
        tensors = [paddle.to_tensor(s[key]) for s in samples]
        padded[key] = pad_list(tensors, pad_value)

        # lens: (Batch,)
        if key not in not_sequence:
            padded[key + "_lens"] = paddle.to_tensor(
                [s[key].shape[0] for s in samples], dtype=paddle.int64)

    feats = paddle.to_tensor(
        feats_extract.get_log_mel_fbank(np.array(padded["speech"][0])))
    feats_lens = paddle.shape(feats)[0]
    feats = paddle.unsqueeze(feats, 0)

    has_text = 'text' in padded
    if has_text:
        text = padded["text"]
        text_lens = padded["text_lens"]
        align_start = padded["align_start"]
        align_start_lens = padded["align_start_lens"]
        align_end = padded["align_end"]
        # Scale the alignment values by sr / hop_length and floor to ints.
        align_start = paddle.floor(feats_extract.sr * align_start /
                                   feats_extract.hop_length).int()
        align_end = paddle.floor(feats_extract.sr * align_end /
                                 feats_extract.hop_length).int()
        max_tlen = max(text_lens)
    else:
        # No text supplied: fabricate a single placeholder token (-2).
        text = paddle.zeros(paddle.shape(feats_lens.unsqueeze(-1))) - 2
        text_lens = paddle.zeros(paddle.shape(feats_lens)) + 1
        max_tlen = 1
        align_start = paddle.zeros(paddle.shape(text))
        align_end = paddle.zeros(paddle.shape(text))
        align_start_lens = paddle.zeros(paddle.shape(feats_lens))
        sega_emb = False
        mean_phn_span = 0
        mlm_prob = 0.15

    max_slen = max(feats_lens)
    speech_pad = feats[:, :max_slen]
    uses_window = attention_window > 0
    if uses_window and pad_speech:
        speech_pad, max_slen = pad_to_longformer_att_window(
            speech_pad, max_slen, max_slen, attention_window)

    max_len = max_slen + max_tlen
    if uses_window:
        text_pad, max_tlen = pad_to_longformer_att_window(
            text, max_len, max_tlen, attention_window)
    else:
        text_pad = text

    text_mask = make_non_pad_mask(
        text_lens, text_pad, length_dim=1).unsqueeze(-2)
    if uses_window:
        text_mask = text_mask * 2
    speech_mask = make_non_pad_mask(
        feats_lens, speech_pad[:, :, 0], length_dim=1).unsqueeze(-2)

    span_bdy = padded.get('span_bdy')

    if text_masking:
        masked_pos, text_masked_pos, _ = phones_text_masking(
            speech_pad, speech_mask, text_pad, text_mask, align_start,
            align_end, align_start_lens, mlm_prob, mean_phn_span, span_bdy)
    else:
        text_masked_pos = paddle.zeros(paddle.shape(text_pad))
        masked_pos, _ = phones_masking(speech_pad, speech_mask, align_start,
                                       align_end, align_start_lens, mlm_prob,
                                       mean_phn_span, span_bdy)

    output_dict = {}
    if duration_collect and has_text:
        (reordered_idx, speech_seg_pos, text_seg_pos, durations,
         feats_lens) = get_seg_pos_reduce_duration(
             speech_pad, text_pad, align_start, align_end, align_start_lens,
             sega_emb, masked_pos, feats_lens)
        # Rebuild the speech mask from the returned lengths.
        speech_mask = make_non_pad_mask(
            feats_lens, speech_pad[:, :reordered_idx.shape[1], 0],
            length_dim=1).unsqueeze(-2)
        output_dict['durations'] = durations
        output_dict['reordered_idx'] = reordered_idx
    else:
        speech_seg_pos, text_seg_pos = get_seg_pos(speech_pad, text_pad,
                                                   align_start, align_end,
                                                   align_start_lens, sega_emb)

    output_dict.update({
        'speech': speech_pad,
        'text': text_pad,
        'masked_pos': masked_pos,
        'text_masked_pos': text_masked_pos,
        'speech_mask': speech_mask,
        'text_mask': text_mask,
        'speech_seg_pos': speech_seg_pos,
        'text_seg_pos': text_seg_pos,
        'speech_lens': padded["speech_lens"],
        'text_lens': text_lens,
    })
    return (uttids, output_dict)

小湉湉's avatar
小湉湉 已提交
889 890

def build_collate_fn(args: argparse.Namespace, train: bool, epoch=-1):
    """Create the MLMCollateFn described by a training configuration.

    Args:
        args: Namespace providing ``feats_extract_conf``, ``encoder_conf`` and
            ``model_conf`` dictionaries. ``feats_extract_conf['win_length']``
            is filled in (in place) from ``'n_fft'`` when it is ``None``.
        train: Not read in this body.
        epoch: ``-1`` keeps the configured ``mlm_prob``; any other value
            scales it by 0.8.

    Returns:
        An ``MLMCollateFn``, i.e. a callable mapping a collection of
        ``(utt_id, dict of arrays)`` to ``(list of utt_ids, dict of tensors)``.
    """
    fe_conf = args.feats_extract_conf
    if fe_conf['win_length'] is None:
        fe_conf['win_length'] = fe_conf['n_fft']

    # LogMelFBank is given the sample rate as ``sr``; the config calls it ``fs``.
    fbank_kwargs = {('sr' if name == 'fs' else name): value
                    for name, value in fe_conf.items()}
    feats_extract = LogMelFBank(**fbank_kwargs)

    sega_emb = args.encoder_conf['input_layer'] == 'sega_mlm'

    attention_window = 0
    pad_speech = False
    if args.encoder_conf['selfattention_layer_type'] == 'longformer':
        attention_window = args.encoder_conf['attention_window']
        pad_speech = bool('pre_speech_layer' in args.encoder_conf and
                          args.encoder_conf['pre_speech_layer'] > 0)

    mlm_prob_factor = 1 if epoch == -1 else 0.8

    duration_collect = bool(
        'duration_predictor_layers' in args.model_conf and
        args.model_conf['duration_predictor_layers'] > 0)

    return MLMCollateFn(
        feats_extract,
        float_pad_value=0.0,
        int_pad_value=0,
        mlm_prob=args.model_conf['mlm_prob'] * mlm_prob_factor,
        mean_phn_span=args.model_conf['mean_phn_span'],
        attention_window=attention_window,
        pad_speech=pad_speech,
        sega_emb=sega_emb,
        duration_collect=duration_collect)


小湉湉's avatar
小湉湉 已提交
940 941 942 943 944 945 946 947 948 949 950 951 952 953
def get_mlm_output(uid: str,
                   wav_path: str,
                   prefix: str="./prompt/dev/",
                   model_name: str="conformer",
                   source_lang: str="english",
                   target_lang: str="english",
                   old_str: str="",
                   new_str: str="",
                   duration_preditor_path: str=None,
                   sid: str=None,
                   decoder: bool=False,
                   use_teacher_forcing: bool=False,
                   duration_adjust: bool=True,
                   start_end_sp: bool=False):
    """Load the model named ``model_name`` and decode one utterance with it.

    The model returned by ``load_model`` is switched to eval mode, a collate
    function is built from its training arguments, and every other argument
    is passed straight through to ``decode_with_model`` (with
    ``processor=None``).

    Returns:
        Whatever ``decode_with_model`` returns.
    """
    mlm_model, train_args = load_model(model_name)
    mlm_model.eval()
    collate_fn = build_collate_fn(train_args, False)

    return decode_with_model(
        uid=uid,
        mlm_model=mlm_model,
        processor=None,
        collate_fn=collate_fn,
        wav_path=wav_path,
        prefix=prefix,
        source_lang=source_lang,
        target_lang=target_lang,
        old_str=old_str,
        new_str=new_str,
        duration_preditor_path=duration_preditor_path,
        sid=sid,
        decoder=decoder,
        use_teacher_forcing=use_teacher_forcing,
        duration_adjust=duration_adjust,
        start_end_sp=start_end_sp,
        train_args=train_args)


小湉湉's avatar
小湉湉 已提交
979 980 981 982 983 984 985 986 987 988
def evaluate(uid: str,
             source_lang: str="english",
             target_lang: str="english",
             use_pt_vocoder: bool=False,
             prefix: str="./prompt/dev/",
             model_name: str="conformer",
             old_str: str="",
             new_str: str="",
             prompt_decoding: bool=False,
             task_name: str=None):
    """Build the target text for a task and run plot_mel_and_vocode_wav().

    The original transcript and wav path are looked up with ``read_data``.
    ``new_str`` is then derived from ``task_name``:

    * ``'edit'``: used as given;
    * ``'synthesize'``: appended to the original transcript;
    * anything else: only the characters accepted by ``is_chinese`` are kept,
      joined with spaces, and appended to the original transcript.

    An empty ``old_str`` falls back to the original transcript.
    ``prompt_decoding`` is not read in this body.

    Returns:
        The first element of the pair returned by ``plot_mel_and_vocode_wav``.
    """
    full_origin_str, wav_path = read_data(uid=uid, prefix=prefix)

    if task_name == 'synthesize':
        new_str = full_origin_str + new_str
    elif task_name != 'edit':
        chinese_chars = [ch for ch in new_str if is_chinese(ch)]
        new_str = full_origin_str + ' '.join(chinese_chars)

    print('new_str is ', new_str)

    old_str = old_str or full_origin_str

    # No duration predictor path and no speaker id are supplied here.
    results_dict, _old_span = plot_mel_and_vocode_wav(
        uid=uid,
        prefix=prefix,
        source_lang=source_lang,
        target_lang=target_lang,
        model_name=model_name,
        wav_path=wav_path,
        full_origin_str=full_origin_str,
        old_str=old_str,
        new_str=new_str,
        use_pt_vocoder=use_pt_vocoder,
        duration_preditor_path=None,
        sid=None)
    return results_dict

小湉湉's avatar
小湉湉 已提交
1022

P
pfZhu 已提交
1023
if __name__ == "__main__":
    # Read the command-line options, run the requested task and save the
    # resulting waveform.
    args = parse_args()

    eval_kwargs = dict(
        uid=args.uid,
        source_lang=args.source_lang,
        target_lang=args.target_lang,
        use_pt_vocoder=args.use_pt_vocoder,
        prefix=args.prefix,
        model_name=args.model_name,
        new_str=args.new_str,
        task_name=args.task_name)
    data_dict = evaluate(**eval_kwargs)

    # NOTE(review): the sample rate is hard-coded to 24000 - confirm it
    # matches the sample rate of the loaded model/vocoder.
    sf.write(args.output_name, data_dict['output'], samplerate=24000)
    print("finished...")