inference.py 21.8 KB
Newer Older
P
pfZhu 已提交
1
#!/usr/bin/env python3
小湉湉's avatar
小湉湉 已提交
2 3 4
import os
import random
from pathlib import Path
P
pfZhu 已提交
5 6
from typing import Dict
from typing import List
小湉湉's avatar
小湉湉 已提交
7 8 9

import librosa
import numpy as np
O
oyjxer 已提交
10
import paddle
小湉湉's avatar
小湉湉 已提交
11
import soundfile as sf
O
oyjxer 已提交
12
import torch
小湉湉's avatar
小湉湉 已提交
13
from paddle import nn
小湉湉's avatar
小湉湉 已提交
14
from ParallelWaveGAN.parallel_wavegan.utils.utils import download_pretrained_model
小湉湉's avatar
小湉湉 已提交
15 16 17 18 19 20 21 22 23

from align import alignment
from align import alignment_zh
from align import words2phns
from align import words2phns_zh
from collect_fn import build_collate_fn
from mlm import build_model_from_file
from read_text import load_num_sequence_text
from read_text import read_2col_text
小湉湉's avatar
小湉湉 已提交
24 25 26 27 28
from sedit_arg_parser import parse_args
from utils import build_vocoder_from_file
from utils import eval_durs
from utils import get_voc_out
from utils import is_chinese
小湉湉's avatar
小湉湉 已提交
29

P
pfZhu 已提交
30 31 32
random.seed(0)
np.random.seed(0)

O
oyjxer 已提交
33

小湉湉's avatar
小湉湉 已提交
34 35 36 37 38 39 40 41
def get_wav(wav_path: str,
            source_lang: str='english',
            target_lang: str='english',
            model_name: str="paddle_checkpoint_en",
            old_str: str="",
            new_str: str="",
            use_pt_vocoder: bool=False,
            non_autoreg: bool=True):
小湉湉's avatar
小湉湉 已提交
42
    wav_org, output_feat, old_span_bdy, new_span_bdy, fs, hop_length = get_mlm_output(
小湉湉's avatar
小湉湉 已提交
43 44 45 46 47 48
        source_lang=source_lang,
        target_lang=target_lang,
        model_name=model_name,
        wav_path=wav_path,
        old_str=old_str,
        new_str=new_str,
小湉湉's avatar
小湉湉 已提交
49
        use_teacher_forcing=non_autoreg)
小湉湉's avatar
小湉湉 已提交
50

小湉湉's avatar
小湉湉 已提交
51
    masked_feat = output_feat[new_span_bdy[0]:new_span_bdy[1]]
小湉湉's avatar
小湉湉 已提交
52

小湉湉's avatar
小湉湉 已提交
53 54 55 56 57
    if target_lang == 'english' and use_pt_vocoder:
        masked_feat = masked_feat.cpu().numpy()
        masked_feat = torch.tensor(masked_feat, dtype=torch.float)
        vocoder = load_vocoder('vctk_parallel_wavegan.v1.long')
        alt_wav = vocoder(masked_feat).cpu().numpy()
P
pfZhu 已提交
58

小湉湉's avatar
小湉湉 已提交
59 60
    else:
        alt_wav = get_voc_out(masked_feat)
小湉湉's avatar
小湉湉 已提交
61

小湉湉's avatar
小湉湉 已提交
62
    old_time_bdy = [hop_length * x for x in old_span_bdy]
P
pfZhu 已提交
63

小湉湉's avatar
小湉湉 已提交
64 65
    wav_replaced = np.concatenate(
        [wav_org[:old_time_bdy[0]], alt_wav, wav_org[old_time_bdy[1]:]])
P
pfZhu 已提交
66

小湉湉's avatar
小湉湉 已提交
67
    data_dict = {"origin": wav_org, "output": wav_replaced}
P
pfZhu 已提交
68

小湉湉's avatar
小湉湉 已提交
69
    return data_dict
P
pfZhu 已提交
70

O
oyjxer 已提交
71

小湉湉's avatar
小湉湉 已提交
72
def load_vocoder(vocoder_tag: str="vctk_parallel_wavegan.v1.long"):
O
oyjxer 已提交
73 74 75
    vocoder_tag = vocoder_tag.replace("parallel_wavegan/", "")
    vocoder_file = download_pretrained_model(vocoder_tag)
    vocoder_config = Path(vocoder_file).parent / "config.yml"
小湉湉's avatar
小湉湉 已提交
76
    vocoder = build_vocoder_from_file(vocoder_config, vocoder_file, None, 'cpu')
O
oyjxer 已提交
77 78
    return vocoder

小湉湉's avatar
小湉湉 已提交
79

小湉湉's avatar
小湉湉 已提交
80
def load_model(model_name: str="paddle_checkpoint_en"):
小湉湉's avatar
小湉湉 已提交
81
    config_path = './pretrained_model/{}/config.yaml'.format(model_name)
