From 445e30407da4bbfc283be321103390694ed1b67c Mon Sep 17 00:00:00 2001
From: TianYuan
Date: Wed, 15 Jun 2022 11:12:13 +0000
Subject: [PATCH] fix vocoder inference

---
 ernie-sat/inference.py | 87 +++++++++++++++++-------------------
 ernie-sat/utils.py     | 23 ++++++-----
 2 files changed, 45 insertions(+), 65 deletions(-)

diff --git a/ernie-sat/inference.py b/ernie-sat/inference.py
index c7fde9f..1760fa4 100644
--- a/ernie-sat/inference.py
+++ b/ernie-sat/inference.py
@@ -11,11 +11,7 @@ import paddle
 import soundfile as sf
 import torch
 from paddle import nn
-from sedit_arg_parser import parse_args
-from utils import build_vocoder_from_file
-from utils import evaluate_durations
-from utils import get_voc_out
-from utils import is_chinese
+from ParallelWaveGAN.parallel_wavegan.utils.utils import download_pretrained_model
 
 from align import alignment
 from align import alignment_zh
@@ -25,20 +21,24 @@ from collect_fn import build_collate_fn
 from mlm import build_model_from_file
 from read_text import load_num_sequence_text
 from read_text import read_2col_text
-# from ParallelWaveGAN.parallel_wavegan.utils.utils import download_pretrained_model
+from sedit_arg_parser import parse_args
+from utils import build_vocoder_from_file
+from utils import eval_durs
+from utils import get_voc_out
+from utils import is_chinese
 
 random.seed(0)
 np.random.seed(0)
 
 
-def plot_mel_and_vocode_wav(wav_path: str,
-                            source_lang: str='english',
-                            target_lang: str='english',
-                            model_name: str="paddle_checkpoint_en",
-                            old_str: str="",
-                            new_str: str="",
-                            use_pt_vocoder: bool=False,
-                            non_autoreg: bool=True):
+def get_wav(wav_path: str,
+            source_lang: str='english',
+            target_lang: str='english',
+            model_name: str="paddle_checkpoint_en",
+            old_str: str="",
+            new_str: str="",
+            use_pt_vocoder: bool=False,
+            non_autoreg: bool=True):
     wav_org, output_feat, old_span_bdy, new_span_bdy, fs, hop_length = get_mlm_output(
         source_lang=source_lang,
         target_lang=target_lang,
@@ -50,41 +50,23 @@ def plot_mel_and_vocode_wav(wav_path: str,
 
     masked_feat = output_feat[new_span_bdy[0]:new_span_bdy[1]]
 
-    if target_lang == 'english':
-        if use_pt_vocoder:
-            output_feat = output_feat.cpu().numpy()
-            output_feat = torch.tensor(output_feat, dtype=torch.float)
-            vocoder = load_vocoder('vctk_parallel_wavegan.v1.long')
-            replaced_wav = vocoder(output_feat).cpu().numpy()
-        else:
-            replaced_wav = get_voc_out(output_feat)
+    if target_lang == 'english' and use_pt_vocoder:
+        masked_feat = masked_feat.cpu().numpy()
+        masked_feat = torch.tensor(masked_feat, dtype=torch.float)
+        vocoder = load_vocoder('vctk_parallel_wavegan.v1.long')
+        alt_wav = vocoder(masked_feat).cpu().numpy()
-    elif target_lang == 'chinese':
-        replaced_wav_only_mask_fst2_voc = get_voc_out(masked_feat)
+    else:
+        alt_wav = get_voc_out(masked_feat)
 
     old_time_bdy = [hop_length * x for x in old_span_bdy]
-    new_time_bdy = [hop_length * x for x in new_span_bdy]
-
-    if target_lang == 'english':
-        wav_org_replaced_paddle_voc = np.concatenate([
-            wav_org[:old_time_bdy[0]],
-            replaced_wav[new_time_bdy[0]:new_time_bdy[1]],
-            wav_org[old_time_bdy[1]:]
-        ])
-        data_dict = {"origin": wav_org, "output": wav_org_replaced_paddle_voc}
+    wav_replaced = np.concatenate(
+        [wav_org[:old_time_bdy[0]], alt_wav, wav_org[old_time_bdy[1]:]])
-    elif target_lang == 'chinese':
-        wav_org_replaced_only_mask_fst2_voc = np.concatenate([
-            wav_org[:old_time_bdy[0]], replaced_wav_only_mask_fst2_voc,
-            wav_org[old_time_bdy[1]:]
-        ])
-        data_dict = {
-            "origin": wav_org,
-            "output": wav_org_replaced_only_mask_fst2_voc,
-        }
+
+    data_dict = {"origin": wav_org, "output": wav_replaced}
 
-    return data_dict, old_span_bdy
+    return data_dict
 
 
 def load_vocoder(vocoder_tag: str="vctk_parallel_wavegan.v1.long"):
@@ -323,9 +305,9 @@ def get_phns_and_spans(wav_path: str,
 
 # durations from MFA may differ from those predicted by the fs2 duration_predictor
 # compute a scaling factor here, used to rescale between predicted and ground-truth values
-def duration_adjust_factor(orig_dur: List[int],
-                           pred_dur: List[int],
-                           phns: List[str]):
+def get_dur_adj_factor(orig_dur: List[int],
+                       pred_dur: List[int],
+                       phns: List[str]):
     length = 0
     factor_list = []
     for orig, pred, phn in zip(orig_dur, pred_dur, phns):
@@ -376,7 +358,7 @@ def prep_feats_with_dur(wav_path: str,
         new_phns = new_phns + ['sp']  # Chinese phns are not always in the fastspeech2 dictionary; substitute sp
 
     if target_lang == "english" or target_lang == "chinese":
-        old_durs = evaluate_durations(old_phns, target_lang=source_lang)
+        old_durs = eval_durs(old_phns, target_lang=source_lang)
     else:
         assert target_lang == "chinese" or target_lang == "english", \
             "calculate duration_predict is not support for this language..."
@@ -385,11 +367,11 @@
     if '[MASK]' in new_str:
         new_phns = old_phns
         span_to_add = span_to_repl
-        d_factor_left = duration_adjust_factor(
+        d_factor_left = get_dur_adj_factor(
             orig_dur=orig_old_durs[:span_to_repl[0]],
             pred_dur=old_durs[:span_to_repl[0]],
             phns=old_phns[:span_to_repl[0]])
-        d_factor_right = duration_adjust_factor(
+        d_factor_right = get_dur_adj_factor(
             orig_dur=orig_old_durs[span_to_repl[1]:],
             pred_dur=old_durs[span_to_repl[1]:],
             phns=old_phns[span_to_repl[1]:])
@@ -397,15 +379,14 @@
         new_durs_adjusted = [d_factor * i for i in old_durs]
     else:
         if duration_adjust:
-            d_factor = duration_adjust_factor(
+            d_factor = get_dur_adj_factor(
                 orig_dur=orig_old_durs, pred_dur=old_durs, phns=old_phns)
-            print("d_factor:", d_factor)
             d_factor = d_factor * 1.25
         else:
             d_factor = 1
 
     if target_lang == "english" or target_lang == "chinese":
-        new_durs = evaluate_durations(new_phns, target_lang=target_lang)
+        new_durs = eval_durs(new_phns, target_lang=target_lang)
     else:
         assert target_lang == "chinese" or target_lang == "english", \
             "calculate duration_predict is not support for this language..."
@@ -616,7 +597,7 @@ def evaluate(uid: str,
 
     print('new_str is ', new_str)
 
-    results_dict, old_span = plot_mel_and_vocode_wav(
+    results_dict = get_wav(
         source_lang=source_lang,
         target_lang=target_lang,
         model_name=model_name,
diff --git a/ernie-sat/utils.py b/ernie-sat/utils.py
index 672c115..1b74dc0 100644
--- a/ernie-sat/utils.py
+++ b/ernie-sat/utils.py
@@ -1,5 +1,4 @@
 import os
-from typing import List
 from typing import Optional
 
 import numpy as np
@@ -32,6 +31,7 @@ model_alias = {
     "paddlespeech.t2s.models.parallel_wavegan:PWGInference",
 }
 
+
 def is_chinese(ch):
     if u'\u4e00' <= ch <= u'\u9fff':
         return True
@@ -61,7 +61,7 @@
     # print("current vocoder: ", args.voc)
     with open(args.voc_config) as f:
         voc_config = CfgNode(yaml.safe_load(f))
-    voc_inference = voc_inference = get_voc_inference(
+    voc_inference = get_voc_inference(
         voc=args.voc,
         voc_config=voc_config,
         voc_ckpt=args.voc_ckpt,
@@ -164,7 +164,7 @@
     return voc_inference
 
 
-def evaluate_durations(phns, target_lang="chinese", fs=24000, hop_length=300):
+def eval_durs(phns, target_lang="chinese", fs=24000, hop_length=300):
     args = parse_args()
 
     if target_lang == 'english':
@@ -176,10 +176,10 @@
 
     elif target_lang == 'chinese':
         args.am = "fastspeech2_csmsc"
-        args.am_config="download/fastspeech2_conformer_baker_ckpt_0.5/conformer.yaml"
+        args.am_config = "download/fastspeech2_conformer_baker_ckpt_0.5/conformer.yaml"
         args.am_ckpt = "download/fastspeech2_conformer_baker_ckpt_0.5/snapshot_iter_76000.pdz"
         args.am_stat = "download/fastspeech2_conformer_baker_ckpt_0.5/speech_stats.npy"
-        args.phones_dict ="download/fastspeech2_conformer_baker_ckpt_0.5/phone_id_map.txt"
+        args.phones_dict = "download/fastspeech2_conformer_baker_ckpt_0.5/phone_id_map.txt"
 
     if args.ngpu == 0:
         paddle.set_device("cpu")
@@ -211,11 +211,10 @@
 
     phonemes = [phn if phn in vocab_phones else "sp" for phn in phns]
     phone_ids = [vocab_phones[item] for item in phonemes]
-    phone_ids_new = phone_ids
-    phone_ids_new.append(vocab_size - 1)
-    phone_ids_new = paddle.to_tensor(np.array(phone_ids_new, np.int64))
-    _, d_outs, _, _ = am.inference(phone_ids_new, spk_id=None, spk_emb=None)
+    phone_ids.append(vocab_size - 1)
+    phone_ids = paddle.to_tensor(np.array(phone_ids, np.int64))
+    _, d_outs, _, _ = am.inference(phone_ids, spk_id=None, spk_emb=None)
     pre_d_outs = d_outs
-    phoneme_durations_new = pre_d_outs * hop_length / fs
-    phoneme_durations_new = phoneme_durations_new.tolist()[:-1]
-    return phoneme_durations_new
+    phu_durs_new = pre_d_outs * hop_length / fs
+    phu_durs_new = phu_durs_new.tolist()[:-1]
+    return phu_durs_new
--
GitLab
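
Note: below is a minimal usage sketch, not part of the patch, showing how the renamed entry point behaves after this change. It assumes the script runs from inside the ernie-sat directory with the pretrained checkpoints downloaded; the wav path and edit strings are hypothetical placeholders, and the 24 kHz write rate mirrors the fs=24000 default in eval_durs.

import soundfile as sf

from inference import get_wav

# After this patch, get_wav() vocodes only the masked span and splices it
# into the original waveform, returning a dict instead of the old
# (data_dict, old_span_bdy) tuple.
data_dict = get_wav(
    wav_path="prompt/p243_313.wav",  # hypothetical input utterance
    source_lang="english",
    target_lang="english",
    model_name="paddle_checkpoint_en",
    old_str="For that reason cover should not be given.",  # hypothetical
    new_str="For that reason cover must not be given.",    # hypothetical
    use_pt_vocoder=False)  # False -> PaddleSpeech vocoder via get_voc_out()

# "origin" is the untouched input; "output" has the re-synthesized span spliced in.
sf.write("edited.wav", data_dict["output"], 24000)  # sample rate is an assumption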