inference.py 22.8 KB
Newer Older
P
pfZhu 已提交
1
#!/usr/bin/env python3
小湉湉's avatar
小湉湉 已提交
2 3 4
import os
import random
from pathlib import Path
P
pfZhu 已提交
5 6
from typing import Dict
from typing import List
小湉湉's avatar
小湉湉 已提交
7 8 9

import librosa
import numpy as np
O
oyjxer 已提交
10
import paddle
小湉湉's avatar
小湉湉 已提交
11
import soundfile as sf
O
oyjxer 已提交
12
import torch
小湉湉's avatar
小湉湉 已提交
13
from paddle import nn
小湉湉's avatar
小湉湉 已提交
14 15 16 17 18
from sedit_arg_parser import parse_args
from utils import build_vocoder_from_file
from utils import evaluate_durations
from utils import get_voc_out
from utils import is_chinese
小湉湉's avatar
小湉湉 已提交
19 20 21 22 23 24 25 26 27 28 29

from align import alignment
from align import alignment_zh
from align import words2phns
from align import words2phns_zh
from collect_fn import build_collate_fn
from mlm import build_model_from_file
from read_text import load_num_sequence_text
from read_text import read_2col_text
# from ParallelWaveGAN.parallel_wavegan.utils.utils import download_pretrained_model

P
pfZhu 已提交
30 31 32
random.seed(0)
np.random.seed(0)

O
oyjxer 已提交
33

小湉湉's avatar
小湉湉 已提交
34
def plot_mel_and_vocode_wav(wav_path: str,
小湉湉's avatar
小湉湉 已提交
35 36
                            source_lang: str='english',
                            target_lang: str='english',
小湉湉's avatar
小湉湉 已提交
37
                            model_name: str="paddle_checkpoint_en",
小湉湉's avatar
小湉湉 已提交
38 39 40 41
                            old_str: str="",
                            new_str: str="",
                            use_pt_vocoder: bool=False,
                            non_autoreg: bool=True):
小湉湉's avatar
小湉湉 已提交
42
    wav_org, output_feat, old_span_bdy, new_span_bdy, fs, hop_length = get_mlm_output(
小湉湉's avatar
小湉湉 已提交
43 44 45 46 47 48
        source_lang=source_lang,
        target_lang=target_lang,
        model_name=model_name,
        wav_path=wav_path,
        old_str=old_str,
        new_str=new_str,
小湉湉's avatar
小湉湉 已提交
49
        use_teacher_forcing=non_autoreg)
小湉湉's avatar
小湉湉 已提交
50

小湉湉's avatar
小湉湉 已提交
51
    masked_feat = output_feat[new_span_bdy[0]:new_span_bdy[1]]
小湉湉's avatar
小湉湉 已提交
52

小湉湉's avatar
小湉湉 已提交
53
    if target_lang == 'english':
O
oyjxer 已提交
54
        if use_pt_vocoder:
小湉湉's avatar
小湉湉 已提交
55
            output_feat = output_feat.cpu().numpy()
小湉湉's avatar
小湉湉 已提交
56
            output_feat = torch.tensor(output_feat, dtype=torch.float)
O
oyjxer 已提交
57
            vocoder = load_vocoder('vctk_parallel_wavegan.v1.long')
小湉湉's avatar
小湉湉 已提交
58
            replaced_wav = vocoder(output_feat).cpu().numpy()
O
oyjxer 已提交
59
        else:
小湉湉's avatar
小湉湉 已提交
60
            replaced_wav = get_voc_out(output_feat)
P
pfZhu 已提交
61

小湉湉's avatar
小湉湉 已提交
62
    elif target_lang == 'chinese':
小湉湉's avatar
小湉湉 已提交
63
        replaced_wav_only_mask_fst2_voc = get_voc_out(masked_feat)
小湉湉's avatar
小湉湉 已提交
64

小湉湉's avatar
小湉湉 已提交
65 66
    old_time_bdy = [hop_length * x for x in old_span_bdy]
    new_time_bdy = [hop_length * x for x in new_span_bdy]
小湉湉's avatar
小湉湉 已提交
67

小湉湉's avatar
小湉湉 已提交
68
    if target_lang == 'english':
小湉湉's avatar
小湉湉 已提交
69
        wav_org_replaced_paddle_voc = np.concatenate([
小湉湉's avatar
小湉湉 已提交
70 71 72
            wav_org[:old_time_bdy[0]],
            replaced_wav[new_time_bdy[0]:new_time_bdy[1]],
            wav_org[old_time_bdy[1]:]
小湉湉's avatar
小湉湉 已提交
73
        ])
P
pfZhu 已提交
74

小湉湉's avatar
小湉湉 已提交
75
        data_dict = {"origin": wav_org, "output": wav_org_replaced_paddle_voc}
P
pfZhu 已提交
76

小湉湉's avatar
小湉湉 已提交
77
    elif target_lang == 'chinese':