P
pfZhu 已提交
82
    model_path = './pretrained_model/{}/model.pdparams'.format(model_name)
小湉湉's avatar
小湉湉 已提交
83
    mlm_model, conf = build_model_from_file(
小湉湉's avatar
小湉湉 已提交
84
        config_file=config_path, model_file=model_path)
小湉湉's avatar
小湉湉 已提交
85
    return mlm_model, conf
P
pfZhu 已提交
86 87


小湉湉's avatar
小湉湉 已提交
88 89 90 91 92 93 94
def read_data(uid: str, prefix: os.PathLike):
    # 获取 uid 对应的文本
    mfa_text = read_2col_text(prefix + '/text')[uid]
    # 获取 uid 对应的音频路径
    mfa_wav_path = read_2col_text(prefix + '/wav.scp')[uid]
    if not os.path.isabs(mfa_wav_path):
        mfa_wav_path = prefix + mfa_wav_path
P
pfZhu 已提交
95
    return mfa_text, mfa_wav_path
小湉湉's avatar
小湉湉 已提交
96 97


小湉湉's avatar
小湉湉 已提交
98
def get_align_data(uid: str, prefix: os.PathLike):
小湉湉's avatar
小湉湉 已提交
99
    mfa_path = prefix + "mfa_"
小湉湉's avatar
小湉湉 已提交
100
    mfa_text = read_2col_text(mfa_path + 'text')[uid]
小湉湉's avatar
小湉湉 已提交
101 102 103 104
    mfa_start = load_num_sequence_text(
        mfa_path + 'start', loader_type='text_float')[uid]
    mfa_end = load_num_sequence_text(
        mfa_path + 'end', loader_type='text_float')[uid]
小湉湉's avatar
小湉湉 已提交
105
    mfa_wav_path = read_2col_text(mfa_path + 'wav.scp')[uid]
P
pfZhu 已提交
106 107 108
    return mfa_text, mfa_start, mfa_end, mfa_wav_path


小湉湉's avatar
小湉湉 已提交
109
# 获取需要被 mask 的 mel 帧的范围
小湉湉's avatar
小湉湉 已提交
110 111 112 113 114
def get_masked_mel_bdy(mfa_start: List[float],
                       mfa_end: List[float],
                       fs: int,
                       hop_length: int,
                       span_to_repl: List[List[int]]):
小湉湉's avatar
小湉湉 已提交
115 116 117 118
    align_start = np.array(mfa_start)
    align_end = np.array(mfa_end)
    align_start = np.floor(fs * align_start / hop_length).astype('int')
    align_end = np.floor(fs * align_end / hop_length).astype('int')
小湉湉's avatar
小湉湉 已提交
119
    if span_to_repl[0] >= len(mfa_start):
小湉湉's avatar
小湉湉 已提交
120
        span_bdy = [align_end[-1], align_end[-1]]
P
pfZhu 已提交
121
    else:
小湉湉's avatar
小湉湉 已提交
122
        span_bdy = [
小湉湉's avatar
小湉湉 已提交
123
            align_start[span_to_repl[0]], align_end[span_to_repl[1] - 1]
小湉湉's avatar
小湉湉 已提交
124
        ]
小湉湉's avatar
小湉湉 已提交
125
    return span_bdy, align_start, align_end
P
pfZhu 已提交
126 127


小湉湉's avatar
小湉湉 已提交
128
def recover_dict(word2phns: Dict[str, str], tp_word2phns: Dict[str, str]):
O
oyjxer 已提交
129
    dic = {}
小湉湉's avatar
小湉湉 已提交
130 131
    keys_to_del = []
    exist_idx = []
小湉湉's avatar
小湉湉 已提交
132 133
    sp_count = 0
    add_sp_count = 0
O
oyjxer 已提交
134 135 136
    for key in word2phns.keys():
        idx, wrd = key.split('_')
        if wrd == 'sp':
小湉湉's avatar
小湉湉 已提交
137
            sp_count += 1
小湉湉's avatar
小湉湉 已提交
138
            exist_idx.append(int(idx))
P
pfZhu 已提交
139
        else:
小湉湉's avatar
小湉湉 已提交
140
            keys_to_del.append(key)
小湉湉's avatar
小湉湉 已提交
141

小湉湉's avatar
小湉湉 已提交
142
    for key in keys_to_del:
O
oyjxer 已提交
143 144 145 146
        del word2phns[key]

    cur_id = 0
    for key in tp_word2phns.keys():
小湉湉's avatar
小湉湉 已提交
147
        if cur_id in exist_idx:
小湉湉's avatar
小湉湉 已提交
148 149 150
            dic[str(cur_id) + "_sp"] = 'sp'
            cur_id += 1
            add_sp_count += 1
O
oyjxer 已提交
151
        idx, wrd = key.split('_')
