inference.py 39.2 KB
Newer Older
P
pfZhu 已提交
1 2
#!/usr/bin/env python3
import argparse
小湉湉's avatar
小湉湉 已提交
3 4 5 6 7 8 9
import math
import os
import pickle
import random
import string
import sys
from pathlib import Path
P
pfZhu 已提交
10 11 12 13 14
from typing import Collection
from typing import Dict
from typing import List
from typing import Tuple
from typing import Union
小湉湉's avatar
小湉湉 已提交
15 16 17

import librosa
import numpy as np
O
oyjxer 已提交
18
import paddle
小湉湉's avatar
小湉湉 已提交
19
import soundfile as sf
O
oyjxer 已提交
20
import torch
小湉湉's avatar
小湉湉 已提交
21

O
oyjxer 已提交
22
from ParallelWaveGAN.parallel_wavegan.utils.utils import download_pretrained_model
P
pfZhu 已提交
23

O
oyjxer 已提交
24 25
from align_english import alignment
from align_mandarin import alignment_zh
小湉湉's avatar
小湉湉 已提交
26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41
from dataset import get_segment_pos
from dataset import make_non_pad_mask
from dataset import make_pad_mask
from dataset import pad_list
from dataset import pad_to_longformer_att_window
from dataset import phones_masking
from model_paddle import build_model_from_file
from read_text import load_num_sequence_text
from read_text import read_2column_text
from sedit_arg_parser import parse_args
from utils import build_vocoder_from_file
from utils import evaluate_durations
from utils import get_voc_out
from utils import is_chinese
from utils import sentence2phns
from paddlespeech.t2s.datasets.get_feats import LogMelFBank
P
pfZhu 已提交
42 43 44
# Seed Python's and NumPy's global random number generators at import time.
random.seed(0)
np.random.seed(0)

O
oyjxer 已提交
45 46 47 48 49
# Path of the external `phoneme` executable that get_unk_phns runs.
PHONEME = 'tools/aligner/english_envir/english2phoneme/phoneme'
# Directories whose `dict` file is read by words2phns (English) and
# words2phns_zh (Mandarin) respectively.
MODEL_DIR_EN = 'tools/aligner/english'
MODEL_DIR_ZH = 'tools/aligner/mandarin'


小湉湉's avatar
小湉湉 已提交
50 51 52 53 54 55 56 57 58 59 60 61 62 63 64
def plot_mel_and_vocode_wav(uid,
                            prefix,
                            clone_uid,
                            clone_prefix,
                            source_language,
                            target_language,
                            model_name,
                            wav_path,
                            full_origin_str,
                            old_str,
                            new_str,
                            use_pt_vocoder,
                            duration_preditor_path,
                            sid=None,
                            non_autoreg=True):
    """Generate features with the MLM model, vocode them and splice the result
    into the original waveform.

    Args:
        uid, prefix, clone_uid, clone_prefix, source_language, model_name,
        wav_path, old_str, new_str, duration_preditor_path, sid: forwarded
            unchanged to ``get_mlm_output``.
        target_language: ``'english'`` or ``'chinese'``; selects the vocoding
            strategy (see below).
        full_origin_str: accepted for interface compatibility; not used here.
        use_pt_vocoder: for English only, vocode with the vocoder returned by
            ``load_vocoder('vctk_parallel_wavegan.v1.long')`` instead of
            ``get_voc_out``.
        non_autoreg: passed to ``get_mlm_output`` as ``use_teacher_forcing``.

    Returns:
        ``(data_dict, old_span_boundary)`` where ``data_dict`` has the keys
        ``"origin"`` (the waveform returned by ``get_mlm_output``) and
        ``"output"`` (that waveform with the edited span replaced).

    Raises:
        ValueError: if ``target_language`` is neither ``'english'`` nor
            ``'chinese'`` (previously this surfaced as an UnboundLocalError
            at the ``return`` statement).
    """
    (wav_org, _, output_feat, old_span_boundary, new_span_boundary, _,
     hop_length) = get_mlm_output(
         uid,
         prefix,
         clone_uid,
         clone_prefix,
         source_language,
         target_language,
         model_name,
         wav_path,
         old_str,
         new_str,
         duration_preditor_path,
         use_teacher_forcing=non_autoreg,
         sid=sid)

    # Span boundaries are multiplied by hop_length to index into waveforms.
    old_time_boundary = [hop_length * x for x in old_span_boundary]
    new_time_boundary = [hop_length * x for x in new_span_boundary]

    if target_language == 'english':
        # English: vocode the whole utterance, then cut out the new span.
        output_feat_np = output_feat.detach().float().cpu().numpy()
        if use_pt_vocoder:
            vocoder = load_vocoder('vctk_parallel_wavegan.v1.long')
            replaced_wav = vocoder(
                torch.tensor(output_feat_np, dtype=torch.float)).detach(
                ).float().data.cpu().numpy()
        else:
            replaced_wav = get_voc_out(output_feat_np, target_language)
        new_segment = replaced_wav[new_time_boundary[0]:new_time_boundary[1]]
    elif target_language == 'chinese':
        # Chinese: vocode only the frames inside the new span.
        masked_feat = output_feat[new_span_boundary[0]:new_span_boundary[
            1]].detach().float().cpu().numpy()
        new_segment = get_voc_out(masked_feat, target_language)
    else:
        raise ValueError(
            "unsupported target_language: {!r}".format(target_language))

    spliced_wav = np.concatenate([
        wav_org[:old_time_boundary[0]], new_segment,
        wav_org[old_time_boundary[1]:]
    ])
    data_dict = {"origin": wav_org, "output": spliced_wav}

    return data_dict, old_span_boundary
P
pfZhu 已提交
122

O
oyjxer 已提交
123 124 125 126 127 128

def _phoneme_code_to_arpabet(code):
    """Expand one whitespace-free code string from the `phoneme` tool into a
    list of phone symbols."""
    seq = []
    j = 0
    while j < len(code):
        if code[j] > 'Z':
            # Characters sorting after 'Z' (e.g. lower-case letters) are
            # one-character codes.
            if code[j] == 'j':
                seq.append('JH')
            elif code[j] == 'h':
                seq.append('HH')
            else:
                seq.append(code[j].upper())
            j += 1
        else:
            # Everything else is consumed two characters at a time.
            p = code[j:j + 2]
            if p == 'WH':
                seq.append('W')
            elif p in ['TH', 'SH', 'HH', 'DH', 'CH', 'ZH', 'NG']:
                seq.append(p)
            elif p == 'AX':
                seq.append('AH0')
            else:
                seq.append(p + '1')
            j += 2
    return seq