小湉湉's avatar
小湉湉 已提交
78
        wav_org_replaced_only_mask_fst2_voc = np.concatenate([
小湉湉's avatar
小湉湉 已提交
79 80
            wav_org[:old_time_bdy[0]], replaced_wav_only_mask_fst2_voc,
            wav_org[old_time_bdy[1]:]
小湉湉's avatar
小湉湉 已提交
81
        ])
P
pfZhu 已提交
82
        data_dict = {
小湉湉's avatar
小湉湉 已提交
83 84 85
            "origin": wav_org,
            "output": wav_org_replaced_only_mask_fst2_voc,
        }
P
pfZhu 已提交
86

小湉湉's avatar
小湉湉 已提交
87
    return data_dict, old_span_bdy
P
pfZhu 已提交
88

O
oyjxer 已提交
89

小湉湉's avatar
小湉湉 已提交
90
def load_vocoder(vocoder_tag: str="vctk_parallel_wavegan.v1.long"):
O
oyjxer 已提交
91 92 93
    vocoder_tag = vocoder_tag.replace("parallel_wavegan/", "")
    vocoder_file = download_pretrained_model(vocoder_tag)
    vocoder_config = Path(vocoder_file).parent / "config.yml"
小湉湉's avatar
小湉湉 已提交
94
    vocoder = build_vocoder_from_file(vocoder_config, vocoder_file, None, 'cpu')
O
oyjxer 已提交
95 96
    return vocoder

小湉湉's avatar
小湉湉 已提交
97

小湉湉's avatar
小湉湉 已提交
98
def load_model(model_name: str="paddle_checkpoint_en"):
小湉湉's avatar
小湉湉 已提交
99
    config_path = './pretrained_model/{}/config.yaml'.format(model_name)
P
pfZhu 已提交
100
    model_path = './pretrained_model/{}/model.pdparams'.format(model_name)
小湉湉's avatar
小湉湉 已提交
101
    mlm_model, conf = build_model_from_file(
小湉湉's avatar
小湉湉 已提交
102
        config_file=config_path, model_file=model_path)
小湉湉's avatar
小湉湉 已提交
103
    return mlm_model, conf
P
pfZhu 已提交
104 105


小湉湉's avatar
小湉湉 已提交
106 107 108 109 110 111 112
def read_data(uid: str, prefix: os.PathLike):
    # 获取 uid 对应的文本
    mfa_text = read_2col_text(prefix + '/text')[uid]
    # 获取 uid 对应的音频路径
    mfa_wav_path = read_2col_text(prefix + '/wav.scp')[uid]
    if not os.path.isabs(mfa_wav_path):
        mfa_wav_path = prefix + mfa_wav_path
P
pfZhu 已提交
113
    return mfa_text, mfa_wav_path
小湉湉's avatar
小湉湉 已提交
114 115


小湉湉's avatar
小湉湉 已提交
116
def get_align_data(uid: str, prefix: os.PathLike):
小湉湉's avatar
小湉湉 已提交
117
    mfa_path = prefix + "mfa_"
小湉湉's avatar
小湉湉 已提交
118
    mfa_text = read_2col_text(mfa_path + 'text')[uid]
小湉湉's avatar
小湉湉 已提交
119 120 121 122
    mfa_start = load_num_sequence_text(
        mfa_path + 'start', loader_type='text_float')[uid]
    mfa_end = load_num_sequence_text(
        mfa_path + 'end', loader_type='text_float')[uid]
小湉湉's avatar
小湉湉 已提交
123
    mfa_wav_path = read_2col_text(mfa_path + 'wav.scp')[uid]
P
pfZhu 已提交
124 125 126
    return mfa_text, mfa_start, mfa_end, mfa_wav_path


小湉湉's avatar
小湉湉 已提交
127
# 获取需要被 mask 的 mel 帧的范围
小湉湉's avatar
小湉湉 已提交
128 129 130 131 132
def get_masked_mel_bdy(mfa_start: List[float],
                       mfa_end: List[float],
                       fs: int,
                       hop_length: int,
                       span_to_repl: List[List[int]]):
小湉湉's avatar
小湉湉 已提交
133 134 135 136
    align_start = np.array(mfa_start)
    align_end = np.array(mfa_end)
    align_start = np.floor(fs * align_start / hop_length).astype('int')
    align_end = np.floor(fs * align_end / hop_length).astype('int')
小湉湉's avatar
小湉湉 已提交
137
    if span_to_repl[0] >= len(mfa_start):
小湉湉's avatar
小湉湉 已提交
138
        span_bdy = [align_end[-1], align_end[-1]]
P
pfZhu 已提交
139
    else:
小湉湉's avatar
小湉湉 已提交
140
        span_bdy = [
小湉湉's avatar
小湉湉 已提交
141
            align_start[span_to_repl[0]], align_end[span_to_repl[1] - 1]
小湉湉's avatar
小湉湉 已提交
142
        ]
小湉湉's avatar
小湉湉 已提交
143
    return span_bdy, align_start, align_end
P
pfZhu 已提交
144 145