小湉湉's avatar
小湉湉 已提交
152
        dic[str(cur_id) + "_" + wrd] = tp_word2phns[key]
O
oyjxer 已提交
153
        cur_id += 1
小湉湉's avatar
小湉湉 已提交
154

O
oyjxer 已提交
155
    if add_sp_count + 1 == sp_count:
小湉湉's avatar
小湉湉 已提交
156 157 158
        dic[str(cur_id) + "_sp"] = 'sp'
        add_sp_count += 1

O
oyjxer 已提交
159 160
    assert add_sp_count == sp_count, "sp are not added in dic"
    return dic
P
pfZhu 已提交
161 162


小湉湉's avatar
小湉湉 已提交
163 164 165 166
def get_max_idx(dic):
    return sorted([int(key.split('_')[0]) for key in dic.keys()])[-1]


小湉湉's avatar
小湉湉 已提交
167 168 169 170 171
def get_phns_and_spans(wav_path: str,
                       old_str: str="",
                       new_str: str="",
                       source_lang: str="english",
                       target_lang: str="english"):
小湉湉's avatar
小湉湉 已提交
172
    is_append = (old_str == new_str[:len(old_str)])
P
pfZhu 已提交
173
    old_phns, mfa_start, mfa_end = [], [], []
小湉湉's avatar
小湉湉 已提交
174
    # source
小湉湉's avatar
小湉湉 已提交
175
    if source_lang == "english":
小湉湉's avatar
小湉湉 已提交
176
        intervals, word2phns = alignment(wav_path, old_str)
小湉湉's avatar
小湉湉 已提交
177
    elif source_lang == "chinese":
小湉湉's avatar
小湉湉 已提交
178
        intervals, word2phns = alignment_zh(wav_path, old_str)
小湉湉's avatar
小湉湉 已提交
179 180 181
        _, tp_word2phns = words2phns_zh(old_str)

        for key, value in tp_word2phns.items():
O
oyjxer 已提交
182 183
            idx, wrd = key.split('_')
            cur_val = " ".join(value)
小湉湉's avatar
小湉湉 已提交
184
            tp_word2phns[key] = cur_val
P
pfZhu 已提交
185

O
oyjxer 已提交
186 187
        word2phns = recover_dict(word2phns, tp_word2phns)
    else:
小湉湉's avatar
小湉湉 已提交
188 189
        assert source_lang == "chinese" or source_lang == "english", \
            "source_lang is wrong..."
P
pfZhu 已提交
190

小湉湉's avatar
小湉湉 已提交
191 192
    for item in intervals:
        old_phns.append(item[0])
O
oyjxer 已提交
193 194
        mfa_start.append(float(item[1]))
        mfa_end.append(float(item[2]))
小湉湉's avatar
小湉湉 已提交
195 196 197
    # target
    if is_append and (source_lang != target_lang):
        cross_lingual_clone = True
P
pfZhu 已提交
198
    else:
小湉湉's avatar
小湉湉 已提交
199
        cross_lingual_clone = False
P
pfZhu 已提交
200

小湉湉's avatar
小湉湉 已提交
201 202 203
    if cross_lingual_clone:
        str_origin = new_str[:len(old_str)]
        str_append = new_str[len(old_str):]
P
pfZhu 已提交
204

小湉湉's avatar
小湉湉 已提交
205
        if target_lang == "chinese":
小湉湉's avatar
小湉湉 已提交
206 207
            phns_origin, origin_word2phns = words2phns(str_origin)
            phns_append, append_word2phns_tmp = words2phns_zh(str_append)
P
pfZhu 已提交
208

小湉湉's avatar
小湉湉 已提交
209 210
        elif target_lang == "english":
            # 原始句子
小湉湉's avatar
小湉湉 已提交
211 212 213
            phns_origin, origin_word2phns = words2phns_zh(str_origin)
            # clone 句子 
            phns_append, append_word2phns_tmp = words2phns(str_append)
P
pfZhu 已提交
214
        else:
小湉湉's avatar
小湉湉 已提交
215 216
            assert target_lang == "chinese" or target_lang == "english", \
                "cloning is not support for this language, please check it."
小湉湉's avatar
小湉湉 已提交
217

小湉湉's avatar
小湉湉 已提交
218
        new_phns = phns_origin + phns_append
O
oyjxer 已提交
219

小湉湉's avatar
小湉湉 已提交
220 221 222
        append_word2phns = {}
        length = len(origin_word2phns)
        for key, value in append_word2phns_tmp.items():
O
oyjxer 已提交
223
            idx, wrd = key.split('_')
小湉湉's avatar
小湉湉 已提交
224 225 226
            append_word2phns[str(int(idx) + length) + '_' + wrd] = value
        new_word2phns = origin_word2phns.copy()
        new_word2phns.update(append_word2phns)