def get_unk_phns(word_str):
    """Return the phone sequence for ``word_str`` using the external
    ``PHONEME`` executable.

    The text is written to a words file, ``PHONEME <words> <phons>`` is run,
    and the first line of the produced phons file is parsed.

    The scratch files live in a fresh temporary directory instead of the
    fixed ``/tmp/tp.temp.*`` paths, so concurrent calls cannot overwrite each
    other and a stale output from an earlier run can never be read. The tool
    is started without a shell.

    Args:
        word_str: text written verbatim to the tool's input file.

    Returns:
        list of phone symbol strings.

    Raises:
        OSError: if the executable cannot be started, or if it produced no
            output file (FileNotFoundError when opening it).
    """
    import subprocess
    import tempfile

    with tempfile.TemporaryDirectory(prefix='tp.') as workdir:
        words_path = os.path.join(workdir, 'temp.words')
        phons_path = os.path.join(workdir, 'temp.phons')
        with open(words_path, 'w') as fid:
            fid.write(word_str)
        # As before, the exit status is not inspected; a missing output file
        # is reported by the open() below.
        subprocess.run([PHONEME, words_path, phons_path])
        with open(phons_path, 'r') as fid:
            codes = fid.readline().strip().split()

    phns = []
    for code in codes:
        phns.extend(_phoneme_code_to_arpabet(code))
    return phns

小湉湉's avatar
小湉湉 已提交
162

O
oyjxer 已提交
163
def words2phns(line):
    """Convert an English sentence into phones using ``MODEL_DIR_EN/dict``.

    Punctuation is replaced by spaces, one trailing ``-`` and one leading
    ``'`` are stripped from every token, and tokens that end up empty are
    dropped (previously a lone ``-`` raised IndexError). Words are looked up
    upper-cased; words missing from the dictionary are sent to
    ``get_unk_phns`` (once per occurrence). ``[MASK]`` is passed through as
    its own phone.

    Args:
        line: the sentence to convert.

    Returns:
        ``(phns, wrd2phns)``: the flat phone list, and a dict mapping
        ``"<index>_<WORD>"`` (word upper-cased, ``[MASK]`` kept as is) to that
        word's phone list.
    """
    dictfile = MODEL_DIR_EN + '/dict'
    line = line.strip()
    # Longest dash runs first: replacing '--' before '---' leaves a stray '-'.
    for pun in ['---', '--', ',', '.', ':', ';', '!', '?', '"', '(', ')']:
        line = line.replace(pun, ' ')

    words = []
    for wrd in line.split():
        if wrd.endswith('-'):
            wrd = wrd[:-1]
        if wrd.startswith("'"):
            wrd = wrd[1:]
        if wrd:
            words.append(wrd)

    # word -> space-joined phones; the first entry for a word wins.
    word2phns_dict = {}
    with open(dictfile, 'r') as fid:
        for entry in fid:
            fields = entry.split()
            if not fields:
                continue  # tolerate blank lines in the dictionary
            word2phns_dict.setdefault(fields[0], " ".join(fields[1:]))

    phns = []
    wrd2phns = {}
    for index, wrd in enumerate(words):
        if wrd == '[MASK]':
            wrd2phns[str(index) + "_" + wrd] = [wrd]
            phns.append(wrd)
            continue
        upper = wrd.upper()
        if upper in word2phns_dict:
            word_phns = word2phns_dict[upper].split()
        else:
            word_phns = get_unk_phns(wrd)
        wrd2phns[str(index) + "_" + upper] = word_phns
        phns.extend(word_phns)

    return phns, wrd2phns


def words2phns_zh(line):
    """Convert a space-separated Mandarin sentence into phones using
    ``MODEL_DIR_ZH/dict``.

    ASCII and full-width punctuation is replaced by spaces (the full-width
    marks were previously listed as ASCII duplicates and therefore never
    removed), one trailing ``-`` and one leading ``'`` are stripped from every
    token, and tokens that end up empty are dropped. ``[MASK]`` is passed
    through as its own phone. A word missing from the dictionary is reported
    on stdout and skipped; its index is still consumed.

    Args:
        line: the sentence to convert.

    Returns:
        ``(phns, wrd2phns)``: the flat phone list, and a dict mapping
        ``"<index>_<word>"`` to that word's phone list.
    """
    dictfile = MODEL_DIR_ZH + '/dict'
    line = line.strip()
    # Longest dash runs first: replacing '--' before '---' leaves a stray '-'.
    # The escapes are the full-width forms of , . : ; ! ? ( ) in that order
    # (the full stop being the ideographic one, U+3002).
    for pun in [
            '---', '--', ',', '.', ':', ';', '!', '?', '"', '(', ')',
            '\uff0c', '\u3002', '\uff1a', '\uff1b', '\uff01', '\uff1f',
            '\uff08', '\uff09'
    ]:
        line = line.replace(pun, ' ')

    words = []
    for wrd in line.split():
        if wrd.endswith('-'):
            wrd = wrd[:-1]
        if wrd.startswith("'"):
            wrd = wrd[1:]
        if wrd:
            words.append(wrd)

    # word -> space-joined phones; the first entry for a word wins.
    # The dictionary holds Chinese text, so do not depend on the locale.
    word2phns_dict = {}
    with open(dictfile, 'r', encoding='utf-8') as fid:
        for entry in fid:
            fields = entry.split()
            if not fields:
                continue  # tolerate blank lines in the dictionary
            word2phns_dict.setdefault(fields[0], " ".join(fields[1:]))

    phns = []
    wrd2phns = {}
    for index, wrd in enumerate(words):
        if wrd == '[MASK]':
            wrd2phns[str(index) + "_" + wrd] = [wrd]
            phns.append(wrd)
        elif wrd not in word2phns_dict:
            # Test the same key that is looked up below; the old check used
            # wrd.upper() and could pass while the lookup raised KeyError.
            print("出现非法词错误,请输入正确的文本...")
        else:
            word_phns = word2phns_dict[wrd].split()
            wrd2phns[str(index) + "_" + wrd] = word_phns
            phns.extend(word_phns)

    return phns, wrd2phns


def load_vocoder(vocoder_tag="vctk_parallel_wavegan.v1.long"):
    """Download a pretrained vocoder by tag and build it on the CPU.

    A ``parallel_wavegan/`` substring in the tag is removed before the
    download. The config is expected as ``config.yml`` next to the downloaded
    model file.
    """
    tag = vocoder_tag.replace("parallel_wavegan/", "")
    checkpoint = download_pretrained_model(tag)
    config = Path(checkpoint).parent / "config.yml"
    return build_vocoder_from_file(config, checkpoint, None, 'cpu')

小湉湉's avatar
小湉湉 已提交
252

P
pfZhu 已提交
253
def load_model(model_name):
    """Build the model stored under ``./pretrained_model/<model_name>/``.

    Reads ``config.yaml`` and ``model.pdparams`` from that directory via
    ``build_model_from_file`` and returns its ``(model, args)`` pair.
    """
    model_dir = './pretrained_model/{}/'.format(model_name)
    model, train_args = build_model_from_file(
        config_file=model_dir + 'config.yaml',
        model_file=model_dir + 'model.pdparams')
    return model, train_args


