Commit 445e3040 authored by 小湉湉

fix vocoder inference

Parent e522009d
@@ -11,11 +11,7 @@ import paddle
 import soundfile as sf
 import torch
 from paddle import nn
-from sedit_arg_parser import parse_args
-from utils import build_vocoder_from_file
-from utils import evaluate_durations
-from utils import get_voc_out
-from utils import is_chinese
-from ParallelWaveGAN.parallel_wavegan.utils.utils import download_pretrained_model
 from align import alignment
 from align import alignment_zh
@@ -25,13 +21,17 @@ from collect_fn import build_collate_fn
 from mlm import build_model_from_file
 from read_text import load_num_sequence_text
 from read_text import read_2col_text
+# from ParallelWaveGAN.parallel_wavegan.utils.utils import download_pretrained_model
+from sedit_arg_parser import parse_args
+from utils import build_vocoder_from_file
+from utils import eval_durs
+from utils import get_voc_out
+from utils import is_chinese
 
 random.seed(0)
 np.random.seed(0)
 
-def plot_mel_and_vocode_wav(wav_path: str,
+def get_wav(wav_path: str,
         source_lang: str='english',
         target_lang: str='english',
         model_name: str="paddle_checkpoint_en",
@@ -50,41 +50,23 @@ def plot_mel_and_vocode_wav(wav_path: str,
     masked_feat = output_feat[new_span_bdy[0]:new_span_bdy[1]]
 
-    if target_lang == 'english':
-        if use_pt_vocoder:
-            output_feat = output_feat.cpu().numpy()
-            output_feat = torch.tensor(output_feat, dtype=torch.float)
+    if target_lang == 'english' and use_pt_vocoder:
+        masked_feat = masked_feat.cpu().numpy()
+        masked_feat = torch.tensor(masked_feat, dtype=torch.float)
         vocoder = load_vocoder('vctk_parallel_wavegan.v1.long')
-            replaced_wav = vocoder(output_feat).cpu().numpy()
-        else:
-            replaced_wav = get_voc_out(output_feat)
+        alt_wav = vocoder(masked_feat).cpu().numpy()
-    elif target_lang == 'chinese':
-        replaced_wav_only_mask_fst2_voc = get_voc_out(masked_feat)
+    else:
+        alt_wav = get_voc_out(masked_feat)
 
     old_time_bdy = [hop_length * x for x in old_span_bdy]
     new_time_bdy = [hop_length * x for x in new_span_bdy]
-    if target_lang == 'english':
-        wav_org_replaced_paddle_voc = np.concatenate([
-            wav_org[:old_time_bdy[0]],
-            replaced_wav[new_time_bdy[0]:new_time_bdy[1]],
-            wav_org[old_time_bdy[1]:]
-        ])
-        data_dict = {"origin": wav_org, "output": wav_org_replaced_paddle_voc}
+    wav_replaced = np.concatenate(
+        [wav_org[:old_time_bdy[0]], alt_wav, wav_org[old_time_bdy[1]:]])
-    elif target_lang == 'chinese':
-        wav_org_replaced_only_mask_fst2_voc = np.concatenate([
-            wav_org[:old_time_bdy[0]], replaced_wav_only_mask_fst2_voc,
-            wav_org[old_time_bdy[1]:]
-        ])
-        data_dict = {
-            "origin": wav_org,
-            "output": wav_org_replaced_only_mask_fst2_voc,
-        }
+    data_dict = {"origin": wav_org, "output": wav_replaced}
-    return data_dict, old_span_bdy
+    return data_dict
 
 
 def load_vocoder(vocoder_tag: str="vctk_parallel_wavegan.v1.long"):
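
After this refactor both language branches produce a single `alt_wav`, and the splice back into the original waveform is shared. A minimal sketch of that splice step with hypothetical toy sizes (illustration only, not code from this repository):

```python
import numpy as np

# Frame-level span boundaries are scaled by the vocoder hop length to get
# sample indices; the re-synthesized span replaces the original samples.
hop_length = 300
wav_org = np.zeros(24000)                  # original waveform (toy)
alt_wav = np.ones(6 * hop_length)          # vocoded masked span (6 frames, toy)
old_span_bdy = [40, 46]                    # edited frame span (hypothetical)

old_time_bdy = [hop_length * x for x in old_span_bdy]
wav_replaced = np.concatenate(
    [wav_org[:old_time_bdy[0]], alt_wav, wav_org[old_time_bdy[1]:]])
assert len(wav_replaced) == len(wav_org)   # spans have equal length in this toy case
```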
@@ -323,7 +305,7 @@ def get_phns_and_spans(wav_path: str,
 # The durations from MFA and the durations from fs2's duration_predictor may differ;
 # compute a scaling factor here to map between predicted and ground-truth values
-def duration_adjust_factor(orig_dur: List[int],
+def get_dur_adj_factor(orig_dur: List[int],
         pred_dur: List[int],
         phns: List[str]):
     length = 0
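
The hunk shows only the renamed signature, not the body. A minimal sketch of how such a scaling factor could be computed, under the assumption that it averages per-phoneme ratios of ground-truth to predicted duration (`dur_adj_factor_sketch` is a hypothetical stand-in, not the repository's exact implementation):

```python
import numpy as np
from typing import List

def dur_adj_factor_sketch(orig_dur: List[int],
                          pred_dur: List[int],
                          phns: List[str]) -> float:
    # Per-phoneme ratio of ground-truth (MFA) to predicted duration,
    # skipping silences and zero predictions so the ratio stays defined.
    factors = [o / p for o, p, phn in zip(orig_dur, pred_dur, phns)
               if phn != 'sp' and p > 0]
    return float(np.mean(factors)) if factors else 1.0

print(dur_adj_factor_sketch([10, 12, 8], [8, 12, 10], ['AH0', 'sp', 'T']))  # 1.025
```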
@@ -376,7 +358,7 @@ def prep_feats_with_dur(wav_path: str,
         new_phns = new_phns + ['sp']
     # Chinese phns are not always in the fastspeech2 dictionary; unknown ones fall back to 'sp'
 
     if target_lang == "english" or target_lang == "chinese":
-        old_durs = evaluate_durations(old_phns, target_lang=source_lang)
+        old_durs = eval_durs(old_phns, target_lang=source_lang)
     else:
         assert target_lang == "chinese" or target_lang == "english", \
             "calculate duration_predict is not support for this language..."
@@ -385,11 +367,11 @@ def prep_feats_with_dur(wav_path: str,
     if '[MASK]' in new_str:
         new_phns = old_phns
         span_to_add = span_to_repl
-        d_factor_left = duration_adjust_factor(
+        d_factor_left = get_dur_adj_factor(
             orig_dur=orig_old_durs[:span_to_repl[0]],
             pred_dur=old_durs[:span_to_repl[0]],
             phns=old_phns[:span_to_repl[0]])
-        d_factor_right = duration_adjust_factor(
+        d_factor_right = get_dur_adj_factor(
             orig_dur=orig_old_durs[span_to_repl[1]:],
             pred_dur=old_durs[span_to_repl[1]:],
             phns=old_phns[span_to_repl[1]:])
@@ -397,15 +379,14 @@ def prep_feats_with_dur(wav_path: str,
         new_durs_adjusted = [d_factor * i for i in old_durs]
     else:
         if duration_adjust:
-            d_factor = duration_adjust_factor(
+            d_factor = get_dur_adj_factor(
                 orig_dur=orig_old_durs, pred_dur=old_durs, phns=old_phns)
-            print("d_factor:", d_factor)
             d_factor = d_factor * 1.25
         else:
             d_factor = 1
 
     if target_lang == "english" or target_lang == "chinese":
-        new_durs = evaluate_durations(new_phns, target_lang=target_lang)
+        new_durs = eval_durs(new_phns, target_lang=target_lang)
     else:
         assert target_lang == "chinese" or target_lang == "english", \
             "calculate duration_predict is not support for this language..."
@@ -616,7 +597,7 @@ def evaluate(uid: str,
     print('new_str is ', new_str)
 
-    results_dict, old_span = plot_mel_and_vocode_wav(
+    results_dict = get_wav(
         source_lang=source_lang,
         target_lang=target_lang,
         model_name=model_name,
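
Since `get_wav` now returns only the results dict, callers no longer unpack a span value. A hypothetical call mirroring the updated call site (the input path, output path, and sample rate are illustrative assumptions, not from this commit):

```python
import soundfile as sf

results_dict = get_wav(
    wav_path="exp/p243_313.wav",           # illustrative input path
    source_lang="english",
    target_lang="english",
    model_name="paddle_checkpoint_en")
# Write the edited waveform; 24000 Hz is an assumed sample rate.
sf.write("exp/output.wav", results_dict["output"], samplerate=24000)
```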
...
 import os
 from typing import List
 from typing import Optional
 import numpy as np
@@ -32,6 +31,7 @@ model_alias = {
     "paddlespeech.t2s.models.parallel_wavegan:PWGInference",
 }
 
 def is_chinese(ch):
     if u'\u4e00' <= ch <= u'\u9fff':
         return True
@@ -61,7 +61,7 @@ def get_voc_out(mel):
     # print("current vocoder: ", args.voc)
     with open(args.voc_config) as f:
         voc_config = CfgNode(yaml.safe_load(f))
-    voc_inference = voc_inference = get_voc_inference(
+    voc_inference = get_voc_inference(
         voc=args.voc,
         voc_config=voc_config,
         voc_ckpt=args.voc_ckpt,
@@ -164,7 +164,7 @@ def get_voc_inference(
     return voc_inference
 
-def evaluate_durations(phns, target_lang="chinese", fs=24000, hop_length=300):
+def eval_durs(phns, target_lang="chinese", fs=24000, hop_length=300):
     args = parse_args()
 
     if target_lang == 'english':
@@ -176,10 +176,10 @@ def evaluate_durations(phns, target_lang="chinese", fs=24000, hop_length=300):
     elif target_lang == 'chinese':
         args.am = "fastspeech2_csmsc"
-        args.am_config="download/fastspeech2_conformer_baker_ckpt_0.5/conformer.yaml"
+        args.am_config = "download/fastspeech2_conformer_baker_ckpt_0.5/conformer.yaml"
         args.am_ckpt = "download/fastspeech2_conformer_baker_ckpt_0.5/snapshot_iter_76000.pdz"
         args.am_stat = "download/fastspeech2_conformer_baker_ckpt_0.5/speech_stats.npy"
-        args.phones_dict ="download/fastspeech2_conformer_baker_ckpt_0.5/phone_id_map.txt"
+        args.phones_dict = "download/fastspeech2_conformer_baker_ckpt_0.5/phone_id_map.txt"
 
     if args.ngpu == 0:
         paddle.set_device("cpu")
@@ -211,11 +211,10 @@
     phonemes = [phn if phn in vocab_phones else "sp" for phn in phns]
     phone_ids = [vocab_phones[item] for item in phonemes]
-    phone_ids_new = phone_ids
-    phone_ids_new.append(vocab_size - 1)
-    phone_ids_new = paddle.to_tensor(np.array(phone_ids_new, np.int64))
-    _, d_outs, _, _ = am.inference(phone_ids_new, spk_id=None, spk_emb=None)
+    phone_ids.append(vocab_size - 1)
+    phone_ids = paddle.to_tensor(np.array(phone_ids, np.int64))
+    _, d_outs, _, _ = am.inference(phone_ids, spk_id=None, spk_emb=None)
     pre_d_outs = d_outs
-    phoneme_durations_new = pre_d_outs * hop_length / fs
-    phoneme_durations_new = phoneme_durations_new.tolist()[:-1]
-    return phoneme_durations_new
+    phu_durs_new = pre_d_outs * hop_length / fs
+    phu_durs_new = phu_durs_new.tolist()[:-1]
+    return phu_durs_new
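
The tail of `eval_durs` converts per-phoneme frame counts to seconds and drops the appended end token via `[:-1]`. A worked example of that arithmetic with toy frame counts:

```python
# With fs=24000 and hop_length=300, one frame lasts 300 / 24000 = 0.0125 s.
fs, hop_length = 24000, 300
d_outs = [8, 12, 20, 4]                        # toy per-phoneme frame counts
durs_sec = [d * hop_length / fs for d in d_outs]
print(durs_sec)                                # [0.1, 0.15, 0.25, 0.05]
```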