小湉湉's avatar
小湉湉 已提交
227 228

    else:
小湉湉's avatar
小湉湉 已提交
229
        if source_lang == target_lang and target_lang == "english":
O
oyjxer 已提交
230
            new_phns, new_word2phns = words2phns(new_str)
小湉湉's avatar
小湉湉 已提交
231
        elif source_lang == target_lang and target_lang == "chinese":
O
oyjxer 已提交
232 233
            new_phns, new_word2phns = words2phns_zh(new_str)
        else:
小湉湉's avatar
小湉湉 已提交
234 235
            assert source_lang == target_lang, \
                "source language is not same with target language..."
小湉湉's avatar
小湉湉 已提交
236

小湉湉's avatar
小湉湉 已提交
237 238 239
    span_to_repl = [0, len(old_phns) - 1]
    span_to_add = [0, len(new_phns) - 1]
    left_idx = 0
O
oyjxer 已提交
240 241 242 243 244
    new_phns_left = []
    sp_count = 0
    # find the left different index
    for key in word2phns.keys():
        idx, wrd = key.split('_')
小湉湉's avatar
小湉湉 已提交
245 246
        if wrd == 'sp':
            sp_count += 1
O
oyjxer 已提交
247 248 249
            new_phns_left.append('sp')
        else:
            idx = str(int(idx) - sp_count)
小湉湉's avatar
小湉湉 已提交
250
            if idx + '_' + wrd in new_word2phns:
小湉湉's avatar
小湉湉 已提交
251
                left_idx += len(new_word2phns[idx + '_' + wrd])
O
oyjxer 已提交
252
                new_phns_left.extend(word2phns[key].split())
P
pfZhu 已提交
253
            else:
小湉湉's avatar
小湉湉 已提交
254 255
                span_to_repl[0] = len(new_phns_left)
                span_to_add[0] = len(new_phns_left)
O
oyjxer 已提交
256
                break
小湉湉's avatar
小湉湉 已提交
257

O
oyjxer 已提交
258
    # reverse word2phns and new_word2phns
小湉湉's avatar
小湉湉 已提交
259
    right_idx = 0
O
oyjxer 已提交
260 261
    new_phns_right = []
    sp_count = 0
小湉湉's avatar
小湉湉 已提交
262 263
    word2phns_max_idx = get_max_idx(word2phns)
    new_word2phns_max_idx = get_max_idx(new_word2phns)
小湉湉's avatar
小湉湉 已提交
264
    new_phns_mid = []
小湉湉's avatar
小湉湉 已提交
265
    if is_append:
P
pfZhu 已提交
266
        new_phns_right = []
小湉湉's avatar
小湉湉 已提交
267 268 269 270 271
        new_phns_mid = new_phns[left_idx:]
        span_to_repl[0] = len(new_phns_left)
        span_to_add[0] = len(new_phns_left)
        span_to_add[1] = len(new_phns_left) + len(new_phns_mid)
        span_to_repl[1] = len(old_phns) - len(new_phns_right)
小湉湉's avatar
小湉湉 已提交
272
    # speech edit
O
oyjxer 已提交
273 274 275
    else:
        for key in list(word2phns.keys())[::-1]:
            idx, wrd = key.split('_')
小湉湉's avatar
小湉湉 已提交
276 277 278
            if wrd == 'sp':
                sp_count += 1
                new_phns_right = ['sp'] + new_phns_right
P
pfZhu 已提交
279
            else:
小湉湉's avatar
小湉湉 已提交
280 281
                idx = str(new_word2phns_max_idx - (word2phns_max_idx - int(idx)
                                                   - sp_count))
小湉湉's avatar
小湉湉 已提交
282
                if idx + '_' + wrd in new_word2phns:
小湉湉's avatar
小湉湉 已提交
283
                    right_idx -= len(new_word2phns[idx + '_' + wrd])
O
oyjxer 已提交
284
                    new_phns_right = word2phns[key].split() + new_phns_right
P
pfZhu 已提交
285
                else:
小湉湉's avatar
小湉湉 已提交
286 287 288 289 290 291 292 293 294
                    span_to_repl[1] = len(old_phns) - len(new_phns_right)
                    new_phns_mid = new_phns[left_idx:right_idx]
                    span_to_add[1] = len(new_phns_left) + len(new_phns_mid)
                    if len(new_phns_mid) == 0:
                        span_to_add[1] = min(span_to_add[1] + 1, len(new_phns))
                        span_to_add[0] = max(0, span_to_add[0] - 1)
                        span_to_repl[0] = max(0, span_to_repl[0] - 1)
                        span_to_repl[1] = min(span_to_repl[1] + 1,
                                              len(old_phns))
O
oyjxer 已提交
295
                    break
小湉湉's avatar
小湉湉 已提交
296
    new_phns = new_phns_left + new_phns_mid + new_phns_right