小湉湉's avatar
小湉湉 已提交
261 262 263
def read_data(uid, prefix):
    """Look up the text and wav path of ``uid`` in ``<prefix>/text`` and
    ``<prefix>/wav.scp``.

    A wav path that does not contain ``'mnt'`` is prepended with the part of
    ``prefix`` that precedes the first ``'dump'``.
    """
    text = read_2column_text(prefix + '/text')[uid]
    wav_path = read_2column_text(prefix + '/wav.scp')[uid]
    if 'mnt' in wav_path:
        return text, wav_path
    root = prefix.split('dump')[0]
    return text, root + wav_path
小湉湉's avatar
小湉湉 已提交
267 268 269 270 271 272 273 274 275 276


def get_align_data(uid, prefix):
    """Load the ``mfa_*`` alignment entries of ``uid`` found under ``prefix``.

    Returns:
        ``(text, start, end, wav_path)`` read from ``<prefix>mfa_text``,
        ``<prefix>mfa_start``, ``<prefix>mfa_end`` and ``<prefix>mfa_wav.scp``.
    """
    base = prefix + "mfa_"
    text = read_2column_text(base + 'text')[uid]
    start = load_num_sequence_text(base + 'start', loader_type='text_float')[uid]
    end = load_num_sequence_text(base + 'end', loader_type='text_float')[uid]
    wav_path = read_2column_text(base + 'wav.scp')[uid]
    return text, start, end, wav_path


小湉湉's avatar
小湉湉 已提交
280 281 282 283 284 285 286 287
def get_masked_mel_boundary(mfa_start, mfa_end, fs, hop_length,
                            span_tobe_replaced):
    """Map a ``[first, stop)`` span of alignment entries to a pair of indices
    computed as ``floor(fs * time / hop_length)``.

    When the span starts at or beyond the last alignment entry, both
    boundaries are the index of the final end time.
    """
    start_times = paddle.to_tensor(mfa_start).unsqueeze(0)
    end_times = paddle.to_tensor(mfa_end).unsqueeze(0)
    start_frames = paddle.floor(fs * start_times / hop_length).int()
    end_frames = paddle.floor(fs * end_times / hop_length).int()

    if span_tobe_replaced[0] >= len(mfa_start):
        last_frame = end_frames[0].tolist()[-1]
        return [last_frame, end_frames[0].tolist()[-1]]

    left = start_frames[0].tolist()[span_tobe_replaced[0]]
    right = end_frames[0].tolist()[span_tobe_replaced[1] - 1]
    return [left, right]


O
oyjxer 已提交
296 297 298
def recover_dict(word2phns, tp_word2phns):
    """Re-insert the ``sp`` entries of ``word2phns`` into ``tp_word2phns``.

    Both dicts are keyed ``"<index>_<word>"``. The result contains the words
    of ``tp_word2phns`` in order, re-indexed from 0, with an ``"<i>_sp"``
    entry (value ``'sp'``) placed wherever ``word2phns`` had ``sp`` at index
    ``i``; one missing trailing ``sp`` is appended at the end.

    Keys are split on the first underscore only, so words that themselves
    contain ``_`` no longer raise ValueError.

    Side effect (kept for compatibility): every non-``sp`` key is deleted
    from ``word2phns``.

    Args:
        word2phns: dict whose ``sp`` keys give the indices to restore.
        tp_word2phns: dict without ``sp`` entries providing the words and
            their values.

    Returns:
        the merged dict.

    Raises:
        AssertionError: if the number of inserted ``sp`` entries differs from
            the number found in ``word2phns``.
    """
    dic = {}
    sp_count = 0
    add_sp_count = 0
    sp_indices = set()
    non_sp_keys = []
    for key in word2phns:
        idx, wrd = key.split('_', 1)
        if wrd == 'sp':
            sp_count += 1
            sp_indices.add(int(idx))
        else:
            non_sp_keys.append(key)

    for key in non_sp_keys:
        del word2phns[key]

    cur_id = 0
    for key, value in tp_word2phns.items():
        if cur_id in sp_indices:
            dic[str(cur_id) + "_sp"] = 'sp'
            cur_id += 1
            add_sp_count += 1
        _, wrd = key.split('_', 1)
        dic[str(cur_id) + "_" + wrd] = value
        cur_id += 1

    # Exactly one sp still missing: it is added after the last word.
    if add_sp_count + 1 == sp_count:
        dic[str(cur_id) + "_sp"] = 'sp'
        add_sp_count += 1

    # Raised explicitly so the check also runs under `python -O`.
    if add_sp_count != sp_count:
        raise AssertionError("sp are not added in dic")
    return dic
P
pfZhu 已提交
330 331