小湉湉's avatar
小湉湉 已提交
146
def recover_dict(word2phns: Dict[str, str], tp_word2phns: Dict[str, str]):
O
oyjxer 已提交
147
    dic = {}
小湉湉's avatar
小湉湉 已提交
148 149
    keys_to_del = []
    exist_idx = []
小湉湉's avatar
小湉湉 已提交
150 151
    sp_count = 0
    add_sp_count = 0
O
oyjxer 已提交
152 153 154
    for key in word2phns.keys():
        idx, wrd = key.split('_')
        if wrd == 'sp':
小湉湉's avatar
小湉湉 已提交
155
            sp_count += 1
小湉湉's avatar
小湉湉 已提交
156
            exist_idx.append(int(idx))
P
pfZhu 已提交
157
        else:
小湉湉's avatar
小湉湉 已提交
158
            keys_to_del.append(key)
小湉湉's avatar
小湉湉 已提交
159

小湉湉's avatar
小湉湉 已提交
160
    for key in keys_to_del:
O
oyjxer 已提交
161 162 163 164
        del word2phns[key]

    cur_id = 0
    for key in tp_word2phns.keys():
小湉湉's avatar
小湉湉 已提交
165
        if cur_id in exist_idx:
小湉湉's avatar
小湉湉 已提交
166 167 168
            dic[str(cur_id) + "_sp"] = 'sp'
            cur_id += 1
            add_sp_count += 1
O
oyjxer 已提交
169
        idx, wrd = key.split('_')
小湉湉's avatar
小湉湉 已提交
170
        dic[str(cur_id) + "_" + wrd] = tp_word2phns[key]
O
oyjxer 已提交
171
        cur_id += 1
小湉湉's avatar
小湉湉 已提交
172

O
oyjxer 已提交
173
    if add_sp_count + 1 == sp_count:
小湉湉's avatar
小湉湉 已提交
174 175 176
        dic[str(cur_id) + "_sp"] = 'sp'
        add_sp_count += 1

O
oyjxer 已提交
177 178
    assert add_sp_count == sp_count, "sp are not added in dic"
    return dic
P
pfZhu 已提交
179 180


小湉湉's avatar
小湉湉 已提交
181 182 183 184
def get_max_idx(dic):
    return sorted([int(key.split('_')[0]) for key in dic.keys()])[-1]


小湉湉's avatar
小湉湉 已提交
185 186 187 188 189
def get_phns_and_spans(wav_path: str,
                       old_str: str="",
                       new_str: str="",
                       source_lang: str="english",
                       target_lang: str="english"):
小湉湉's avatar
小湉湉 已提交
190
    is_append = (old_str == new_str[:len(old_str)])
P
pfZhu 已提交
191
    old_phns, mfa_start, mfa_end = [], [], []
小湉湉's avatar
小湉湉 已提交
192
    # source
小湉湉's avatar
小湉湉 已提交
193
    if source_lang == "english":
小湉湉's avatar
小湉湉 已提交
194
        intervals, word2phns = alignment(wav_path, old_str)
小湉湉's avatar
小湉湉 已提交
195
    elif source_lang == "chinese":
小湉湉's avatar
小湉湉 已提交
196
        intervals, word2phns = alignment_zh(wav_path, old_str)
小湉湉's avatar
小湉湉 已提交
197 198 199
        _, tp_word2phns = words2phns_zh(old_str)

        for key, value in tp_word2phns.items():
O
oyjxer 已提交
200 201
            idx, wrd = key.split('_')
            cur_val = " ".join(value)
小湉湉's avatar
小湉湉 已提交
202
            tp_word2phns[key] = cur_val
P
pfZhu 已提交
203

O
oyjxer 已提交
204 205
        word2phns = recover_dict(word2phns, tp_word2phns)
    else:
小湉湉's avatar
小湉湉 已提交
206 207
        assert source_lang == "chinese" or source_lang == "english", \
            "source_lang is wrong..."
P
pfZhu 已提交
208

小湉湉's avatar
小湉湉 已提交
209 210
    for item in intervals:
        old_phns.append(item[0])
O
oyjxer 已提交
211 212
        mfa_start.append(float(item[1]))
        mfa_end.append(float(item[2]))
小湉湉's avatar
小湉湉 已提交
213 214 215
    # target
    if is_append and (source_lang != target_lang):
        cross_lingual_clone = True
P
pfZhu 已提交
216
    else:
小湉湉's avatar
小湉湉 已提交
217
        cross_lingual_clone = False
P
pfZhu 已提交
218

小湉湉's avatar
小湉湉 已提交
219 220 221
    if cross_lingual_clone:
        str_origin = new_str[:len(old_str)]
        str_append = new_str[len(old_str):]
P
pfZhu 已提交
222

小湉湉's avatar
小湉湉 已提交
223
        if target_lang == "chinese":
小湉湉's avatar
小湉湉 已提交
224 225
            phns_origin, origin_word2phns = words2phns(str_origin)
            phns_append, append_word2phns_tmp = words2phns_zh(str_append)
P
pfZhu 已提交
226

小湉湉's avatar
小湉湉 已提交
227 228
        elif target_lang == "english":
            # 原始句子