小湉湉's avatar
小湉湉 已提交
297 298 299 300 301 302
    '''
    For that reason cover should not be given.
    For that reason cover is impossible to be given.
    span_to_repl: [17, 23] "should not"
    span_to_add: [17, 30]  "is impossible to"
    '''
小湉湉's avatar
小湉湉 已提交
303
    return mfa_start, mfa_end, old_phns, new_phns, span_to_repl, span_to_add
P
pfZhu 已提交
304 305


小湉湉's avatar
小湉湉 已提交
306 307
# mfa 获得的 duration 和 fs2 的 duration_predictor 获取的 duration 可能不同
# 此处获得一个缩放比例, 用于预测值和真实值之间的缩放
小湉湉's avatar
小湉湉 已提交
308 309 310
def get_dur_adj_factor(orig_dur: List[int],
                       pred_dur: List[int],
                       phns: List[str]):
P
pfZhu 已提交
311 312
    length = 0
    factor_list = []
小湉湉's avatar
小湉湉 已提交
313
    for orig, pred, phn in zip(orig_dur, pred_dur, phns):
小湉湉's avatar
小湉湉 已提交
314
        if pred == 0 or phn == 'sp':
P
pfZhu 已提交
315 316
            continue
        else:
小湉湉's avatar
小湉湉 已提交
317
            factor_list.append(orig / pred)
P
pfZhu 已提交
318 319
    factor_list = np.array(factor_list)
    factor_list.sort()
小湉湉's avatar
小湉湉 已提交
320
    if len(factor_list) < 5:
P
pfZhu 已提交
321 322
        return 1
    length = 2
小湉湉's avatar
小湉湉 已提交
323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347
    avg = np.average(factor_list[length:-length])
    return avg


def prep_feats_with_dur(wav_path: str,
                        mlm_model: nn.Layer,
                        source_lang: str="English",
                        target_lang: str="English",
                        old_str: str="",
                        new_str: str="",
                        mask_reconstruct: bool=False,
                        duration_adjust: bool=True,
                        start_end_sp: bool=False,
                        fs: int=24000,
                        hop_length: int=300):
    '''
    Returns:
        np.ndarray: new wav, replace the part to be edited in original wav with 0
        List[str]: new phones
        List[float]: mfa start of new wav
        List[float]: mfa end of new wav
        List[int]: masked mel boundary of original wav
        List[int]: masked mel boundary of new wav
    '''
    wav_org, _ = librosa.load(wav_path, sr=fs)
小湉湉's avatar
小湉湉 已提交
348

小湉湉's avatar
小湉湉 已提交
349 350 351 352 353 354
    mfa_start, mfa_end, old_phns, new_phns, span_to_repl, span_to_add = get_phns_and_spans(
        wav_path=wav_path,
        old_str=old_str,
        new_str=new_str,
        source_lang=source_lang,
        target_lang=target_lang)
P
pfZhu 已提交
355 356

    if start_end_sp:
小湉湉's avatar
小湉湉 已提交
357 358
        if new_phns[-1] != 'sp':
            new_phns = new_phns + ['sp']
小湉湉's avatar
小湉湉 已提交
359 360
    # 中文的 phns 不一定都在 fastspeech2 的字典里, 用 sp 代替
    if target_lang == "english" or target_lang == "chinese":
小湉湉's avatar
小湉湉 已提交
361
        old_durs = eval_durs(old_phns, target_lang=source_lang)
P
pfZhu 已提交
362
    else:
小湉湉's avatar
小湉湉 已提交
363 364
        assert target_lang == "chinese" or target_lang == "english", \
            "calculate duration_predict is not support for this language..."
P
pfZhu 已提交
365

小湉湉's avatar
小湉湉 已提交
366
    orig_old_durs = [e - s for e, s in zip(mfa_end, mfa_start)]
P
pfZhu 已提交
367 368
    if '[MASK]' in new_str:
        new_phns = old_phns
小湉湉's avatar
小湉湉 已提交
369
        span_to_add = span_to_repl
小湉湉's avatar
小湉湉 已提交
370
        d_factor_left = get_dur_adj_factor(
小湉湉's avatar
小湉湉 已提交
371 372 373
            orig_dur=orig_old_durs[:span_to_repl[0]],
            pred_dur=old_durs[:span_to_repl[0]],
            phns=old_phns[:span_to_repl[0]])
小湉湉's avatar
小湉湉 已提交
374
        d_factor_right = get_dur_adj_factor(
小湉湉's avatar
小湉湉 已提交
375 376 377
            orig_dur=orig_old_durs[span_to_repl[1]:],
            pred_dur=old_durs[span_to_repl[1]:],
            phns=old_phns[span_to_repl[1]:])
小湉湉's avatar
小湉湉 已提交
378
        d_factor = (d_factor_left + d_factor_right) / 2
