提交 445e3040 编写于 作者: 小湉湉's avatar 小湉湉

fix vocoder inference

上级 e522009d
...@@ -11,11 +11,7 @@ import paddle ...@@ -11,11 +11,7 @@ import paddle
import soundfile as sf import soundfile as sf
import torch import torch
from paddle import nn from paddle import nn
from sedit_arg_parser import parse_args from ParallelWaveGAN.parallel_wavegan.utils.utils import download_pretrained_model
from utils import build_vocoder_from_file
from utils import evaluate_durations
from utils import get_voc_out
from utils import is_chinese
from align import alignment from align import alignment
from align import alignment_zh from align import alignment_zh
...@@ -25,13 +21,17 @@ from collect_fn import build_collate_fn ...@@ -25,13 +21,17 @@ from collect_fn import build_collate_fn
from mlm import build_model_from_file from mlm import build_model_from_file
from read_text import load_num_sequence_text from read_text import load_num_sequence_text
from read_text import read_2col_text from read_text import read_2col_text
# from ParallelWaveGAN.parallel_wavegan.utils.utils import download_pretrained_model from sedit_arg_parser import parse_args
from utils import build_vocoder_from_file
from utils import eval_durs
from utils import get_voc_out
from utils import is_chinese
random.seed(0) random.seed(0)
np.random.seed(0) np.random.seed(0)
def plot_mel_and_vocode_wav(wav_path: str, def get_wav(wav_path: str,
source_lang: str='english', source_lang: str='english',
target_lang: str='english', target_lang: str='english',
model_name: str="paddle_checkpoint_en", model_name: str="paddle_checkpoint_en",
...@@ -50,41 +50,23 @@ def plot_mel_and_vocode_wav(wav_path: str, ...@@ -50,41 +50,23 @@ def plot_mel_and_vocode_wav(wav_path: str,
masked_feat = output_feat[new_span_bdy[0]:new_span_bdy[1]] masked_feat = output_feat[new_span_bdy[0]:new_span_bdy[1]]
if target_lang == 'english': if target_lang == 'english' and use_pt_vocoder:
if use_pt_vocoder: masked_feat = masked_feat.cpu().numpy()
output_feat = output_feat.cpu().numpy() masked_feat = torch.tensor(masked_feat, dtype=torch.float)
output_feat = torch.tensor(output_feat, dtype=torch.float)
vocoder = load_vocoder('vctk_parallel_wavegan.v1.long') vocoder = load_vocoder('vctk_parallel_wavegan.v1.long')
replaced_wav = vocoder(output_feat).cpu().numpy() alt_wav = vocoder(masked_feat).cpu().numpy()
else:
replaced_wav = get_voc_out(output_feat)
elif target_lang == 'chinese': else:
replaced_wav_only_mask_fst2_voc = get_voc_out(masked_feat) alt_wav = get_voc_out(masked_feat)
old_time_bdy = [hop_length * x for x in old_span_bdy] old_time_bdy = [hop_length * x for x in old_span_bdy]
new_time_bdy = [hop_length * x for x in new_span_bdy]
if target_lang == 'english':
wav_org_replaced_paddle_voc = np.concatenate([
wav_org[:old_time_bdy[0]],
replaced_wav[new_time_bdy[0]:new_time_bdy[1]],
wav_org[old_time_bdy[1]:]
])
data_dict = {"origin": wav_org, "output": wav_org_replaced_paddle_voc} wav_replaced = np.concatenate(
[wav_org[:old_time_bdy[0]], alt_wav, wav_org[old_time_bdy[1]:]])
elif target_lang == 'chinese': data_dict = {"origin": wav_org, "output": wav_replaced}
wav_org_replaced_only_mask_fst2_voc = np.concatenate([
wav_org[:old_time_bdy[0]], replaced_wav_only_mask_fst2_voc,
wav_org[old_time_bdy[1]:]
])
data_dict = {
"origin": wav_org,
"output": wav_org_replaced_only_mask_fst2_voc,
}
return data_dict, old_span_bdy return data_dict
def load_vocoder(vocoder_tag: str="vctk_parallel_wavegan.v1.long"): def load_vocoder(vocoder_tag: str="vctk_parallel_wavegan.v1.long"):
...@@ -323,7 +305,7 @@ def get_phns_and_spans(wav_path: str, ...@@ -323,7 +305,7 @@ def get_phns_and_spans(wav_path: str,
# mfa 获得的 duration 和 fs2 的 duration_predictor 获取的 duration 可能不同 # mfa 获得的 duration 和 fs2 的 duration_predictor 获取的 duration 可能不同
# 此处获得一个缩放比例, 用于预测值和真实值之间的缩放 # 此处获得一个缩放比例, 用于预测值和真实值之间的缩放
def duration_adjust_factor(orig_dur: List[int], def get_dur_adj_factor(orig_dur: List[int],
pred_dur: List[int], pred_dur: List[int],
phns: List[str]): phns: List[str]):
length = 0 length = 0
...@@ -376,7 +358,7 @@ def prep_feats_with_dur(wav_path: str, ...@@ -376,7 +358,7 @@ def prep_feats_with_dur(wav_path: str,
new_phns = new_phns + ['sp'] new_phns = new_phns + ['sp']
# 中文的 phns 不一定都在 fastspeech2 的字典里, 用 sp 代替 # 中文的 phns 不一定都在 fastspeech2 的字典里, 用 sp 代替
if target_lang == "english" or target_lang == "chinese": if target_lang == "english" or target_lang == "chinese":
old_durs = evaluate_durations(old_phns, target_lang=source_lang) old_durs = eval_durs(old_phns, target_lang=source_lang)
else: else:
assert target_lang == "chinese" or target_lang == "english", \ assert target_lang == "chinese" or target_lang == "english", \
"calculate duration_predict is not support for this language..." "calculate duration_predict is not support for this language..."
...@@ -385,11 +367,11 @@ def prep_feats_with_dur(wav_path: str, ...@@ -385,11 +367,11 @@ def prep_feats_with_dur(wav_path: str,
if '[MASK]' in new_str: if '[MASK]' in new_str:
new_phns = old_phns new_phns = old_phns
span_to_add = span_to_repl span_to_add = span_to_repl
d_factor_left = duration_adjust_factor( d_factor_left = get_dur_adj_factor(
orig_dur=orig_old_durs[:span_to_repl[0]], orig_dur=orig_old_durs[:span_to_repl[0]],
pred_dur=old_durs[:span_to_repl[0]], pred_dur=old_durs[:span_to_repl[0]],
phns=old_phns[:span_to_repl[0]]) phns=old_phns[:span_to_repl[0]])
d_factor_right = duration_adjust_factor( d_factor_right = get_dur_adj_factor(
orig_dur=orig_old_durs[span_to_repl[1]:], orig_dur=orig_old_durs[span_to_repl[1]:],
pred_dur=old_durs[span_to_repl[1]:], pred_dur=old_durs[span_to_repl[1]:],
phns=old_phns[span_to_repl[1]:]) phns=old_phns[span_to_repl[1]:])
...@@ -397,15 +379,14 @@ def prep_feats_with_dur(wav_path: str, ...@@ -397,15 +379,14 @@ def prep_feats_with_dur(wav_path: str,
new_durs_adjusted = [d_factor * i for i in old_durs] new_durs_adjusted = [d_factor * i for i in old_durs]
else: else:
if duration_adjust: if duration_adjust:
d_factor = duration_adjust_factor( d_factor = get_dur_adj_factor(
orig_dur=orig_old_durs, pred_dur=old_durs, phns=old_phns) orig_dur=orig_old_durs, pred_dur=old_durs, phns=old_phns)
print("d_factor:", d_factor)
d_factor = d_factor * 1.25 d_factor = d_factor * 1.25
else: else:
d_factor = 1 d_factor = 1
if target_lang == "english" or target_lang == "chinese": if target_lang == "english" or target_lang == "chinese":
new_durs = evaluate_durations(new_phns, target_lang=target_lang) new_durs = eval_durs(new_phns, target_lang=target_lang)
else: else:
assert target_lang == "chinese" or target_lang == "english", \ assert target_lang == "chinese" or target_lang == "english", \
"calculate duration_predict is not support for this language..." "calculate duration_predict is not support for this language..."
...@@ -616,7 +597,7 @@ def evaluate(uid: str, ...@@ -616,7 +597,7 @@ def evaluate(uid: str,
print('new_str is ', new_str) print('new_str is ', new_str)
results_dict, old_span = plot_mel_and_vocode_wav( results_dict = get_wav(
source_lang=source_lang, source_lang=source_lang,
target_lang=target_lang, target_lang=target_lang,
model_name=model_name, model_name=model_name,
......
import os import os
from typing import List
from typing import Optional from typing import Optional
import numpy as np import numpy as np
...@@ -32,6 +31,7 @@ model_alias = { ...@@ -32,6 +31,7 @@ model_alias = {
"paddlespeech.t2s.models.parallel_wavegan:PWGInference", "paddlespeech.t2s.models.parallel_wavegan:PWGInference",
} }
def is_chinese(ch): def is_chinese(ch):
if u'\u4e00' <= ch <= u'\u9fff': if u'\u4e00' <= ch <= u'\u9fff':
return True return True
...@@ -61,7 +61,7 @@ def get_voc_out(mel): ...@@ -61,7 +61,7 @@ def get_voc_out(mel):
# print("current vocoder: ", args.voc) # print("current vocoder: ", args.voc)
with open(args.voc_config) as f: with open(args.voc_config) as f:
voc_config = CfgNode(yaml.safe_load(f)) voc_config = CfgNode(yaml.safe_load(f))
voc_inference = voc_inference = get_voc_inference( voc_inference = get_voc_inference(
voc=args.voc, voc=args.voc,
voc_config=voc_config, voc_config=voc_config,
voc_ckpt=args.voc_ckpt, voc_ckpt=args.voc_ckpt,
...@@ -164,7 +164,7 @@ def get_voc_inference( ...@@ -164,7 +164,7 @@ def get_voc_inference(
return voc_inference return voc_inference
def evaluate_durations(phns, target_lang="chinese", fs=24000, hop_length=300): def eval_durs(phns, target_lang="chinese", fs=24000, hop_length=300):
args = parse_args() args = parse_args()
if target_lang == 'english': if target_lang == 'english':
...@@ -176,10 +176,10 @@ def evaluate_durations(phns, target_lang="chinese", fs=24000, hop_length=300): ...@@ -176,10 +176,10 @@ def evaluate_durations(phns, target_lang="chinese", fs=24000, hop_length=300):
elif target_lang == 'chinese': elif target_lang == 'chinese':
args.am = "fastspeech2_csmsc" args.am = "fastspeech2_csmsc"
args.am_config="download/fastspeech2_conformer_baker_ckpt_0.5/conformer.yaml" args.am_config = "download/fastspeech2_conformer_baker_ckpt_0.5/conformer.yaml"
args.am_ckpt = "download/fastspeech2_conformer_baker_ckpt_0.5/snapshot_iter_76000.pdz" args.am_ckpt = "download/fastspeech2_conformer_baker_ckpt_0.5/snapshot_iter_76000.pdz"
args.am_stat = "download/fastspeech2_conformer_baker_ckpt_0.5/speech_stats.npy" args.am_stat = "download/fastspeech2_conformer_baker_ckpt_0.5/speech_stats.npy"
args.phones_dict ="download/fastspeech2_conformer_baker_ckpt_0.5/phone_id_map.txt" args.phones_dict = "download/fastspeech2_conformer_baker_ckpt_0.5/phone_id_map.txt"
if args.ngpu == 0: if args.ngpu == 0:
paddle.set_device("cpu") paddle.set_device("cpu")
...@@ -211,11 +211,10 @@ def evaluate_durations(phns, target_lang="chinese", fs=24000, hop_length=300): ...@@ -211,11 +211,10 @@ def evaluate_durations(phns, target_lang="chinese", fs=24000, hop_length=300):
phonemes = [phn if phn in vocab_phones else "sp" for phn in phns] phonemes = [phn if phn in vocab_phones else "sp" for phn in phns]
phone_ids = [vocab_phones[item] for item in phonemes] phone_ids = [vocab_phones[item] for item in phonemes]
phone_ids_new = phone_ids phone_ids.append(vocab_size - 1)
phone_ids_new.append(vocab_size - 1) phone_ids = paddle.to_tensor(np.array(phone_ids, np.int64))
phone_ids_new = paddle.to_tensor(np.array(phone_ids_new, np.int64)) _, d_outs, _, _ = am.inference(phone_ids, spk_id=None, spk_emb=None)
_, d_outs, _, _ = am.inference(phone_ids_new, spk_id=None, spk_emb=None)
pre_d_outs = d_outs pre_d_outs = d_outs
phoneme_durations_new = pre_d_outs * hop_length / fs phu_durs_new = pre_d_outs * hop_length / fs
phoneme_durations_new = phoneme_durations_new.tolist()[:-1] phu_durs_new = phu_durs_new.tolist()[:-1]
return phoneme_durations_new return phu_durs_new
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册