小湉湉's avatar
小湉湉 已提交
229 230 231
            phns_origin, origin_word2phns = words2phns_zh(str_origin)
            # clone 句子 
            phns_append, append_word2phns_tmp = words2phns(str_append)
P
pfZhu 已提交
232
        else:
小湉湉's avatar
小湉湉 已提交
233 234
            assert target_lang == "chinese" or target_lang == "english", \
                "cloning is not support for this language, please check it."
小湉湉's avatar
小湉湉 已提交
235

小湉湉's avatar
小湉湉 已提交
236
        new_phns = phns_origin + phns_append
O
oyjxer 已提交
237

小湉湉's avatar
小湉湉 已提交
238 239 240
        append_word2phns = {}
        length = len(origin_word2phns)
        for key, value in append_word2phns_tmp.items():
O
oyjxer 已提交
241
            idx, wrd = key.split('_')
小湉湉's avatar
小湉湉 已提交
242 243 244
            append_word2phns[str(int(idx) + length) + '_' + wrd] = value
        new_word2phns = origin_word2phns.copy()
        new_word2phns.update(append_word2phns)
小湉湉's avatar
小湉湉 已提交
245 246

    else:
小湉湉's avatar
小湉湉 已提交
247
        if source_lang == target_lang and target_lang == "english":
O
oyjxer 已提交
248
            new_phns, new_word2phns = words2phns(new_str)
小湉湉's avatar
小湉湉 已提交
249
        elif source_lang == target_lang and target_lang == "chinese":
O
oyjxer 已提交
250 251
            new_phns, new_word2phns = words2phns_zh(new_str)
        else:
小湉湉's avatar
小湉湉 已提交
252 253
            assert source_lang == target_lang, \
                "source language is not same with target language..."
小湉湉's avatar
小湉湉 已提交
254

小湉湉's avatar
小湉湉 已提交
255 256 257
    span_to_repl = [0, len(old_phns) - 1]
    span_to_add = [0, len(new_phns) - 1]
    left_idx = 0
O
oyjxer 已提交
258 259 260 261 262
    new_phns_left = []
    sp_count = 0
    # find the left different index
    for key in word2phns.keys():
        idx, wrd = key.split('_')
小湉湉's avatar
小湉湉 已提交
263 264
        if wrd == 'sp':
            sp_count += 1
O
oyjxer 已提交
265 266 267
            new_phns_left.append('sp')
        else:
            idx = str(int(idx) - sp_count)
小湉湉's avatar
小湉湉 已提交
268
            if idx + '_' + wrd in new_word2phns:
小湉湉's avatar
小湉湉 已提交
269
                left_idx += len(new_word2phns[idx + '_' + wrd])
O
oyjxer 已提交
270
                new_phns_left.extend(word2phns[key].split())
P
pfZhu 已提交
271
            else:
小湉湉's avatar
小湉湉 已提交
272 273
                span_to_repl[0] = len(new_phns_left)
                span_to_add[0] = len(new_phns_left)
O
oyjxer 已提交
274
                break
小湉湉's avatar
小湉湉 已提交
275

O
oyjxer 已提交
276
    # reverse word2phns and new_word2phns
小湉湉's avatar
小湉湉 已提交
277
    right_idx = 0
O
oyjxer 已提交
278 279
    new_phns_right = []
    sp_count = 0
小湉湉's avatar
小湉湉 已提交
280 281
    word2phns_max_idx = get_max_idx(word2phns)
    new_word2phns_max_idx = get_max_idx(new_word2phns)
小湉湉's avatar
小湉湉 已提交
282
    new_phns_mid = []
小湉湉's avatar
小湉湉 已提交
283
    if is_append:
P
pfZhu 已提交
284
        new_phns_right = []
小湉湉's avatar
小湉湉 已提交
285 286 287 288 289
        new_phns_mid = new_phns[left_idx:]
        span_to_repl[0] = len(new_phns_left)
        span_to_add[0] = len(new_phns_left)
        span_to_add[1] = len(new_phns_left) + len(new_phns_mid)
        span_to_repl[1] = len(old_phns) - len(new_phns_right)
小湉湉's avatar
小湉湉 已提交
290
    # speech edit
O
oyjxer 已提交
291 292 293
    else:
        for key in list(word2phns.keys())[::-1]:
            idx, wrd = key.split('_')
小湉湉's avatar
小湉湉 已提交
294 295 296
            if wrd == 'sp':
                sp_count += 1
                new_phns_right = ['sp'] + new_phns_right
P
pfZhu 已提交
297
            else:
小湉湉's avatar
小湉湉 已提交
298 299
                idx = str(new_word2phns_max_idx - (word2phns_max_idx - int(idx)
                                                   - sp_count))
小湉湉's avatar
小湉湉 已提交
300
                if idx + '_' + wrd in new_word2phns:
小湉湉's avatar
小湉湉 已提交
301
                    right_idx -= len(new_word2phns[idx + '_' + wrd])
