From 4d7cd0e0638818bef61ebc52e535f545be43d590 Mon Sep 17 00:00:00 2001 From: TianYuan Date: Thu, 31 Mar 2022 12:57:51 +0000 Subject: [PATCH] add streaming synthesize, test=tts --- .../csmsc/tts3/local/synthesize_streaming.sh | 18 +- paddlespeech/t2s/exps/synthesize_streaming.py | 269 ++++++++++++++++++ .../t2s/models/fastspeech2/fastspeech2.py | 49 +++- .../t2s/modules/transformer/encoder.py | 4 - 4 files changed, 316 insertions(+), 24 deletions(-) create mode 100644 paddlespeech/t2s/exps/synthesize_streaming.py diff --git a/examples/csmsc/tts3/local/synthesize_streaming.sh b/examples/csmsc/tts3/local/synthesize_streaming.sh index 69bb22df..7606c238 100755 --- a/examples/csmsc/tts3/local/synthesize_streaming.sh +++ b/examples/csmsc/tts3/local/synthesize_streaming.sh @@ -22,9 +22,9 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then --voc_stat=pwg_baker_ckpt_0.4/pwg_stats.npy \ --lang=zh \ --text=${BIN_DIR}/../sentences.txt \ - --output_dir=${train_output_path}/test_e2e \ + --output_dir=${train_output_path}/test_e2e_streaming \ --phones_dict=dump/phone_id_map.txt \ - --inference_dir=${train_output_path}/inference + --am_streaming=True fi # for more GAN Vocoders @@ -43,9 +43,9 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then --voc_stat=mb_melgan_csmsc_ckpt_0.1.1/feats_stats.npy \ --lang=zh \ --text=${BIN_DIR}/../sentences.txt \ - --output_dir=${train_output_path}/test_e2e \ + --output_dir=${train_output_path}/test_e2e_streaming \ --phones_dict=dump/phone_id_map.txt \ - --inference_dir=${train_output_path}/inference + --am_streaming=True fi # the pretrained models haven't release now @@ -65,9 +65,9 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then --voc_stat=style_melgan_csmsc_ckpt_0.1.1/feats_stats.npy \ --lang=zh \ --text=${BIN_DIR}/../sentences.txt \ - --output_dir=${train_output_path}/test_e2e \ - --phones_dict=dump/phone_id_map.txt - # --inference_dir=${train_output_path}/inference + --output_dir=${train_output_path}/test_e2e_streaming \ + --phones_dict=dump/phone_id_map.txt \ + --am_streaming=True fi # hifigan @@ -86,7 +86,7 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then --voc_stat=hifigan_csmsc_ckpt_0.1.1/feats_stats.npy \ --lang=zh \ --text=${BIN_DIR}/../sentences.txt \ - --output_dir=${train_output_path}/test_e2e \ + --output_dir=${train_output_path}/test_e2e_streaming \ --phones_dict=dump/phone_id_map.txt \ - --inference_dir=${train_output_path}/inference + --am_streaming=True fi diff --git a/paddlespeech/t2s/exps/synthesize_streaming.py b/paddlespeech/t2s/exps/synthesize_streaming.py new file mode 100644 index 00000000..62915539 --- /dev/null +++ b/paddlespeech/t2s/exps/synthesize_streaming.py @@ -0,0 +1,269 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import math +from pathlib import Path + +import numpy as np +import paddle +import soundfile as sf +import yaml +from timer import timer +from yacs.config import CfgNode + +from paddlespeech.s2t.utils.dynamic_import import dynamic_import +from paddlespeech.t2s.exps.syn_utils import get_frontend +from paddlespeech.t2s.exps.syn_utils import get_sentences +from paddlespeech.t2s.exps.syn_utils import get_voc_inference +from paddlespeech.t2s.exps.syn_utils import model_alias +from paddlespeech.t2s.utils import str2bool + + +def denorm(data, mean, std): + return data * std + mean + + +def get_chunks(data, chunk_size, pad_size): + data_len = data.shape[1] + chunks = [] + n = math.ceil(data_len / chunk_size) + for i in range(n): + start = max(0, i * chunk_size - pad_size) + end = min((i + 1) * chunk_size + pad_size, data_len) + chunks.append(data[:, start:end, :]) + return chunks + + +def evaluate(args): + + # Init body. + with open(args.am_config) as f: + am_config = CfgNode(yaml.safe_load(f)) + with open(args.voc_config) as f: + voc_config = CfgNode(yaml.safe_load(f)) + + print("========Args========") + print(yaml.safe_dump(vars(args))) + print("========Config========") + print(am_config) + print(voc_config) + + sentences = get_sentences(args) + + # frontend + frontend = get_frontend(args) + + with open(args.phones_dict, "r") as f: + phn_id = [line.strip().split() for line in f.readlines()] + vocab_size = len(phn_id) + print("vocab_size:", vocab_size) + + # acoustic model, only support fastspeech2 here now! + # am_inference, am_name, am_dataset = get_am_inference(args, am_config) + # model: {model_name}_{dataset} + am_name = args.am[:args.am.rindex('_')] + am_dataset = args.am[args.am.rindex('_') + 1:] + odim = am_config.n_mels + + am_class = dynamic_import(am_name, model_alias) + am = am_class(idim=vocab_size, odim=odim, **am_config["model"]) + am.set_state_dict(paddle.load(args.am_ckpt)["main_params"]) + am.eval() + am_mu, am_std = np.load(args.am_stat) + am_mu = paddle.to_tensor(am_mu) + am_std = paddle.to_tensor(am_std) + + # vocoder + voc_inference = get_voc_inference(args, voc_config) + + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + merge_sentences = True + + N = 0 + T = 0 + chunk_size = 42 + pad_size = 12 + + for utt_id, sentence in sentences: + with timer() as t: + get_tone_ids = False + + if args.lang == 'zh': + input_ids = frontend.get_input_ids( + sentence, + merge_sentences=merge_sentences, + get_tone_ids=get_tone_ids) + + phone_ids = input_ids["phone_ids"] + else: + print("lang should in be 'zh' here!") + # merge_sentences=False here, so we only use the first item of phone_ids + phone_ids = phone_ids[0] + with paddle.no_grad(): + # acoustic model + orig_hs, h_masks = am.encoder_infer(phone_ids) + + if args.am_streaming: + hss = get_chunks(orig_hs, chunk_size, pad_size) + chunk_num = len(hss) + mel_list = [] + for i, hs in enumerate(hss): + before_outs, _ = am.decoder(hs) + after_outs = before_outs + am.postnet( + before_outs.transpose((0, 2, 1))).transpose( + (0, 2, 1)) + normalized_mel = after_outs[0] + sub_mel = denorm(normalized_mel, am_mu, am_std) + # clip output part of pad + if i == 0: + sub_mel = sub_mel[:-pad_size] + elif i == chunk_num - 1: + # 最后一块的右侧一定没有 pad 够 + sub_mel = sub_mel[pad_size:] + else: + # 倒数几块的右侧也可能没有 pad 够 + sub_mel = sub_mel[pad_size:(chunk_size + pad_size) - + sub_mel.shape[0]] + mel_list.append(sub_mel) + mel = paddle.concat(mel_list, axis=0) + + else: + before_outs, _ = am.decoder(orig_hs) + after_outs = before_outs + am.postnet( + before_outs.transpose((0, 2, 1))).transpose((0, 2, 1)) + normalized_mel = after_outs[0] + mel = denorm(normalized_mel, am_mu, am_std) + + # vocoder + wav = voc_inference(mel) + + wav = wav.numpy() + N += wav.size + T += t.elapse + speed = wav.size / t.elapse + rtf = am_config.fs / speed + print( + f"{utt_id}, mel: {mel.shape}, wave: {wav.shape}, time: {t.elapse}s, Hz: {speed}, RTF: {rtf}." + ) + sf.write( + str(output_dir / (utt_id + ".wav")), wav, samplerate=am_config.fs) + print(f"{utt_id} done!") + print(f"generation speed: {N / T}Hz, RTF: {am_config.fs / (N / T) }") + + +def parse_args(): + # parse args and config and redirect to train_sp + parser = argparse.ArgumentParser( + description="Synthesize with acoustic model & vocoder") + # acoustic model + parser.add_argument( + '--am', + type=str, + default='fastspeech2_csmsc', + choices=['fastspeech2_csmsc'], + help='Choose acoustic model type of tts task.') + parser.add_argument( + '--am_config', + type=str, + default=None, + help='Config of acoustic model. Use deault config when it is None.') + parser.add_argument( + '--am_ckpt', + type=str, + default=None, + help='Checkpoint file of acoustic model.') + parser.add_argument( + "--am_stat", + type=str, + default=None, + help="mean and standard deviation used to normalize spectrogram when training acoustic model." + ) + parser.add_argument( + "--phones_dict", type=str, default=None, help="phone vocabulary file.") + parser.add_argument( + "--tones_dict", type=str, default=None, help="tone vocabulary file.") + + # vocoder + parser.add_argument( + '--voc', + type=str, + default='pwgan_csmsc', + choices=[ + 'pwgan_csmsc', + 'pwgan_ljspeech', + 'pwgan_aishell3', + 'pwgan_vctk', + 'mb_melgan_csmsc', + 'style_melgan_csmsc', + 'hifigan_csmsc', + 'hifigan_ljspeech', + 'hifigan_aishell3', + 'hifigan_vctk', + 'wavernn_csmsc', + ], + help='Choose vocoder type of tts task.') + parser.add_argument( + '--voc_config', + type=str, + default=None, + help='Config of voc. Use deault config when it is None.') + parser.add_argument( + '--voc_ckpt', type=str, default=None, help='Checkpoint file of voc.') + parser.add_argument( + "--voc_stat", + type=str, + default=None, + help="mean and standard deviation used to normalize spectrogram when training voc." + ) + # other + parser.add_argument( + '--lang', + type=str, + default='zh', + help='Choose model language. zh or en') + + parser.add_argument( + "--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.") + parser.add_argument( + "--text", + type=str, + help="text to synthesize, a 'utt_id sentence' pair per line.") + + parser.add_argument( + "--am_streaming", + type=str2bool, + default=False, + help="whether use streaming acoustic model") + parser.add_argument("--output_dir", type=str, help="output dir.") + + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + + if args.ngpu == 0: + paddle.set_device("cpu") + elif args.ngpu > 0: + paddle.set_device("gpu") + else: + print("ngpu should >= 0 !") + + evaluate(args) + + +if __name__ == "__main__": + main() diff --git a/paddlespeech/t2s/models/fastspeech2/fastspeech2.py b/paddlespeech/t2s/models/fastspeech2/fastspeech2.py index 1c805051..c2f1e218 100644 --- a/paddlespeech/t2s/models/fastspeech2/fastspeech2.py +++ b/paddlespeech/t2s/models/fastspeech2/fastspeech2.py @@ -509,6 +509,7 @@ class FastSpeech2(nn.Layer): ps: paddle.Tensor=None, es: paddle.Tensor=None, is_inference: bool=False, + return_after_enc=False, alpha: float=1.0, spk_emb=None, spk_id=None, @@ -589,8 +590,10 @@ class FastSpeech2(nn.Layer): h_masks = self._source_mask(olens_in) else: h_masks = None - # (B, Lmax, adim) + if return_after_enc: + return hs, h_masks + # (B, Lmax, adim) zs, _ = self.decoder(hs, h_masks) # (B, Lmax, odim) if self.decoder_type == 'cnndecoder': @@ -608,10 +611,42 @@ class FastSpeech2(nn.Layer): return before_outs, after_outs, d_outs, p_outs, e_outs + def encoder_infer( + self, + text: paddle.Tensor, + alpha: float=1.0, + spk_emb=None, + spk_id=None, + tone_id=None, + ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: + # input of embedding must be int64 + x = paddle.cast(text, 'int64') + # setup batch axis + ilens = paddle.shape(x)[0] + + xs = x.unsqueeze(0) + + if spk_emb is not None: + spk_emb = spk_emb.unsqueeze(0) + + if tone_id is not None: + tone_id = tone_id.unsqueeze(0) + + # (1, L, odim) + hs, h_masks = self._forward( + xs, + ilens, + is_inference=True, + return_after_enc=True, + alpha=alpha, + spk_emb=spk_emb, + spk_id=spk_id, + tone_id=tone_id) + return hs, h_masks + def inference( self, text: paddle.Tensor, - speech: paddle.Tensor=None, durations: paddle.Tensor=None, pitch: paddle.Tensor=None, energy: paddle.Tensor=None, @@ -625,7 +660,6 @@ class FastSpeech2(nn.Layer): Args: text(Tensor(int64)): Input sequence of characters (T,). - speech(Tensor, optional): Feature sequence to extract style (N, idim). durations(Tensor, optional (int64)): Groundtruth of duration (T,). pitch(Tensor, optional): Groundtruth of token-averaged pitch (T, 1). energy(Tensor, optional): Groundtruth of token-averaged energy (T, 1). @@ -642,15 +676,11 @@ class FastSpeech2(nn.Layer): """ # input of embedding must be int64 x = paddle.cast(text, 'int64') - y = speech d, p, e = durations, pitch, energy # setup batch axis ilens = paddle.shape(x)[0] - xs, ys = x.unsqueeze(0), None - - if y is not None: - ys = y.unsqueeze(0) + xs = x.unsqueeze(0) if spk_emb is not None: spk_emb = spk_emb.unsqueeze(0) @@ -668,7 +698,6 @@ class FastSpeech2(nn.Layer): _, outs, d_outs, p_outs, e_outs = self._forward( xs, ilens, - ys, ds=ds, ps=ps, es=es, @@ -681,7 +710,6 @@ class FastSpeech2(nn.Layer): _, outs, d_outs, p_outs, e_outs = self._forward( xs, ilens, - ys, is_inference=True, alpha=alpha, spk_emb=spk_emb, @@ -829,7 +857,6 @@ class StyleFastSpeech2Inference(FastSpeech2Inference): Args: text(Tensor(int64)): Input sequence of characters (T,). - speech(Tensor, optional): Feature sequence to extract style (N, idim). durations(paddle.Tensor/np.ndarray, optional (int64)): Groundtruth of duration (T,), this will overwrite the set of durations_scale and durations_bias durations_scale(int/float, optional): durations_bias(int/float, optional): diff --git a/paddlespeech/t2s/modules/transformer/encoder.py b/paddlespeech/t2s/modules/transformer/encoder.py index 25a11ff6..f6420282 100644 --- a/paddlespeech/t2s/modules/transformer/encoder.py +++ b/paddlespeech/t2s/modules/transformer/encoder.py @@ -587,7 +587,6 @@ class CNNDecoder(nn.Layer): Returns: Tensor: Output tensor (#batch, time, odim). """ - # print("input.shape in CNNDecoder:",xs.shape) # exchange the temporal dimension and the feature dimension xs = xs.transpose([0, 2, 1]) if masks is not None: @@ -603,7 +602,6 @@ class CNNDecoder(nn.Layer): if masks is not None: outputs = outputs * masks outputs = outputs.transpose([0, 2, 1]) - # print("outputs.shape in CNNDecoder:",outputs.shape) return outputs, masks @@ -636,7 +634,6 @@ class CNNPostnet(nn.Layer): Returns: Tensor: Output tensor (#batch, odim, time). """ - # print("xs.shape in CNNPostnet:",xs.shape) for layer in self.residual_blocks: outputs = layer(xs) if masks is not None: @@ -646,5 +643,4 @@ class CNNPostnet(nn.Layer): outputs = self.conv1d(outputs) if masks is not None: outputs = outputs * masks - # print("outputs.shape in CNNPostnet:",outputs.shape) return outputs -- GitLab