diff --git a/examples/aishell3/ernie_sat/conf/default.yaml b/examples/aishell3/ernie_sat/conf/default.yaml
index d724fe47c8988e24fa971f171caf988b0d853e23..d8993e86f4f4d4872f584db1f70b574275127a10 100644
--- a/examples/aishell3/ernie_sat/conf/default.yaml
+++ b/examples/aishell3/ernie_sat/conf/default.yaml
@@ -21,7 +21,7 @@ mlm_prob: 0.8
 ###########################################################
 #                       DATA SETTING                      #
 ###########################################################
-batch_size: 64
+batch_size: 20
 num_workers: 2
 
 ###########################################################
@@ -71,14 +71,15 @@ model:
 ###########################################################
 #                     OPTIMIZER SETTING                   #
 ###########################################################
-optimizer:
-  optim: adam            # optimizer type
-  learning_rate: 0.001   # learning rate
+scheduler_params:
+  d_model: 384
+  warmup_steps: 4000
+grad_clip: 1.0
 
 ###########################################################
 #                     TRAINING SETTING                    #
 ###########################################################
-max_epoch: 200
+max_epoch: 600
 num_snapshots: 5
 
 ###########################################################
diff --git a/examples/aishell3_vctk/ernie_sat/conf/default.yaml b/examples/aishell3_vctk/ernie_sat/conf/default.yaml
index 8c9cd7e024c285c46fd05ee0a2244336e123a8ee..745a5b840549e5e07deffcd14546808b5591f5ed 100644
--- a/examples/aishell3_vctk/ernie_sat/conf/default.yaml
+++ b/examples/aishell3_vctk/ernie_sat/conf/default.yaml
@@ -21,7 +21,7 @@ mlm_prob: 0.8
 ###########################################################
 #                       DATA SETTING                      #
 ###########################################################
-batch_size: 64
+batch_size: 20
 num_workers: 2
 
 ###########################################################
@@ -71,14 +71,15 @@ model:
 ###########################################################
 #                     OPTIMIZER SETTING                   #
 ###########################################################
-optimizer:
-  optim: adam            # optimizer type
-  learning_rate: 0.001   # learning rate
+scheduler_params:
+  d_model: 384
+  warmup_steps: 4000
+grad_clip: 1.0
 
 ###########################################################
 #                     TRAINING SETTING                    #
 ###########################################################
-max_epoch: 100
+max_epoch: 300
 num_snapshots: 5
 
 ###########################################################
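Both configs drop the fixed-rate Adam block in favour of `scheduler_params` plus gradient clipping. A minimal sketch of the Noam-style warmup curve that `d_model` and `warmup_steps` usually parameterize is given below; the scheduler class actually used by the ERNIE-SAT training script is not part of this patch, so treat the function as illustrative only.

```python
# Illustrative only: the learning-rate curve typically implied by
# scheduler_params (d_model=384, warmup_steps=4000). The real scheduler
# object is built by the training script, not shown in this patch.
def noam_lr(step: int, d_model: int = 384, warmup_steps: int = 4000) -> float:
    step = max(step, 1)
    return d_model**-0.5 * min(step**-0.5, step * warmup_steps**-1.5)


if __name__ == "__main__":
    for step in (1, 1000, 4000, 20000):
        print(step, f"{noam_lr(step):.6f}")
```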
diff --git a/examples/aishell3_vctk/ernie_sat/local/preprocess.sh b/examples/aishell3_vctk/ernie_sat/local/preprocess.sh
index 0cdc7bcafb00a4cbdd940d11f721f47b705487c2..783fd6333340ec4028a1fe7063c0335231ce0f98 100755
--- a/examples/aishell3_vctk/ernie_sat/local/preprocess.sh
+++ b/examples/aishell3_vctk/ernie_sat/local/preprocess.sh
@@ -7,14 +7,29 @@ config_path=$1
 
 if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
     # get durations from MFA's result
-    echo "Generate durations.txt from MFA results ..."
+    echo "Generate durations_aishell3.txt from MFA results for aishell3 ..."
     python3 ${MAIN_ROOT}/utils/gen_duration_from_textgrid.py \
         --inputdir=./aishell3_alignment_tone \
-        --output durations.txt \
+        --output durations_aishell3.txt \
         --config=${config_path}
 fi
 
 if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+    # get durations from MFA's result
+    echo "Generate durations_vctk.txt from MFA results for vctk ..."
+    python3 ${MAIN_ROOT}/utils/gen_duration_from_textgrid.py \
+        --inputdir=./vctk_alignment \
+        --output durations_vctk.txt \
+        --config=${config_path}
+fi
+
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+    # merge the durations of the two datasets
+    echo "Concatenate durations_aishell3.txt and durations_vctk.txt into durations.txt ..."
+    cat durations_aishell3.txt durations_vctk.txt > durations.txt
+fi
+
+if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
     # extract features
     echo "Extract features ..."
     python3 ${BIN_DIR}/preprocess.py \
@@ -27,7 +42,20 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
         --cut-sil=True
 fi
 
-if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
+    # extract features
+    echo "Extract features for vctk ..."
+    python3 ${BIN_DIR}/preprocess.py \
+        --dataset=vctk \
+        --rootdir=~/datasets/VCTK-Corpus-0.92/ \
+        --dumpdir=dump \
+        --dur-file=durations.txt \
+        --config=${config_path} \
+        --num-cpu=20 \
+        --cut-sil=True
+fi
+
+if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
     # get features' stats(mean and std)
     echo "Get features' stats ..."
     python3 ${MAIN_ROOT}/utils/compute_statistics.py \
@@ -35,15 +63,13 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
         --field-name="speech"
 fi
 
-if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
     # normalize and covert phone/speaker to id, dev and test should use train's stats
     echo "Normalize ..."
     python3 ${BIN_DIR}/normalize.py \
         --metadata=dump/train/raw/metadata.jsonl \
         --dumpdir=dump/train/norm \
         --speech-stats=dump/train/speech_stats.npy \
-        --pitch-stats=dump/train/pitch_stats.npy \
-        --energy-stats=dump/train/energy_stats.npy \
         --phones-dict=dump/phone_id_map.txt \
         --speaker-dict=dump/speaker_id_map.txt
 
@@ -51,8 +77,6 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
         --metadata=dump/dev/raw/metadata.jsonl \
         --dumpdir=dump/dev/norm \
         --speech-stats=dump/train/speech_stats.npy \
-        --pitch-stats=dump/train/pitch_stats.npy \
-        --energy-stats=dump/train/energy_stats.npy \
         --phones-dict=dump/phone_id_map.txt \
         --speaker-dict=dump/speaker_id_map.txt
 
@@ -60,8 +84,6 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
         --metadata=dump/test/raw/metadata.jsonl \
         --dumpdir=dump/test/norm \
         --speech-stats=dump/train/speech_stats.npy \
-        --pitch-stats=dump/train/pitch_stats.npy \
-        --energy-stats=dump/train/energy_stats.npy \
         --phones-dict=dump/phone_id_map.txt \
         --speaker-dict=dump/speaker_id_map.txt
 fi
diff --git a/examples/aishell3_vctk/ernie_sat/path.sh b/examples/aishell3_vctk/ernie_sat/path.sh
index d46d2f612a135fa73a3be303060163e53840b4c6..4ecab02517d5a4bd59baf95bb5f536e263ce7ac0 100755
--- a/examples/aishell3_vctk/ernie_sat/path.sh
+++ b/examples/aishell3_vctk/ernie_sat/path.sh
@@ -1,5 +1,5 @@
 #!/bin/bash
-export MAIN_ROOT=`realpath ${PWD}/../../`
+export MAIN_ROOT=`realpath ${PWD}/../../../`
 
 export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
 export LC_ALL=C
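Stage 2 simply concatenates the two per-dataset duration files. The Python sketch below does the same merge with a duplicate-utterance check added as a safeguard; it assumes each line of a durations file starts with its utterance id, and the helper name is invented for illustration.

```python
# Equivalent of `cat durations_aishell3.txt durations_vctk.txt > durations.txt`,
# plus a duplicate-utt-id check that the shell stage does not perform.
# Assumption: each line of a durations file starts with the utterance id.
from pathlib import Path


def merge_durations(parts, out_path="durations.txt"):
    seen = set()
    with open(out_path, "w", encoding="utf-8") as out:
        for part in parts:
            for line in Path(part).read_text(encoding="utf-8").splitlines():
                if not line.strip():
                    continue
                utt_id = line.split()[0]
                if utt_id in seen:
                    raise ValueError(f"duplicate utt_id: {utt_id}")
                seen.add(utt_id)
                out.write(line + "\n")


# merge_durations(["durations_aishell3.txt", "durations_vctk.txt"])
```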
diff --git a/examples/vctk/ernie_sat/local/synthesize.sh b/examples/vctk/ernie_sat/local/synthesize.sh
index cc1f786e84631faabc68d86a3aefffbd1ae03a06..b24db018ac000aa7bbe1fd04b4d7c6fd5930eb5d 100755
--- a/examples/vctk/ernie_sat/local/synthesize.sh
+++ b/examples/vctk/ernie_sat/local/synthesize.sh
@@ -1 +1,45 @@
-#!/bin/bash
\ No newline at end of file
+#!/bin/bash
+
+config_path=$1
+train_output_path=$2
+ckpt_name=$3
+
+stage=1
+stop_stage=1
+
+# use am to predict duration here
+# to add: am_phones_dict, am_tones_dict, etc.; the am could also be built in a new way so that this many parameters are no longer needed
+
+# pwgan
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+    FLAGS_allocator_strategy=naive_best_fit \
+    FLAGS_fraction_of_gpu_memory_to_use=0.01 \
+    python3 ${BIN_DIR}/synthesize.py \
+        --erniesat_config=${config_path} \
+        --erniesat_ckpt=${train_output_path}/checkpoints/${ckpt_name} \
+        --erniesat_stat=dump/train/speech_stats.npy \
+        --voc=pwgan_vctk \
+        --voc_config=pwg_vctk_ckpt_0.1.1/default.yaml \
+        --voc_ckpt=pwg_vctk_ckpt_0.1.1/snapshot_iter_1500000.pdz \
+        --voc_stat=pwg_vctk_ckpt_0.1.1/feats_stats.npy \
+        --test_metadata=dump/test/norm/metadata.jsonl \
+        --output_dir=${train_output_path}/test \
+        --phones_dict=dump/phone_id_map.txt
+fi
+
+# hifigan
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+    FLAGS_allocator_strategy=naive_best_fit \
+    FLAGS_fraction_of_gpu_memory_to_use=0.01 \
+    python3 ${BIN_DIR}/synthesize.py \
+        --erniesat_config=${config_path} \
+        --erniesat_ckpt=${train_output_path}/checkpoints/${ckpt_name} \
+        --erniesat_stat=dump/train/speech_stats.npy \
+        --voc=hifigan_vctk \
+        --voc_config=hifigan_vctk_ckpt_0.2.0/default.yaml \
+        --voc_ckpt=hifigan_vctk_ckpt_0.2.0/snapshot_iter_2500000.pdz \
+        --voc_stat=hifigan_vctk_ckpt_0.2.0/feats_stats.npy \
+        --test_metadata=dump/test/norm/metadata.jsonl \
+        --output_dir=${train_output_path}/test \
+        --phones_dict=dump/phone_id_map.txt
+fi
diff --git a/paddlespeech/t2s/datasets/am_batch_fn.py b/paddlespeech/t2s/datasets/am_batch_fn.py
index fbc03c0de860ef2f51fcd5c8bfcf30f69f93df0f..9c964d8e95396249a72f5a2537847e09f53b4c70 100644
--- a/paddlespeech/t2s/datasets/am_batch_fn.py
+++ b/paddlespeech/t2s/datasets/am_batch_fn.py
@@ -119,9 +119,17 @@ def erniesat_batch_fn(examples,
     speech_mask = make_non_pad_mask(
         speech_lengths, speech_pad[:, :, 0], length_dim=1).unsqueeze(-2)
 
+    # for training: no span boundary is given
+    span_bdy = None
+    # for inference: the synthesize script provides one span per example
+    if 'span_bdy' in examples[0].keys():
+        span_bdy = [
+            np.array(item["span_bdy"], dtype=np.int64) for item in examples
+        ]
+        span_bdy = paddle.to_tensor(span_bdy)
+
     # dual_mask means masking both speech and text for mixed Chinese-English input
     # ERNIE-SAT masks both in the cross-lingual case
-    span_bdy = None
     if text_masking:
         masked_pos, text_masked_pos = phones_text_masking(
             xs_pad=speech_pad,
diff --git a/paddlespeech/t2s/exps/ernie_sat/preprocess.py b/paddlespeech/t2s/exps/ernie_sat/preprocess.py
index a56314a264c623e6fa344538145376cb71fa950a..fc9e0888b262d79ddba96ebaa095feeb1b1ab315 100644
--- a/paddlespeech/t2s/exps/ernie_sat/preprocess.py
+++ b/paddlespeech/t2s/exps/ernie_sat/preprocess.py
@@ -166,7 +166,8 @@ def process_sentences(config,
                 results.append(record)
 
     results.sort(key=itemgetter("utt_id"))
-    with jsonlines.open(output_dir / "metadata.jsonl", 'w') as writer:
+    # open in append mode ('a') so the aishell3 and vctk runs write into one metadata.jsonl
+    with jsonlines.open(output_dir / "metadata.jsonl", 'a') as writer:
        for item in results:
            writer.write(item)
    print("Done")
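The `span_bdy` handling added to `erniesat_batch_fn` is the only training/inference difference in the collate path: training examples carry no `span_bdy` key, while the inference script injects one `[left, right]` pair per example. A stripped-down sketch of just that branch (the helper name is invented for illustration):

```python
# Mirrors the span_bdy branch added to erniesat_batch_fn: None during
# training, a (batch, 2) int64 tensor during inference.
import numpy as np
import paddle


def collect_span_bdy(examples):
    if 'span_bdy' not in examples[0]:
        return None
    span_bdy = [np.array(ex["span_bdy"], dtype=np.int64) for ex in examples]
    return paddle.to_tensor(span_bdy)


# collect_span_bdy([{"span_bdy": [100, 200]}])  -> Tensor of shape [1, 2]
```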
diff --git a/paddlespeech/t2s/exps/ernie_sat/synthesize.py b/paddlespeech/t2s/exps/ernie_sat/synthesize.py
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..56f26a8bbd04f9b7e4023347ed82411d32ff7b07 100644
--- a/paddlespeech/t2s/exps/ernie_sat/synthesize.py
+++ b/paddlespeech/t2s/exps/ernie_sat/synthesize.py
@@ -0,0 +1,201 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import argparse
+import logging
+from pathlib import Path
+
+import jsonlines
+import numpy as np
+import paddle
+import soundfile as sf
+import yaml
+from yacs.config import CfgNode
+
+from paddlespeech.t2s.datasets.am_batch_fn import build_erniesat_collate_fn
+from paddlespeech.t2s.exps.syn_utils import denorm
+from paddlespeech.t2s.exps.syn_utils import get_am_inference
+from paddlespeech.t2s.exps.syn_utils import get_test_dataset
+from paddlespeech.t2s.exps.syn_utils import get_voc_inference
+
+
+def evaluate(args):
+    # dataloader has been too verbose
+    logging.getLogger("DataLoader").disabled = True
+
+    # construct dataset for evaluation
+    with jsonlines.open(args.test_metadata, 'r') as reader:
+        test_metadata = list(reader)
+
+    # Init body.
+    with open(args.erniesat_config) as f:
+        erniesat_config = CfgNode(yaml.safe_load(f))
+    with open(args.voc_config) as f:
+        voc_config = CfgNode(yaml.safe_load(f))
+
+    print("========Args========")
+    print(yaml.safe_dump(vars(args)))
+    print("========Config========")
+    print(erniesat_config)
+    print(voc_config)
+
+    # ernie sat model
+    erniesat_inference = get_am_inference(
+        am='erniesat_dataset',
+        am_config=erniesat_config,
+        am_ckpt=args.erniesat_ckpt,
+        am_stat=args.erniesat_stat,
+        phones_dict=args.phones_dict)
+
+    test_dataset = get_test_dataset(
+        test_metadata=test_metadata, am='erniesat_dataset')
+
+    # vocoder
+    voc_inference = get_voc_inference(
+        voc=args.voc,
+        voc_config=voc_config,
+        voc_ckpt=args.voc_ckpt,
+        voc_stat=args.voc_stat)
+
+    output_dir = Path(args.output_dir)
+    output_dir.mkdir(parents=True, exist_ok=True)
+    collate_fn = build_erniesat_collate_fn(
+        mlm_prob=erniesat_config.mlm_prob,
+        mean_phn_span=erniesat_config.mean_phn_span,
+        seg_emb=erniesat_config.model['enc_input_layer'] == 'sega_mlm',
+        text_masking=False,
+        epoch=-1)
+
+    gen_raw = True
+    erniesat_mu, erniesat_std = np.load(args.erniesat_stat)
+
+    for datum in test_dataset:
+        # collate function and dataloader
+        utt_id = datum["utt_id"]
+        speech_len = datum["speech_lengths"]
+
+        # mask the middle 1/3 speech
+        left_bdy, right_bdy = speech_len // 3, 2 * speech_len // 3
+        span_bdy = [left_bdy, right_bdy]
+        datum.update({"span_bdy": span_bdy})
+
+        batch = collate_fn([datum])
+        with paddle.no_grad():
+            out_mels = erniesat_inference(
+                speech=batch["speech"],
+                text=batch["text"],
+                masked_pos=batch["masked_pos"],
+                speech_mask=batch["speech_mask"],
+                text_mask=batch["text_mask"],
+                speech_seg_pos=batch["speech_seg_pos"],
+                text_seg_pos=batch["text_seg_pos"],
+                span_bdy=span_bdy)
+
+        # vocoder
+        wav_list = []
+        for mel in out_mels:
+            part_wav = voc_inference(mel)
+            wav_list.append(part_wav)
+        wav = paddle.concat(wav_list)
+        wav = wav.numpy()
+        if gen_raw:
+            speech = datum['speech']
+            denorm_mel = denorm(speech, erniesat_mu, erniesat_std)
+            denorm_mel = paddle.to_tensor(denorm_mel)
+            wav_raw = voc_inference(denorm_mel)
+            wav_raw = wav_raw.numpy()
+
+        sf.write(
+            str(output_dir / (utt_id + ".wav")),
+            wav,
+            samplerate=erniesat_config.fs)
+        if gen_raw:
+            sf.write(
+                str(output_dir / (utt_id + "_raw" + ".wav")),
+                wav_raw,
+                samplerate=erniesat_config.fs)
+
+        print(f"{utt_id} done!")
+
+
+def parse_args():
+    # parse args and config
+    parser = argparse.ArgumentParser(
+        description="Synthesize with acoustic model & vocoder")
+    # ernie sat
+
+    parser.add_argument(
+        '--erniesat_config',
+        type=str,
+        default=None,
+        help='Config of acoustic model.')
+    parser.add_argument(
+        '--erniesat_ckpt',
+        type=str,
+        default=None,
+        help='Checkpoint file of acoustic model.')
+    parser.add_argument(
+        "--erniesat_stat",
+        type=str,
+        default=None,
+        help="mean and standard deviation used to normalize spectrogram when training acoustic model."
+    )
+    parser.add_argument(
+        "--phones_dict", type=str, default=None, help="phone vocabulary file.")
+    # vocoder
+    parser.add_argument(
+        '--voc',
+        type=str,
+        default='pwgan_csmsc',
+        choices=[
+            'pwgan_aishell3',
+            'pwgan_vctk',
+            'hifigan_aishell3',
+            'hifigan_vctk',
+        ],
+        help='Choose vocoder type of tts task.')
+    parser.add_argument(
+        '--voc_config', type=str, default=None, help='Config of voc.')
+    parser.add_argument(
+        '--voc_ckpt', type=str, default=None, help='Checkpoint file of voc.')
+    parser.add_argument(
+        "--voc_stat",
+        type=str,
+        default=None,
+        help="mean and standard deviation used to normalize spectrogram when training voc."
+    )
+    # other
+    parser.add_argument(
+        "--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.")
+    parser.add_argument("--test_metadata", type=str, help="test metadata.")
+    parser.add_argument("--output_dir", type=str, help="output dir.")
+
+    args = parser.parse_args()
+    return args
+
+
+def main():
+
+    args = parse_args()
+    if args.ngpu == 0:
+        paddle.set_device("cpu")
+    elif args.ngpu > 0:
+        paddle.set_device("gpu")
+    else:
+        print("ngpu should be >= 0 !")
+
+    evaluate(args)
+
+
+if __name__ == "__main__":
+    main()
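In `evaluate`, the `gen_raw` branch re-synthesizes the reference audio as `<utt_id>_raw.wav` so the masked-and-regenerated output can be compared against a copy-synthesis baseline. The stored speech features are z-score normalized with the training stats, so the inverse applied before vocoding is simply `mel * std + mu`; the sketch below restates that inverse for clarity and mirrors what `paddlespeech.t2s.exps.syn_utils.denorm` is expected to do (it is not the function defined by this patch).

```python
# Inverse of z-score normalization, applied to the stored mel features
# before they are sent to the vocoder for the "_raw" reference wav.
import numpy as np


def denorm(norm_mel: np.ndarray, mu: np.ndarray, std: np.ndarray) -> np.ndarray:
    return norm_mel * std + mu
```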
print(f"{utt_id} done!") + + +def parse_args(): + # parse args and config + parser = argparse.ArgumentParser( + description="Synthesize with acoustic model & vocoder") + # ernie sat + + parser.add_argument( + '--erniesat_config', + type=str, + default=None, + help='Config of acoustic model.') + parser.add_argument( + '--erniesat_ckpt', + type=str, + default=None, + help='Checkpoint file of acoustic model.') + parser.add_argument( + "--erniesat_stat", + type=str, + default=None, + help="mean and standard deviation used to normalize spectrogram when training acoustic model." + ) + parser.add_argument( + "--phones_dict", type=str, default=None, help="phone vocabulary file.") + # vocoder + parser.add_argument( + '--voc', + type=str, + default='pwgan_csmsc', + choices=[ + 'pwgan_aishell3', + 'pwgan_vctk', + 'hifigan_aishell3', + 'hifigan_vctk', + ], + help='Choose vocoder type of tts task.') + parser.add_argument( + '--voc_config', type=str, default=None, help='Config of voc.') + parser.add_argument( + '--voc_ckpt', type=str, default=None, help='Checkpoint file of voc.') + parser.add_argument( + "--voc_stat", + type=str, + default=None, + help="mean and standard deviation used to normalize spectrogram when training voc." + ) + # other + parser.add_argument( + "--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.") + parser.add_argument("--test_metadata", type=str, help="test metadata.") + parser.add_argument("--output_dir", type=str, help="output dir.") + + args = parser.parse_args() + return args + + +def main(): + + args = parse_args() + if args.ngpu == 0: + paddle.set_device("cpu") + elif args.ngpu > 0: + paddle.set_device("gpu") + else: + print("ngpu should >= 0 !") + + evaluate(args) + + +if __name__ == "__main__": + main() diff --git a/paddlespeech/t2s/exps/ernie_sat/train.py b/paddlespeech/t2s/exps/ernie_sat/train.py index 73354447692b771b1dc45f503a534d52eb58c96d..977b8fc52b426f2779a19225067717a4f5efbedb 100644 --- a/paddlespeech/t2s/exps/ernie_sat/train.py +++ b/paddlespeech/t2s/exps/ernie_sat/train.py @@ -62,8 +62,6 @@ def train_sp(args, config): "align_end" ] converters = {"speech": np.load} - spk_num = None - # dataloader has been too verbose logging.getLogger("DataLoader").disabled = True diff --git a/paddlespeech/t2s/exps/syn_utils.py b/paddlespeech/t2s/exps/syn_utils.py index cabea98972e147955317abf76aee069739eed49f..6005867aee6d5a5763ba5ed8766e69d8ae80c3ff 100644 --- a/paddlespeech/t2s/exps/syn_utils.py +++ b/paddlespeech/t2s/exps/syn_utils.py @@ -68,6 +68,10 @@ model_alias = { "paddlespeech.t2s.models.wavernn:WaveRNN", "wavernn_inference": "paddlespeech.t2s.models.wavernn:WaveRNNInference", + "erniesat": + "paddlespeech.t2s.models.ernie_sat:ErnieSAT", + "erniesat_inference": + "paddlespeech.t2s.models.ernie_sat:ErnieSATInference", } @@ -109,6 +113,7 @@ def get_test_dataset(test_metadata: List[Dict[str, Any]], # model: {model_name}_{dataset} am_name = am[:am.rindex('_')] am_dataset = am[am.rindex('_') + 1:] + converters = {} if am_name == 'fastspeech2': fields = ["utt_id", "text"] if am_dataset in {"aishell3", "vctk"} and speaker_dict is not None: @@ -126,8 +131,17 @@ def get_test_dataset(test_metadata: List[Dict[str, Any]], if voice_cloning: print("voice cloning!") fields += ["spk_emb"] + elif am_name == 'erniesat': + fields = [ + "utt_id", "text", "text_lengths", "speech", "speech_lengths", + "align_start", "align_end" + ] + converters = {"speech": np.load} + else: + print("wrong am, please input right am!!!") - test_dataset = DataTable(data=test_metadata, fields=fields) 
diff --git a/paddlespeech/t2s/exps/syn_utils.py b/paddlespeech/t2s/exps/syn_utils.py
index cabea98972e147955317abf76aee069739eed49f..6005867aee6d5a5763ba5ed8766e69d8ae80c3ff 100644
--- a/paddlespeech/t2s/exps/syn_utils.py
+++ b/paddlespeech/t2s/exps/syn_utils.py
@@ -68,6 +68,10 @@ model_alias = {
     "paddlespeech.t2s.models.wavernn:WaveRNN",
     "wavernn_inference":
     "paddlespeech.t2s.models.wavernn:WaveRNNInference",
+    "erniesat":
+    "paddlespeech.t2s.models.ernie_sat:ErnieSAT",
+    "erniesat_inference":
+    "paddlespeech.t2s.models.ernie_sat:ErnieSATInference",
 }
 
 
@@ -109,6 +113,7 @@ def get_test_dataset(test_metadata: List[Dict[str, Any]],
     # model: {model_name}_{dataset}
     am_name = am[:am.rindex('_')]
     am_dataset = am[am.rindex('_') + 1:]
+    converters = {}
     if am_name == 'fastspeech2':
         fields = ["utt_id", "text"]
         if am_dataset in {"aishell3", "vctk"} and speaker_dict is not None:
@@ -126,8 +131,17 @@ def get_test_dataset(test_metadata: List[Dict[str, Any]],
         if voice_cloning:
             print("voice cloning!")
             fields += ["spk_emb"]
+    elif am_name == 'erniesat':
+        fields = [
+            "utt_id", "text", "text_lengths", "speech", "speech_lengths",
+            "align_start", "align_end"
+        ]
+        converters = {"speech": np.load}
+    else:
+        print("Wrong am, please input the right am!")
 
-    test_dataset = DataTable(data=test_metadata, fields=fields)
+    test_dataset = DataTable(
+        data=test_metadata, fields=fields, converters=converters)
 
     return test_dataset
 
@@ -193,6 +207,10 @@ def get_am_inference(am: str='fastspeech2_csmsc',
             **am_config["model"])
     elif am_name == 'tacotron2':
         am = am_class(idim=vocab_size, odim=odim, **am_config["model"])
+    elif am_name == 'erniesat':
+        am = am_class(idim=vocab_size, odim=odim, **am_config["model"])
+    else:
+        print("Wrong am, please input the right am!")
 
     am.set_state_dict(paddle.load(am_ckpt)["main_params"])
     am.eval()
diff --git a/paddlespeech/t2s/models/ernie_sat/ernie_sat.py b/paddlespeech/t2s/models/ernie_sat/ernie_sat.py
index 9f7f8aa78911c97d24ba65ff898973cbaf08b446..54f5d542d035830f4f90c56daac659debe498d53 100644
--- a/paddlespeech/t2s/models/ernie_sat/ernie_sat.py
+++ b/paddlespeech/t2s/models/ernie_sat/ernie_sat.py
@@ -389,7 +389,7 @@ class MLM(nn.Layer):
             speech_seg_pos: paddle.Tensor,
             text_seg_pos: paddle.Tensor,
             span_bdy: List[int],
-            use_teacher_forcing: bool=False, ) -> Dict[str, paddle.Tensor]:
+            use_teacher_forcing: bool=False, ) -> List[paddle.Tensor]:
         '''
         Args:
             speech (paddle.Tensor): input speech (1, Tmax, D).
@@ -668,3 +668,38 @@ class ErnieSAT(nn.Layer):
             text_seg_pos=text_seg_pos,
             span_bdy=span_bdy,
             use_teacher_forcing=use_teacher_forcing)
+
+
+class ErnieSATInference(nn.Layer):
+    def __init__(self, normalizer, model):
+        super().__init__()
+        self.normalizer = normalizer
+        self.acoustic_model = model
+
+    def forward(
+            self,
+            speech: paddle.Tensor,
+            text: paddle.Tensor,
+            masked_pos: paddle.Tensor,
+            speech_mask: paddle.Tensor,
+            text_mask: paddle.Tensor,
+            speech_seg_pos: paddle.Tensor,
+            text_seg_pos: paddle.Tensor,
+            span_bdy: List[int],
+            use_teacher_forcing: bool=True, ):
+        outs = self.acoustic_model.inference(
+            speech=speech,
+            text=text,
+            masked_pos=masked_pos,
+            speech_mask=speech_mask,
+            text_mask=text_mask,
+            speech_seg_pos=speech_seg_pos,
+            text_seg_pos=text_seg_pos,
+            span_bdy=span_bdy,
+            use_teacher_forcing=use_teacher_forcing)
+
+        normed_mel_pre, normed_mel_masked, normed_mel_post = outs
+        logmel_pre = self.normalizer.inverse(normed_mel_pre)
+        logmel_masked = self.normalizer.inverse(normed_mel_masked)
+        logmel_post = self.normalizer.inverse(normed_mel_post)
+        return logmel_pre, logmel_masked, logmel_post
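The new `ErnieSATInference` wrapper only un-normalizes the three mel outputs of `MLM.inference`. For context, below is a sketch of how such a wrapper is typically paired with a `ZScore` normalizer built from the training stats, following the existing `get_am_inference` pattern in `syn_utils.py`; the helper function is illustrative and not part of this patch.

```python
# Illustrative assembly of the inference wrapper: a ZScore(mean, std) normalizer
# built from speech_stats.npy plus the trained ErnieSAT model, mirroring how
# other acoustic models are wrapped in syn_utils.get_am_inference.
import numpy as np
import paddle

from paddlespeech.t2s.models.ernie_sat import ErnieSAT, ErnieSATInference
from paddlespeech.t2s.modules.normalizer import ZScore


def build_erniesat_inference(model: ErnieSAT, stat_path: str) -> ErnieSATInference:
    mu, std = np.load(stat_path)
    normalizer = ZScore(paddle.to_tensor(mu), paddle.to_tensor(std))
    inference = ErnieSATInference(normalizer, model)
    inference.eval()
    return inference
```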