O
oyjxer 已提交
302
                    new_phns_right = word2phns[key].split() + new_phns_right
P
pfZhu 已提交
303
                else:
小湉湉's avatar
小湉湉 已提交
304 305 306 307 308 309 310 311 312
                    span_to_repl[1] = len(old_phns) - len(new_phns_right)
                    new_phns_mid = new_phns[left_idx:right_idx]
                    span_to_add[1] = len(new_phns_left) + len(new_phns_mid)
                    if len(new_phns_mid) == 0:
                        span_to_add[1] = min(span_to_add[1] + 1, len(new_phns))
                        span_to_add[0] = max(0, span_to_add[0] - 1)
                        span_to_repl[0] = max(0, span_to_repl[0] - 1)
                        span_to_repl[1] = min(span_to_repl[1] + 1,
                                              len(old_phns))
O
oyjxer 已提交
313
                    break
小湉湉's avatar
小湉湉 已提交
314
    new_phns = new_phns_left + new_phns_mid + new_phns_right
小湉湉's avatar
小湉湉 已提交
315 316 317 318 319 320
    '''
    For that reason cover should not be given.
    For that reason cover is impossible to be given.
    span_to_repl: [17, 23] "should not"
    span_to_add: [17, 30]  "is impossible to"
    '''
小湉湉's avatar
小湉湉 已提交
321
    return mfa_start, mfa_end, old_phns, new_phns, span_to_repl, span_to_add
P
pfZhu 已提交
322 323


小湉湉's avatar
小湉湉 已提交
324 325 326
# mfa 获得的 duration 和 fs2 的 duration_predictor 获取的 duration 可能不同
# 此处获得一个缩放比例, 用于预测值和真实值之间的缩放
def duration_adjust_factor(orig_dur: List[int],
小湉湉's avatar
小湉湉 已提交
327 328
                           pred_dur: List[int],
                           phns: List[str]):
P
pfZhu 已提交
329 330
    length = 0
    factor_list = []
小湉湉's avatar
小湉湉 已提交
331
    for orig, pred, phn in zip(orig_dur, pred_dur, phns):
小湉湉's avatar
小湉湉 已提交
332
        if pred == 0 or phn == 'sp':
P
pfZhu 已提交
333 334
            continue
        else:
小湉湉's avatar
小湉湉 已提交
335
            factor_list.append(orig / pred)
P
pfZhu 已提交
336 337
    factor_list = np.array(factor_list)
    factor_list.sort()
小湉湉's avatar
小湉湉 已提交
338
    if len(factor_list) < 5:
P
pfZhu 已提交
339 340
        return 1
    length = 2
小湉湉's avatar
小湉湉 已提交
341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365
    avg = np.average(factor_list[length:-length])
    return avg


def prep_feats_with_dur(wav_path: str,
                        mlm_model: nn.Layer,
                        source_lang: str="English",
                        target_lang: str="English",
                        old_str: str="",
                        new_str: str="",
                        mask_reconstruct: bool=False,
                        duration_adjust: bool=True,
                        start_end_sp: bool=False,
                        fs: int=24000,
                        hop_length: int=300):
    '''
    Returns:
        np.ndarray: new wav, replace the part to be edited in original wav with 0
        List[str]: new phones
        List[float]: mfa start of new wav
        List[float]: mfa end of new wav
        List[int]: masked mel boundary of original wav
        List[int]: masked mel boundary of new wav
    '''
    wav_org, _ = librosa.load(wav_path, sr=fs)
小湉湉's avatar
小湉湉 已提交
366

小湉湉's avatar
小湉湉 已提交
367 368 369 370 371 372
    mfa_start, mfa_end, old_phns, new_phns, span_to_repl, span_to_add = get_phns_and_spans(
        wav_path=wav_path,
        old_str=old_str,
        new_str=new_str,
        source_lang=source_lang,
        target_lang=target_lang)
P
pfZhu 已提交
373 374

    if start_end_sp:
小湉湉's avatar
小湉湉 已提交
375 376
        if new_phns[-1] != 'sp':
            new_phns = new_phns + ['sp']
小湉湉's avatar
小湉湉 已提交
377 378 379
    # 中文的 phns 不一定都在 fastspeech2 的字典里, 用 sp 代替
    if target_lang == "english" or target_lang == "chinese":
        old_durs = evaluate_durations(old_phns, target_lang=source_lang)
P
pfZhu 已提交
380
    else:
小湉湉's avatar
小湉湉 已提交
381 382
        assert target_lang == "chinese" or target_lang == "english", \
            "calculate duration_predict is not support for this language..."
P
pfZhu 已提交
383

小湉湉's avatar
小湉湉 已提交
384
    orig_old_durs = [e - s for e, s in zip(mfa_end, mfa_start)]
P
pfZhu 已提交
385 386
    if '[MASK]' in new_str:
        new_phns = old_phns
小湉湉's avatar
小湉湉 已提交
387
        span_to_add = span_to_repl
