From 9d0224460bec81139fd7d69732dce0f7c7ec36fa Mon Sep 17 00:00:00 2001 From: lym0302 Date: Mon, 11 Apr 2022 15:54:44 +0800 Subject: [PATCH] code format, test=doc --- paddlespeech/server/tests/tts/infer/run.sh | 12 ++-- .../server/tests/tts/infer/test_online_tts.py | 67 ++++++++++--------- 2 files changed, 42 insertions(+), 37 deletions(-) diff --git a/paddlespeech/server/tests/tts/infer/run.sh b/paddlespeech/server/tests/tts/infer/run.sh index 631daddd..3733c3fb 100644 --- a/paddlespeech/server/tests/tts/infer/run.sh +++ b/paddlespeech/server/tests/tts/infer/run.sh @@ -1,6 +1,6 @@ model_path=~/.paddlespeech/models/ -am_model_dir=$model_path/fastspeech2_csmsc-zh/fastspeech2_cnndecoder_csmsc_ckpt_1.0.0/ ## fastspeech2_c -voc_model_dir=$model_path/mb_melgan_csmsc-zh/mb_melgan_csmsc_ckpt_0.1.1/ ## mb_melgan +am_model_dir=$model_path/fastspeech2_csmsc-zh/fastspeech2_cnndecoder_csmsc_ckpt_1.0.0/ +voc_model_dir=$model_path/mb_melgan_csmsc-zh/mb_melgan_csmsc_ckpt_0.1.1/ testdata=../../../../t2s/exps/csmsc_test.txt # get am file @@ -33,9 +33,13 @@ done # run test -# am can choose fastspeech2_csmsc or fastspeech2-C_csmsc, where fastspeech2-C_csmsc supports streaming inference. +# am can choose fastspeech2_csmsc or fastspeech2_cnndecoder_csmsc, where fastspeech2_cnndecoder_csmsc supports streaming inference. # voc can choose hifigan_csmsc and mb_melgan_csmsc, They can both support streaming inference. -python test_online_tts.py --am fastspeech2-C_csmsc \ +# When am is fastspeech2_cnndecoder_csmsc and am_pad is set to 12, there is no diff between streaming and non-streaming inference results. +# When voc is mb_melgan_csmsc and voc_pad is set to 14, there is no diff between streaming and non-streaming inference results. +# When voc is hifigan_csmsc and voc_pad is set to 20, there is no diff between streaming and non-streaming inference results. + +python test_online_tts.py --am fastspeech2_cnndecoder_csmsc \ --am_config $am_model_dir/$am_config_file \ --am_ckpt $am_model_dir/$am_ckpt_file \ --am_stat $am_model_dir/$am_stat_file \ diff --git a/paddlespeech/server/tests/tts/infer/test_online_tts.py b/paddlespeech/server/tests/tts/infer/test_online_tts.py index 8ccf724b..eb5fc80b 100644 --- a/paddlespeech/server/tests/tts/infer/test_online_tts.py +++ b/paddlespeech/server/tests/tts/infer/test_online_tts.py @@ -34,8 +34,8 @@ from paddlespeech.t2s.utils import str2bool mel_streaming = None wav_streaming = None -stream_first_time = 0.0 -voc_stream_st = 0.0 +streaming_first_time = 0.0 +streaming_voc_st = 0.0 sample_rate = 0 @@ -65,7 +65,7 @@ def get_chunks(data, block_size, pad_size, step): return chunks -def get_stream_am_inference(args, am_config): +def get_streaming_am_inference(args, am_config): with open(args.phones_dict, "r") as f: phn_id = [line.strip().split() for line in f.readlines()] vocab_size = len(phn_id) @@ -99,8 +99,8 @@ def init(args): frontend = get_frontend(args) # acoustic model - if args.am == 'fastspeech2-C_csmsc': - am, am_mu, am_std = get_stream_am_inference(args, am_config) + if args.am == 'fastspeech2_cnndecoder_csmsc': + am, am_mu, am_std = get_streaming_am_inference(args, am_config) am_infer_info = [am, am_mu, am_std, am_config] else: am_inference, am_name, am_dataset = get_am_inference(args, am_config) @@ -139,7 +139,7 @@ def get_phone(args, frontend, sentence, merge_sentences, get_tone_ids): # 生成完整的mel def gen_mel(args, am_infer_info, part_phone_ids, part_tone_ids): # 如果是支持流式的AM模型 - if args.am == 'fastspeech2-C_csmsc': + if args.am == 'fastspeech2_cnndecoder_csmsc': am, am_mu, am_std, am_config = am_infer_info orig_hs, h_masks = am.encoder_infer(part_phone_ids) if args.am_streaming: @@ -183,9 +183,9 @@ def gen_mel(args, am_infer_info, part_phone_ids, part_tone_ids): @paddle.no_grad() -def stream_voc_infer(args, voc_infer_info, mel_len): +def streaming_voc_infer(args, voc_infer_info, mel_len): global mel_streaming - global stream_first_time + global streaming_first_time global wav_streaming voc_inference, voc_config = voc_infer_info block = args.voc_block @@ -203,7 +203,7 @@ def stream_voc_infer(args, voc_infer_info, mel_len): while valid_end <= mel_len: sub_wav = voc_inference(mel_chunk) if flag == 1: - stream_first_time = time.time() + streaming_first_time = time.time() flag = 0 # get valid wav @@ -233,8 +233,8 @@ def stream_voc_infer(args, voc_infer_info, mel_len): @paddle.no_grad() # 非流式AM / 流式AM + 非流式Voc -def am_nostream_voc(args, am_infer_info, voc_infer_info, part_phone_ids, - part_tone_ids): +def am_nonstreaming_voc(args, am_infer_info, voc_infer_info, part_phone_ids, + part_tone_ids): mel = gen_mel(args, am_infer_info, part_phone_ids, part_tone_ids) am_infer_time = time.time() voc_inference, voc_config = voc_infer_info @@ -248,10 +248,10 @@ def am_nostream_voc(args, am_infer_info, voc_infer_info, part_phone_ids, @paddle.no_grad() # 非流式AM + 流式Voc -def nostream_am_stream_voc(args, am_infer_info, voc_infer_info, part_phone_ids, - part_tone_ids): +def nonstreaming_am_streaming_voc(args, am_infer_info, voc_infer_info, + part_phone_ids, part_tone_ids): global mel_streaming - global stream_first_time + global streaming_first_time global wav_streaming mel = gen_mel(args, am_infer_info, part_phone_ids, part_tone_ids) @@ -260,8 +260,8 @@ def nostream_am_stream_voc(args, am_infer_info, voc_infer_info, part_phone_ids, # voc streaming mel_streaming = mel mel_len = mel.shape[0] - stream_voc_infer(args, voc_infer_info, mel_len) - first_response_time = stream_first_time + streaming_voc_infer(args, voc_infer_info, mel_len) + first_response_time = streaming_first_time wav = wav_streaming final_response_time = time.time() voc_infer_time = final_response_time @@ -271,12 +271,12 @@ def nostream_am_stream_voc(args, am_infer_info, voc_infer_info, part_phone_ids, @paddle.no_grad() # 流式AM + 流式 Voc -def stream_am_stream_voc(args, am_infer_info, voc_infer_info, part_phone_ids, - part_tone_ids): +def streaming_am_streaming_voc(args, am_infer_info, voc_infer_info, + part_phone_ids, part_tone_ids): global mel_streaming - global stream_first_time + global streaming_first_time global wav_streaming - global voc_stream_st + global streaming_voc_st mel_streaming = None #用来表示开启流式voc的线程 flag = 1 @@ -311,15 +311,16 @@ def stream_am_stream_voc(args, am_infer_info, voc_infer_info, part_phone_ids, if flag and mel_streaming.shape[0] > args.voc_block + args.voc_pad: t = threading.Thread( - target=stream_voc_infer, args=(args, voc_infer_info, mel_len, )) + target=streaming_voc_infer, + args=(args, voc_infer_info, mel_len, )) t.start() - voc_stream_st = time.time() + streaming_voc_st = time.time() flag = 0 t.join() final_response_time = time.time() voc_infer_time = final_response_time - first_response_time = stream_first_time + first_response_time = streaming_first_time wav = wav_streaming return am_infer_time, voc_infer_time, first_response_time, final_response_time, wav @@ -337,11 +338,11 @@ def warm_up(args, logger, frontend, am_infer_info, voc_infer_info): if args.voc_streaming: if args.am_streaming: - infer_func = stream_am_stream_voc + infer_func = streaming_am_streaming_voc else: - infer_func = nostream_am_stream_voc + infer_func = nonstreaming_am_streaming_voc else: - infer_func = am_nostream_voc + infer_func = am_nonstreaming_voc merge_sentences = True get_tone_ids = False @@ -376,11 +377,11 @@ def evaluate(args, logger, frontend, am_infer_info, voc_infer_info): # choose infer function if args.voc_streaming: if args.am_streaming: - infer_func = stream_am_stream_voc + infer_func = streaming_am_streaming_voc else: - infer_func = nostream_am_stream_voc + infer_func = nonstreaming_am_streaming_voc else: - infer_func = am_nostream_voc + infer_func = am_nonstreaming_voc final_up_duration = 0.0 sentence_count = 0 @@ -410,7 +411,7 @@ def evaluate(args, logger, frontend, am_infer_info, voc_infer_info): args, am_infer_info, voc_infer_info, part_phone_ids, part_tone_ids) am_time = am_infer_time - am_st if args.voc_streaming and args.am_streaming: - voc_time = voc_infer_time - voc_stream_st + voc_time = voc_infer_time - streaming_voc_st else: voc_time = voc_infer_time - am_infer_time @@ -482,8 +483,8 @@ def parse_args(): '--am', type=str, default='fastspeech2_csmsc', - choices=['fastspeech2_csmsc', 'fastspeech2-C_csmsc'], - help='Choose acoustic model type of tts task. where fastspeech2-C_csmsc supports streaming inference' + choices=['fastspeech2_csmsc', 'fastspeech2_cnndecoder_csmsc'], + help='Choose acoustic model type of tts task. where fastspeech2_cnndecoder_csmsc supports streaming inference' ) parser.add_argument( @@ -576,7 +577,7 @@ def main(): args = parse_args() paddle.set_device(args.device) if args.am_streaming: - assert (args.am == 'fastspeech2-C_csmsc') + assert (args.am == 'fastspeech2_cnndecoder_csmsc') logger = logging.getLogger() fhandler = logging.FileHandler(filename=args.log_file, mode='w') -- GitLab