小湉湉's avatar
小湉湉 已提交
379
        new_durs_adjusted = [d_factor * i for i in old_durs]
P
pfZhu 已提交
380 381
    else:
        if duration_adjust:
小湉湉's avatar
小湉湉 已提交
382
            d_factor = get_dur_adj_factor(
小湉湉's avatar
小湉湉 已提交
383
                orig_dur=orig_old_durs, pred_dur=old_durs, phns=old_phns)
小湉湉's avatar
小湉湉 已提交
384
            d_factor = d_factor * 1.25
P
pfZhu 已提交
385 386 387
        else:
            d_factor = 1

小湉湉's avatar
小湉湉 已提交
388
        if target_lang == "english" or target_lang == "chinese":
小湉湉's avatar
小湉湉 已提交
389
            new_durs = eval_durs(new_phns, target_lang=target_lang)
小湉湉's avatar
小湉湉 已提交
390 391 392 393 394 395 396 397 398
        else:
            assert target_lang == "chinese" or target_lang == "english", \
                "calculate duration_predict is not support for this language..."

        new_durs_adjusted = [d_factor * i for i in new_durs]

    new_span_dur_sum = sum(new_durs_adjusted[span_to_add[0]:span_to_add[1]])
    old_span_dur_sum = sum(orig_old_durs[span_to_repl[0]:span_to_repl[1]])
    dur_offset = new_span_dur_sum - old_span_dur_sum
小湉湉's avatar
小湉湉 已提交
399 400
    new_mfa_start = mfa_start[:span_to_repl[0]]
    new_mfa_end = mfa_end[:span_to_repl[0]]
小湉湉's avatar
小湉湉 已提交
401
    for i in new_durs_adjusted[span_to_add[0]:span_to_add[1]]:
小湉湉's avatar
小湉湉 已提交
402
        if len(new_mfa_end) == 0:
P
pfZhu 已提交
403 404 405 406
            new_mfa_start.append(0)
            new_mfa_end.append(i)
        else:
            new_mfa_start.append(new_mfa_end[-1])
小湉湉's avatar
小湉湉 已提交
407
            new_mfa_end.append(new_mfa_end[-1] + i)
小湉湉's avatar
小湉湉 已提交
408 409
    new_mfa_start += [i + dur_offset for i in mfa_start[span_to_repl[1]:]]
    new_mfa_end += [i + dur_offset for i in mfa_end[span_to_repl[1]:]]
小湉湉's avatar
小湉湉 已提交
410

小湉湉's avatar
小湉湉 已提交
411 412
    # 3. get new wav
    # 在原始句子后拼接
小湉湉's avatar
小湉湉 已提交
413 414 415
    if span_to_repl[0] >= len(mfa_start):
        left_idx = len(wav_org)
        right_idx = left_idx
小湉湉's avatar
小湉湉 已提交
416
    # 在原始句子中间替换
P
pfZhu 已提交
417
    else:
小湉湉's avatar
小湉湉 已提交
418 419
        left_idx = int(np.floor(mfa_start[span_to_repl[0]] * fs))
        right_idx = int(np.ceil(mfa_end[span_to_repl[1] - 1] * fs))
小湉湉's avatar
小湉湉 已提交
420 421 422 423 424
    blank_wav = np.zeros(
        (int(np.ceil(new_span_dur_sum * fs)), ), dtype=wav_org.dtype)
    # 原始音频,需要编辑的部分替换成空音频,空音频的时间由 fs2 的 duration_predictor 决定
    new_wav = np.concatenate(
        [wav_org[:left_idx], blank_wav, wav_org[right_idx:]])
P
pfZhu 已提交
425 426

    # 4. get old and new mel span to be mask
小湉湉's avatar
小湉湉 已提交
427
    # [92, 92]
小湉湉's avatar
小湉湉 已提交
428 429 430 431 432 433 434

    old_span_bdy, mfa_start, mfa_end = get_masked_mel_bdy(
        mfa_start=mfa_start,
        mfa_end=mfa_end,
        fs=fs,
        hop_length=hop_length,
        span_to_repl=span_to_repl)
小湉湉's avatar
小湉湉 已提交
435
    # [92, 174]
小湉湉's avatar
小湉湉 已提交
436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460
    # new_mfa_start, new_mfa_end 时间级别的开始和结束时间 -> 帧级别
    new_span_bdy, new_mfa_start, new_mfa_end = get_masked_mel_bdy(
        mfa_start=new_mfa_start,
        mfa_end=new_mfa_end,
        fs=fs,
        hop_length=hop_length,
        span_to_repl=span_to_add)

    # old_span_bdy, new_span_bdy 是帧级别的范围
    return new_wav, new_phns, new_mfa_start, new_mfa_end, old_span_bdy, new_span_bdy