小湉湉's avatar
小湉湉 已提交
388
        d_factor_left = duration_adjust_factor(
小湉湉's avatar
小湉湉 已提交
389 390 391
            orig_dur=orig_old_durs[:span_to_repl[0]],
            pred_dur=old_durs[:span_to_repl[0]],
            phns=old_phns[:span_to_repl[0]])
小湉湉's avatar
小湉湉 已提交
392
        d_factor_right = duration_adjust_factor(
小湉湉's avatar
小湉湉 已提交
393 394 395
            orig_dur=orig_old_durs[span_to_repl[1]:],
            pred_dur=old_durs[span_to_repl[1]:],
            phns=old_phns[span_to_repl[1]:])
小湉湉's avatar
小湉湉 已提交
396
        d_factor = (d_factor_left + d_factor_right) / 2
小湉湉's avatar
小湉湉 已提交
397
        new_durs_adjusted = [d_factor * i for i in old_durs]
P
pfZhu 已提交
398 399
    else:
        if duration_adjust:
小湉湉's avatar
小湉湉 已提交
400 401 402
            d_factor = duration_adjust_factor(
                orig_dur=orig_old_durs, pred_dur=old_durs, phns=old_phns)
            print("d_factor:", d_factor)
小湉湉's avatar
小湉湉 已提交
403
            d_factor = d_factor * 1.25
P
pfZhu 已提交
404 405 406
        else:
            d_factor = 1

小湉湉's avatar
小湉湉 已提交
407 408 409 410 411 412 413 414 415 416 417
        if target_lang == "english" or target_lang == "chinese":
            new_durs = evaluate_durations(new_phns, target_lang=target_lang)
        else:
            assert target_lang == "chinese" or target_lang == "english", \
                "calculate duration_predict is not support for this language..."

        new_durs_adjusted = [d_factor * i for i in new_durs]

    new_span_dur_sum = sum(new_durs_adjusted[span_to_add[0]:span_to_add[1]])
    old_span_dur_sum = sum(orig_old_durs[span_to_repl[0]:span_to_repl[1]])
    dur_offset = new_span_dur_sum - old_span_dur_sum
小湉湉's avatar
小湉湉 已提交
418 419
    new_mfa_start = mfa_start[:span_to_repl[0]]
    new_mfa_end = mfa_end[:span_to_repl[0]]
小湉湉's avatar
小湉湉 已提交
420
    for i in new_durs_adjusted[span_to_add[0]:span_to_add[1]]:
小湉湉's avatar
小湉湉 已提交
421
        if len(new_mfa_end) == 0:
P
pfZhu 已提交
422 423 424 425
            new_mfa_start.append(0)
            new_mfa_end.append(i)
        else:
            new_mfa_start.append(new_mfa_end[-1])
小湉湉's avatar
小湉湉 已提交
426
            new_mfa_end.append(new_mfa_end[-1] + i)
小湉湉's avatar
小湉湉 已提交
427 428
    new_mfa_start += [i + dur_offset for i in mfa_start[span_to_repl[1]:]]
    new_mfa_end += [i + dur_offset for i in mfa_end[span_to_repl[1]:]]
小湉湉's avatar
小湉湉 已提交
429

小湉湉's avatar
小湉湉 已提交
430 431
    # 3. get new wav
    # 在原始句子后拼接
小湉湉's avatar
小湉湉 已提交
432 433 434
    if span_to_repl[0] >= len(mfa_start):
        left_idx = len(wav_org)
        right_idx = left_idx
小湉湉's avatar
小湉湉 已提交
435
    # 在原始句子中间替换
P
pfZhu 已提交
436
    else:
小湉湉's avatar
小湉湉 已提交
437 438
        left_idx = int(np.floor(mfa_start[span_to_repl[0]] * fs))
        right_idx = int(np.ceil(mfa_end[span_to_repl[1] - 1] * fs))
小湉湉's avatar
小湉湉 已提交
439 440 441 442 443
    blank_wav = np.zeros(
        (int(np.ceil(new_span_dur_sum * fs)), ), dtype=wav_org.dtype)
    # 原始音频,需要编辑的部分替换成空音频,空音频的时间由 fs2 的 duration_predictor 决定
    new_wav = np.concatenate(
        [wav_org[:left_idx], blank_wav, wav_org[right_idx:]])
P
pfZhu 已提交
444 445

    # 4. get old and new mel span to be mask
小湉湉's avatar
小湉湉 已提交
446
    # [92, 92]
小湉湉's avatar
小湉湉 已提交
447 448 449 450 451 452 453

    old_span_bdy, mfa_start, mfa_end = get_masked_mel_bdy(
        mfa_start=mfa_start,
        mfa_end=mfa_end,
        fs=fs,
        hop_length=hop_length,
        span_to_repl=span_to_repl)
小湉湉's avatar
小湉湉 已提交
454
    # [92, 174]