小湉湉's avatar
小湉湉 已提交
332 333
def get_phns_and_spans(wav_path, old_str, new_str, source_language,
                       clone_target_language):
    """Align the original text, phonemize the new text and locate the edit.

    The original utterance is aligned with ``alignment`` / ``alignment_zh``;
    each returned item is read as ``(phone, start, end)``. The new text is
    phonemized with ``words2phns`` / ``words2phns_zh``. Words are then matched
    from the left and from the right to find the phone span that changes.

    Args:
        wav_path: path of the original recording, passed to the aligner.
        old_str: text of the original recording.
        new_str: edited text. If it starts with ``old_str`` the edit is
            treated as an append.
        source_language: ``"english"`` or ``"chinese"``.
        clone_target_language: language of the new text. It may differ from
            ``source_language`` only when ``new_str`` starts with ``old_str``.

    Returns:
        ``(mfa_start, mfa_end, old_phns, new_phns, span_tobe_replaced,
        span_tobe_added)``. The two spans are ``[begin, end]`` index pairs
        into ``old_phns`` and ``new_phns`` respectively.

    Raises:
        AssertionError: for an unsupported language combination. Raised
            explicitly so the check also runs under ``python -O``.
    """
    append_new_str = (old_str == new_str[:len(old_str)])
    old_phns, mfa_start, mfa_end = [], [], []

    if source_language == "english":
        times2, word2phns = alignment(wav_path, old_str)
    elif source_language == "chinese":
        times2, word2phns = alignment_zh(wav_path, old_str)
        _, tp_word2phns = words2phns_zh(old_str)
        # Values are .split() further down, so store them space-joined.
        tp_word2phns = {
            key: " ".join(value)
            for key, value in tp_word2phns.items()
        }
        word2phns = recover_dict(word2phns, tp_word2phns)
    else:
        raise AssertionError("source_language is wrong...")

    for item in times2:
        mfa_start.append(float(item[1]))
        mfa_end.append(float(item[2]))
        old_phns.append(item[0])

    is_cross_lingual_clone = append_new_str and (
        source_language != clone_target_language)

    if is_cross_lingual_clone:
        new_str_origin = new_str[:len(old_str)]
        new_str_append = new_str[len(old_str):]

        # NOTE(review): each branch phonemizes the *original* part with the
        # other language's function than clone_target_language; kept as is.
        if clone_target_language == "chinese":
            new_phns_origin, new_origin_word2phns = words2phns(new_str_origin)
            new_phns_append, temp_new_append_word2phns = words2phns_zh(
                new_str_append)
        elif clone_target_language == "english":
            new_phns_origin, new_origin_word2phns = words2phns_zh(
                new_str_origin)  # original sentence
            new_phns_append, temp_new_append_word2phns = words2phns(
                new_str_append)  # cloned sentence
        else:
            raise AssertionError(
                "cloning is not support for this language, please check it.")

        new_phns = new_phns_origin + new_phns_append

        # Shift the appended words' indices behind the original words.
        length = len(new_origin_word2phns)
        new_word2phns = dict(new_origin_word2phns)
        for key, value in temp_new_append_word2phns.items():
            idx, wrd = key.split('_', 1)
            new_word2phns[str(int(idx) + length) + '_' + wrd] = value

    else:
        if source_language != clone_target_language:
            raise AssertionError(
                "source language is not same with target language...")
        if clone_target_language == "english":
            new_phns, new_word2phns = words2phns(new_str)
        else:
            # source_language was validated above, so this is "chinese".
            new_phns, new_word2phns = words2phns_zh(new_str)

    span_tobe_replaced = [0, len(old_phns) - 1]
    span_tobe_added = [0, len(new_phns) - 1]
    left_index = 0
    new_phns_left = []
    sp_count = 0
    # find the left different index
    for key in word2phns.keys():
        idx, wrd = key.split('_', 1)
        if wrd == 'sp':
            sp_count += 1
            new_phns_left.append('sp')
        else:
            # new_word2phns has no sp entries, so discount the sp seen so far.
            idx = str(int(idx) - sp_count)
            if idx + '_' + wrd in new_word2phns:
                left_index += len(new_word2phns[idx + '_' + wrd])
                new_phns_left.extend(word2phns[key].split())
            else:
                span_tobe_replaced[0] = len(new_phns_left)
                span_tobe_added[0] = len(new_phns_left)
                break

    # reverse word2phns and new_word2phns
    right_index = 0
    new_phns_right = []
    sp_count = 0
    word2phns_max_index = int(list(word2phns.keys())[-1].split('_')[0])
    new_word2phns_max_index = int(list(new_word2phns.keys())[-1].split('_')[0])
    new_phns_middle = []
    if append_new_str:
        new_phns_right = []
        new_phns_middle = new_phns[left_index:]
        span_tobe_replaced[0] = len(new_phns_left)
        span_tobe_added[0] = len(new_phns_left)
        span_tobe_added[1] = len(new_phns_left) + len(new_phns_middle)
        span_tobe_replaced[1] = len(old_phns) - len(new_phns_right)
    else:
        for key in list(word2phns.keys())[::-1]:
            idx, wrd = key.split('_', 1)
            if wrd == 'sp':
                sp_count += 1
                new_phns_right = ['sp'] + new_phns_right
            else:
                # Same distance from the end in the new word list.
                idx = str(new_word2phns_max_index - (word2phns_max_index - int(
                    idx) - sp_count))
                if idx + '_' + wrd in new_word2phns:
                    right_index -= len(new_word2phns[idx + '_' + wrd])
                    new_phns_right = word2phns[key].split() + new_phns_right
                else:
                    span_tobe_replaced[1] = len(old_phns) - len(new_phns_right)
                    # right_index is 0 when nothing matched from the right;
                    # a literal [left:0] slice would be empty and silently
                    # drop the new phones, so slice to the end in that case.
                    new_phns_middle = new_phns[left_index:right_index or None]
                    span_tobe_added[1] = len(new_phns_left) + len(
                        new_phns_middle)
                    if len(new_phns_middle) == 0:
                        # Nothing is inserted: widen both spans by one phone
                        # on each side, clamped to the valid range.
                        span_tobe_added[1] = min(span_tobe_added[1] + 1,
                                                 len(new_phns))
                        span_tobe_added[0] = max(0, span_tobe_added[0] - 1)
                        span_tobe_replaced[0] = max(0,
                                                    span_tobe_replaced[0] - 1)
                        span_tobe_replaced[1] = min(span_tobe_replaced[1] + 1,
                                                    len(old_phns))
                    break
    new_phns = new_phns_left + new_phns_middle + new_phns_right

    return mfa_start, mfa_end, old_phns, new_phns, span_tobe_replaced, span_tobe_added


def duration_adjust_factor(original_dur, pred_dur, phns, trim=2):
    """Return a trimmed mean of the per-phone ratios ``original / predicted``.

    Phones named ``'sp'`` and entries whose predicted duration is 0 are
    ignored. The remaining ratios are sorted and the ``trim`` smallest and
    ``trim`` largest are discarded before averaging.

    Args:
        original_dur: measured durations.
        pred_dur: predicted durations, parallel to ``original_dur``.
        phns: phone names, parallel to the durations.
        trim: number of ratios dropped at each end (default 2, the value
            that used to be hard-coded).

    Returns:
        the average of the kept ratios, or ``1`` when fewer than
        ``2 * trim + 1`` ratios are available.
    """
    factors = np.sort(
        np.array([
            ori / pred for ori, pred, phn in zip(original_dur, pred_dur, phns)
            if pred != 0 and phn != 'sp'
        ]))
    if len(factors) < 2 * trim + 1:
        return 1
    # Explicit end index so that trim=0 keeps every ratio.
    return np.average(factors[trim:len(factors) - trim])

小湉湉's avatar
小湉湉 已提交
483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501