def prep_feats(mlm_model: nn.Layer,
               wav_path: str,
               source_lang: str="english",
               target_lang: str="english",
               old_str: str="",
               new_str: str="",
               duration_adjust: bool=True,
               start_end_sp: bool=False,
               mask_reconstruct: bool=False,
               fs: int=24000,
               hop_length: int=300,
               token_list: List[str]=[]):
    wav, phns, mfa_start, mfa_end, old_span_bdy, new_span_bdy = prep_feats_with_dur(
小湉湉's avatar
小湉湉 已提交
461 462 463 464 465 466
        source_lang=source_lang,
        target_lang=target_lang,
        mlm_model=mlm_model,
        old_str=old_str,
        new_str=new_str,
        wav_path=wav_path,
小湉湉's avatar
小湉湉 已提交
467 468 469
        duration_adjust=duration_adjust,
        start_end_sp=start_end_sp,
        mask_reconstruct=mask_reconstruct,
小湉湉's avatar
小湉湉 已提交
470 471
        fs=fs,
        hop_length=hop_length)
小湉湉's avatar
小湉湉 已提交
472

小湉湉's avatar
小湉湉 已提交
473 474 475
    token_to_id = {item: i for i, item in enumerate(token_list)}
    text = np.array(
        list(map(lambda x: token_to_id.get(x, token_to_id['<unk>']), phns)))
小湉湉's avatar
小湉湉 已提交
476
    span_bdy = np.array(new_span_bdy)
小湉湉's avatar
小湉湉 已提交
477

小湉湉's avatar
小湉湉 已提交
478
    batch = [('1', {
小湉湉's avatar
小湉湉 已提交
479 480 481
        "speech": wav,
        "align_start": mfa_start,
        "align_end": mfa_end,
小湉湉's avatar
小湉湉 已提交
482
        "text": text,
小湉湉's avatar
小湉湉 已提交
483
        "span_bdy": span_bdy
小湉湉's avatar
小湉湉 已提交
484 485
    })]

小湉湉's avatar
小湉湉 已提交
486
    return batch, old_span_bdy, new_span_bdy
P
pfZhu 已提交
487 488


小湉湉's avatar
小湉湉 已提交
489
def decode_with_model(mlm_model: nn.Layer,
小湉湉's avatar
小湉湉 已提交
490
                      collate_fn,
小湉湉's avatar
小湉湉 已提交
491 492 493 494 495 496 497 498
                      wav_path: str,
                      source_lang: str="english",
                      target_lang: str="english",
                      old_str: str="",
                      new_str: str="",
                      use_teacher_forcing: bool=False,
                      duration_adjust: bool=True,
                      start_end_sp: bool=False,
小湉湉's avatar
小湉湉 已提交
499 500 501 502
                      fs: int=24000,
                      hop_length: int=300,
                      token_list: List[str]=[]):
    batch, old_span_bdy, new_span_bdy = prep_feats(
小湉湉's avatar
小湉湉 已提交
503 504 505 506 507 508
        source_lang=source_lang,
        target_lang=target_lang,
        mlm_model=mlm_model,
        wav_path=wav_path,
        old_str=old_str,
        new_str=new_str,
小湉湉's avatar
小湉湉 已提交
509 510
        duration_adjust=duration_adjust,
        start_end_sp=start_end_sp,
小湉湉's avatar
小湉湉 已提交
511 512 513
        fs=fs,
        hop_length=hop_length,
        token_list=token_list)
小湉湉's avatar
小湉湉 已提交
514

O
oyjxer 已提交
515 516
    feats = collate_fn(batch)[1]

小湉湉's avatar
小湉湉 已提交
517 518
    if 'text_masked_pos' in feats.keys():
        feats.pop('text_masked_pos')
小湉湉's avatar
小湉湉 已提交
519 520 521 522 523 524 525 526 527 528 529 530

    output = mlm_model.inference(
        text=feats['text'],
        speech=feats['speech'],
        masked_pos=feats['masked_pos'],
        speech_mask=feats['speech_mask'],
        text_mask=feats['text_mask'],
        speech_seg_pos=feats['speech_seg_pos'],
        text_seg_pos=feats['text_seg_pos'],
        span_bdy=new_span_bdy,
        use_teacher_forcing=use_teacher_forcing)

小湉湉's avatar
小湉湉 已提交
531 532
    # 拼接音频
    output_feat = paddle.concat(x=output, axis=0)
小湉湉's avatar
小湉湉 已提交
533 534 535 536 537 538
    wav_org, _ = librosa.load(wav_path, sr=fs)
    return wav_org, output_feat, old_span_bdy, new_span_bdy, fs, hop_length


def get_mlm_output(wav_path: str,
                   model_name: str="paddle_checkpoint_en",
小湉湉's avatar
小湉湉 已提交
539 540 541 542 543 544 545
                   source_lang: str="english",
                   target_lang: str="english",
                   old_str: str="",
                   new_str: str="",
                   use_teacher_forcing: bool=False,
                   duration_adjust: bool=True,
                   start_end_sp: bool=False):