小湉湉's avatar
小湉湉 已提交
455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479
    # new_mfa_start, new_mfa_end 时间级别的开始和结束时间 -> 帧级别
    new_span_bdy, new_mfa_start, new_mfa_end = get_masked_mel_bdy(
        mfa_start=new_mfa_start,
        mfa_end=new_mfa_end,
        fs=fs,
        hop_length=hop_length,
        span_to_repl=span_to_add)

    # old_span_bdy, new_span_bdy 是帧级别的范围
    return new_wav, new_phns, new_mfa_start, new_mfa_end, old_span_bdy, new_span_bdy


def prep_feats(mlm_model: nn.Layer,
               wav_path: str,
               source_lang: str="english",
               target_lang: str="english",
               old_str: str="",
               new_str: str="",
               duration_adjust: bool=True,
               start_end_sp: bool=False,
               mask_reconstruct: bool=False,
               fs: int=24000,
               hop_length: int=300,
               token_list: List[str]=[]):
    wav, phns, mfa_start, mfa_end, old_span_bdy, new_span_bdy = prep_feats_with_dur(
小湉湉's avatar
小湉湉 已提交
480 481 482 483 484 485
        source_lang=source_lang,
        target_lang=target_lang,
        mlm_model=mlm_model,
        old_str=old_str,
        new_str=new_str,
        wav_path=wav_path,
小湉湉's avatar
小湉湉 已提交
486 487 488
        duration_adjust=duration_adjust,
        start_end_sp=start_end_sp,
        mask_reconstruct=mask_reconstruct,
小湉湉's avatar
小湉湉 已提交
489 490
        fs=fs,
        hop_length=hop_length)
小湉湉's avatar
小湉湉 已提交
491

小湉湉's avatar
小湉湉 已提交
492 493 494
    token_to_id = {item: i for i, item in enumerate(token_list)}
    text = np.array(
        list(map(lambda x: token_to_id.get(x, token_to_id['<unk>']), phns)))
小湉湉's avatar
小湉湉 已提交
495
    span_bdy = np.array(new_span_bdy)
小湉湉's avatar
小湉湉 已提交
496

小湉湉's avatar
小湉湉 已提交
497
    batch = [('1', {
小湉湉's avatar
小湉湉 已提交
498 499 500
        "speech": wav,
        "align_start": mfa_start,
        "align_end": mfa_end,
小湉湉's avatar
小湉湉 已提交
501
        "text": text,
小湉湉's avatar
小湉湉 已提交
502
        "span_bdy": span_bdy
小湉湉's avatar
小湉湉 已提交
503 504
    })]

小湉湉's avatar
小湉湉 已提交
505
    return batch, old_span_bdy, new_span_bdy
P
pfZhu 已提交
506 507


小湉湉's avatar
小湉湉 已提交
508
def decode_with_model(mlm_model: nn.Layer,
小湉湉's avatar
小湉湉 已提交
509
                      collate_fn,
小湉湉's avatar
小湉湉 已提交
510 511 512 513 514 515 516 517
                      wav_path: str,
                      source_lang: str="english",
                      target_lang: str="english",
                      old_str: str="",
                      new_str: str="",
                      use_teacher_forcing: bool=False,
                      duration_adjust: bool=True,
                      start_end_sp: bool=False,
小湉湉's avatar
小湉湉 已提交
518 519 520 521
                      fs: int=24000,
                      hop_length: int=300,
                      token_list: List[str]=[]):
    batch, old_span_bdy, new_span_bdy = prep_feats(
小湉湉's avatar
小湉湉 已提交
522 523 524 525 526 527
        source_lang=source_lang,
        target_lang=target_lang,
        mlm_model=mlm_model,
        wav_path=wav_path,
        old_str=old_str,
        new_str=new_str,
小湉湉's avatar
小湉湉 已提交
528 529
        duration_adjust=duration_adjust,
        start_end_sp=start_end_sp,
小湉湉's avatar
小湉湉 已提交
530 531 532
        fs=fs,
        hop_length=hop_length,
        token_list=token_list)
小湉湉's avatar
小湉湉 已提交
533

O
oyjxer 已提交
534 535
    feats = collate_fn(batch)[1]

小湉湉's avatar
小湉湉 已提交
536 537
    if 'text_masked_pos' in feats.keys():
        feats.pop('text_masked_pos')
小湉湉's avatar
小湉湉 已提交
538 539 540 541 542 543 544 545 546 547 548 549

    output = mlm_model.inference(
        text=feats['text'],
        speech=feats['speech'],
        masked_pos=feats['masked_pos'],
        speech_mask=feats['speech_mask'],
        text_mask=feats['text_mask'],
        speech_seg_pos=feats['speech_seg_pos'],
        text_seg_pos=feats['text_seg_pos'],
        span_bdy=new_span_bdy,
        use_teacher_forcing=use_teacher_forcing)

小湉湉's avatar
小湉湉 已提交
550 551
    # 拼接音频
    output_feat = paddle.concat(x=output, axis=0)
小湉湉's avatar
小湉湉 已提交
552 553 554 555 556 557
    wav_org, _ = librosa.load(wav_path, sr=fs)
    return wav_org, output_feat, old_span_bdy, new_span_bdy, fs, hop_length