def prepare_features_with_duration(uid,
                                   prefix,
                                   clone_uid,
                                   clone_prefix,
                                   source_language,
                                   target_language,
                                   mlm_model,
                                   old_str,
                                   new_str,
                                   wav_path,
                                   duration_preditor_path,
                                   sid=None,
                                   mask_reconstruct=False,
                                   duration_adjust=True,
                                   start_end_sp=False,
                                   train_args=None):
    """Build the edited waveform, phone list and alignment for one utterance.

    Steps: load the wav, get phones and edit spans from
    ``get_phns_and_spans``, predict durations with ``evaluate_durations``,
    rescale them with ``duration_adjust_factor``, rebuild the alignment
    around the edited span, replace the edited samples with zeros of the new
    length, and convert both spans with ``get_masked_mel_boundary``.

    Args:
        source_language, target_language: ``"english"`` or ``"chinese"``.
        old_str, new_str: original and edited text. If ``new_str`` contains
            ``'[MASK]'`` the original phones and their predicted durations
            are reused.
        wav_path: path of the original recording.
        duration_adjust: if False, predicted durations are used unscaled
            (ignored in the ``'[MASK]'`` case).
        start_end_sp: append ``'sp'`` to the new phones if not already last.
        train_args: must provide ``feats_extract_conf['fs']`` and
            ``feats_extract_conf['hop_length']``.
        uid, prefix, clone_uid, clone_prefix, mlm_model,
        duration_preditor_path, sid, mask_reconstruct: accepted for interface
            compatibility; not used in this function.

    Returns:
        ``(new_wav_org, new_phns, new_mfa_start, new_mfa_end,
        old_span_boundary, new_span_boundary)``.

    Raises:
        AssertionError: if ``target_language`` is unsupported. Raised
            explicitly so the check also runs under ``python -O``.
    """
    fs = train_args.feats_extract_conf['fs']
    hop_length = train_args.feats_extract_conf['hop_length']
    wav_org, _ = librosa.load(wav_path, sr=fs)

    mfa_start, mfa_end, old_phns, new_phns, span_tobe_replaced, span_tobe_added = get_phns_and_spans(
        wav_path, old_str, new_str, source_language, target_language)

    if start_end_sp and new_phns[-1] != 'sp':
        new_phns = new_phns + ['sp']

    # 1. predicted durations of the original phones
    if target_language == "english":
        old_duration_language = target_language
    elif target_language == "chinese":
        # source_language was already validated by get_phns_and_spans.
        old_duration_language = source_language
    else:
        raise AssertionError(
            "calculate duration_predict is not support for this language...")
    old_durations = evaluate_durations(
        old_phns, target_language=old_duration_language)

    # 2. durations of the new phones, scaled towards the measured durations
    original_old_durations = [e - s for e, s in zip(mfa_end, mfa_start)]
    if '[MASK]' in new_str:
        new_phns = old_phns
        span_tobe_added = span_tobe_replaced
        # Estimate the scale separately left and right of the span.
        d_factor_left = duration_adjust_factor(
            original_old_durations[:span_tobe_replaced[0]],
            old_durations[:span_tobe_replaced[0]],
            old_phns[:span_tobe_replaced[0]])
        d_factor_right = duration_adjust_factor(
            original_old_durations[span_tobe_replaced[1]:],
            old_durations[span_tobe_replaced[1]:],
            old_phns[span_tobe_replaced[1]:])
        d_factor = (d_factor_left + d_factor_right) / 2
        new_durations_adjusted = [d_factor * i for i in old_durations]
    else:
        if duration_adjust:
            # NOTE(review): the extra 1.25 scaling is kept from the original;
            # its rationale is not visible in this file.
            d_factor = duration_adjust_factor(original_old_durations,
                                              old_durations, old_phns) * 1.25
        else:
            d_factor = 1

        new_durations = evaluate_durations(
            new_phns, target_language=target_language)
        new_durations_adjusted = [d_factor * i for i in new_durations]

        # Where the phone at a span boundary is unchanged, keep its measured
        # duration. Both checks guard both indices.
        if span_tobe_replaced[0] < len(old_phns) and span_tobe_added[0] < len(
                new_phns):
            if old_phns[span_tobe_replaced[0]] == new_phns[span_tobe_added[0]]:
                new_durations_adjusted[span_tobe_added[
                    0]] = original_old_durations[span_tobe_replaced[0]]
        if span_tobe_replaced[1] < len(old_phns) and span_tobe_added[1] < len(
                new_phns):
            if old_phns[span_tobe_replaced[1]] == new_phns[span_tobe_added[1]]:
                new_durations_adjusted[span_tobe_added[
                    1]] = original_old_durations[span_tobe_replaced[1]]

    new_span_durations = new_durations_adjusted[span_tobe_added[0]:
                                                span_tobe_added[1]]
    new_span_duration_sum = sum(new_span_durations)
    old_span_duration_sum = sum(
        original_old_durations[span_tobe_replaced[0]:span_tobe_replaced[1]])
    duration_offset = new_span_duration_sum - old_span_duration_sum

    # Alignment = untouched prefix + new span laid end to end + shifted tail.
    new_mfa_start = mfa_start[:span_tobe_replaced[0]]
    new_mfa_end = mfa_end[:span_tobe_replaced[0]]
    for i in new_span_durations:
        if len(new_mfa_end) == 0:
            new_mfa_start.append(0)
            new_mfa_end.append(i)
        else:
            new_mfa_start.append(new_mfa_end[-1])
            new_mfa_end.append(new_mfa_end[-1] + i)
    new_mfa_start += [
        i + duration_offset for i in mfa_start[span_tobe_replaced[1]:]
    ]
    new_mfa_end += [
        i + duration_offset for i in mfa_end[span_tobe_replaced[1]:]
    ]

    # 3. get new wav
    if span_tobe_replaced[0] >= len(mfa_start):
        # The span starts past the last aligned phone: insert at the end.
        left_index = len(wav_org)
        right_index = left_index
    else:
        left_index = int(np.floor(mfa_start[span_tobe_replaced[0]] * fs))
        right_index = int(np.ceil(mfa_end[span_tobe_replaced[1] - 1] * fs))
    new_blank_wav = np.zeros(
        (int(np.ceil(new_span_duration_sum * fs)), ), dtype=wav_org.dtype)
    new_wav_org = np.concatenate(
        [wav_org[:left_index], new_blank_wav, wav_org[right_index:]])

    # 4. get old and new mel span to be mask
    old_span_boundary = get_masked_mel_boundary(
        mfa_start, mfa_end, fs, hop_length, span_tobe_replaced)
    new_span_boundary = get_masked_mel_boundary(new_mfa_start, new_mfa_end, fs,
                                                hop_length, span_tobe_added)

    return new_wav_org, new_phns, new_mfa_start, new_mfa_end, old_span_boundary, new_span_boundary

小湉湉's avatar
小湉湉 已提交
614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651

def prepare_features(uid,
                     prefix,
                     clone_uid,
                     clone_prefix,
                     source_language,
                     target_language,
                     mlm_model,
                     processor,
                     wav_path,
                     old_str,
                     new_str,
                     duration_preditor_path,
                     sid=None,
                     duration_adjust=True,
                     start_end_sp=False,
                     mask_reconstruct=False,
                     train_args=None):
    """Assemble a one-item batch from ``prepare_features_with_duration``.

    Phones are mapped to ids by their position in ``train_args.token_list``;
    phones not in that list get the id of ``'<unk>'``. ``processor`` is
    accepted but not used.

    Returns:
        ``(batch, old_span_boundary, new_span_boundary)`` where ``batch`` is
        ``[('1', {...})]`` with the keys ``speech``, ``align_start``,
        ``align_end``, ``text`` and ``span_boundary`` (all NumPy arrays).
    """
    (wav_org, phns_list, mfa_start, mfa_end, old_span_boundary,
     new_span_boundary) = prepare_features_with_duration(
         uid,
         prefix,
         clone_uid,
         clone_prefix,
         source_language,
         target_language,
         mlm_model,
         old_str,
         new_str,
         wav_path,
         duration_preditor_path,
         sid=sid,
         duration_adjust=duration_adjust,
         start_end_sp=start_end_sp,
         mask_reconstruct=mask_reconstruct,
         train_args=train_args)

    token_to_id = {}
    for token_id, token in enumerate(train_args.token_list):
        token_to_id[token] = token_id

    sample = {
        "speech": np.array(wav_org, dtype=np.float32),
        "align_start": np.array(mfa_start),
        "align_end": np.array(mfa_end),
        "text": np.array(
            [token_to_id.get(phn, token_to_id['<unk>']) for phn in phns_list]),
        "span_boundary": np.array(new_span_boundary),
    }
    batch = [('1', sample)]

    return batch, old_span_boundary, new_span_boundary