小湉湉's avatar
小湉湉 已提交
546
    mlm_model, train_conf = load_model(model_name)
P
pfZhu 已提交
547
    mlm_model.eval()
小湉湉's avatar
小湉湉 已提交
548 549 550 551 552 553 554 555 556 557 558 559 560

    collate_fn = build_collate_fn(
        sr=train_conf.feats_extract_conf['fs'],
        n_fft=train_conf.feats_extract_conf['n_fft'],
        hop_length=train_conf.feats_extract_conf['hop_length'],
        win_length=train_conf.feats_extract_conf['win_length'],
        n_mels=train_conf.feats_extract_conf['n_mels'],
        fmin=train_conf.feats_extract_conf['fmin'],
        fmax=train_conf.feats_extract_conf['fmax'],
        mlm_prob=train_conf['mlm_prob'],
        mean_phn_span=train_conf['mean_phn_span'],
        train=False,
        seg_emb=train_conf.encoder_conf['input_layer'] == 'sega_mlm')
P
pfZhu 已提交
561

小湉湉's avatar
小湉湉 已提交
562
    return decode_with_model(
小湉湉's avatar
小湉湉 已提交
563 564 565 566 567 568 569
        source_lang=source_lang,
        target_lang=target_lang,
        mlm_model=mlm_model,
        collate_fn=collate_fn,
        wav_path=wav_path,
        old_str=old_str,
        new_str=new_str,
小湉湉's avatar
小湉湉 已提交
570 571 572
        use_teacher_forcing=use_teacher_forcing,
        duration_adjust=duration_adjust,
        start_end_sp=start_end_sp,
小湉湉's avatar
小湉湉 已提交
573 574 575
        fs=train_conf.feats_extract_conf['fs'],
        hop_length=train_conf.feats_extract_conf['hop_length'],
        token_list=train_conf.token_list)
小湉湉's avatar
小湉湉 已提交
576 577


小湉湉's avatar
小湉湉 已提交
578 579 580 581
def evaluate(uid: str,
             source_lang: str="english",
             target_lang: str="english",
             use_pt_vocoder: bool=False,
小湉湉's avatar
小湉湉 已提交
582 583
             prefix: os.PathLike="./prompt/dev/",
             model_name: str="paddle_checkpoint_en",
小湉湉's avatar
小湉湉 已提交
584 585 586
             new_str: str="",
             prompt_decoding: bool=False,
             task_name: str=None):
P
pfZhu 已提交
587

小湉湉's avatar
小湉湉 已提交
588 589
    # get origin text and path of origin wav
    old_str, wav_path = read_data(uid=uid, prefix=prefix)
小湉湉's avatar
小湉湉 已提交
590

O
oyjxer 已提交
591 592 593
    if task_name == 'edit':
        new_str = new_str
    elif task_name == 'synthesize':
小湉湉's avatar
小湉湉 已提交
594
        new_str = old_str + new_str
O
oyjxer 已提交
595
    else:
小湉湉's avatar
小湉湉 已提交
596
        new_str = old_str + ' '.join([ch for ch in new_str if is_chinese(ch)])
小湉湉's avatar
小湉湉 已提交
597

P
pfZhu 已提交
598
    print('new_str is ', new_str)
小湉湉's avatar
小湉湉 已提交
599

小湉湉's avatar
小湉湉 已提交
600
    results_dict = get_wav(
小湉湉's avatar
小湉湉 已提交
601 602 603 604 605 606
        source_lang=source_lang,
        target_lang=target_lang,
        model_name=model_name,
        wav_path=wav_path,
        old_str=old_str,
        new_str=new_str,
小湉湉's avatar
小湉湉 已提交
607
        use_pt_vocoder=use_pt_vocoder)
P
pfZhu 已提交
608 609
    return results_dict

小湉湉's avatar
小湉湉 已提交
610

P
pfZhu 已提交
611
if __name__ == "__main__":
小湉湉's avatar
小湉湉 已提交
612
    # parse config and args
P
pfZhu 已提交
613
    args = parse_args()
小湉湉's avatar
小湉湉 已提交
614

小湉湉's avatar
小湉湉 已提交
615 616 617 618 619 620 621
    data_dict = evaluate(
        uid=args.uid,
        source_lang=args.source_lang,
        target_lang=args.target_lang,
        use_pt_vocoder=args.use_pt_vocoder,
        prefix=args.prefix,
        model_name=args.model_name,
P
pfZhu 已提交
622 623
        new_str=args.new_str,
        task_name=args.task_name)
小湉湉's avatar
小湉湉 已提交
624
    sf.write(args.output_name, data_dict['output'], samplerate=24000)
O
oyjxer 已提交
625
    print("finished...")