def get_mlm_output(wav_path: str,
                   model_name: str="paddle_checkpoint_en",
小湉湉's avatar
小湉湉 已提交
558 559 560 561 562 563 564
                   source_lang: str="english",
                   target_lang: str="english",
                   old_str: str="",
                   new_str: str="",
                   use_teacher_forcing: bool=False,
                   duration_adjust: bool=True,
                   start_end_sp: bool=False):
小湉湉's avatar
小湉湉 已提交
565
    mlm_model, train_conf = load_model(model_name)
P
pfZhu 已提交
566
    mlm_model.eval()
小湉湉's avatar
小湉湉 已提交
567 568 569 570 571 572 573 574 575 576 577 578 579

    collate_fn = build_collate_fn(
        sr=train_conf.feats_extract_conf['fs'],
        n_fft=train_conf.feats_extract_conf['n_fft'],
        hop_length=train_conf.feats_extract_conf['hop_length'],
        win_length=train_conf.feats_extract_conf['win_length'],
        n_mels=train_conf.feats_extract_conf['n_mels'],
        fmin=train_conf.feats_extract_conf['fmin'],
        fmax=train_conf.feats_extract_conf['fmax'],
        mlm_prob=train_conf['mlm_prob'],
        mean_phn_span=train_conf['mean_phn_span'],
        train=False,
        seg_emb=train_conf.encoder_conf['input_layer'] == 'sega_mlm')
P
pfZhu 已提交
580

小湉湉's avatar
小湉湉 已提交
581
    return decode_with_model(
小湉湉's avatar
小湉湉 已提交
582 583 584 585 586 587 588
        source_lang=source_lang,
        target_lang=target_lang,
        mlm_model=mlm_model,
        collate_fn=collate_fn,
        wav_path=wav_path,
        old_str=old_str,
        new_str=new_str,
小湉湉's avatar
小湉湉 已提交
589 590 591
        use_teacher_forcing=use_teacher_forcing,
        duration_adjust=duration_adjust,
        start_end_sp=start_end_sp,
小湉湉's avatar
小湉湉 已提交
592 593 594
        fs=train_conf.feats_extract_conf['fs'],
        hop_length=train_conf.feats_extract_conf['hop_length'],
        token_list=train_conf.token_list)
小湉湉's avatar
小湉湉 已提交
595 596


小湉湉's avatar
小湉湉 已提交
597 598 599 600
def evaluate(uid: str,
             source_lang: str="english",
             target_lang: str="english",
             use_pt_vocoder: bool=False,
小湉湉's avatar
小湉湉 已提交
601 602
             prefix: os.PathLike="./prompt/dev/",
             model_name: str="paddle_checkpoint_en",
小湉湉's avatar
小湉湉 已提交
603 604 605
             new_str: str="",
             prompt_decoding: bool=False,
             task_name: str=None):
P
pfZhu 已提交
606

小湉湉's avatar
小湉湉 已提交
607 608
    # get origin text and path of origin wav
    old_str, wav_path = read_data(uid=uid, prefix=prefix)
小湉湉's avatar
小湉湉 已提交
609

O
oyjxer 已提交
610 611 612
    if task_name == 'edit':
        new_str = new_str
    elif task_name == 'synthesize':
小湉湉's avatar
小湉湉 已提交
613
        new_str = old_str + new_str
O
oyjxer 已提交
614
    else:
小湉湉's avatar
小湉湉 已提交
615
        new_str = old_str + ' '.join([ch for ch in new_str if is_chinese(ch)])
小湉湉's avatar
小湉湉 已提交
616

P
pfZhu 已提交
617
    print('new_str is ', new_str)
小湉湉's avatar
小湉湉 已提交
618 619

    results_dict, old_span = plot_mel_and_vocode_wav(
小湉湉's avatar
小湉湉 已提交
620 621 622 623 624 625
        source_lang=source_lang,
        target_lang=target_lang,
        model_name=model_name,
        wav_path=wav_path,
        old_str=old_str,
        new_str=new_str,
小湉湉's avatar
小湉湉 已提交
626
        use_pt_vocoder=use_pt_vocoder)
P
pfZhu 已提交
627 628
    return results_dict

小湉湉's avatar
小湉湉 已提交
629

P
pfZhu 已提交
630
if __name__ == "__main__":
小湉湉's avatar
小湉湉 已提交
631
    # parse config and args
P
pfZhu 已提交
632
    args = parse_args()
小湉湉's avatar
小湉湉 已提交
633

小湉湉's avatar
小湉湉 已提交
634 635 636 637 638 639 640
    data_dict = evaluate(
        uid=args.uid,
        source_lang=args.source_lang,
        target_lang=args.target_lang,
        use_pt_vocoder=args.use_pt_vocoder,
        prefix=args.prefix,
        model_name=args.model_name,
P
pfZhu 已提交
641 642
        new_str=args.new_str,
        task_name=args.task_name)
小湉湉's avatar
小湉湉 已提交
643
    sf.write(args.output_name, data_dict['output'], samplerate=24000)
O
oyjxer 已提交
644
    print("finished...")