小湉湉's avatar
小湉湉 已提交
670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709
def decode_with_model(uid,
                      prefix,
                      clone_uid,
                      clone_prefix,
                      source_language,
                      target_language,
                      mlm_model,
                      processor,
                      collate_fn,
                      wav_path,
                      old_str,
                      new_str,
                      duration_preditor_path,
                      sid=None,
                      decoder=False,
                      use_teacher_forcing=False,
                      duration_adjust=True,
                      start_end_sp=False,
                      train_args=None):
    """Run ``mlm_model.inference`` on one utterance and stitch its output.

    The utterance is turned into a one-item batch by ``prepare_features``,
    collated with ``collate_fn``, fed to the model, and the returned
    ``'feat_gen'`` segments are concatenated along axis 0.

    Args:
        uid, prefix, clone_uid, clone_prefix, source_language,
        target_language, processor, old_str, new_str,
        duration_preditor_path, sid, duration_adjust, start_end_sp:
            Forwarded unchanged to ``prepare_features``.
        mlm_model: Model exposing ``inference(**feats, span_boundary=...,
            use_teacher_forcing=...)`` that returns a dict with
            ``'feat_gen'``.
        collate_fn: Callable mapping the batch to ``(uttids, feats_dict)``.
        wav_path: Path of the waveform; reloaded here at the training
            sample rate for the return value.
        decoder: Accepted for interface compatibility; not used here.
        use_teacher_forcing: Forwarded to ``mlm_model.inference``.
        train_args: Training configuration; ``feats_extract_conf`` must
            provide ``'fs'`` and ``'hop_length'``.

    Returns:
        Tuple ``(wav_org, None, output_feat, old_span_boundary,
        new_span_boundary, fs, hop_length)``.
    """
    feats_conf = train_args.feats_extract_conf
    fs, hop_length = feats_conf['fs'], feats_conf['hop_length']

    batch, old_span_boundary, new_span_boundary = prepare_features(
        uid,
        prefix,
        clone_uid,
        clone_prefix,
        source_language,
        target_language,
        mlm_model,
        processor,
        wav_path,
        old_str,
        new_str,
        duration_preditor_path,
        sid,
        duration_adjust=duration_adjust,
        start_end_sp=start_end_sp,
        train_args=train_args)

    feats = collate_fn(batch)[1]

    # The model's inference entry point does not take this key.
    feats.pop('text_masked_position', None)
    for k, v in feats.items():
        feats[k] = paddle.to_tensor(v)
    rtn = mlm_model.inference(
        **feats,
        span_boundary=new_span_boundary,
        use_teacher_forcing=use_teacher_forcing)
    output = rtn['feat_gen']

    # The first and last segments are kept only when they are non-empty
    # (no zero-sized dimension).  Only the leading axis is squeezed: a bare
    # squeeze() would also drop a length-1 frame axis and make the segment's
    # rank disagree with the middle segments in the concat below.
    pieces = []
    if 0 not in output[0].shape:
        pieces.append(output[0].squeeze(0))
    pieces.extend(output[1:-1])
    if 0 not in output[-1].shape:
        pieces.append(output[-1].squeeze(0))
    output_feat = paddle.concat(pieces, axis=0).cpu()

    # Reload the original waveform at the training sample rate for callers.
    wav_org, _ = librosa.load(wav_path, sr=fs)
    return wav_org, None, output_feat, old_span_boundary, new_span_boundary, fs, hop_length


O
oyjxer 已提交
742 743 744
class MLMCollateFn:
    """Functor wrapper around ``mlm_collate_fn()``.

    Stores the collate configuration once so that an instance can be called
    with just the batch data.
    """

    def __init__(self,
                 feats_extract,
                 float_pad_value: Union[float, int]=0.0,
                 int_pad_value: int=-32768,
                 not_sequence: Collection[str]=(),
                 mlm_prob: float=0.8,
                 mean_phn_span: int=8,
                 attention_window: int=0,
                 pad_speech: bool=False,
                 sega_emb: bool=False,
                 duration_collect: bool=False,
                 text_masking: bool=False):
        """Remember every option that is later forwarded to ``mlm_collate_fn``.

        Args:
            feats_extract: Feature extractor handed to ``mlm_collate_fn``.
            float_pad_value: Padding value for non-integer arrays.
            int_pad_value: Padding value for integer arrays.
            not_sequence: Keys for which no ``*_lengths`` entry is produced;
                stored as a set.
            mlm_prob, mean_phn_span, attention_window, pad_speech, sega_emb,
            duration_collect, text_masking: Forwarded unchanged.
        """
        self.mlm_prob = mlm_prob
        self.mean_phn_span = mean_phn_span
        self.feats_extract = feats_extract
        self.float_pad_value = float_pad_value
        self.int_pad_value = int_pad_value
        self.not_sequence = set(not_sequence)
        self.attention_window = attention_window
        self.pad_speech = pad_speech
        self.sega_emb = sega_emb
        self.duration_collect = duration_collect
        self.text_masking = text_masking

    def __repr__(self):
        # Report the integer pad value under its own label (it previously
        # echoed float_pad_value twice).
        return (f"{self.__class__}(float_pad_value={self.float_pad_value}, "
                f"int_pad_value={self.int_pad_value})")

    def __call__(self, data: Collection[Tuple[str, Dict[str, np.ndarray]]]
                 ) -> Tuple[List[str], Dict[str, "paddle.Tensor"]]:
        """Collate ``data`` with the stored configuration."""
        return mlm_collate_fn(
            data,
            float_pad_value=self.float_pad_value,
            int_pad_value=self.int_pad_value,
            not_sequence=self.not_sequence,
            mlm_prob=self.mlm_prob,
            mean_phn_span=self.mean_phn_span,
            feats_extract=self.feats_extract,
            attention_window=self.attention_window,
            pad_speech=self.pad_speech,
            sega_emb=self.sega_emb,
            duration_collect=self.duration_collect,
            text_masking=self.text_masking)

O
oyjxer 已提交
789 790

