diff --git a/ernie-sat/.DS_Store b/ernie-sat/.DS_Store
deleted file mode 100644
index d786441fb926dbb0a14d94f455c5814ada534d7f..0000000000000000000000000000000000000000
Binary files a/ernie-sat/.DS_Store and /dev/null differ
diff --git a/ernie-sat/README.md b/ernie-sat/README.md
index 704abc64567f036927dc4dfc441f2922e0cfb054..51c6244e64330d8e41982e93315f0d27b22e8942 100644
--- a/ernie-sat/README.md
+++ b/ernie-sat/README.md
@@ -11,7 +11,7 @@ Two innovations are proposed in ERNIE-SAT:
 
 ### 1. Install PaddlePaddle
 
-Our code is based on Paddle (version>=2.0)
+This project's code is based on Paddle (version>=2.0)
 
 ### 2. Pretrained models
 
@@ -23,7 +23,7 @@ Two innovations are proposed in ERNIE-SAT:
 
 ### 3. Downloads
 
-1. We use parallel wavegan as the vocoder:
+1. This project uses parallel wavegan as the vocoder:
 - [pwg_aishell3_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_aishell3_ckpt_0.5.zip)
 
 Create a download folder, download the pretrained vocoder model above, and unzip it:
 
 ```
 cd download
 unzip pwg_aishell3_ckpt_0.5.zip
 ```
 
-2. We use [FastSpeech2](https://arxiv.org/abs/2006.04558) as the phoneme duration predictor:
+2. This project uses [FastSpeech2](https://arxiv.org/abs/2006.04558) as the phoneme duration predictor:
 - [fastspeech2_conformer_baker_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_conformer_baker_ckpt_0.5.zip) used for Chinese
 - [fastspeech2_nosil_ljspeech_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_ljspeech_ckpt_0.5.zip) used for English
 
@@ -48,7 +48,7 @@ unzip fastspeech2_nosil_ljspeech_ckpt_0.5.zip
 
 ### 4. Inference
 
-We have so far open-sourced the inference code for speech editing, personalized speech synthesis, and cross-lingual speech synthesis; the rest will be released gradually.
+This project has so far open-sourced the inference code for speech editing, personalized speech synthesis, and cross-lingual speech synthesis; the rest will be released gradually.
 Note: for English, the vocoder version used here differs from the [version used at training time](https://github.com/kan-bayashi/ParallelWaveGAN); you can use the training-time version as your vocoder. The model will be upgraded in a later update.
 
 We provide specific audio files, together with their corresponding text and phoneme files:
diff --git a/ernie-sat/inference.py b/ernie-sat/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..39e74b16fbd1f0bb932d342c4219c9329fe54824
--- /dev/null
+++ b/ernie-sat/inference.py
@@ -0,0 +1,475 @@
+#!/usr/bin/env python3
+
+import pickle
+import random
+
+import librosa
+import numpy as np
+import paddle
+import soundfile as sf
+
+from model_paddle import build_model_from_file
+from read_text import read_2column_text, load_num_sequence_text
+from sedit_arg_parser import parse_args
+from utils import sentence2phns, get_voc_out, evaluate_durations
+
+random.seed(0)
+np.random.seed(0)
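+
+
+# The pipeline below regenerates mel frames for one edited span, renders them
+# with the vocoder, and splices the result back into the original waveform
+# (mel frame index * hop_length gives the sample index).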
+def plot_mel_and_vocode_wav(uid, prefix, clone_uid, clone_prefix, source_language, target_language, model_name, wav_path, full_origin_str, old_str, new_str, vocoder, duration_preditor_path, sid=None, non_autoreg=True):
+    wav_org, input_feat, output_feat, old_span_boundary, new_span_boundary, fs, hop_length = get_mlm_output(
+        uid,
+        prefix,
+        clone_uid,
+        clone_prefix,
+        source_language,
+        target_language,
+        model_name,
+        wav_path,
+        old_str,
+        new_str,
+        duration_preditor_path,
+        use_teacher_forcing=non_autoreg,
+        sid=sid
+    )
+
+    masked_feat = output_feat[new_span_boundary[0]:new_span_boundary[1]].detach().float().cpu().numpy()
+
+    if target_language == 'english':
+        output_feat_np = output_feat.detach().float().cpu().numpy()
+        replaced_wav_paddle_voc = get_voc_out(output_feat_np, target_language)
+        replaced_wav = replaced_wav_paddle_voc
+
+    elif target_language == 'chinese':
+        assert old_span_boundary[1] == new_span_boundary[0], "old_span_boundary[1] does not match new_span_boundary[0]."
+        output_feat_np = output_feat.detach().float().cpu().numpy()
+        replaced_wav = get_voc_out(output_feat_np)
+
+        replaced_wav_only_mask = get_voc_out(masked_feat)
+        replaced_wav_only_mask_fst2_voc = get_voc_out(masked_feat, target_language)
+
+    old_time_boundary = [hop_length * x for x in old_span_boundary]
+    new_time_boundary = [hop_length * x for x in new_span_boundary]
+
+    wav_org_replaced = np.concatenate([wav_org[:old_time_boundary[0]], replaced_wav[new_time_boundary[0]:new_time_boundary[1]], wav_org[old_time_boundary[1]:]])
+
+    if target_language == 'english':
+        # splice the vocoder output for the new span into the original waveform
+        wav_org_replaced_paddle_voc = np.concatenate([wav_org[:old_time_boundary[0]], replaced_wav_paddle_voc[new_time_boundary[0]:new_time_boundary[1]], wav_org[old_time_boundary[1]:]])
+
+        data_dict = {
+            "origin": wav_org,
+            "output": wav_org_replaced_paddle_voc}
+
+    elif target_language == 'chinese':
+        wav_org_replaced_only_mask = np.concatenate([wav_org[:old_time_boundary[0]], replaced_wav_only_mask, wav_org[old_time_boundary[1]:]])
+        wav_org_replaced_only_mask_fst2_voc = np.concatenate([wav_org[:old_time_boundary[0]], replaced_wav_only_mask_fst2_voc, wav_org[old_time_boundary[1]:]])
+        data_dict = {
+            "origin": wav_org,
+            "output": wav_org_replaced_only_mask_fst2_voc}
+
+    return data_dict, old_span_boundary
+
+
+def load_model(model_name):
+    config_path = './pretrained_model/{}/config.yaml'.format(model_name)
+    model_path = './pretrained_model/{}/model.pdparams'.format(model_name)
+
+    mlm_model, args = build_model_from_file(config_file=config_path,
+                                            model_file=model_path)
+    return mlm_model, args
+
+
+def read_data(uid, prefix):
+    mfa_text = read_2column_text(prefix + '/text')[uid]
+    mfa_wav_path = read_2column_text(prefix + '/wav.scp')[uid]
+    if 'mnt' not in mfa_wav_path:
+        mfa_wav_path = prefix.split('dump')[0] + mfa_wav_path
+    return mfa_text, mfa_wav_path
+
+
+def get_align_data(uid, prefix):
+    mfa_path = prefix + "mfa_"
+    mfa_text = read_2column_text(mfa_path + 'text')[uid]
+    mfa_start = load_num_sequence_text(mfa_path + 'start', loader_type='text_float')[uid]
+    mfa_end = load_num_sequence_text(mfa_path + 'end', loader_type='text_float')[uid]
+    mfa_wav_path = read_2column_text(mfa_path + 'wav.scp')[uid]
+    return mfa_text, mfa_start, mfa_end, mfa_wav_path
+
+
+def get_masked_mel_boundary(mfa_start, mfa_end, fs, hop_length, span_tobe_replaced):
+    # convert alignment times (seconds) to mel frame indices
+    align_start = paddle.to_tensor(mfa_start).unsqueeze(0)
+    align_end = paddle.to_tensor(mfa_end).unsqueeze(0)
+    align_start = paddle.floor(fs * align_start / hop_length).int()
+    align_end = paddle.floor(fs * align_end / hop_length).int()
+    if span_tobe_replaced[0] >= len(mfa_start):
+        span_boundary = [align_end[0].tolist()[-1], align_end[0].tolist()[-1]]
+    else:
+        span_boundary = [align_start[0].tolist()[span_tobe_replaced[0]], align_end[0].tolist()[span_tobe_replaced[1] - 1]]
+    return span_boundary
+
+
+def gen_phns(zh_mapping, phns):
+    # map phones through the phone-mapping table; unknown phones become '<unk>'
+    new_phns = []
+    for x in phns:
+        if x in zh_mapping.keys():
+            new_phns.extend(zh_mapping[x].split(" "))
+        else:
+            new_phns.extend(['<unk>'])
+    return new_phns
+
+
+def get_mapping(phn_mapping="./phn_mapping.txt"):
+    zh_mapping = {}
+    with open(phn_mapping, "r") as f:
+        for line in f:
+            pd_phn = line.split(" ")[0]
+            if pd_phn not in zh_mapping.keys():
+                zh_mapping[pd_phn] = " ".join(line.split()[1:])
+    return zh_mapping
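+
+
+# Locate the phoneme span to edit by diffing the old and new phoneme
+# sequences: scan from the left and from the right until the first mismatch,
+# treating 'sp' (silence) tokens as padding. Returns the MFA alignment plus
+# the phone index ranges to replace (old sequence) and to add (new sequence).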
+def get_phns_and_spans_paddle(uid, prefix, old_str, new_str, source_language, target_language):
+    zh_mapping = get_mapping()
+    old_str = old_str.strip()
+    new_str = new_str.strip()
+    # strip punctuation (ASCII and full-width) before phonemization
+    for pun in [',', '.', ':', ';', '!', '?', '"', '(', ')', '--', '---', u',', u'。', u':', u';', u'!', u'?', u'(', u')']:
+        old_str = old_str.replace(pun, ' ')
+        new_str = new_str.replace(pun, ' ')
+
+    # if new_str extends old_str, only the appended tail is synthesized
+    append_new_str = (old_str == new_str[:len(old_str)])
+    print("append_new_str: ", append_new_str)
+    old_phns, mfa_start, mfa_end = [], [], []
+    mfa_text, mfa_start, mfa_end, mfa_wav_path = get_align_data(uid, prefix)
+    old_phns = mfa_text.split(" ")
+
+    if append_new_str:
+        is_cross_lingual = (source_language != target_language)
+
+        new_str_origin = new_str[:len(old_str)]
+        new_str_append = new_str[len(old_str):]
+        if is_cross_lingual:
+            if source_language == "english" and target_language == "chinese":
+                new_phns_origin = old_phns
+                new_phns_append, _ = sentence2phns(new_str_append, "zh")
+            elif source_language == "chinese" and target_language == "english":
+                new_phns_origin = old_phns
+                new_phns_append, _ = sentence2phns(new_str_append, "en")
+            else:
+                assert target_language == "chinese" or target_language == "english", "cloning is not supported for this language, please check it."
+        else:
+            if source_language == target_language and target_language == "english":
+                new_phns_origin = old_phns
+                new_phns_append, _ = sentence2phns(new_str_append, "en")
+            elif source_language == target_language and target_language == "chinese":
+                new_phns_origin = old_phns
+                new_phns_append, _ = sentence2phns(new_str_append, "zh")
+            else:
+                assert source_language == target_language, "source language does not match target language..."
+
+        if target_language == "chinese":
+            new_phns_append = gen_phns(zh_mapping, new_phns_append)
+
+        new_phns = new_phns_origin + new_phns_append
+
+        span_tobe_replaced = [len(old_phns), len(old_phns)]
+        span_tobe_added = [len(old_phns), len(new_phns)]
+
+    else:
+        if source_language == target_language and target_language == "english":
+            new_phns, _ = sentence2phns(new_str, "en")
+        elif source_language == target_language and target_language == "chinese":
+            new_phns, _ = sentence2phns(new_str, "zh")
+            new_phns = gen_phns(zh_mapping, new_phns)
+        else:
+            assert source_language == target_language, "source language does not match target language..."
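+        # trim leading/trailing silences so the left/right scans below start
+        # on real phones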
+        while new_phns[-1] == 'sp':
+            new_phns.pop()
+
+        while new_phns[0] == 'sp':
+            new_phns.pop(0)
+
+        span_tobe_replaced = [0, len(old_phns) - 1]
+        span_tobe_added = [0, len(new_phns) - 1]
+        new_phns_left = []
+        left_index = 0
+        sp_count = 0
+
+        # scan from the left: phones shared with the old sequence stay as-is
+        for idx, phn in enumerate(old_phns):
+            if phn == "sp":
+                sp_count += 1
+                new_phns_left.append('sp')
+            else:
+                idx = idx - sp_count
+                if phn == new_phns[idx]:
+                    left_index += 1
+                    new_phns_left.append(phn)
+                else:
+                    span_tobe_replaced[0] = len(new_phns_left)
+                    span_tobe_added[0] = len(new_phns_left)
+                    break
+
+        right_index = 0
+        new_phns_middle = []
+        new_phns_right = []
+        sp_count = 0
+        word2phns_max_index = len(old_phns)
+        new_word2phns_max_index = len(new_phns)
+
+        # scan from the right until the first mismatch; everything in between
+        # becomes the middle span that has to be regenerated
+        for idx, phn in enumerate(old_phns[::-1]):
+            cur_idx = len(old_phns) - 1 - idx
+            if phn == "sp":
+                sp_count += 1
+                new_phns_right = ['sp'] + new_phns_right
+            else:
+                cur_idx = new_word2phns_max_index - (word2phns_max_index - cur_idx - sp_count)
+                if phn == new_phns[cur_idx]:
+                    right_index -= 1
+                    new_phns_right = [phn] + new_phns_right
+                else:
+                    span_tobe_replaced[1] = len(old_phns) - len(new_phns_right)
+                    new_phns_middle = new_phns[left_index:right_index]
+                    span_tobe_added[1] = len(new_phns_left) + len(new_phns_middle)
+                    if len(new_phns_middle) == 0:
+                        # pure deletion: widen the span by one phone on each side
+                        span_tobe_added[1] = min(span_tobe_added[1] + 1, len(new_phns))
+                        span_tobe_added[0] = max(0, span_tobe_added[0] - 1)
+                        span_tobe_replaced[0] = max(0, span_tobe_replaced[0] - 1)
+                        span_tobe_replaced[1] = min(span_tobe_replaced[1] + 1, len(old_phns))
+                    break
+
+        new_phns = new_phns_left + new_phns_middle + new_phns_right
+
+    return mfa_start, mfa_end, old_phns, new_phns, span_tobe_replaced, span_tobe_added
+
+
+def duration_adjust_factor(original_dur, pred_dur, phns):
+    # per-phone ratio of ground-truth to predicted duration, averaged after
+    # dropping the two smallest and two largest ratios (a trimmed mean)
+    factor_list = []
+    for ori, pred, phn in zip(original_dur, pred_dur, phns):
+        if pred == 0 or phn == 'sp':
+            continue
+        factor_list.append(ori / pred)
+    factor_list = np.array(factor_list)
+    factor_list.sort()
+    if len(factor_list) < 5:
+        return 1
+    length = 2
+    return np.average(factor_list[length:-length])
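+
+
+# Prepare the inputs for one edit: load the waveform, diff the phoneme
+# sequences, predict durations for the new phones with FastSpeech2 (see the
+# README), rescale them with the trimmed-mean factor above, and splice a
+# silent gap of the target length into the waveform.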
+def prepare_features_with_duration(uid, prefix, clone_uid, clone_prefix, source_language, target_language, mlm_model, old_str, new_str, wav_path, duration_preditor_path, sid=None, mask_reconstruct=False, duration_adjust=True, start_end_sp=False, train_args=None):
+    wav_org, rate = librosa.load(wav_path, sr=train_args.feats_extract_conf['fs'])
+    fs = train_args.feats_extract_conf['fs']
+    hop_length = train_args.feats_extract_conf['hop_length']
+
+    mfa_start, mfa_end, old_phns, new_phns, span_tobe_replaced, span_tobe_added = get_phns_and_spans_paddle(uid, prefix, old_str, new_str, source_language, target_language)
+
+    if start_end_sp:
+        if new_phns[-1] != 'sp':
+            new_phns = new_phns + ['sp']
+
+    # 1. predict durations for the old phones
+    if target_language == "english":
+        old_durations = evaluate_durations(old_phns, target_language=target_language)
+    elif target_language == "chinese":
+        # for cross-lingual cloning the old phones are still in the source language
+        old_durations = evaluate_durations(old_phns, target_language=source_language)
+    else:
+        assert target_language == "chinese" or target_language == "english", "duration prediction is not supported for this language..."
+
+    # 2. compute the duration adjustment factor and the new durations
+    original_old_durations = [e - s for e, s in zip(mfa_end, mfa_start)]
+    if '[MASK]' in new_str:
+        new_phns = old_phns
+        span_tobe_added = span_tobe_replaced
+        d_factor_left = duration_adjust_factor(original_old_durations[:span_tobe_replaced[0]], old_durations[:span_tobe_replaced[0]], old_phns[:span_tobe_replaced[0]])
+        d_factor_right = duration_adjust_factor(original_old_durations[span_tobe_replaced[1]:], old_durations[span_tobe_replaced[1]:], old_phns[span_tobe_replaced[1]:])
+        d_factor = (d_factor_left + d_factor_right) / 2
+        new_durations_adjusted = [d_factor * i for i in old_durations]
+    else:
+        if duration_adjust:
+            d_factor = duration_adjust_factor(original_old_durations, old_durations, old_phns)
+            d_factor = d_factor * 1.25
+        else:
+            d_factor = 1
+
+        new_durations = evaluate_durations(new_phns, target_language=target_language)
+        new_durations_adjusted = [d_factor * i for i in new_durations]
+
+    # reconstruct the alignment of the new phoneme sequence from the adjusted
+    # durations, shifting everything after the edited span by the difference
+    # between the new and old span durations
+    new_span_duration_sum = sum(new_durations_adjusted[span_tobe_added[0]:span_tobe_added[1]])
+    old_span_duration_sum = sum(original_old_durations[span_tobe_replaced[0]:span_tobe_replaced[1]])
+    duration_offset = new_span_duration_sum - old_span_duration_sum
+    new_mfa_start = mfa_start[:span_tobe_replaced[0]]
+    new_mfa_end = mfa_end[:span_tobe_replaced[0]]
+    for i in new_durations_adjusted[span_tobe_added[0]:span_tobe_added[1]]:
+        if len(new_mfa_end) == 0:
+            new_mfa_start.append(0)
+            new_mfa_end.append(i)
+        else:
+            new_mfa_start.append(new_mfa_end[-1])
+            new_mfa_end.append(new_mfa_end[-1] + i)
+    new_mfa_start += [i + duration_offset for i in mfa_start[span_tobe_replaced[1]:]]
+    new_mfa_end += [i + duration_offset for i in mfa_end[span_tobe_replaced[1]:]]
+
+    # 3. get the new wav: replace the edited span with silence of the target length
+    if span_tobe_replaced[0] >= len(mfa_start):
+        left_index = len(wav_org)
+        right_index = left_index
+    else:
+        left_index = int(np.floor(mfa_start[span_tobe_replaced[0]] * fs))
+        right_index = int(np.ceil(mfa_end[span_tobe_replaced[1] - 1] * fs))
+    new_blank_wav = np.zeros((int(np.ceil(new_span_duration_sum * fs)),), dtype=wav_org.dtype)
+    new_wav_org = np.concatenate([wav_org[:left_index], new_blank_wav, wav_org[right_index:]])
+
+    # 4. get the old and new mel spans to be masked
+    old_span_boundary = get_masked_mel_boundary(mfa_start, mfa_end, fs, hop_length, span_tobe_replaced)  # e.g. [92, 92]
+    new_span_boundary = get_masked_mel_boundary(new_mfa_start, new_mfa_end, fs, hop_length, span_tobe_added)  # e.g. [92, 174]
+
+    return new_wav_org, new_phns, new_mfa_start, new_mfa_end, old_span_boundary, new_span_boundary
+
+
+def prepare_features(uid, prefix, clone_uid, clone_prefix, source_language, target_language, mlm_model, processor, wav_path, old_str, new_str, duration_preditor_path, sid=None, duration_adjust=True, start_end_sp=False, mask_reconstruct=False, train_args=None):
+    wav_org, phns_list, mfa_start, mfa_end, old_span_boundary, new_span_boundary = prepare_features_with_duration(
+        uid, prefix, clone_uid, clone_prefix, source_language, target_language, mlm_model, old_str,
+        new_str, wav_path, duration_preditor_path, sid=sid, duration_adjust=duration_adjust, start_end_sp=start_end_sp, mask_reconstruct=mask_reconstruct, train_args=train_args)
+    speech = np.array(wav_org, dtype=np.float32)
+    align_start = np.array(mfa_start)
+    align_end = np.array(mfa_end)
+    token_to_id = {item: i for i, item in enumerate(train_args.token_list)}
+    # map phones to token ids, falling back to the '<unk>' id
+    text = np.array(list(map(lambda x: token_to_id.get(x, token_to_id['<unk>']), phns_list)))
+    span_boundary = np.array(new_span_boundary)
+    batch = [('1', {"speech": speech, "align_start": align_start, "align_end": align_end, "text": text, "span_boundary": span_boundary})]
+
+    return batch, old_span_boundary, new_span_boundary
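+
+
+# Run masked-LM inference and stitch the generated mel chunks back together;
+# the model returns the features before, inside, and after the edited span,
+# and the concat cases below handle empty edge chunks.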
+def decode_with_model(uid, prefix, clone_uid, clone_prefix, source_language, target_language, mlm_model, processor, collate_fn, wav_path, old_str, new_str, duration_preditor_path, sid=None, decoder=False, use_teacher_forcing=False, duration_adjust=True, start_end_sp=False, train_args=None):
+    fs, hop_length = train_args.feats_extract_conf['fs'], train_args.feats_extract_conf['hop_length']
+
+    batch, old_span_boundary, new_span_boundary = prepare_features(uid, prefix, clone_uid, clone_prefix, source_language, target_language, mlm_model, processor, wav_path, old_str, new_str, duration_preditor_path, sid, duration_adjust=duration_adjust, start_end_sp=start_end_sp, train_args=train_args)
+
+    feats = pickle.load(open('tmp/tmp_pkl.' + str(uid), 'rb'))
+
+    if 'text_masked_position' in feats.keys():
+        feats.pop('text_masked_position')
+
+    for k, v in feats.items():
+        feats[k] = paddle.to_tensor(v)
+    rtn = mlm_model.inference(**feats, span_boundary=new_span_boundary, use_teacher_forcing=use_teacher_forcing)
+    output = rtn['feat_gen']
+    # concatenate [context before span, generated chunks, context after span],
+    # dropping empty edge chunks
+    if 0 in output[0].shape and 0 not in output[-1].shape:
+        output_feat = paddle.concat(output[1:-1] + [output[-1].squeeze()], axis=0).cpu()
+    elif 0 not in output[0].shape and 0 in output[-1].shape:
+        output_feat = paddle.concat([output[0].squeeze()] + output[1:-1], axis=0).cpu()
+    elif 0 in output[0].shape and 0 in output[-1].shape:
+        output_feat = paddle.concat(output[1:-1], axis=0).cpu()
+    else:
+        output_feat = paddle.concat([output[0].squeeze(0)] + output[1:-1] + [output[-1].squeeze(0)], axis=0).cpu()
+
+    wav_org, rate = librosa.load(wav_path, sr=train_args.feats_extract_conf['fs'])
+    return wav_org, None, output_feat, old_span_boundary, new_span_boundary, fs, hop_length
+
+
+def get_mlm_output(uid, prefix, clone_uid, clone_prefix, source_language, target_language, model_name, wav_path, old_str, new_str, duration_preditor_path, sid=None, decoder=False, use_teacher_forcing=False, dynamic_eval=(0, 0), duration_adjust=True, start_end_sp=False):
+    mlm_model, train_args = load_model(model_name)
+    mlm_model.eval()
+    processor = None
+    collate_fn = None
+
+    return decode_with_model(uid, prefix, clone_uid, clone_prefix, source_language, target_language, mlm_model, processor, collate_fn, wav_path, old_str, new_str, duration_preditor_path, sid=sid, decoder=decoder, use_teacher_forcing=use_teacher_forcing,
+                             duration_adjust=duration_adjust, start_end_sp=start_end_sp, train_args=train_args)
+
+
+def test_vctk(uid, clone_uid, clone_prefix, source_language, target_language, vocoder, prefix='dump/raw/dev', model_name="conformer", old_str="", new_str="", prompt_decoding=False, dynamic_eval=(0, 0), task_name=None):
+    duration_preditor_path = None
+    spemd = None
+    full_origin_str, wav_path = read_data(uid, prefix)
+
+    # for editing, new_str is the full target text; for synthesis it is
+    # appended to the original transcript
+    new_str = new_str if task_name == 'edit' else full_origin_str + new_str
+    print('new_str is ', new_str)
+
+    if not old_str:
+        old_str = full_origin_str
+
+    results_dict, old_span = plot_mel_and_vocode_wav(uid, prefix, clone_uid, clone_prefix, source_language, target_language, model_name, wav_path, full_origin_str, old_str, new_str, vocoder, duration_preditor_path, sid=spemd)
+    return results_dict
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    print(args)
+    data_dict = test_vctk(args.uid,
+                          args.clone_uid,
+                          args.clone_prefix,
+                          args.source_language,
+                          args.target_language,
+                          None,
+                          args.prefix,
+                          args.model_name,
+                          new_str=args.new_str,
+                          task_name=args.task_name)
+    sf.write('./wavs/%s' % args.output_name, data_dict['output'], samplerate=24000)
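The `prompt/dev` files changed below use a Kaldi-style two-column layout: an utterance ID followed by space-separated values (alignment times in seconds for `mfa_start`/`mfa_end`, phones for `mfa_text`, transcripts for `text`, wav paths for `wav.scp`). A minimal parsing sketch, assuming `read_2column_text` and `load_num_sequence_text` from `read_text.py` behave as their call sites above suggest:

```python
# Parsing sketch for the prompt/dev files (assumed semantics, not the
# project's actual read_text.py).
def read_2column_text(path):
    # "p243_new ../../prompt_wav/p243_313.wav" -> {"p243_new": "../../prompt_wav/p243_313.wav"}
    out = {}
    with open(path, encoding="utf-8") as f:
        for line in f:
            key, _, value = line.rstrip("\n").partition(" ")
            out[key] = value
    return out

def load_num_sequence_text(path, loader_type="text_float"):
    # "p243_new 1.0225 1.0525 ..." -> {"p243_new": [1.0225, 1.0525, ...]}
    assert loader_type == "text_float"
    return {key: [float(x) for x in value.split()]
            for key, value in read_2column_text(path).items()}
```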
diff --git a/ernie-sat/prompt/dev/mfa_end b/ernie-sat/prompt/dev/mfa_end
index 70c1237a09a56ba06791b9f05c0f694a3816336b..8772d7e4bd2f5f9c74bb904aceaea5d06e035297 100644
--- a/ernie-sat/prompt/dev/mfa_end
+++ b/ernie-sat/prompt/dev/mfa_end
@@ -1,3 +1,3 @@
+p243_new 1.0225 1.0525 1.0925 1.1325 1.1725 1.2625 1.3625 1.4125 1.5125 1.6225 1.6625 1.7925 1.8625 2.0025 2.0925 2.1725 2.2625 2.4325 2.4725 2.5225 2.5825 2.6125 2.6425 2.7425 2.8025 2.9025 2.9525 3.0525 3.0825 3.2125 3.4525
 Prompt_003_new 0.0425 0.0925 0.1825 0.2125 0.2425 0.3225 0.3725 0.4725 0.5325 0.5625 0.6225 0.7425 0.8625 0.9725 0.9975 1.0125 1.0825 1.2625 1.3125
 p299_096 0.7525 0.7925 0.8725 0.9125 0.9425 1.0325 1.0625 1.1925 1.2625 1.3225 1.3725 1.4125 1.5125 1.5425 1.6525 1.6925 1.7325 1.7625 1.8425 1.9625 2.0225 2.1825 2.3325 2.6825
-p243_new 1.0225 1.0525 1.0925 1.1325 1.1725 1.2625 1.3625 1.4125 1.5125 1.6225 1.6625 1.7925 1.8625 2.0025 2.0925 2.1725 2.2625 2.4325 2.4725 2.5225 2.5825 2.6125 2.6425 2.7425 2.8025 2.9025 2.9525 3.0525 3.0825 3.2125 3.4525
diff --git a/ernie-sat/prompt/dev/mfa_start b/ernie-sat/prompt/dev/mfa_start
index a975f8aafeab902a2da5d981f74126d1debf2290..fac9716128f5ec175fe0df8acf22969057ef2ab5 100644
--- a/ernie-sat/prompt/dev/mfa_start
+++ b/ernie-sat/prompt/dev/mfa_start
@@ -1,3 +1,3 @@
-Prompt_003_new 0.0125 0.0425 0.0925 0.1825 0.2125 0.2425 0.3225 0.3725 0.4725 0.5325 0.5625 0.6225 0.7425 0.8625 0.9725 0.9975 1.0125 1.0825 1.2625
 p243_new 0.0125 1.0225 1.0525 1.0925 1.1325 1.1725 1.2625 1.3625 1.4125 1.5125 1.6225 1.6625 1.7925 1.8625 2.0025 2.0925 2.1725 2.2625 2.4325 2.4725 2.5225 2.5825 2.6125 2.6425 2.7425 2.8025 2.9025 2.9525 3.0525 3.0825 3.2125
+Prompt_003_new 0.0125 0.0425 0.0925 0.1825 0.2125 0.2425 0.3225 0.3725 0.4725 0.5325 0.5625 0.6225 0.7425 0.8625 0.9725 0.9975 1.0125 1.0825 1.2625
 p299_096 0.0125 0.7525 0.7925 0.8725 0.9125 0.9425 1.0325 1.0625 1.1925 1.2625 1.3225 1.3725 1.4125 1.5125 1.5425 1.6525 1.6925 1.7325 1.7625 1.8425 1.9625 2.0225 2.1825 2.3325
diff --git a/ernie-sat/prompt/dev/mfa_text b/ernie-sat/prompt/dev/mfa_text
index 68a33eb640acd6b1b9101e5bc79f50bd5c778272..903e2444a0c55e694fff9a7edb92398f04bfefec 100644
--- a/ernie-sat/prompt/dev/mfa_text
+++ b/ernie-sat/prompt/dev/mfa_text
@@ -1,3 +1,3 @@
+p243_new sp F AO1 R DH AE1 T R IY1 Z AH0 N sp K AH1 V ER0 SH UH1 D N AA1 T B IY1 G IH1 V AH0 N sp
 Prompt_003_new DH IH1 S W AA1 Z N AA1 T DH AH0 SH OW1 F AO1 R M IY1 sp
 p299_096 sp W IY1 AA1 R T R AY1 NG T UW1 AH0 S T AE1 B L IH0 SH AH0 D EY1 T sp
-p243_new sp F AO1 R DH AE1 T R IY1 Z AH0 N sp K AH1 V ER0 SH UH1 D N AA1 T B IY1 G IH1 V AH0 N sp
diff --git a/ernie-sat/prompt/dev/mfa_wav.scp b/ernie-sat/prompt/dev/mfa_wav.scp
index ad5b9d9cae8ccce228f7a922227ea641fdc8fc0e..eb0e8e48d5c50fe65271a8039bdd24628fa9daf5 100644
--- a/ernie-sat/prompt/dev/mfa_wav.scp
+++ b/ernie-sat/prompt/dev/mfa_wav.scp
@@ -1,3 +1,3 @@
-Prompt_003_new ../../prompt_wav/this_was_not_the_show_for_me.wav
 p243_new ../../prompt_wav/p243_313.wav
+Prompt_003_new ../../prompt_wav/this_was_not_the_show_for_me.wav
 p299_096 ../../prompt_wav/p299_096.wav
diff --git a/ernie-sat/prompt/dev/text b/ernie-sat/prompt/dev/text
index 026aa9ad4f7952e338dd2e8a1fedf68cfd709f23..f79cdcb42ea0c6bacc124c6ccc668a0891eca33a 100644
--- a/ernie-sat/prompt/dev/text
+++ b/ernie-sat/prompt/dev/text
@@ -1,3 +1,3 @@
-Prompt_003_new This was not the show for me.
 p243_new For that reason cover should not be given.
+Prompt_003_new This was not the show for me.
 p299_096 We are trying to establish a date.
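For reference, `get_masked_mel_boundary` in `inference.py` converts these alignment times to mel-frame indices as `floor(fs * t / hop_length)`; assuming the pwg_aishell3 defaults of `fs = 24000` and `hop_length = 300`, `t = 1.0225` s lands on frame `floor(24000 * 1.0225 / 300) = 81`.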
diff --git a/ernie-sat/prompt/dev/wav.scp b/ernie-sat/prompt/dev/wav.scp
index c0f8a1c7fdc1c76838b68eeeca4c4a5c416f4150..eb0e8e48d5c50fe65271a8039bdd24628fa9daf5 100644
--- a/ernie-sat/prompt/dev/wav.scp
+++ b/ernie-sat/prompt/dev/wav.scp
@@ -1,3 +1,3 @@
+p243_new ../../prompt_wav/p243_313.wav
 Prompt_003_new ../../prompt_wav/this_was_not_the_show_for_me.wav
 p299_096 ../../prompt_wav/p299_096.wav
-p243_new ../../prompt_wav/p243_313.wav
diff --git a/ernie-sat/run_clone_en_to_zh.sh b/ernie-sat/run_clone_en_to_zh.sh
index 85b013c7612979e3eb60f318c73f398851663544..2a50ef1a1bc242cf9f3f06ae1ae9d4e74a3f7bce 100644
--- a/ernie-sat/run_clone_en_to_zh.sh
+++ b/ernie-sat/run_clone_en_to_zh.sh
@@ -1,15 +1,15 @@
-# en --> zh cloning
-python sedit_inference_0520.py \
+# en --> zh speech synthesis
+# Synthesize '今天天气很好' from the speech of Prompt_003_new: "This was not the show for me."
+
+python inference.py \
 --task_name cross-lingual_clone \
 --model_name paddle_checkpoint_ench \
 --uid Prompt_003_new \
 --new_str '今天天气很好' \
 --prefix ./prompt/dev/ \
---clone_prefix ./prompt/dev_aishell3/ \
---clone_uid SSB07510054 \
 --source_language english \
 --target_language chinese \
---output_name task_cross_lingual_pred.wav \
+--output_name pred_zh.wav \
 --voc pwgan_aishell3 \
 --voc_config download/pwg_aishell3_ckpt_0.5/default.yaml \
 --voc_ckpt download/pwg_aishell3_ckpt_0.5/snapshot_iter_1000000.pdz \
diff --git a/ernie-sat/run_gen_en.sh b/ernie-sat/run_gen_en.sh
index c89431c0653c01dc79a94fe0088c05754a2a521c..cb63a4cc44c8998d6d0ad449af865b8c3607a148 100644
--- a/ernie-sat/run_gen_en.sh
+++ b/ernie-sat/run_gen_en.sh
@@ -1,26 +1,7 @@
 # English-only speech synthesis
-# python sedit_inference_0518.py \
-# --task_name synthesize \
-# --model_name paddle_checkpoint_en \
-# --uid p323_083 \
-# --new_str 'I enjoy my life.' \
-# --prefix ./prompt/dev/ \
-# --source_language english \
-# --target_language english \
-# --output_name pred.wav \
-# --voc pwgan_aishell3 \
-# --voc_config download/pwg_aishell3_ckpt_0.5/default.yaml \
-# --voc_ckpt download/pwg_aishell3_ckpt_0.5/snapshot_iter_1000000.pdz \
-# --voc_stat download/pwg_aishell3_ckpt_0.5/feats_stats.npy \
-# --am fastspeech2_ljspeech \
-# --am_config download/fastspeech2_nosil_ljspeech_ckpt_0.5/default.yaml \
-# --am_ckpt download/fastspeech2_nosil_ljspeech_ckpt_0.5/snapshot_iter_100000.pdz \
-# --am_stat download/fastspeech2_nosil_ljspeech_ckpt_0.5/speech_stats.npy \
-# --phones_dict download/fastspeech2_nosil_ljspeech_ckpt_0.5/phone_id_map.txt
+# Synthesize 'I enjoy my life.' from the speech of p299_096: "We are trying to establish a date."
-
-# English-only speech synthesis
-python sedit_inference_0520.py \
+python inference.py \
 --task_name synthesize \
 --model_name paddle_checkpoint_en \
 --uid p299_096 \
@@ -28,7 +9,7 @@ python sedit_inference_0520.py \
 --prefix ./prompt/dev/ \
 --source_language english \
 --target_language english \
---output_name task_synthesize_pred.wav \
+--output_name pred.wav \
 --voc pwgan_aishell3 \
 --voc_config download/pwg_aishell3_ckpt_0.5/default.yaml \
 --voc_ckpt download/pwg_aishell3_ckpt_0.5/snapshot_iter_1000000.pdz \
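Aside from `--task_name`, `--uid`, `--new_str`, and the output name, these `run_*.sh` scripts share the same vocoder flags; each is presumably launched from the `ernie-sat` directory, e.g. `sh run_sedit_en.sh`.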
diff --git a/ernie-sat/run_sedit_en.sh b/ernie-sat/run_sedit_en.sh
index c3d5a7457aa1fa38732fc75eb1823e985cd29b47..ebece42c26233a62a2a3925c919182b86c37cf5f 100644
--- a/ernie-sat/run_sedit_en.sh
+++ b/ernie-sat/run_sedit_en.sh
@@ -1,5 +1,7 @@
 # English-only speech editing
-python sedit_inference_0520.py \
+# Edit the original speech of p243_new ("For that reason cover should not be given.") into speech for 'for that reason cover is impossible to be given.'
+
+python inference.py \
 --task_name edit \
 --model_name paddle_checkpoint_en \
 --uid p243_new \
@@ -7,7 +9,7 @@ python sedit_inference_0520.py \
 --prefix ./prompt/dev/ \
 --source_language english \
 --target_language english \
---output_name task_edit_pred.wav \
+--output_name pred.wav \
 --voc pwgan_aishell3 \
 --voc_config download/pwg_aishell3_ckpt_0.5/default.yaml \
 --voc_ckpt download/pwg_aishell3_ckpt_0.5/snapshot_iter_1000000.pdz \
diff --git a/ernie-sat/sedit_inference_0520.py b/ernie-sat/sedit_inference_0520.py
deleted file mode 100644
index 09ca3e567a6d2925091bc619cc8c19410c637ea6..0000000000000000000000000000000000000000
--- a/ernie-sat/sedit_inference_0520.py
+++ /dev/null
@@ -1,1086 +0,0 @@
-#!/usr/bin/env python3
-
-"""Script to run the inference of text-to-speech model."""
-
-import os
-os.environ["CUDA_VISIBLE_DEVICES"] = "3"
-
-from parallel_wavegan.utils import download_pretrained_model
-from pathlib import Path
-import paddle
-import soundfile
-import os
-import math
-import string
-import numpy as np
-
-from espnet2.tasks.mlm import MLMTask
-from read_text import read_2column_text,load_num_sequence_text
-from util import sentence2phns,get_voc_out, evaluate_durations
-import librosa
-import random
-from ipywidgets import widgets
-import IPython.display as ipd
-import soundfile as sf
-import sys
-import pickle
-from model_paddle import build_model_from_file
-
-from sedit_arg_parser import parse_args
-import argparse
-from typing import Collection
-from typing import Dict
-from typing import List
-from typing import Tuple
-from typing import Union
-
-from paddlespeech.t2s.datasets.get_feats import LogMelFBank
-from paddlespeech.t2s.modules.nets_utils import make_non_pad_mask
-
-duration_path_dict = {
-    "ljspeech":"/mnt/home/v_baihe/projects/espnet/egs2/ljspeech/tts1/exp/kan-bayashi/ljspeech_tts_train_conformer_fastspeech2_raw_phn_tacotron_g2p_en_no_space_train.loss.ave/train.loss.ave_5best.pth",
-    "vctk": "/mnt/home/v_baihe/projects/espnet/egs2/vctk/tts1/exp/kan-bayashi/vctk_tts_train_gst+xvector_conformer_fastspeech2_transformer_teacher_raw_phn_tacotron_g2p_en_no_space_train.loss.ave/train.loss.ave_5best.pth",
-    # "ljspeech":"/home/mnt2/zz/workspace/work/espnet_richard_infer/egs2/ljspeech/tts1/exp/kan-bayashi/ljspeech_tts_train_conformer_fastspeech2_raw_phn_tacotron_g2p_en_no_space_train.loss.ave/train.loss.ave_5best.pth",
-    # "vctk": "/home/mnt2/zz/workspace/work/espnet_richard_infer/egs2/vctk/tts1/exp/kan-bayashi/vctk_tts_train_gst+xvector_conformer_fastspeech2_transformer_teacher_raw_phn_tacotron_g2p_en_no_space_train.loss.ave/train.loss.ave_5best.pth",
-    "vctk_unseen":"/mnt/home/v_baihe/projects/espnet/egs2/vctk/tts1/exp/tts_train_fs2_raw_phn_tacotron_g2p_en_no_space/train.loss.ave_5best.pth",
-    "libritts":"/mnt/home/v_baihe/projects/espnet/egs2/libritts/tts1/exp/kan-bayashi/libritts_tts_train_gst+xvector_conformer_fastspeech2_transformer_teacher_raw_phn_tacotron_g2p_en_no_space_train.loss/train.loss.ave_5best.pth"
-}
-
-random.seed(0)
-np.random.seed(0)
-
-
-def plot_mel_and_vocode_wav(uid, prefix, clone_uid, clone_prefix, source_language, target_language, model_name, wav_path,full_origin_str, old_str, new_str, vocoder,duration_preditor_path,sid=None, non_autoreg=True):
-    wav_org, input_feat, output_feat, old_span_boundary, new_span_boundary, fs, hop_length = get_mlm_output(
-        uid,
-        prefix,
-        clone_uid,
-        clone_prefix,
-        source_language,
-        target_language,
-        model_name,
-        wav_path,
-        old_str,
-        new_str,
-        duration_preditor_path,
-        use_teacher_forcing=non_autoreg,
-        sid=sid
-    )
-
-    masked_feat = output_feat[new_span_boundary[0]:new_span_boundary[1]].detach().float().cpu().numpy()
-
-    if target_language == 'english':
-        output_feat_np = output_feat.detach().float().cpu().numpy()
-        replaced_wav_paddle_voc = get_voc_out(output_feat_np, target_language)
-
-    elif target_language == 'chinese':
-        output_feat_np = output_feat.detach().float().cpu().numpy()
-        replaced_wav_only_mask_fst2_voc = get_voc_out(masked_feat, target_language)
-
-
-    old_time_boundary = [hop_length * x for x in old_span_boundary]
-    new_time_boundary = [hop_length * x for x in new_span_boundary]
-
-
-    if target_language == 'english':
-        wav_org_replaced_paddle_voc = np.concatenate([wav_org[:old_time_boundary[0]], replaced_wav_paddle_voc[new_time_boundary[0]:new_time_boundary[1]], wav_org[old_time_boundary[1]:]])
-
-        data_dict = {"origin":wav_org,
-                     "output":wav_org_replaced_paddle_voc}
-
-    elif target_language == 'chinese':
-        wav_org_replaced_only_mask_fst2_voc = np.concatenate([wav_org[:old_time_boundary[0]], replaced_wav_only_mask_fst2_voc, wav_org[old_time_boundary[1]:]])
-        data_dict = {"origin":wav_org,
-                     "output": wav_org_replaced_only_mask_fst2_voc,}
-
-    return data_dict, old_span_boundary
-
-
-
-def load_vocoder(vocoder_tag="parallel_wavegan/libritts_parallel_wavegan.v1"):
-    vocoder_tag = vocoder_tag.replace("parallel_wavegan/", "")
-    vocoder_file = download_pretrained_model(vocoder_tag)
-    vocoder_config = Path(vocoder_file).parent / "config.yml"
-
-    vocoder = TTSTask.build_vocoder_from_file(
-        vocoder_config, vocoder_file, None, 'cpu'
-    )
-    return vocoder
-
-def load_model(model_name):
-    config_path='./pretrained_model/{}/config.yaml'.format(model_name)
-    model_path = './pretrained_model/{}/model.pdparams'.format(model_name)
-
-    mlm_model, args = build_model_from_file(config_file=config_path,
-                                            model_file=model_path)
-    return mlm_model, args
-
-
-def read_data(uid,prefix):
-    mfa_text = read_2column_text(prefix+'/text')[uid]
-    mfa_wav_path = read_2column_text(prefix+'/wav.scp')[uid]
-    if 'mnt' not in mfa_wav_path:
-        mfa_wav_path = prefix.split('dump')[0] + mfa_wav_path
-    return mfa_text, mfa_wav_path
-
-def get_align_data(uid,prefix):
-    mfa_path = prefix+"mfa_"
-    mfa_text = read_2column_text(mfa_path+'text')[uid]
-    mfa_start = load_num_sequence_text(mfa_path+'start',loader_type='text_float')[uid]
-    mfa_end = load_num_sequence_text(mfa_path+'end',loader_type='text_float')[uid]
-    mfa_wav_path = read_2column_text(mfa_path+'wav.scp')[uid]
-    return mfa_text, mfa_start, mfa_end, mfa_wav_path
-
-
-
-def get_fs2_model(model_name):
-    model, config = TTSTask.build_model_from_file(model_file=model_name)
-    processor = TTSTask.build_preprocess_fn(config, train=False)
-    return model, processor
-
-def get_masked_mel_boundary(mfa_start, mfa_end, fs, hop_length, span_tobe_replaced):
-    align_start=paddle.to_tensor(mfa_start).unsqueeze(0)
-    align_end =paddle.to_tensor(mfa_end).unsqueeze(0)
-    align_start = paddle.floor(fs*align_start/hop_length).int()
-    align_end = paddle.floor(fs*align_end/hop_length).int()
-    if span_tobe_replaced[0]>=len(mfa_start):
-        span_boundary = [align_end[0].tolist()[-1],align_end[0].tolist()[-1]]
-    else:
-        span_boundary=[align_start[0].tolist()[span_tobe_replaced[0]],align_end[0].tolist()[span_tobe_replaced[1]-1]]
-    return span_boundary
-
-
-def get_mapping(phn_mapping="./phn_mapping.txt"):
-    zh_mapping = {}
-    with open(phn_mapping, "r") as f:
-        for line in f:
-            pd_phn = line.split(" ")[0]
-            if pd_phn not in zh_mapping.keys():
-                zh_mapping[pd_phn] = " ".join(line.split()[1:])
-    return zh_mapping
-
-
-def gen_phns(zh_mapping, phns):
-    new_phns = []
-    for x in phns:
-        if x in zh_mapping.keys():
-            new_phns.extend(zh_mapping[x].split(" "))
-        else:
-            new_phns.extend(['<unk>'])
-    return new_phns
-
-def get_phns_and_spans_paddle(uid, prefix, old_str, new_str, source_language, target_language):
-    zh_mapping = get_mapping()
-    old_str = old_str.strip()
-    new_str = new_str.strip()
-    words = []
-    for pun in [',', '.', ':', ';', '!', '?', '"', '(', ')', '--', '---', u',', u'。', u':', u';', u'!', u'?', u'(', u')']:
-        old_str = old_str.replace(pun, ' ')
-        new_str = new_str.replace(pun, ' ')
-
-
-    append_new_str = (old_str == new_str[:len(old_str)])
-    print("append_new_str: ", append_new_str)
-    old_phns, mfa_start, mfa_end = [], [], []
-    mfa_text, mfa_start, mfa_end, mfa_wav_path = get_align_data(uid, prefix)
-    old_phns = mfa_text.split(" ")
-
-    if append_new_str:
-        if source_language != target_language:
-            is_cross_lingual = True
-        else:
-            is_cross_lingual = False
-
-        new_str_origin = new_str[:len(old_str)]
-        new_str_append = new_str[len(old_str):]
-        if is_cross_lingual:
-            if source_language == "english" and target_language == "chinese":
-                new_phns_origin = old_phns
-                new_phns_append, _ = sentence2phns(new_str_append, "zh")
-
-            elif source_language=="chinese" and target_language == "english":
-                new_phns_origin = old_phns
-                new_phns_append, _ = sentence2phns(new_str_append, "en")
-            else:
-                assert target_language == "chinese" or target_language == "english", "cloning is not support for this language, please check it."
-
-        else:
-            if source_language == target_language and target_language == "english":
-                new_phns_origin = old_phns
-                new_phns_append, _ = sentence2phns(new_str_append, "en")
-
-            elif source_language == target_language and target_language == "chinese":
-                new_phns_origin = old_phns
-                new_phns_append, _ = sentence2phns(new_str_append, "zh")
-            else:
-                assert source_language == target_language, "source language is not same with target language..."
-
-        if target_language == "chinese":
-            new_phns_append = gen_phns(zh_mapping, new_phns_append)
-
-        new_phns = new_phns_origin + new_phns_append
-
-        span_tobe_replaced = [len(old_phns),len(old_phns)]
-        span_tobe_added = [len(old_phns),len(new_phns)]
-
-    else:
-        if source_language == target_language and target_language == "english":
-            new_phns, _ = sentence2phns(new_str, "en")
-        # Chinese-only
-        elif source_language == target_language and target_language == "chinese":
-            new_phns, _ = sentence2phns(new_str, "zh")
-            new_phns = gen_phns(zh_mapping, new_phns)
-
-
-        else:
-            assert source_language == target_language, "source language is not same with target language..."
-
-        while(new_phns[-1] == 'sp'):
-            new_phns.pop()
-
-        while(new_phns[0] == 'sp'):
-            new_phns.pop(0)
-
-        span_tobe_replaced = [0,len(old_phns)-1]
-        span_tobe_added = [0,len(new_phns)-1]
-        new_phns_left = []
-        left_index = 0
-        sp_count = 0
-
-        # find the left different index
-        for idx, phn in enumerate(old_phns):
-            if phn == "sp":
-                sp_count += 1
-                new_phns_left.append('sp')
-            else:
-                idx = idx - sp_count
-                if phn == new_phns[idx]:
-                    left_index += 1
-                    new_phns_left.append(phn)
-                else:
-                    span_tobe_replaced[0] = len(new_phns_left)
-                    span_tobe_added[0] = len(new_phns_left)
-                    break
-
-        right_index = 0
-        new_phns_middle = []
-        new_phns_right = []
-        sp_count = 0
-        word2phns_max_index = len(old_phns)
-        new_word2phns_max_index = len(new_phns)
-
-        for idx, phn in enumerate(old_phns[::-1]):
-            cur_idx = len(old_phns) - 1 - idx
-            if phn == "sp":
-                sp_count += 1
-                new_phns_right = ['sp']+new_phns_right
-            else:
-                cur_idx = new_word2phns_max_index - (word2phns_max_index - cur_idx -sp_count)
-                if phn == new_phns[cur_idx]:
-                    right_index -= 1
-                    new_phns_right = [phn] + new_phns_right
-
-                else:
-                    span_tobe_replaced[1] = len(old_phns) - len(new_phns_right)
-                    new_phns_middle = new_phns[left_index:right_index]
-                    span_tobe_added[1] = len(new_phns_left) + len(new_phns_middle)
-                    if len(new_phns_middle) == 0:
-                        span_tobe_added[1] = min(span_tobe_added[1]+1, len(new_phns))
-                        span_tobe_added[0] = max(0, span_tobe_added[0]-1)
-                        span_tobe_replaced[0] = max(0, span_tobe_replaced[0]-1)
-                        span_tobe_replaced[1] = min(span_tobe_replaced[1]+1, len(old_phns))
-                    break
-
-        new_phns = new_phns_left+new_phns_middle+new_phns_right
-
-    return mfa_start, mfa_end, old_phns, new_phns, span_tobe_replaced, span_tobe_added
-
-
-
-
-def duration_adjust_factor(original_dur, pred_dur, phns):
-    length = 0
-    accumulate = 0
-    factor_list = []
-    for ori,pred,phn in zip(original_dur, pred_dur,phns):
-        if pred==0 or phn=='sp':
-            continue
-        else:
-            factor_list.append(ori/pred)
-    factor_list = np.array(factor_list)
-    factor_list.sort()
-    if len(factor_list)<5:
-        return 1
-    length = 2
-    return np.average(factor_list[length:-length])
-
-
-def prepare_features_with_duration(uid, prefix, clone_uid, clone_prefix, source_language, target_language, mlm_model, old_str, new_str, wav_path,duration_preditor_path,sid=None, mask_reconstruct=False,duration_adjust=True,start_end_sp=False, train_args=None):
-    wav_org, rate = librosa.load(wav_path, sr=train_args.feats_extract_conf['fs'])
-    fs = train_args.feats_extract_conf['fs']
-    hop_length = train_args.feats_extract_conf['hop_length']
-
-    mfa_start, mfa_end, old_phns, new_phns, span_tobe_replaced, span_tobe_added = get_phns_and_spans_paddle(uid, prefix, old_str, new_str, source_language, target_language)
-
-    if start_end_sp:
-        if new_phns[-1]!='sp':
-            new_phns = new_phns+['sp']
-
-
-    if target_language == "english":
-        old_durations = evaluate_durations(old_phns, target_language=target_language)
-
-    elif target_language =="chinese":
-        if source_language == "english":
-            old_durations = evaluate_durations(old_phns, target_language=source_language)
-        elif source_language == "chinese":
-            old_durations = evaluate_durations(old_phns, target_language=source_language)
-
-    else:
-        assert target_language == "chinese" or target_language == "english", "calculate duration_predict is not support for this language..."
-
-
-    original_old_durations = [e-s for e,s in zip(mfa_end, mfa_start)]
-    if '[MASK]' in new_str:
-        new_phns = old_phns
-        span_tobe_added = span_tobe_replaced
-        d_factor_left = duration_adjust_factor(original_old_durations[:span_tobe_replaced[0]],old_durations[:span_tobe_replaced[0]], old_phns[:span_tobe_replaced[0]])
-        d_factor_right = duration_adjust_factor(original_old_durations[span_tobe_replaced[1]:],old_durations[span_tobe_replaced[1]:], old_phns[span_tobe_replaced[1]:])
-        d_factor = (d_factor_left+d_factor_right)/2
-        new_durations_adjusted = [d_factor*i for i in old_durations]
-    else:
-        if duration_adjust:
-            d_factor = duration_adjust_factor(original_old_durations,old_durations, old_phns)
-            d_factor_paddle = duration_adjust_factor(original_old_durations,old_durations, old_phns)
-            if target_language =="chinese":
-                d_factor = d_factor * 1.35
-        else:
-            d_factor = 1
-
-        if target_language == "english":
-            new_durations = evaluate_durations(new_phns, target_language=target_language)
-        elif target_language =="chinese":
-            new_durations = evaluate_durations(new_phns, target_language=target_language)
-
-        new_durations_adjusted = [d_factor*i for i in new_durations]
-
-    new_span_duration_sum = sum(new_durations_adjusted[span_tobe_added[0]:span_tobe_added[1]])
-    old_span_duration_sum = sum(original_old_durations[span_tobe_replaced[0]:span_tobe_replaced[1]])
-    duration_offset = new_span_duration_sum - old_span_duration_sum
-    new_mfa_start = mfa_start[:span_tobe_replaced[0]]
-    new_mfa_end = mfa_end[:span_tobe_replaced[0]]
-    for i in new_durations_adjusted[span_tobe_added[0]:span_tobe_added[1]]:
-        if len(new_mfa_end) == 0:
-            new_mfa_start.append(0)
-            new_mfa_end.append(i)
-        else:
-            new_mfa_start.append(new_mfa_end[-1])
-            new_mfa_end.append(new_mfa_end[-1]+i)
-    new_mfa_start += [i+duration_offset for i in mfa_start[span_tobe_replaced[1]:]]
-    new_mfa_end += [i+duration_offset for i in mfa_end[span_tobe_replaced[1]:]]
-
-    if span_tobe_replaced[0] >= len(mfa_start):
-        left_index = len(wav_org)
-        right_index = left_index
-    else:
-        left_index = int(np.floor(mfa_start[span_tobe_replaced[0]]*fs))
-        right_index = int(np.ceil(mfa_end[span_tobe_replaced[1]-1]*fs))
-    new_blank_wav = np.zeros((int(np.ceil(new_span_duration_sum*fs)),), dtype=wav_org.dtype)
-    new_wav_org = np.concatenate([wav_org[:left_index], new_blank_wav, wav_org[right_index:]])
-
-
-    # 4. get old and new mel span to be mask
-    old_span_boundary = get_masked_mel_boundary(mfa_start, mfa_end, fs, hop_length, span_tobe_replaced)   # [92, 92]
-    new_span_boundary=get_masked_mel_boundary(new_mfa_start, new_mfa_end, fs, hop_length, span_tobe_added) # [92, 174]
-
-
-    return new_wav_org, new_phns, new_mfa_start, new_mfa_end, old_span_boundary, new_span_boundary
-
-def prepare_features(uid, prefix, clone_uid, clone_prefix, source_language, target_language, mlm_model,processor, wav_path, old_str,new_str,duration_preditor_path, sid=None,duration_adjust=True,start_end_sp=False,
-mask_reconstruct=False, train_args=None):
-    wav_org, phns_list, mfa_start, mfa_end, old_span_boundary, new_span_boundary = prepare_features_with_duration(uid, prefix, clone_uid, clone_prefix, source_language, target_language, mlm_model, old_str,
-    new_str, wav_path,duration_preditor_path,sid=sid,duration_adjust=duration_adjust,start_end_sp=start_end_sp,mask_reconstruct=mask_reconstruct, train_args = train_args)
-    speech = np.array(wav_org,dtype=np.float32)
-    align_start=np.array(mfa_start)
-    align_end =np.array(mfa_end)
-    token_to_id = {item: i for i, item in enumerate(train_args.token_list)}
-    text = np.array(list(map(lambda x: token_to_id.get(x, token_to_id['<unk>']), phns_list)))
-    print('unk id is', token_to_id['<unk>'])
-    # text = np.array(processor(uid='1', data={'text':" ".join(phns_list)})['text'])
-    span_boundary = np.array(new_span_boundary)
-    batch=[('1', {"speech":speech,"align_start":align_start,"align_end":align_end,"text":text,"span_boundary":span_boundary})]
-
-    return batch, old_span_boundary, new_span_boundary
-def decode_with_model(uid, prefix, clone_uid, clone_prefix, source_language, target_language, mlm_model, processor, collate_fn, wav_path, old_str, new_str,duration_preditor_path, sid=None, decoder=False,use_teacher_forcing=False,duration_adjust=True,start_end_sp=False, train_args=None):
-    # fs, hop_length = mlm_model.feats_extract.fs, mlm_model.feats_extract.hop_length
-    fs, hop_length = train_args.feats_extract_conf['fs'], train_args.feats_extract_conf['hop_length']
-
-    batch,old_span_boundary,new_span_boundary = prepare_features(uid,prefix, clone_uid, clone_prefix, source_language, target_language, mlm_model,processor,wav_path,old_str,new_str,duration_preditor_path, sid,duration_adjust=duration_adjust,start_end_sp=start_end_sp, train_args=train_args)
-
-    feats = pickle.load(open('tmp/tmp_pkl.'+str(uid), 'rb'))
-
-    # wav_len * 80
-    # set_all_random_seed(9999)
-    if 'text_masked_position' in feats.keys():
-        feats.pop('text_masked_position')
-    for k, v in feats.items():
-        feats[k] = paddle.to_tensor(v)
-    rtn = mlm_model.inference(**feats,span_boundary=new_span_boundary,use_teacher_forcing=use_teacher_forcing)
-    output = rtn['feat_gen']
-    if 0 in output[0].shape and 0 not in output[-1].shape:
-        output_feat = paddle.concat(output[1:-1]+[output[-1].squeeze()], axis=0).cpu()
-    elif 0 not in output[0].shape and 0 in output[-1].shape:
-        output_feat = paddle.concat([output[0].squeeze()]+output[1:-1], axis=0).cpu()
-    elif 0 in output[0].shape and 0 in output[-1].shape:
-        output_feat = paddle.concat(output[1:-1], axis=0).cpu()
-    else:
-        output_feat = paddle.concat([output[0].squeeze(0)]+ output[1:-1]+[output[-1].squeeze(0)], axis=0).cpu()
-
-
-    # wav_org, rate = soundfile.read(
-    #     wav_path, always_2d=False)
-    wav_org, rate = librosa.load(wav_path, sr=train_args.feats_extract_conf['fs'])
-    origin_speech = paddle.to_tensor(np.array(wav_org,dtype=np.float32)).unsqueeze(0)
-    speech_lengths = paddle.to_tensor(len(wav_org)).unsqueeze(0)
-    # input_feat, feats_lengths = mlm_model.feats_extract(origin_speech, speech_lengths)
-    # return wav_org, input_feat.squeeze(), output_feat, old_span_boundary, new_span_boundary, fs, hop_length
-    return wav_org, None, output_feat, old_span_boundary, new_span_boundary, fs, hop_length
-
-class MLMCollateFn:
-    """Functor class of common_collate_fn()"""
-
-    def __init__(
-        self,
-        feats_extract,
-        float_pad_value: Union[float, int] = 0.0,
-        int_pad_value: int = -32768,
-        not_sequence: Collection[str] = (),
-        mlm_prob: float=0.8,
-        mean_phn_span: int=8,
-        attention_window: int=0,
-        pad_speech: bool=False,
-        sega_emb: bool=False,
-        duration_collect: bool=False,
-        text_masking: bool=False
-
-    ):
-        self.mlm_prob=mlm_prob
-        self.mean_phn_span=mean_phn_span
-        self.feats_extract = feats_extract
-        self.float_pad_value = float_pad_value
-        self.int_pad_value = int_pad_value
-        self.not_sequence = set(not_sequence)
-        self.attention_window=attention_window
-        self.pad_speech=pad_speech
-        self.sega_emb=sega_emb
-        self.duration_collect = duration_collect
-        self.text_masking = text_masking
-
-    def __repr__(self):
-        return (
-            f"{self.__class__}(float_pad_value={self.float_pad_value}, "
-            f"int_pad_value={self.float_pad_value})"
-        )
-
-    def __call__(
-        self, data: Collection[Tuple[str, Dict[str, np.ndarray]]]
-    ) -> Tuple[List[str], Dict[str, paddle.Tensor]]:
-        return mlm_collate_fn(
-            data,
-            float_pad_value=self.float_pad_value,
-            int_pad_value=self.int_pad_value,
-            not_sequence=self.not_sequence,
-            mlm_prob=self.mlm_prob,
-            mean_phn_span=self.mean_phn_span,
-            feats_extract=self.feats_extract,
-            attention_window=self.attention_window,
-            pad_speech=self.pad_speech,
-            sega_emb=self.sega_emb,
-            duration_collect=self.duration_collect,
-            text_masking=self.text_masking
-        )
-
-def pad_list(xs, pad_value):
-    """Perform padding for the list of tensors.
-
-    Args:
-        xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
-        pad_value (float): Value for padding.
-
-    Returns:
-        Tensor: Padded tensor (B, Tmax, `*`).
-
-    Examples:
-        >>> x = [paddle.ones(4), paddle.ones(2), paddle.ones(1)]
-        >>> x
-        [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
-        >>> pad_list(x, 0)
-        tensor([[1., 1., 1., 1.],
-                [1., 1., 0., 0.],
-                [1., 0., 0., 0.]])
-
-    """
-    n_batch = len(xs)
-    max_len = max(paddle.shape(x)[0] for x in xs)
-    pad = paddle.full((n_batch, max_len), pad_value, dtype=xs[0].dtype)
-
-    for i in range(n_batch):
-        pad[i, : paddle.shape(xs[i])[0]] = xs[i]
-
-    return pad
-
-def pad_to_longformer_att_window(text, max_len, max_tlen, attention_window):
-    round = max_len % attention_window
-    if round != 0:
-        max_tlen += (attention_window - round)
-        n_batch = paddle.shape(text)[0]
-        text_pad = paddle.zeros((n_batch, max_tlen, *paddle.shape(text[0])[1:]), dtype=text.dtype)
-        for i in range(n_batch):
-            text_pad[i, : paddle.shape(text[i])[0]] = text[i]
-    else:
-        text_pad = text[:, : max_tlen]
-    return text_pad, max_tlen
-
-def make_pad_mask(lengths, xs=None, length_dim=-1):
-    print('inputs are:', lengths, xs, length_dim)
-    """Make mask tensor containing indices of padded part.
-
-    Args:
-        lengths (LongTensor or List): Batch of lengths (B,).
-        xs (Tensor, optional): The reference tensor.
-            If set, masks will be the same shape as this tensor.
-        length_dim (int, optional): Dimension indicator of the above tensor.
-            See the example.
-
-    Returns:
-        Tensor: Mask tensor containing indices of padded part.
-
-    Examples:
-        With only lengths.
-
-        >>> lengths = [5, 3, 2]
-        >>> make_non_pad_mask(lengths)
-        masks = [[0, 0, 0, 0, 0],
-                 [0, 0, 0, 1, 1],
-                 [0, 0, 1, 1, 1]]
-
-        With the reference tensor.
-
-        >>> xs = paddle.zeros((3, 2, 4))
-        >>> make_pad_mask(lengths, xs)
-        tensor([[[0, 0, 0, 0],
-                 [0, 0, 0, 0]],
-                [[0, 0, 0, 1],
-                 [0, 0, 0, 1]],
-                [[0, 0, 1, 1],
-                 [0, 0, 1, 1]]], dtype=paddle.uint8)
-        >>> xs = paddle.zeros((3, 2, 6))
-        >>> make_pad_mask(lengths, xs)
-        tensor([[[0, 0, 0, 0, 0, 1],
-                 [0, 0, 0, 0, 0, 1]],
-                [[0, 0, 0, 1, 1, 1],
-                 [0, 0, 0, 1, 1, 1]],
-                [[0, 0, 1, 1, 1, 1],
-                 [0, 0, 1, 1, 1, 1]]], dtype=paddle.uint8)
-
-        With the reference tensor and dimension indicator.
-
-        >>> xs = paddle.zeros((3, 6, 6))
-        >>> make_pad_mask(lengths, xs, 1)
-        tensor([[[0, 0, 0, 0, 0, 0],
-                 [0, 0, 0, 0, 0, 0],
-                 [0, 0, 0, 0, 0, 0],
-                 [0, 0, 0, 0, 0, 0],
-                 [0, 0, 0, 0, 0, 0],
-                 [1, 1, 1, 1, 1, 1]],
-                [[0, 0, 0, 0, 0, 0],
-                 [0, 0, 0, 0, 0, 0],
-                 [0, 0, 0, 0, 0, 0],
-                 [1, 1, 1, 1, 1, 1],
-                 [1, 1, 1, 1, 1, 1],
-                 [1, 1, 1, 1, 1, 1]],
-                [[0, 0, 0, 0, 0, 0],
-                 [0, 0, 0, 0, 0, 0],
-                 [1, 1, 1, 1, 1, 1],
-                 [1, 1, 1, 1, 1, 1],
-                 [1, 1, 1, 1, 1, 1],
-                 [1, 1, 1, 1, 1, 1]]], dtype=paddle.uint8)
-        >>> make_pad_mask(lengths, xs, 2)
-        tensor([[[0, 0, 0, 0, 0, 1],
-                 [0, 0, 0, 0, 0, 1],
-                 [0, 0, 0, 0, 0, 1],
-                 [0, 0, 0, 0, 0, 1],
-                 [0, 0, 0, 0, 0, 1],
-                 [0, 0, 0, 0, 0, 1]],
-                [[0, 0, 0, 1, 1, 1],
-                 [0, 0, 0, 1, 1, 1],
-                 [0, 0, 0, 1, 1, 1],
-                 [0, 0, 0, 1, 1, 1],
-                 [0, 0, 0, 1, 1, 1],
-                 [0, 0, 0, 1, 1, 1]],
-                [[0, 0, 1, 1, 1, 1],
-                 [0, 0, 1, 1, 1, 1],
-                 [0, 0, 1, 1, 1, 1],
-                 [0, 0, 1, 1, 1, 1],
-                 [0, 0, 1, 1, 1, 1],
-                 [0, 0, 1, 1, 1, 1]]], dtype=paddle.uint8)
-
-    """
-    if length_dim == 0:
-        raise ValueError("length_dim cannot be 0: {}".format(length_dim))
-
-    if not isinstance(lengths, list):
-        lengths = list(lengths)
-    print('lengths', lengths)
-    bs = int(len(lengths))
-    if xs is None:
-        maxlen = int(max(lengths))
-    else:
-        maxlen = paddle.shape(xs)[length_dim]
-
-    seq_range = paddle.arange(0, maxlen, dtype=paddle.int64)
-    seq_range_expand = paddle.expand(paddle.unsqueeze(seq_range, 0), (bs, maxlen))
-    seq_length_expand = paddle.unsqueeze(paddle.to_tensor(lengths), -1)
-    print('seq_length_expand', paddle.shape(seq_length_expand))
-    print('seq_range_expand', paddle.shape(seq_range_expand))
-    mask = seq_range_expand >= seq_length_expand
-
-    if xs is not None:
-        assert paddle.shape(xs)[0] == bs, (paddle.shape(xs)[0], bs)
-
-        if length_dim < 0:
-            length_dim = len(paddle.shape(xs)) + length_dim
-        # ind = (:, None, ..., None, :, , None, ..., None)
-        ind = tuple(
-            slice(None) if i in (0, length_dim) else None for i in range(len(paddle.shape(xs)))
-        )
-        print('0:', paddle.shape(mask))
-        print('1:', paddle.shape(mask[ind]))
-        print('2:', paddle.shape(xs))
-        mask = paddle.expand(mask[ind], paddle.shape(xs))
-    return mask
-
-
-def make_non_pad_mask(lengths, xs=None, length_dim=-1):
-    """Make mask tensor containing indices of non-padded part.
-
-    Args:
-        lengths (LongTensor or List): Batch of lengths (B,).
-        xs (Tensor, optional): The reference tensor.
-            If set, masks will be the same shape as this tensor.
-        length_dim (int, optional): Dimension indicator of the above tensor.
-            See the example.
-
-    Returns:
-        ByteTensor: mask tensor containing indices of padded part.
-
-    Examples:
-        With only lengths.
-
-        >>> lengths = [5, 3, 2]
-        >>> make_non_pad_mask(lengths)
-        masks = [[1, 1, 1, 1, 1],
-                 [1, 1, 1, 0, 0],
-                 [1, 1, 0, 0, 0]]
-
-        With the reference tensor.
-
-        >>> xs = paddle.zeros((3, 2, 4))
-        >>> make_non_pad_mask(lengths, xs)
-        tensor([[[1, 1, 1, 1],
-                 [1, 1, 1, 1]],
-                [[1, 1, 1, 0],
-                 [1, 1, 1, 0]],
-                [[1, 1, 0, 0],
-                 [1, 1, 0, 0]]], dtype=paddle.uint8)
-        >>> xs = paddle.zeros((3, 2, 6))
-        >>> make_non_pad_mask(lengths, xs)
-        tensor([[[1, 1, 1, 1, 1, 0],
-                 [1, 1, 1, 1, 1, 0]],
-                [[1, 1, 1, 0, 0, 0],
-                 [1, 1, 1, 0, 0, 0]],
-                [[1, 1, 0, 0, 0, 0],
-                 [1, 1, 0, 0, 0, 0]]], dtype=paddle.uint8)
-
-        With the reference tensor and dimension indicator.
-
-        >>> xs = paddle.zeros((3, 6, 6))
-        >>> make_non_pad_mask(lengths, xs, 1)
-        tensor([[[1, 1, 1, 1, 1, 1],
-                 [1, 1, 1, 1, 1, 1],
-                 [1, 1, 1, 1, 1, 1],
-                 [1, 1, 1, 1, 1, 1],
-                 [1, 1, 1, 1, 1, 1],
-                 [0, 0, 0, 0, 0, 0]],
-                [[1, 1, 1, 1, 1, 1],
-                 [1, 1, 1, 1, 1, 1],
-                 [1, 1, 1, 1, 1, 1],
-                 [0, 0, 0, 0, 0, 0],
-                 [0, 0, 0, 0, 0, 0],
-                 [0, 0, 0, 0, 0, 0]],
-                [[1, 1, 1, 1, 1, 1],
-                 [1, 1, 1, 1, 1, 1],
-                 [0, 0, 0, 0, 0, 0],
-                 [0, 0, 0, 0, 0, 0],
-                 [0, 0, 0, 0, 0, 0],
-                 [0, 0, 0, 0, 0, 0]]], dtype=paddle.uint8)
-        >>> make_non_pad_mask(lengths, xs, 2)
-        tensor([[[1, 1, 1, 1, 1, 0],
-                 [1, 1, 1, 1, 1, 0],
-                 [1, 1, 1, 1, 1, 0],
-                 [1, 1, 1, 1, 1, 0],
-                 [1, 1, 1, 1, 1, 0],
-                 [1, 1, 1, 1, 1, 0]],
-                [[1, 1, 1, 0, 0, 0],
-                 [1, 1, 1, 0, 0, 0],
-                 [1, 1, 1, 0, 0, 0],
-                 [1, 1, 1, 0, 0, 0],
-                 [1, 1, 1, 0, 0, 0],
-                 [1, 1, 1, 0, 0, 0]],
-                [[1, 1, 0, 0, 0, 0],
-                 [1, 1, 0, 0, 0, 0],
-                 [1, 1, 0, 0, 0, 0],
-                 [1, 1, 0, 0, 0, 0],
-                 [1, 1, 0, 0, 0, 0],
-                 [1, 1, 0, 0, 0, 0]]], dtype=paddle.uint8)
-
-    """
-    return ~make_pad_mask(lengths, xs, length_dim)
-
-def phones_masking(xs_pad, src_mask, align_start, align_end, align_start_lengths, mlm_prob, mean_phn_span, span_boundary=None):
-    bz, sent_len, _ = paddle.shape(xs_pad)
-    mask_num_lower = math.ceil(sent_len * mlm_prob)
-    masked_position = np.zeros((bz, sent_len))
-    y_masks = None
-    # y_masks = torch.ones(bz,sent_len,sent_len,device=xs_pad.device,dtype=xs_pad.dtype)
-    # tril_masks = torch.tril(y_masks)
-    if mlm_prob == 1.0:
-        masked_position += 1
-        # y_masks = tril_masks
-    elif mean_phn_span == 0:
-        # only speech
-        length = sent_len
-        mean_phn_span = min(length*mlm_prob//3, 50)
-        masked_phn_indices = random_spans_noise_mask(length,mlm_prob, mean_phn_span).nonzero()
-        masked_position[:,masked_phn_indices]=1
-    else:
-        for idx in range(bz):
-            if span_boundary is not None:
-                for s,e in zip(span_boundary[idx][::2], span_boundary[idx][1::2]):
-                    masked_position[idx, s:e] = 1
-                    # y_masks[idx, :, s:e] = tril_masks[idx, :, s:e]
-                    # y_masks[idx, e:, s:e ] = 0
-            else:
-                length = align_start_lengths[idx].item()
-                if length<2:
-                    continue
-                masked_phn_indices = random_spans_noise_mask(length,mlm_prob, mean_phn_span).nonzero()
-                masked_start = align_start[idx][masked_phn_indices].tolist()
-                masked_end = align_end[idx][masked_phn_indices].tolist()
-                for s,e in zip(masked_start, masked_end):
-                    masked_position[idx, s:e] = 1
-                    # y_masks[idx, :, s:e] = tril_masks[idx, :, s:e]
-                    # y_masks[idx, e:, s:e ] = 0
-    non_eos_mask = np.array(paddle.reshape(src_mask, paddle.shape(xs_pad)[:2]).float().cpu())
-    masked_position = masked_position * non_eos_mask
-    # y_masks = src_mask & y_masks.bool()
-
-    return paddle.cast(paddle.to_tensor(masked_position), paddle.bool), y_masks
-
-def get_segment_pos(speech_pad, text_pad, align_start, align_end, align_start_lengths, sega_emb):
-    bz, speech_len, _ = speech_pad.size()
-    text_segment_pos = paddle.zeros_like(text_pad)
-    speech_segment_pos = paddle.zeros((bz, speech_len), dtype=text_pad.dtype)
-    if not sega_emb:
-        return speech_segment_pos, text_segment_pos
-    for idx in range(bz):
-        align_length = align_start_lengths[idx].item()
-        for j in range(align_length):
-            s,e = align_start[idx][j].item(), align_end[idx][j].item()
-            speech_segment_pos[idx][s:e] = j+1
-            text_segment_pos[idx][j] = j+1
-
-    return speech_segment_pos, text_segment_pos
-
-def mlm_collate_fn(
-    data: Collection[Tuple[str, Dict[str, np.ndarray]]],
-    float_pad_value: Union[float, int] = 0.0,
-    int_pad_value: int = -32768,
-    not_sequence: Collection[str] = (),
-    mlm_prob: float = 0.8,
mean_phn_span: int = 8, - feats_extract=None, - attention_window: int = 0, - pad_speech: bool=False, - sega_emb: bool=False, - duration_collect: bool=False, - text_masking: bool=False -) -> Tuple[List[str], Dict[str, paddle.Tensor]]: - """Concatenate ndarray-list to an array and convert to paddle.Tensor. - - Examples: - >>> from espnet2.samplers.constant_batch_sampler import ConstantBatchSampler, - >>> import espnet2.tasks.abs_task - >>> from espnet2.train.dataset import ESPnetDataset - >>> sampler = ConstantBatchSampler(...) - >>> dataset = ESPnetDataset(...) - >>> keys = next(iter(sampler) - >>> batch = [dataset[key] for key in keys] - >>> batch = common_collate_fn(batch) - >>> model(**batch) - - Note that the dict-keys of batch are propagated from - that of the dataset as they are. - - """ - uttids = [u for u, _ in data] - data = [d for _, d in data] - - assert all(set(data[0]) == set(d) for d in data), "dict-keys mismatching" - assert all( - not k.endswith("_lengths") for k in data[0] - ), f"*_lengths is reserved: {list(data[0])}" - - output = {} - for key in data[0]: - # NOTE(kamo): - # Each models, which accepts these values finally, are responsible - # to repaint the pad_value to the desired value for each tasks. - if data[0][key].dtype.kind == "i": - pad_value = int_pad_value - else: - pad_value = float_pad_value - - array_list = [d[key] for d in data] - - # Assume the first axis is length: - # tensor_list: Batch x (Length, ...) - tensor_list = [paddle.to_tensor(a) for a in array_list] - # tensor: (Batch, Length, ...) - tensor = pad_list(tensor_list, pad_value) - output[key] = tensor - - # lens: (Batch,) - if key not in not_sequence: - lens = paddle.to_tensor([d[key].shape[0] for d in data], dtype=paddle.long) - output[key + "_lengths"] = lens - - f = open('tmp_var.out', 'w') - for item in [round(item, 6) for item in output["speech"][0].tolist()]: - f.write(str(item)+'\n') - feats = feats_extract.get_log_mel_fbank(np.array(output["speech"][0])) - feats = paddle.to_tensor(feats) - print('out shape', paddle.shape(feats)) - feats_lengths = paddle.shape(feats)[0] - feats = paddle.unsqueeze(feats, 0) - batch_size = paddle.shape(feats)[0] - if 'text' not in output: - text=paddle.zeros_like(feats_lengths.unsqueeze(-1))-2 - text_lengths=paddle.zeros_like(feats_lengths)+1 - max_tlen=1 - align_start=paddle.zeros_like(text) - align_end=paddle.zeros_like(text) - align_start_lengths=paddle.zeros_like(feats_lengths) - align_end_lengths=paddle.zeros_like(feats_lengths) - sega_emb=False - mean_phn_span = 0 - mlm_prob = 0.15 - else: - text, text_lengths = output["text"], output["text_lengths"] - align_start, align_start_lengths, align_end, align_end_lengths = output["align_start"], output["align_start_lengths"], output["align_end"], output["align_end_lengths"] - align_start = paddle.floor(feats_extract.sr*align_start/feats_extract.hop_length).int() - align_end = paddle.floor(feats_extract.sr*align_end/feats_extract.hop_length).int() - max_tlen = max(text_lengths).item() - max_slen = max(feats_lengths).item() - speech_pad = feats[:, : max_slen] - if attention_window>0 and pad_speech: - speech_pad,max_slen = pad_to_longformer_att_window(speech_pad, max_slen, max_slen, attention_window) - max_len = max_slen + max_tlen - if attention_window>0: - text_pad, max_tlen = pad_to_longformer_att_window(text, max_len, max_tlen, attention_window) - else: - text_pad = text - text_mask = make_non_pad_mask(text_lengths.tolist(), text_pad, length_dim=1).unsqueeze(-2) - if attention_window>0: - text_mask = 
-        text_mask = text_mask * 2  # longformer convention: 2 flags global-attention positions
-    speech_mask = make_non_pad_mask(feats_lengths.tolist(), speech_pad[:, :, 0], length_dim=1).unsqueeze(-2)
-    span_boundary = None
-    if 'span_boundary' in output.keys():
-        span_boundary = output['span_boundary']
-
-    if text_masking:
-        masked_position, text_masked_position, _ = phones_text_masking(
-            speech_pad,
-            speech_mask,
-            text_pad,
-            text_mask,
-            align_start,
-            align_end,
-            align_start_lengths,
-            mlm_prob,
-            mean_phn_span,
-            span_boundary)
-    else:
-        text_masked_position = np.zeros(text_pad.shape)
-        masked_position, _ = phones_masking(
-            speech_pad,
-            speech_mask,
-            align_start,
-            align_end,
-            align_start_lengths,
-            mlm_prob,
-            mean_phn_span,
-            span_boundary)
-
-    output_dict = {}
-    if duration_collect and 'text' in output:
-        reordered_index, speech_segment_pos, text_segment_pos, durations, feats_lengths = get_segment_pos_reduce_duration(speech_pad, text_pad, align_start, align_end, align_start_lengths, sega_emb, masked_position, feats_lengths)
-        speech_mask = make_non_pad_mask(feats_lengths.tolist(), speech_pad[:, :reordered_index.shape[1], 0], length_dim=1).unsqueeze(-2)
-        output_dict['durations'] = durations
-        output_dict['reordered_index'] = reordered_index
-    else:
-        speech_segment_pos, text_segment_pos = get_segment_pos(speech_pad, text_pad, align_start, align_end, align_start_lengths, sega_emb)
-    output_dict['speech'] = speech_pad
-    output_dict['text'] = text_pad
-    output_dict['masked_position'] = masked_position
-    output_dict['text_masked_position'] = text_masked_position
-    output_dict['speech_mask'] = speech_mask
-    output_dict['text_mask'] = text_mask
-    output_dict['speech_segment_pos'] = speech_segment_pos
-    output_dict['text_segment_pos'] = text_segment_pos
-    output_dict['speech_lengths'] = output["speech_lengths"]
-    output_dict['text_lengths'] = text_lengths
-    output = (uttids, output_dict)
-    return output
-
-def build_collate_fn(
-        args: argparse.Namespace, train: bool, epoch=-1
-    ):
-    feats_extract_class = LogMelFBank
-    args_dic = {}
-    for k, v in args.feats_extract_conf.items():
-        if k == 'fs':
-            # LogMelFBank expects the sample rate under "sr"
-            args_dic['sr'] = v
-        else:
-            args_dic[k] = v
-    feats_extract = feats_extract_class(**args_dic)
-
-    sega_emb = True if args.encoder_conf['input_layer'] == 'sega_mlm' else False
-    if args.encoder_conf['selfattention_layer_type'] == 'longformer':
-        attention_window = args.encoder_conf['attention_window']
-        pad_speech = True if 'pre_speech_layer' in args.encoder_conf and args.encoder_conf['pre_speech_layer'] > 0 else False
-    else:
-        attention_window = 0
-        pad_speech = False
-    if epoch == -1:
-        mlm_prob_factor = 1
-    else:
-        # a per-epoch schedule (e.g. [1.0, 1.0, 0.7, 0.6, 0.5] per 100 epochs)
-        # was tried here; a fixed factor of 0.8 is used instead
-        mlm_prob_factor = 0.8
-    if 'duration_predictor_layers' in args.model_conf.keys() and args.model_conf['duration_predictor_layers'] > 0:
-        duration_collect = True
-    else:
-        duration_collect = False
-    return MLMCollateFn(feats_extract,
-                        float_pad_value=0.0,
-                        int_pad_value=0,
-                        mlm_prob=args.model_conf['mlm_prob'] * mlm_prob_factor,
-                        mean_phn_span=args.model_conf['mean_phn_span'],
-                        attention_window=attention_window,
-                        pad_speech=pad_speech,
-                        sega_emb=sega_emb,
-                        duration_collect=duration_collect)
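The config plumbing at the top of `build_collate_fn` exists because the YAML config stores the sample rate under `fs` while `LogMelFBank` (imported from paddlespeech, as in inference.py) takes it as `sr`. A small sketch of that mapping; the concrete values are illustrative, not the project's defaults:

```
from paddlespeech.t2s.datasets.get_feats import LogMelFBank

feats_extract_conf = {"fs": 24000, "hop_length": 300, "win_length": 1200, "n_mels": 80}
# rename "fs" -> "sr", pass everything else through unchanged
args_dic = {("sr" if k == "fs" else k): v for k, v in feats_extract_conf.items()}
feats_extract = LogMelFBank(**args_dic)
```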
-def get_mlm_output(uid, prefix, clone_uid, clone_prefix, source_language, target_language, model_name, wav_path, old_str, new_str, duration_preditor_path, sid=None, decoder=False, use_teacher_forcing=False, dynamic_eval=(0,0), duration_adjust=True, start_end_sp=False):
-    mlm_model, train_args = load_model(model_name)
-    mlm_model.eval()
-    processor = None
-    collate_fn = MLMTask.build_collate_fn(train_args, False)
-
-    return decode_with_model(uid, prefix, clone_uid, clone_prefix, source_language, target_language, mlm_model, processor, collate_fn, wav_path, old_str, new_str, duration_preditor_path, sid=sid, decoder=decoder, use_teacher_forcing=use_teacher_forcing,
-                             duration_adjust=duration_adjust, start_end_sp=start_end_sp, train_args=train_args)
-
-def prompt_decoding_fn(model_name, wav_path, full_origin_str, old_str, new_str, vocoder, duration_preditor_path, sid=None, non_autoreg=True, dynamic_eval=(0,0), duration_adjust=True):
-    # NOTE: this call does not match get_mlm_output's positional signature
-    # (the uid/prefix/clone_* arguments are not passed), so the prompt-decoding
-    # path cannot work as written.
-    wav_org, input_feat, output_feat, old_span_boundary, new_span_boundary, fs, hop_length = get_mlm_output(
-        model_name,
-        wav_path,
-        old_str,
-        new_str,
-        duration_preditor_path,
-        use_teacher_forcing=non_autoreg,
-        sid=sid,
-        dynamic_eval=dynamic_eval,
-        duration_adjust=duration_adjust,
-        start_end_sp=False
-    )
-
-    replaced_wav = vocoder(output_feat).detach().float().data.cpu().numpy()
-
-    old_time_boundary = [hop_length * x for x in old_span_boundary]
-    new_time_boundary = [hop_length * x for x in new_span_boundary]
-    new_wav = replaced_wav[new_time_boundary[0]:]
-    data_dict = {"prompt": wav_org,
-                 "new_wav": new_wav}
-    return data_dict
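The splice that `plot_mel_and_vocode_wav` and `prompt_decoding_fn` perform reduces to one idea: span boundaries are mel-frame indices, and multiplying them by `hop_length` converts them to sample indices in the waveform. A minimal sketch with toy arrays standing in for the real audio:

```
import numpy as np

hop_length = 300
wav_org = np.zeros(24000, dtype=np.float32)      # original utterance
replaced_wav = np.ones(24000, dtype=np.float32)  # vocoded output for the edited mel
old_span = [20, 40]   # frames to cut from the original
new_span = [20, 50]   # frames to take from the re-synthesized audio

old_t = [hop_length * x for x in old_span]
new_t = [hop_length * x for x in new_span]
wav_edited = np.concatenate([wav_org[:old_t[0]],
                             replaced_wav[new_t[0]:new_t[1]],
                             wav_org[old_t[1]:]])
print(wav_edited.shape)  # original length, minus the old span, plus the new span
```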
-def test_vctk(uid, clone_uid, clone_prefix, source_language, target_language, vocoder, prefix='dump/raw/dev', model_name="conformer", old_str="", new_str="", prompt_decoding=False, dynamic_eval=(0,0), task_name=None):
-
-    new_str = new_str.strip()
-    if clone_uid is not None and clone_prefix is not None:
-        if target_language == "english":
-            duration_preditor_path = duration_path_dict['ljspeech']
-        elif target_language == "chinese":
-            duration_preditor_path = duration_path_dict['ljspeech']
-        else:
-            assert target_language == "chinese" or target_language == "english", "duration_preditor_path is not supported for this language..."
-    else:
-        duration_preditor_path = duration_path_dict['ljspeech']
-
-    spemd = None
-    full_origin_str, wav_path = read_data(uid, prefix)
-
-    new_str = new_str if task_name == 'edit' else full_origin_str + new_str
-    print('new_str is ', new_str)
-
-    if not old_str:
-        old_str = full_origin_str
-    if not new_str:
-        new_str = input("input the new string:")
-    if prompt_decoding:
-        print(new_str)
-        return prompt_decoding_fn(model_name, wav_path, full_origin_str, old_str, new_str, vocoder, duration_preditor_path, sid=spemd, dynamic_eval=dynamic_eval)
-    print(full_origin_str)
-    results_dict, old_span = plot_mel_and_vocode_wav(uid, prefix, clone_uid, clone_prefix, source_language, target_language, model_name, wav_path, full_origin_str, old_str, new_str, vocoder, duration_preditor_path, sid=spemd)
-    return results_dict
-
-if __name__ == "__main__":
-    args = parse_args()
-    print(args)
-    data_dict = test_vctk(args.uid,
-                          args.clone_uid,
-                          args.clone_prefix,
-                          args.source_language,
-                          args.target_language,
-                          None,
-                          args.prefix,
-                          args.model_name,
-                          new_str=args.new_str,
-                          task_name=args.task_name)
-    sf.write('./wavs/%s' % args.output_name, data_dict['output'], samplerate=24000)
\ No newline at end of file
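For reference, the removed script's entry point can be reconstructed as a direct call to `test_vctk`. Every value below is illustrative: it assumes `dump/raw/dev` contains the uid's `text`/`wav.scp` entries, that `pretrained_model/<model_name>/` holds the checkpoint, and that "p243_new" is a valid uid (it only appears here as a tmp-pickle name):

```
data_dict = test_vctk(uid="p243_new",
                      clone_uid=None,
                      clone_prefix=None,
                      source_language="english",
                      target_language="english",
                      vocoder=None,   # the paddle vocoder is resolved internally
                      prefix="dump/raw/dev",
                      model_name="conformer",
                      new_str=" and a newly appended phrase",  # appended unless task_name == 'edit'
                      task_name="synthesize")
sf.write("./wavs/demo.wav", data_dict["output"], samplerate=24000)
```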
diff --git a/ernie-sat/tmp/tmp_pkl.Prompt_003_new b/ernie-sat/tmp/tmp_pkl.Prompt_003_new
index c7432dac7c6f3261a37ffc994f852dd7ea89fcd8..9883958b8dad23c61e783d30cc18044e5e3c75d1 100644
Binary files a/ernie-sat/tmp/tmp_pkl.Prompt_003_new and b/ernie-sat/tmp/tmp_pkl.Prompt_003_new differ
diff --git a/ernie-sat/tmp/tmp_pkl.p243_new b/ernie-sat/tmp/tmp_pkl.p243_new
index 33075eb1cf1c5ab94bd85538409ed0c23cc90e0a..c4d88f1de7881058dcc8e6c6de1550c48f9d0181 100644
Binary files a/ernie-sat/tmp/tmp_pkl.p243_new and b/ernie-sat/tmp/tmp_pkl.p243_new differ
diff --git a/ernie-sat/tmp/tmp_pkl.p299_096 b/ernie-sat/tmp/tmp_pkl.p299_096
index c0553e427ffa71571d819f7b9879f062e7bb68fe..5b88d59a83dcbbe666efd517e0d31e2f3df32e6c 100644
Binary files a/ernie-sat/tmp/tmp_pkl.p299_096 and b/ernie-sat/tmp/tmp_pkl.p299_096 differ
diff --git a/ernie-sat/util.py b/ernie-sat/utils.py
similarity index 97%
rename from ernie-sat/util.py
rename to ernie-sat/utils.py
index 45f10e7da157b9e8646c8318ebf6a6f72473944c..4c9b6fc56d0b9867c98944646a0120e831e239b8 100644
--- a/ernie-sat/util.py
+++ b/ernie-sat/utils.py
@@ -70,13 +70,15 @@ def get_voc_out(mel, target_language="chinese"):
     print("current vocoder: ", args.voc)
     with open(args.voc_config) as f:
         voc_config = CfgNode(yaml.safe_load(f))
+    # print(voc_config)
 
     voc_inference = get_voc_inference(args, voc_config)
 
     mel = paddle.to_tensor(mel)
+    # print("masked_mel: ", mel.shape)
     with paddle.no_grad():
         wav = voc_inference(mel)
-    print("shepe of wav (time x n_channels):%s"%wav.shape) # (31800,1)
+    # print("shape of wav (time x n_channels): %s" % wav.shape)
     return np.squeeze(wav)
 
 # dygraph
@@ -134,6 +136,7 @@ def get_am_inference(args, am_config):
 
 def evaluate_durations(phns, target_language="chinese", fs=24000, hop_length=300):
     args = parse_args()
+    # args = parser.parse_args(args=[])
 
     if args.ngpu == 0:
         paddle.set_device("cpu")
     elif args.ngpu > 0:
@@ -154,6 +157,7 @@ def evaluate_durations(phns, target_language="chinese", fs=24000, hop_length=300
 
     # acoustic model
     am, am_inference, am_name, am_dataset, phn_id = get_am_inference(args, am_config)
+    torch_phns = phns
 
     vocab_phones = {}
     for tone, id in phn_id:
@@ -165,17 +169,16 @@ def evaluate_durations(phns, target_language="chinese", fs=24000, hop_length=300
     ]
     phone_ids = [vocab_phones[item] for item in phonemes]
     phone_ids_new = phone_ids
-
     phone_ids_new.append(vocab_size - 1)
     phone_ids_new = paddle.to_tensor(np.array(phone_ids_new, np.int64))
     normalized_mel, d_outs, p_outs, e_outs = am.inference(phone_ids_new, spk_id=None, spk_emb=None)
     pre_d_outs = d_outs
     phoneme_durations_new = pre_d_outs * hop_length / fs
     phoneme_durations_new = phoneme_durations_new.tolist()[:-1]
-
     return phoneme_durations_new
+
 
 def sentence2phns(sentence, target_language="en"):
     args = parse_args()
     if target_language == 'en':
diff --git a/ernie-sat/wavs/ori.wav b/ernie-sat/wavs/ori.wav
deleted file mode 100644
index d50fcb59b8c4828138648e57c4dad168b874b4ca..0000000000000000000000000000000000000000
Binary files a/ernie-sat/wavs/ori.wav and /dev/null differ
diff --git a/ernie-sat/wavs/pred.wav b/ernie-sat/wavs/pred.wav
index 0210a420898d81f5d90068a8691572c30df25cb3..5b45b642069229550e758f93020a1f126f17e90e 100644
Binary files a/ernie-sat/wavs/pred.wav and b/ernie-sat/wavs/pred.wav differ
diff --git a/ernie-sat/wavs/pred_en_edit_paddle_voc.wav b/ernie-sat/wavs/pred_en_edit_paddle_voc.wav
index 8a05b71046ee44808bdc0228277610f686ef81ca..4c6a7ef103d6dfbb670a06844da4146387ba2cad 100644
Binary files a/ernie-sat/wavs/pred_en_edit_paddle_voc.wav and b/ernie-sat/wavs/pred_en_edit_paddle_voc.wav differ
diff --git a/ernie-sat/wavs/pred_zh.wav b/ernie-sat/wavs/pred_zh.wav
index 124258b94eab1c608a7f95575c72d701894ab5b3..7f220184bf660df083be0898eb990165d3f63d68 100644
Binary files a/ernie-sat/wavs/pred_zh.wav and b/ernie-sat/wavs/pred_zh.wav differ
diff --git a/ernie-sat/wavs/pred_zh_fst2_voc.wav b/ernie-sat/wavs/pred_zh_fst2_voc.wav
deleted file mode 100644
index 57ce66e5a4f9a44800292202f1b27c08d72c1b99..0000000000000000000000000000000000000000
Binary files a/ernie-sat/wavs/pred_zh_fst2_voc.wav and /dev/null differ
diff --git a/ernie-sat/wavs/task_cross_lingual_pred.wav b/ernie-sat/wavs/task_cross_lingual_pred.wav
deleted file mode 100644
index cffebaf81c7bb95d3ad65272826efc11fd04fb01..0000000000000000000000000000000000000000
Binary files a/ernie-sat/wavs/task_cross_lingual_pred.wav and /dev/null differ
diff --git a/ernie-sat/wavs/task_edit_pred.wav b/ernie-sat/wavs/task_edit_pred.wav
deleted file mode 100644
index 6bfda0fa42584cce5690ac86da095384200647d4..0000000000000000000000000000000000000000
Binary files a/ernie-sat/wavs/task_edit_pred.wav and /dev/null differ
diff --git a/ernie-sat/wavs/task_synthesize_pred.wav b/ernie-sat/wavs/task_synthesize_pred.wav
deleted file mode 100644
index ce1379919274e5a8cc113d09e73ed869db0bdd56..0000000000000000000000000000000000000000
Binary files a/ernie-sat/wavs/task_synthesize_pred.wav and /dev/null differ
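The unit conversion at the heart of `evaluate_durations` in utils.py: FastSpeech2's duration predictor emits per-phone lengths in mel frames, and multiplying by `hop_length / fs` converts frames to seconds; an `<eos>` token is appended before inference and its duration dropped afterwards. A worked example with illustrative numbers:

```
fs = 24000          # sample rate (Hz)
hop_length = 300    # samples per mel frame

d_outs = [7, 12, 9, 1]                      # predicted frames; last entry is the appended <eos>
durations_sec = [d * hop_length / fs for d in d_outs]
durations_sec = durations_sec[:-1]          # drop the <eos> duration, as the function does
print(durations_sec)                        # [0.0875, 0.15, 0.1125]
```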