def mlm_collate_fn(
        data: Collection[Tuple[str, Dict[str, np.ndarray]]],
        float_pad_value: Union[float, int]=0.0,
        int_pad_value: int=-32768,
        not_sequence: Collection[str]=(),
        mlm_prob: float=0.8,
        mean_phn_span: int=8,
        feats_extract=None,
        attention_window: int=0,
        pad_speech: bool=False,
        sega_emb: bool=False,
        duration_collect: bool=False,
        text_masking: bool=False
) -> Tuple[List[str], Dict[str, "paddle.Tensor"]]:
    """Pad a list of ``(utt_id, arrays)`` items and build the MLM model inputs.

    Every array in the per-utterance dicts is converted to a paddle tensor
    and padded with ``pad_list``; a ``<key>_lengths`` tensor is added for
    each key not listed in ``not_sequence``.  Log-mel features are then
    extracted from ``output["speech"][0]`` (only the first item of the
    batch), and the masks, masked positions and segment positions are
    derived from them.

    Args:
        data: Items of ``(utt_id, {name: ndarray})``; all dicts must share
            the same keys and no key may end in ``_lengths``.
        float_pad_value: Padding value for non-integer arrays.
        int_pad_value: Padding value for integer arrays.
        not_sequence: Keys for which no ``*_lengths`` entry is produced.
        mlm_prob: Masking probability handed to the masking helper
            (overridden to 0.15 when the batch has no ``'text'``).
        mean_phn_span: Mean span handed to the masking helper (overridden
            to 0 when the batch has no ``'text'``).
        feats_extract: Extractor providing ``get_log_mel_fbank``, ``sr`` and
            ``hop_length``.  Required despite the ``None`` default.
        attention_window: When > 0, text (and speech if ``pad_speech``) is
            padded with ``pad_to_longformer_att_window``.
        pad_speech: See ``attention_window``.
        sega_emb: Forwarded to the segment-position helper.
        duration_collect: When true and ``'text'`` is present, also emit
            ``'durations'`` and ``'reordered_index'``.
        text_masking: Select ``phones_text_masking`` instead of
            ``phones_masking``.

    Returns:
        ``(uttids, output_dict)`` where ``output_dict`` holds ``speech``,
        ``text``, ``masked_position``, ``text_masked_position``,
        ``speech_mask``, ``text_mask``, ``speech_segment_pos``,
        ``text_segment_pos``, ``speech_lengths`` and ``text_lengths`` (plus
        the duration entries described above).

    Raises:
        ValueError: If ``feats_extract`` is ``None``.
        AssertionError: If the per-utterance dicts disagree on their keys or
            use a reserved ``*_lengths`` key.
    """
    if feats_extract is None:
        # Fail with a clear message instead of an AttributeError on None
        # once feature extraction is reached.
        raise ValueError("mlm_collate_fn requires a feats_extract instance")

    uttids = [u for u, _ in data]
    data = [d for _, d in data]

    assert all(set(data[0]) == set(d) for d in data), "dict-keys mismatching"
    assert all(not k.endswith("_lengths")
               for k in data[0]), f"*_lengths is reserved: {list(data[0])}"

    output = {}
    for key in data[0]:
        # NOTE(kamo):
        # Each models, which accepts these values finally, are responsible
        # to repaint the pad_value to the desired value for each tasks.
        if data[0][key].dtype.kind == "i":
            pad_value = int_pad_value
        else:
            pad_value = float_pad_value

        array_list = [d[key] for d in data]

        # Assume the first axis is length:
        # tensor_list: Batch x (Length, ...)
        tensor_list = [paddle.to_tensor(a) for a in array_list]
        # tensor: (Batch, Length, ...)
        tensor = pad_list(tensor_list, pad_value)
        output[key] = tensor

        # lens: (Batch,)
        if key not in not_sequence:
            lens = paddle.to_tensor(
                [d[key].shape[0] for d in data], dtype=paddle.long)
            output[key + "_lengths"] = lens

    # Only the first utterance of the batch is featurized.
    feats = feats_extract.get_log_mel_fbank(np.array(output["speech"][0]))
    feats = paddle.to_tensor(feats)
    feats_lengths = paddle.shape(feats)[0]
    feats = paddle.unsqueeze(feats, 0)
    if 'text' not in output:
        # No text in the batch: fabricate placeholder text/alignment tensors
        # and fall back to the speech-only masking settings.
        text = paddle.zeros_like(feats_lengths.unsqueeze(-1)) - 2
        text_lengths = paddle.zeros_like(feats_lengths) + 1
        max_tlen = 1
        align_start = paddle.zeros_like(text)
        align_end = paddle.zeros_like(text)
        align_start_lengths = paddle.zeros_like(feats_lengths)
        align_end_lengths = paddle.zeros_like(feats_lengths)
        sega_emb = False
        mean_phn_span = 0
        mlm_prob = 0.15
    else:
        text, text_lengths = output["text"], output["text_lengths"]
        align_start, align_start_lengths, align_end, align_end_lengths = output[
            "align_start"], output["align_start_lengths"], output[
                "align_end"], output["align_end_lengths"]
        # Scale alignments by sr / hop_length and floor to integers
        # (presumably seconds -> frame indices; confirm against the aligner).
        align_start = paddle.floor(feats_extract.sr * align_start /
                                   feats_extract.hop_length).int()
        align_end = paddle.floor(feats_extract.sr * align_end /
                                 feats_extract.hop_length).int()
        max_tlen = max(text_lengths).item()
    max_slen = max(feats_lengths).item()
    speech_pad = feats[:, :max_slen]
    if attention_window > 0 and pad_speech:
        speech_pad, max_slen = pad_to_longformer_att_window(
            speech_pad, max_slen, max_slen, attention_window)
    max_len = max_slen + max_tlen
    if attention_window > 0:
        text_pad, max_tlen = pad_to_longformer_att_window(
            text, max_len, max_tlen, attention_window)
    else:
        text_pad = text
    text_mask = make_non_pad_mask(
        text_lengths.tolist(), text_pad, length_dim=1).unsqueeze(-2)
    if attention_window > 0:
        # NOTE(review): the doubled mask value is presumably a longformer
        # attention-mode marker for text positions -- confirm with the model.
        text_mask = text_mask * 2
    speech_mask = make_non_pad_mask(
        feats_lengths.tolist(), speech_pad[:, :, 0], length_dim=1).unsqueeze(-2)
    # Optional key: None when the batch carries no span boundary.
    span_boundary = output.get('span_boundary')

    if text_masking:
        masked_position, text_masked_position, _ = phones_text_masking(
            speech_pad, speech_mask, text_pad, text_mask, align_start,
            align_end, align_start_lengths, mlm_prob, mean_phn_span,
            span_boundary)
    else:
        text_masked_position = np.zeros(text_pad.size())
        masked_position, _ = phones_masking(
            speech_pad, speech_mask, align_start, align_end,
            align_start_lengths, mlm_prob, mean_phn_span, span_boundary)

    output_dict = {}
    if duration_collect and 'text' in output:
        reordered_index, speech_segment_pos, text_segment_pos, durations, feats_lengths = get_segment_pos_reduce_duration(
            speech_pad, text_pad, align_start, align_end, align_start_lengths,
            sega_emb, masked_position, feats_lengths)
        # The speech mask is rebuilt from the lengths returned by the helper,
        # truncated to the width of reordered_index.
        speech_mask = make_non_pad_mask(
            feats_lengths.tolist(),
            speech_pad[:, :reordered_index.shape[1], 0],
            length_dim=1).unsqueeze(-2)
        output_dict['durations'] = durations
        output_dict['reordered_index'] = reordered_index
    else:
        speech_segment_pos, text_segment_pos = get_segment_pos(
            speech_pad, text_pad, align_start, align_end, align_start_lengths,
            sega_emb)
    output_dict['speech'] = speech_pad
    output_dict['text'] = text_pad
    output_dict['masked_position'] = masked_position
    output_dict['text_masked_position'] = text_masked_position
    output_dict['speech_mask'] = speech_mask
    output_dict['text_mask'] = text_mask
    output_dict['speech_segment_pos'] = speech_segment_pos
    output_dict['text_segment_pos'] = text_segment_pos
    output_dict['speech_lengths'] = output["speech_lengths"]
    output_dict['text_lengths'] = text_lengths
    output = (uttids, output_dict)
    return output

小湉湉's avatar
小湉湉 已提交
939 940

def build_collate_fn(args: argparse.Namespace, train: bool, epoch=-1):
    """Build the ``MLMCollateFn`` described by a training configuration.

    Args:
        args: Namespace providing the ``feats_extract_conf``,
            ``encoder_conf`` and ``model_conf`` dicts.  Note that
            ``feats_extract_conf['win_length']`` is filled in place with
            ``n_fft`` when it is ``None``.
        train: Not used by this function; kept for interface compatibility.
        epoch: ``-1`` leaves ``model_conf['mlm_prob']`` unscaled; any other
            value scales it by 0.8.

    Returns:
        A configured ``MLMCollateFn`` with ``float_pad_value=0.0`` and
        ``int_pad_value=0``.
    """
    feats_conf = args.feats_extract_conf
    if feats_conf['win_length'] is None:
        feats_conf['win_length'] = feats_conf['n_fft']

    # LogMelFBank is constructed with `sr` where the config says `fs`.
    extractor_kwargs = {('sr' if k == 'fs' else k): v
                        for k, v in feats_conf.items()}
    feats_extract = LogMelFBank(**extractor_kwargs)

    encoder_conf = args.encoder_conf
    sega_emb = encoder_conf['input_layer'] == 'sega_mlm'
    if encoder_conf['selfattention_layer_type'] == 'longformer':
        attention_window = encoder_conf['attention_window']
        pad_speech = bool('pre_speech_layer' in encoder_conf and
                          encoder_conf['pre_speech_layer'] > 0)
    else:
        attention_window = 0
        pad_speech = False

    mlm_prob_factor = 1 if epoch == -1 else 0.8

    model_conf = args.model_conf
    duration_collect = bool('duration_predictor_layers' in model_conf and
                            model_conf['duration_predictor_layers'] > 0)

    return MLMCollateFn(
        feats_extract,
        float_pad_value=0.0,
        int_pad_value=0,
        mlm_prob=model_conf['mlm_prob'] * mlm_prob_factor,
        mean_phn_span=model_conf['mean_phn_span'],
        attention_window=attention_window,
        pad_speech=pad_speech,
        sega_emb=sega_emb,
        duration_collect=duration_collect)


def get_mlm_output(uid,
                   prefix,
                   clone_uid,
                   clone_prefix,
                   source_language,
                   target_language,
                   model_name,
                   wav_path,
                   old_str,
                   new_str,
                   duration_preditor_path,
                   sid=None,
                   decoder=False,
                   use_teacher_forcing=False,
                   dynamic_eval=(0, 0),
                   duration_adjust=True,
                   start_end_sp=False):
    """Load the model named ``model_name`` and decode one utterance with it.

    The model is put in eval mode, a collate function is built from its
    training arguments, and everything is handed to ``decode_with_model``,
    whose result is returned unchanged.  ``dynamic_eval`` is accepted but
    not used here; no processor is supplied (``None`` is passed).
    """
    model, model_args = load_model(model_name)
    model.eval()
    collator = build_collate_fn(model_args, False)

    return decode_with_model(
        uid=uid,
        prefix=prefix,
        clone_uid=clone_uid,
        clone_prefix=clone_prefix,
        source_language=source_language,
        target_language=target_language,
        mlm_model=model,
        processor=None,
        collate_fn=collator,
        wav_path=wav_path,
        old_str=old_str,
        new_str=new_str,
        duration_preditor_path=duration_preditor_path,
        sid=sid,
        decoder=decoder,
        use_teacher_forcing=use_teacher_forcing,
        duration_adjust=duration_adjust,
        start_end_sp=start_end_sp,
        train_args=model_args)


def test_vctk(uid,
              clone_uid,
              clone_prefix,
              source_language,
              target_language,
              vocoder,
              prefix='dump/raw/dev',
              model_name="conformer",
              old_str="",
              new_str="",
              prompt_decoding=False,
              dynamic_eval=(0, 0),
              task_name=None):
    """Build the target text for ``task_name`` and run the synthesis pipeline.

    ``task_name`` selects how ``new_str`` is formed:
      * ``'edit'``       -- ``new_str`` is used as given;
      * ``'synthesize'`` -- the original transcript is prepended;
      * anything else    -- the original transcript is followed by the
        Chinese characters of ``new_str`` joined with spaces.

    An empty ``old_str`` defaults to the original transcript.  Returns the
    results dict produced by ``plot_mel_and_vocode_wav``.
    ``prompt_decoding`` and ``dynamic_eval`` are accepted but not used.
    """
    full_origin_str, wav_path = read_data(uid, prefix)

    if task_name == 'synthesize':
        new_str = full_origin_str + new_str
    elif task_name != 'edit':
        chinese_chars = [ch for ch in new_str if is_chinese(ch)]
        new_str = full_origin_str + ' '.join(chinese_chars)

    print('new_str is ', new_str)

    if not old_str:
        old_str = full_origin_str

    # No duration predictor path and no speaker embedding are supplied.
    duration_preditor_path = None
    results_dict, _ = plot_mel_and_vocode_wav(
        uid, prefix, clone_uid, clone_prefix, source_language,
        target_language, model_name, wav_path, full_origin_str, old_str,
        new_str, vocoder, duration_preditor_path, sid=None)
    return results_dict

小湉湉's avatar
小湉湉 已提交
1084

P
pfZhu 已提交
1085
if __name__ == "__main__":
    # Command-line entry point: parse the arguments, run the pipeline and
    # write the generated waveform (at 24 kHz) to args.output_name.
    args = parse_args()

    data_dict = test_vctk(
        uid=args.uid,
        clone_uid=args.clone_uid,
        clone_prefix=args.clone_prefix,
        source_language=args.source_language,
        target_language=args.target_language,
        vocoder=args.use_pt_vocoder,
        prefix=args.prefix,
        model_name=args.model_name,
        new_str=args.new_str,
        task_name=args.task_name)
    sf.write(args.output_name, data_dict['output'], samplerate=24000)
    print("finished...")