diff --git a/ernie-sat/.meta/framework.png b/ernie-sat/.meta/framework.png index c315bc5f32e27954ef1606abd515a4f77d299b21..c68f62467952aaca91f290e5dead723078343d64 100644 Binary files a/ernie-sat/.meta/framework.png and b/ernie-sat/.meta/framework.png differ diff --git a/ernie-sat/README.md b/ernie-sat/README.md index 548688912a6d4158bf0d858bf4a235e731f2e2f3..db8cc07504d45a143557476bc800fcfc075ed706 100644 --- a/ernie-sat/README.md +++ b/ernie-sat/README.md @@ -1,7 +1,7 @@ -ERNIE-SAT是可以同时处理中英文的跨语言的语音-语言跨模态大模型,其在语音编辑、个性化语音合成以及跨语言的语音合成等多个任务取得了领先效果。可以应用于语音编辑、个性化合成、语音克隆、同传翻译等一系列场景,该项目供研究使用。 +ERNIE-SAT 是可以同时处理中英文的跨语言的语音-语言跨模态大模型,其在语音编辑、个性化语音合成以及跨语言的语音合成等多个任务取得了领先效果。可以应用于语音编辑、个性化合成、语音克隆、同传翻译等一系列场景,该项目供研究使用。 ## 模型框架 -ERNIE-SAT中我们提出了两项创新: +ERNIE-SAT 中我们提出了两项创新: - 在预训练过程中将中英双语对应的音素作为输入,实现了跨语言、个性化的软音素映射 - 采用语言和语音的联合掩码学习实现了语言和语音的对齐 @@ -12,14 +12,14 @@ ERNIE-SAT中我们提出了两项创新: ### 1.安装飞桨与环境依赖 - 本项目的代码基于 Paddle(version>=2.0) -- 本项目开放提供加载torch版本的vocoder的功能 +- 本项目开放提供加载 torch 版本的 vocoder 的功能 - torch version>=1.8 -- 安装htk: 在[官方地址](https://htk.eng.cam.ac.uk/)注册完成后,即可进行下载较新版本的htk(例如3.4.1)。同时提供[历史版本htk下载地址](https://htk.eng.cam.ac.uk/ftp/software/) +- 安装 htk: 在[官方地址](https://htk.eng.cam.ac.uk/)注册完成后,即可进行下载较新版本的 htk (例如 3.4.1)。同时提供[历史版本 htk 下载地址](https://htk.eng.cam.ac.uk/ftp/software/) - - 1.注册账号,下载htk - - 2.解压htk文件,**放入项目根目录的tools文件夹中, 以htk文件夹名称放入** - - 3.**注意**: 如果您下载的是3.4.1或者更高版本,需要进入HTKLib/HRec.c文件中, **修改1626行和1650行**, 即把**以下两行的dur<=0 都修改为 dur<0**,如下所示: + - 1.注册账号,下载 htk + - 2.解压 htk 文件,**放入项目根目录的 tools 文件夹中, 以 htk 文件夹名称放入** + - 3.**注意**: 如果您下载的是 3.4.1 或者更高版本, 需要进入 HTKLib/HRec.c 文件中, **修改 1626 行和 1650 行**, 即把**以下两行的 dur<=0 都修改为 dur<0**,如下所示: ```bash 以htk3.4.1版本举例: (1)第1626行: if (dur<=0 && labid != splabid) HError(8522,"LatFromPaths: Align have dur<=0"); @@ -28,26 +28,23 @@ ERNIE-SAT中我们提出了两项创新: (2)1650行: if (dur<=0 && labid != splabid) HError(8522,"LatFromPaths: Align have dur<=0 "); 修改为: if (dur<0 && labid != splabid) HError(8522,"LatFromPaths: Align have dur<0 "); ``` - - 4.**编译**: 详情参见解压后的htk中的README文件(如果未编译, 则无法正常运行) + - 4.**编译**: 详情参见解压后的 htk 中的 README 文件(如果未编译, 则无法正常运行) -- 安装ParallelWaveGAN: 参见[官方地址](https://github.com/kan-bayashi/ParallelWaveGAN):按照该官方链接的安装流程,直接在**项目的根目录下** git clone ParallelWaveGAN项目并且安装相关依赖即可。 - +- 安装 ParallelWaveGAN: 参见[官方地址](https://github.com/kan-bayashi/ParallelWaveGAN):按照该官方链接的安装流程,直接在**项目的根目录下** git clone ParallelWaveGAN 项目并且安装相关依赖即可。 - 安装其他依赖: **sox, libsndfile**等 - - ### 2.预训练模型 -预训练模型ERNIE-SAT的模型如下所示: +预训练模型 ERNIE-SAT 的模型如下所示: - [ERNIE-SAT_ZH](http://bj.bcebos.com/wenxin-models/model-ernie-sat-base-zh.tar.gz) - [ERNIE-SAT_EN](http://bj.bcebos.com/wenxin-models/model-ernie-sat-base-en.tar.gz) - [ERNIE-SAT_ZH_and_EN](http://bj.bcebos.com/wenxin-models/model-ernie-sat-base-en_zh.tar.gz) -创建download文件夹,下载上述ERNIE-SAT预训练模型并将其解压: +创建 pretrained_model 文件夹,下载上述 ERNIE-SAT 预训练模型并将其解压: ```bash mkdir pretrained_model cd pretrained_model @@ -56,13 +53,12 @@ tar -zxvf model-ernie-sat-base-zh.tar.gz tar -zxvf model-ernie-sat-base-en_zh.tar.gz ``` - ### 3.下载 -1. 本项目使用parallel wavegan作为声码器(vocoder): +1. 本项目使用 parallel wavegan 作为声码器(vocoder): - [pwg_aishell3_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_aishell3_ckpt_0.5.zip) -创建download文件夹,下载上述预训练的声码器(vocoder)模型并将其解压: +创建 download 文件夹,下载上述预训练的声码器(vocoder)模型并将其解压: ```bash mkdir download @@ -70,11 +66,11 @@ cd download unzip pwg_aishell3_ckpt_0.5.zip ``` - 2. 本项目使用[FastSpeech2](https://arxiv.org/abs/2006.04558) 作为音素(phoneme)的持续时间预测器: + 2. 
本项目使用 [FastSpeech2](https://arxiv.org/abs/2006.04558) 作为音素(phoneme)的持续时间预测器: - [fastspeech2_conformer_baker_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_conformer_baker_ckpt_0.5.zip) 中文场景下使用 - [fastspeech2_nosil_ljspeech_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_ljspeech_ckpt_0.5.zip) 英文场景下使用 - 下载上述预训练的fastspeech2模型并将其解压 + 下载上述预训练的 fastspeech2 模型并将其解压 ```bash cd download @@ -85,7 +81,7 @@ unzip fastspeech2_nosil_ljspeech_ckpt_0.5.zip ### 4.推理 本项目当前开源了语音编辑、个性化语音合成、跨语言语音合成的推理代码,后续会逐步开源。 -注:当前英文场下的合成语音采用的声码器默认为vctk_parallel_wavegan.v1.long, 可在[该链接](https://github.com/kan-bayashi/ParallelWaveGAN)中找到; 若use_pt_vocoder参数设置为False,则英文场景下使用paddle版本的声码器。 +注:当前英文场下的合成语音采用的声码器默认为 vctk_parallel_wavegan.v1.long, 可在[该链接](https://github.com/kan-bayashi/ParallelWaveGAN)中找到; 若 use_pt_vocoder 参数设置为 False,则英文场景下使用 paddle 版本的声码器。 我们提供特定音频文件, 以及其对应的文本、音素相关文件: - prompt_wav: 提供的音频文件 @@ -114,19 +110,19 @@ prompt/dev 5. `--lang` 对应模型的语言可以是 `zh` 或 `en` 。 6. `--ngpu` 要使用的GPU数,如果 ngpu==0,则使用 cpu。 7. ` --model_name` 模型名称 -8. ` --uid` 特定提示(prompt)语音的id +8. ` --uid` 特定提示(prompt)语音的 id 9. ` --new_str` 输入的文本(本次开源暂时先设置特定的文本) 10. ` --prefix` 特定音频对应的文本、音素相关文件的地址 11. ` --source_language` , 源语言 12. ` --target_language` , 目标语言 13. ` --output_name` , 合成语音名称 14. ` --task_name` , 任务名称, 包括:语音编辑任务、个性化语音合成任务、跨语言语音合成任务 -15. ` use_pt_vocoder`,英文场景下是否使用torch版本的vocoder, 默认情况下为True; 设置为False则在英文场景下使用paddle版本vocoder +15. ` --use_pt_vocoder`, 英文场景下是否使用 torch 版本的 vocoder, 默认情况下为 False; 设置为 False 则在英文场景下使用 paddle 版本 vocoder 运行以下脚本即可进行实验 ```shell -sh run_sedit_en.sh # 语音编辑任务(英文) -sh run_gen_en.sh # 个性化语音合成任务(英文) +sh run_sedit_en.sh # 语音编辑任务(英文) +sh run_gen_en.sh # 个性化语音合成任务(英文) sh run_clone_en_to_zh.sh # 跨语言语音合成任务(英文到中文的语音克隆) ``` diff --git a/ernie-sat/align_english.py b/ernie-sat/align_english.py index e7661519746106e0ed9e0e5cad48949d98717dfd..aff47fe0514e6f03090cbed9dddaf961537eb186 100755 --- a/ernie-sat/align_english.py +++ b/ernie-sat/align_english.py @@ -1,22 +1,21 @@ #!/usr/bin/env python - """ Usage: - align_english.py wavfile trsfile outwordfile outphonefile + align_english.py wavfile trsfile outwordfile outphonefile """ - +import multiprocessing as mp import os import sys -from tqdm import tqdm -import multiprocessing as mp +from tqdm import tqdm PHONEME = 'tools/aligner/english_envir/english2phoneme/phoneme' MODEL_DIR = 'tools/aligner/english' HVITE = 'tools/htk/HTKTools/HVite' HCOPY = 'tools/htk/HTKTools/HCopy' + def prep_txt(line, tmpbase, dictfile): - + words = [] line = line.strip() @@ -48,7 +47,8 @@ def prep_txt(line, tmpbase, dictfile): for unk in unk_words: fwid.write(unk + '\n') try: - os.system(PHONEME + ' ' + tmpbase + '_unk.words' + ' ' + tmpbase + '_unk.phons') + os.system(PHONEME + ' ' + tmpbase + '_unk.words' + ' ' + tmpbase + + '_unk.phons') except: print('english2phoneme error!') sys.exit(1) @@ -79,7 +79,7 @@ def prep_txt(line, tmpbase, dictfile): seq.append(phons[j].upper()) j += 1 else: - p = phons[j:j+2] + p = phons[j:j + 2] if (p == 'WH'): seq.append('W') elif (p in ['TH', 'SH', 'HH', 'DH', 'CH', 'ZH', 'NG']): @@ -96,6 +96,7 @@ def prep_txt(line, tmpbase, dictfile): fw.write('\n') fw.close() + def prep_mlf(txt, tmpbase): with open(tmpbase + '.mlf', 'w') as fwid: @@ -108,6 +109,7 @@ def prep_mlf(txt, tmpbase): fwid.write('sp\n') fwid.write('.\n') + def gen_res(tmpbase, outfile1, outfile2): with open(tmpbase + '.txt', 'r') as fid: words = fid.readline().strip().split() @@ -120,19 +122,20 @@ def 
gen_res(tmpbase, outfile1, outfile2): times1 = [] times2 = [] while (i < len(lines)): - if (len(lines[i].split()) >= 4) and (lines[i].split()[0] != lines[i].split()[1]): + if (len(lines[i].split()) >= 4) and ( + lines[i].split()[0] != lines[i].split()[1]): phn = lines[i].split()[2] - pst = (int(lines[i].split()[0])/1000+125)/10000 - pen = (int(lines[i].split()[1])/1000+125)/10000 + pst = (int(lines[i].split()[0]) / 1000 + 125) / 10000 + pen = (int(lines[i].split()[1]) / 1000 + 125) / 10000 times2.append([phn, pst, pen]) if (len(lines[i].split()) == 5): if (lines[i].split()[0] != lines[i].split()[1]): wrd = lines[i].split()[-1].strip() - st = (int(lines[i].split()[0])/1000+125)/10000 + st = (int(lines[i].split()[0]) / 1000 + 125) / 10000 j = i + 1 while (lines[j] != '.\n') and (len(lines[j].split()) != 5): j += 1 - en = (int(lines[j-1].split()[1])/1000+125)/10000 + en = (int(lines[j - 1].split()[1]) / 1000 + 125) / 10000 times1.append([wrd, st, en]) i += 1 @@ -151,8 +154,13 @@ def gen_res(tmpbase, outfile1, outfile2): for item in times2: fwid.write(str(item[1]) + ' ' + str(item[2]) + ' ' + item[0] + '\n') + +def _get_user(): + return os.path.expanduser('~').split("/")[-1] + + def alignment(wav_path, text_string): - tmpbase = '/tmp/' + os.environ['USER'] + '_' + str(os.getpid()) + tmpbase = '/tmp/' + _get_user() + '_' + str(os.getpid()) #prepare wav and trs files try: @@ -160,7 +168,7 @@ def alignment(wav_path, text_string): except: print('sox error!') return None - + #prepare clean_transcript file try: prep_txt(text_string, tmpbase, MODEL_DIR + '/dict') @@ -179,14 +187,19 @@ def alignment(wav_path, text_string): #prepare scp try: - os.system(HCOPY + ' -C ' + MODEL_DIR + '/16000/config ' + tmpbase + '.wav' + ' ' + tmpbase + '.plp') + os.system(HCOPY + ' -C ' + MODEL_DIR + '/16000/config ' + tmpbase + + '.wav' + ' ' + tmpbase + '.plp') except: print('HCopy error!') return None #run alignment try: - os.system(HVITE + ' -a -m -t 10000.0 10000.0 100000.0 -I ' + tmpbase + '.mlf -H ' + MODEL_DIR + '/16000/macros -H ' + MODEL_DIR + '/16000/hmmdefs -i ' + tmpbase + '.aligned ' + tmpbase + '.dict ' + MODEL_DIR + '/monophones ' + tmpbase + '.plp 2>&1 > /dev/null') + os.system(HVITE + ' -a -m -t 10000.0 10000.0 100000.0 -I ' + tmpbase + + '.mlf -H ' + MODEL_DIR + '/16000/macros -H ' + MODEL_DIR + + '/16000/hmmdefs -i ' + tmpbase + '.aligned ' + tmpbase + + '.dict ' + MODEL_DIR + '/monophones ' + tmpbase + + '.plp 2>&1 > /dev/null') except: print('HVite error!') return None @@ -207,15 +220,15 @@ def alignment(wav_path, text_string): splited_line = lines[i].strip().split() if (len(splited_line) >= 4) and (splited_line[0] != splited_line[1]): phn = splited_line[2] - pst = (int(splited_line[0])/1000+125)/10000 - pen = (int(splited_line[1])/1000+125)/10000 + pst = (int(splited_line[0]) / 1000 + 125) / 10000 + pen = (int(splited_line[1]) / 1000 + 125) / 10000 times2.append([phn, pst, pen]) # splited_line[-1]!='sp' - if len(splited_line)==5: - current_word = str(index)+'_'+splited_line[-1] + if len(splited_line) == 5: + current_word = str(index) + '_' + splited_line[-1] word2phns[current_word] = phn - index+=1 - elif len(splited_line)==4: - word2phns[current_word] += ' '+phn - i+=1 - return times2,word2phns + index += 1 + elif len(splited_line) == 4: + word2phns[current_word] += ' ' + phn + i += 1 + return times2, word2phns diff --git a/ernie-sat/align_mandarin.py b/ernie-sat/align_mandarin.py index 88dad5efa803fec8dd4728a86ff7508e3230a068..fae2a2aea5d66f0853b003b6e5eae81488ff04a2 100755 --- 
a/ernie-sat/align_mandarin.py +++ b/ernie-sat/align_mandarin.py @@ -1,14 +1,12 @@ #!/usr/bin/env python - """ Usage: - align_mandarin.py wavfile trsfile outwordfile putphonefile + align_mandarin.py wavfile trsfile outwordfile putphonefile """ - +import multiprocessing as mp import os import sys -from tqdm import tqdm -import multiprocessing as mp +from tqdm import tqdm MODEL_DIR = 'tools/aligner/mandarin' HVITE = 'tools/htk/HTKTools/HVite' @@ -19,7 +17,10 @@ def prep_txt(line, tmpbase, dictfile): words = [] line = line.strip() - for pun in [',', '.', ':', ';', '!', '?', '"', '(', ')', '--', '---', u',', u'。', u':', u';', u'!', u'?', u'(', u')']: + for pun in [ + ',', '.', ':', ';', '!', '?', '"', '(', ')', '--', '---', u',', + u'。', u':', u';', u'!', u'?', u'(', u')' + ]: line = line.replace(pun, ' ') for wrd in line.split(): if (wrd[-1] == '-'): @@ -43,6 +44,7 @@ def prep_txt(line, tmpbase, dictfile): fwid.write('\n') return unk_words + def prep_mlf(txt, tmpbase): with open(tmpbase + '.mlf', 'w') as fwid: @@ -55,6 +57,7 @@ def prep_mlf(txt, tmpbase): fwid.write('sp\n') fwid.write('.\n') + def gen_res(tmpbase, outfile1, outfile2): with open(tmpbase + '.txt', 'r') as fid: words = fid.readline().strip().split() @@ -67,19 +70,20 @@ def gen_res(tmpbase, outfile1, outfile2): times1 = [] times2 = [] while (i < len(lines)): - if (len(lines[i].split()) >= 4) and (lines[i].split()[0] != lines[i].split()[1]): + if (len(lines[i].split()) >= 4) and ( + lines[i].split()[0] != lines[i].split()[1]): phn = lines[i].split()[2] - pst = (int(lines[i].split()[0])/1000+125)/10000 - pen = (int(lines[i].split()[1])/1000+125)/10000 + pst = (int(lines[i].split()[0]) / 1000 + 125) / 10000 + pen = (int(lines[i].split()[1]) / 1000 + 125) / 10000 times2.append([phn, pst, pen]) if (len(lines[i].split()) == 5): if (lines[i].split()[0] != lines[i].split()[1]): wrd = lines[i].split()[-1].strip() - st = (int(lines[i].split()[0])/1000+125)/10000 + st = (int(lines[i].split()[0]) / 1000 + 125) / 10000 j = i + 1 while (lines[j] != '.\n') and (len(lines[j].split()) != 5): j += 1 - en = (int(lines[j-1].split()[1])/1000+125)/10000 + en = (int(lines[j - 1].split()[1]) / 1000 + 125) / 10000 times1.append([wrd, st, en]) i += 1 @@ -99,18 +103,18 @@ def gen_res(tmpbase, outfile1, outfile2): fwid.write(str(item[1]) + ' ' + str(item[2]) + ' ' + item[0] + '\n') - def alignment_zh(wav_path, text_string): tmpbase = '/tmp/' + os.environ['USER'] + '_' + str(os.getpid()) #prepare wav and trs files try: - os.system('sox ' + wav_path + ' -r 16000 -b 16 ' + tmpbase + '.wav remix -') + os.system('sox ' + wav_path + ' -r 16000 -b 16 ' + tmpbase + + '.wav remix -') except: print('sox error!') return None - + #prepare clean_transcript file try: unk_words = prep_txt(text_string, tmpbase, MODEL_DIR + '/dict') @@ -133,14 +137,19 @@ def alignment_zh(wav_path, text_string): #prepare scp try: - os.system(HCOPY + ' -C ' + MODEL_DIR + '/16000/config ' + tmpbase + '.wav' + ' ' + tmpbase + '.plp') + os.system(HCOPY + ' -C ' + MODEL_DIR + '/16000/config ' + tmpbase + + '.wav' + ' ' + tmpbase + '.plp') except: print('HCopy error!') return None #run alignment try: - os.system(HVITE + ' -a -m -t 10000.0 10000.0 100000.0 -I ' + tmpbase + '.mlf -H ' + MODEL_DIR + '/16000/macros -H ' + MODEL_DIR + '/16000/hmmdefs -i ' + tmpbase + '.aligned ' + MODEL_DIR + '/dict ' + MODEL_DIR + '/monophones ' + tmpbase + '.plp 2>&1 > /dev/null') + os.system(HVITE + ' -a -m -t 10000.0 10000.0 100000.0 -I ' + tmpbase + + '.mlf -H ' + MODEL_DIR + '/16000/macros -H ' + MODEL_DIR + + 
'/16000/hmmdefs -i ' + tmpbase + '.aligned ' + MODEL_DIR + + '/dict ' + MODEL_DIR + '/monophones ' + tmpbase + + '.plp 2>&1 > /dev/null') except: print('HVite error!') @@ -156,23 +165,22 @@ def alignment_zh(wav_path, text_string): i = 2 times2 = [] - word2phns = {} + word2phns = {} current_word = '' index = 0 while (i < len(lines)): splited_line = lines[i].strip().split() if (len(splited_line) >= 4) and (splited_line[0] != splited_line[1]): phn = splited_line[2] - pst = (int(splited_line[0])/1000+125)/10000 - pen = (int(splited_line[1])/1000+125)/10000 + pst = (int(splited_line[0]) / 1000 + 125) / 10000 + pen = (int(splited_line[1]) / 1000 + 125) / 10000 times2.append([phn, pst, pen]) # splited_line[-1]!='sp' - if len(splited_line)==5: - current_word = str(index)+'_'+splited_line[-1] + if len(splited_line) == 5: + current_word = str(index) + '_' + splited_line[-1] word2phns[current_word] = phn - index+=1 - elif len(splited_line)==4: - word2phns[current_word] += ' '+phn - i+=1 - return times2,word2phns - + index += 1 + elif len(splited_line) == 4: + word2phns[current_word] += ' ' + phn + i += 1 + return times2, word2phns diff --git a/ernie-sat/dataset.py b/ernie-sat/dataset.py index 26afbf4eff1364a7bd789ebfc6216e7b482f6f4b..cb1a9b25296117ad4bf924df012578ff8e7dd8ce 100644 --- a/ernie-sat/dataset.py +++ b/ernie-sat/dataset.py @@ -1,9 +1,7 @@ +import math - - -import paddle import numpy as np -import math +import paddle def pad_list(xs, pad_value): @@ -28,23 +26,25 @@ def pad_list(xs, pad_value): """ n_batch = len(xs) max_len = max(paddle.shape(x)[0] for x in xs) - pad = paddle.full((n_batch, max_len), pad_value, dtype = xs[0].dtype) + pad = paddle.full((n_batch, max_len), pad_value, dtype=xs[0].dtype) for i in range(n_batch): - pad[i, : paddle.shape(xs[i])[0]] = xs[i] - + pad[i, :paddle.shape(xs[i])[0]] = xs[i] + return pad -def pad_to_longformer_att_window(text, max_len, max_tlen,attention_window): + +def pad_to_longformer_att_window(text, max_len, max_tlen, attention_window): round = max_len % attention_window if round != 0: max_tlen += (attention_window - round) n_batch = paddle.shape(text)[0] - text_pad = paddle.zeros((n_batch, max_tlen, *paddle.shape(text[0])[1:]), dtype=text.dtype) + text_pad = paddle.zeros( + (n_batch, max_tlen, *paddle.shape(text[0])[1:]), dtype=text.dtype) for i in range(n_batch): - text_pad[i, : paddle.shape(text[i])[0]] = text[i] + text_pad[i, :paddle.shape(text[i])[0]] = text[i] else: - text_pad = text[:, : max_tlen] + text_pad = text[:, :max_tlen] return text_pad, max_tlen @@ -139,7 +139,6 @@ def make_pad_mask(lengths, xs=None, length_dim=-1): if not isinstance(lengths, list): lengths = list(lengths) - # print('lengths', lengths) bs = int(len(lengths)) if xs is None: maxlen = int(max(lengths)) @@ -147,10 +146,9 @@ def make_pad_mask(lengths, xs=None, length_dim=-1): maxlen = paddle.shape(xs)[length_dim] seq_range = paddle.arange(0, maxlen, dtype=paddle.int64) - seq_range_expand = paddle.expand(paddle.unsqueeze(seq_range, 0), (bs, maxlen)) + seq_range_expand = paddle.expand( + paddle.unsqueeze(seq_range, 0), (bs, maxlen)) seq_length_expand = paddle.unsqueeze(paddle.to_tensor(lengths), -1) - # print('seq_length_expand', paddle.shape(seq_length_expand)) - # print('seq_range_expand', paddle.shape(seq_range_expand)) mask = seq_range_expand >= seq_length_expand if xs is not None: @@ -160,16 +158,12 @@ def make_pad_mask(lengths, xs=None, length_dim=-1): length_dim = len(paddle.shape(xs)) + length_dim # ind = (:, None, ..., None, :, , None, ..., None) ind = tuple( - 
slice(None) if i in (0, length_dim) else None for i in range(len(paddle.shape(xs))) - ) - # print('0:', paddle.shape(mask)) - # print('1:', paddle.shape(mask[ind])) - # print('2:', paddle.shape(xs)) + slice(None) if i in (0, length_dim) else None + for i in range(len(paddle.shape(xs)))) mask = paddle.expand(mask[ind], paddle.shape(xs)) return mask - def make_non_pad_mask(lengths, xs=None, length_dim=-1): """Make mask tensor containing indices of non-padded part. @@ -259,8 +253,14 @@ def make_non_pad_mask(lengths, xs=None, length_dim=-1): return ~make_pad_mask(lengths, xs, length_dim) - -def phones_masking(xs_pad, src_mask, align_start, align_end, align_start_lengths, mlm_prob, mean_phn_span, span_boundary=None): +def phones_masking(xs_pad, + src_mask, + align_start, + align_end, + align_start_lengths, + mlm_prob, + mean_phn_span, + span_boundary=None): bz, sent_len, _ = paddle.shape(xs_pad) mask_num_lower = math.ceil(sent_len * mlm_prob) masked_position = np.zeros((bz, sent_len)) @@ -273,38 +273,41 @@ def phones_masking(xs_pad, src_mask, align_start, align_end, align_start_lengths elif mean_phn_span == 0: # only speech length = sent_len - mean_phn_span = min(length*mlm_prob//3, 50) - masked_phn_indices = random_spans_noise_mask(length,mlm_prob, mean_phn_span).nonzero() - masked_position[:,masked_phn_indices]=1 + mean_phn_span = min(length * mlm_prob // 3, 50) + masked_phn_indices = random_spans_noise_mask(length, mlm_prob, + mean_phn_span).nonzero() + masked_position[:, masked_phn_indices] = 1 else: for idx in range(bz): if span_boundary is not None: - for s,e in zip(span_boundary[idx][::2], span_boundary[idx][1::2]): + for s, e in zip(span_boundary[idx][::2], + span_boundary[idx][1::2]): masked_position[idx, s:e] = 1 # y_masks[idx, :, s:e] = tril_masks[idx, :, s:e] # y_masks[idx, e:, s:e ] = 0 else: length = align_start_lengths[idx].item() - if length<2: + if length < 2: continue - masked_phn_indices = random_spans_noise_mask(length,mlm_prob, mean_phn_span).nonzero() + masked_phn_indices = random_spans_noise_mask( + length, mlm_prob, mean_phn_span).nonzero() masked_start = align_start[idx][masked_phn_indices].tolist() masked_end = align_end[idx][masked_phn_indices].tolist() - for s,e in zip(masked_start, masked_end): + for s, e in zip(masked_start, masked_end): masked_position[idx, s:e] = 1 # y_masks[idx, :, s:e] = tril_masks[idx, :, s:e] # y_masks[idx, e:, s:e ] = 0 - non_eos_mask = np.array(paddle.reshape(src_mask, paddle.shape(xs_pad)[:2]).float().cpu()) + non_eos_mask = np.array( + paddle.reshape(src_mask, paddle.shape(xs_pad)[:2]).float().cpu()) masked_position = masked_position * non_eos_mask # y_masks = src_mask & y_masks.bool() return paddle.cast(paddle.to_tensor(masked_position), paddle.bool), y_masks - - -def get_segment_pos(speech_pad, text_pad, align_start, align_end, align_start_lengths,sega_emb): +def get_segment_pos(speech_pad, text_pad, align_start, align_end, + align_start_lengths, sega_emb): bz, speech_len, _ = speech_pad.size() _, text_len = text_pad.size() @@ -313,7 +316,6 @@ def get_segment_pos(speech_pad, text_pad, align_start, align_end, align_start_le text_segment_pos = np.zeros((bz, text_len)).astype('int64') speech_segment_pos = np.zeros((bz, speech_len)).astype('int64') - if not sega_emb: text_segment_pos = paddle.to_tensor(text_segment_pos) speech_segment_pos = paddle.to_tensor(speech_segment_pos) @@ -321,11 +323,11 @@ def get_segment_pos(speech_pad, text_pad, align_start, align_end, align_start_le for idx in range(bz): align_length = 
align_start_lengths[idx].item() for j in range(align_length): - s,e = align_start[idx][j].item(), align_end[idx][j].item() - speech_segment_pos[idx][s:e] = j+1 - text_segment_pos[idx][j] = j+1 - + s, e = align_start[idx][j].item(), align_end[idx][j].item() + speech_segment_pos[idx][s:e] = j + 1 + text_segment_pos[idx][j] = j + 1 + text_segment_pos = paddle.to_tensor(text_segment_pos) speech_segment_pos = paddle.to_tensor(speech_segment_pos) - return speech_segment_pos, text_segment_pos \ No newline at end of file + return speech_segment_pos, text_segment_pos diff --git a/ernie-sat/inference.py b/ernie-sat/inference.py index c51ba304b9e26159cbe451ba3d5c9b1576946abb..e0ad021e968bba65af532f7680d9d3bb0e8f60a8 100644 --- a/ernie-sat/inference.py +++ b/ernie-sat/inference.py @@ -1,95 +1,124 @@ #!/usr/bin/env python3 - -import os -from pathlib import Path -import librosa -import random -import soundfile as sf -import sys -import pickle import argparse +import math +import os +import pickle +import random +import string +import sys +from pathlib import Path from typing import Collection from typing import Dict from typing import List from typing import Tuple from typing import Union + +import librosa +import numpy as np import paddle +import soundfile as sf import torch -import math -import string -import numpy as np + from ParallelWaveGAN.parallel_wavegan.utils.utils import download_pretrained_model -from read_text import read_2column_text,load_num_sequence_text -from utils import sentence2phns,get_voc_out, evaluate_durations, is_chinese, build_vocoder_from_file -from model_paddle import build_model_from_file -from sedit_arg_parser import parse_args -from paddlespeech.t2s.datasets.get_feats import LogMelFBank -from dataset import pad_list, pad_to_longformer_att_window, make_pad_mask, make_non_pad_mask, phones_masking, get_segment_pos from align_english import alignment from align_mandarin import alignment_zh -from ParallelWaveGAN.parallel_wavegan.utils.utils import download_pretrained_model +from dataset import get_segment_pos +from dataset import make_non_pad_mask +from dataset import make_pad_mask +from dataset import pad_list +from dataset import pad_to_longformer_att_window +from dataset import phones_masking +from model_paddle import build_model_from_file +from read_text import load_num_sequence_text +from read_text import read_2column_text +from sedit_arg_parser import parse_args +from utils import build_vocoder_from_file +from utils import evaluate_durations +from utils import get_voc_out +from utils import is_chinese +from utils import sentence2phns +from paddlespeech.t2s.datasets.get_feats import LogMelFBank random.seed(0) np.random.seed(0) - PHONEME = 'tools/aligner/english_envir/english2phoneme/phoneme' MODEL_DIR_EN = 'tools/aligner/english' MODEL_DIR_ZH = 'tools/aligner/mandarin' -def plot_mel_and_vocode_wav(uid, prefix, clone_uid, clone_prefix, source_language, target_language, model_name, wav_path,full_origin_str, old_str, new_str, use_pt_vocoder, duration_preditor_path,sid=None, non_autoreg=True): +def plot_mel_and_vocode_wav(uid, + prefix, + clone_uid, + clone_prefix, + source_language, + target_language, + model_name, + wav_path, + full_origin_str, + old_str, + new_str, + use_pt_vocoder, + duration_preditor_path, + sid=None, + non_autoreg=True): wav_org, input_feat, output_feat, old_span_boundary, new_span_boundary, fs, hop_length = get_mlm_output( - uid, - prefix, - clone_uid, - clone_prefix, - source_language, - target_language, - model_name, - wav_path, - old_str, - new_str, - 
duration_preditor_path, - use_teacher_forcing=non_autoreg, - sid=sid - ) - - - masked_feat = output_feat[new_span_boundary[0]:new_span_boundary[1]].detach().float().cpu().numpy() - + uid, + prefix, + clone_uid, + clone_prefix, + source_language, + target_language, + model_name, + wav_path, + old_str, + new_str, + duration_preditor_path, + use_teacher_forcing=non_autoreg, + sid=sid) + + masked_feat = output_feat[new_span_boundary[0]:new_span_boundary[ + 1]].detach().float().cpu().numpy() + if target_language == 'english': if use_pt_vocoder: output_feat = output_feat.detach().float().cpu().numpy() - output_feat = torch.tensor(output_feat,dtype=torch.float) + output_feat = torch.tensor(output_feat, dtype=torch.float) vocoder = load_vocoder('vctk_parallel_wavegan.v1.long') - replaced_wav = vocoder(output_feat).detach().float().data.cpu().numpy() + replaced_wav = vocoder( + output_feat).detach().float().data.cpu().numpy() else: output_feat_np = output_feat.detach().float().cpu().numpy() replaced_wav = get_voc_out(output_feat_np, target_language) elif target_language == 'chinese': output_feat_np = output_feat.detach().float().cpu().numpy() - replaced_wav_only_mask_fst2_voc = get_voc_out(masked_feat, target_language) - - old_time_boundary = [hop_length * x for x in old_span_boundary] - new_time_boundary = [hop_length * x for x in new_span_boundary] - + replaced_wav_only_mask_fst2_voc = get_voc_out(masked_feat, + target_language) + + old_time_boundary = [hop_length * x for x in old_span_boundary] + new_time_boundary = [hop_length * x for x in new_span_boundary] + if target_language == 'english': - wav_org_replaced_paddle_voc = np.concatenate([wav_org[:old_time_boundary[0]], replaced_wav[new_time_boundary[0]:new_time_boundary[1]], wav_org[old_time_boundary[1]:]]) + wav_org_replaced_paddle_voc = np.concatenate([ + wav_org[:old_time_boundary[0]], + replaced_wav[new_time_boundary[0]:new_time_boundary[1]], + wav_org[old_time_boundary[1]:] + ]) - data_dict = { - "origin":wav_org, - "output":wav_org_replaced_paddle_voc} + data_dict = {"origin": wav_org, "output": wav_org_replaced_paddle_voc} - elif target_language == 'chinese': - wav_org_replaced_only_mask_fst2_voc = np.concatenate([wav_org[:old_time_boundary[0]], replaced_wav_only_mask_fst2_voc, wav_org[old_time_boundary[1]:]]) + elif target_language == 'chinese': + wav_org_replaced_only_mask_fst2_voc = np.concatenate([ + wav_org[:old_time_boundary[0]], replaced_wav_only_mask_fst2_voc, + wav_org[old_time_boundary[1]:] + ]) data_dict = { - "origin":wav_org, - "output": wav_org_replaced_only_mask_fst2_voc,} - - return data_dict, old_span_boundary + "origin": wav_org, + "output": wav_org_replaced_only_mask_fst2_voc, + } + return data_dict, old_span_boundary def get_unk_phns(word_str): @@ -97,7 +126,8 @@ def get_unk_phns(word_str): f = open(tmpbase + 'temp.words', 'w') f.write(word_str) f.close() - os.system(PHONEME + ' ' + tmpbase + 'temp.words' + ' ' + tmpbase + 'temp.phons') + os.system(PHONEME + ' ' + tmpbase + 'temp.words' + ' ' + tmpbase + + 'temp.phons') f = open(tmpbase + 'temp.phons', 'r') lines2 = f.readline().strip().split() f.close() @@ -116,7 +146,7 @@ def get_unk_phns(word_str): seq.append(phons[j].upper()) j += 1 else: - p = phons[j:j+2] + p = phons[j:j + 2] if (p == 'WH'): seq.append('W') elif (p in ['TH', 'SH', 'HH', 'DH', 'CH', 'ZH', 'NG']): @@ -129,8 +159,9 @@ def get_unk_phns(word_str): phns.extend(seq) return phns + def words2phns(line): - dictfile = MODEL_DIR_EN+'/dict' + dictfile = MODEL_DIR_EN + '/dict' tmpbase = '/tmp/tp.' 
line = line.strip() words = [] @@ -151,30 +182,33 @@ def words2phns(line): ds.add(word) if word not in word2phns_dict.keys(): word2phns_dict[word] = " ".join(line.split()[1:]) - + phns = [] wrd2phns = {} for index, wrd in enumerate(words): if wrd == '[MASK]': - wrd2phns[str(index)+"_"+wrd] = [wrd] + wrd2phns[str(index) + "_" + wrd] = [wrd] phns.append(wrd) elif (wrd.upper() not in ds): - wrd2phns[str(index)+"_"+wrd.upper()] = get_unk_phns(wrd) + wrd2phns[str(index) + "_" + wrd.upper()] = get_unk_phns(wrd) phns.extend(get_unk_phns(wrd)) else: - wrd2phns[str(index)+"_"+wrd.upper()] = word2phns_dict[wrd.upper()].split() + wrd2phns[str(index) + + "_" + wrd.upper()] = word2phns_dict[wrd.upper()].split() phns.extend(word2phns_dict[wrd.upper()].split()) return phns, wrd2phns - def words2phns_zh(line): - dictfile = MODEL_DIR_ZH+'/dict' + dictfile = MODEL_DIR_ZH + '/dict' tmpbase = '/tmp/tp.' line = line.strip() words = [] - for pun in [',', '.', ':', ';', '!', '?', '"', '(', ')', '--', '---', u',', u'。', u':', u';', u'!', u'?', u'(', u')']: + for pun in [ + ',', '.', ':', ';', '!', '?', '"', '(', ')', '--', '---', u',', + u'。', u':', u';', u'!', u'?', u'(', u')' + ]: line = line.replace(pun, ' ') for wrd in line.split(): if (wrd[-1] == '-'): @@ -183,7 +217,7 @@ def words2phns_zh(line): wrd = wrd[1:] if wrd: words.append(wrd) - + ds = set([]) word2phns_dict = {} with open(dictfile, 'r') as fid: @@ -192,17 +226,17 @@ def words2phns_zh(line): ds.add(word) if word not in word2phns_dict.keys(): word2phns_dict[word] = " ".join(line.split()[1:]) - + phns = [] wrd2phns = {} for index, wrd in enumerate(words): if wrd == '[MASK]': - wrd2phns[str(index)+"_"+wrd] = [wrd] + wrd2phns[str(index) + "_" + wrd] = [wrd] phns.append(wrd) elif (wrd.upper() not in ds): print("出现非法词错误,请输入正确的文本...") else: - wrd2phns[str(index)+"_"+wrd] = word2phns_dict[wrd].split() + wrd2phns[str(index) + "_" + wrd] = word2phns_dict[wrd].split() phns.extend(word2phns_dict[wrd].split()) return phns, wrd2phns @@ -212,62 +246,67 @@ def load_vocoder(vocoder_tag="vctk_parallel_wavegan.v1.long"): vocoder_tag = vocoder_tag.replace("parallel_wavegan/", "") vocoder_file = download_pretrained_model(vocoder_tag) vocoder_config = Path(vocoder_file).parent / "config.yml" - vocoder = build_vocoder_from_file( - vocoder_config, vocoder_file, None, 'cpu' - ) + vocoder = build_vocoder_from_file(vocoder_config, vocoder_file, None, 'cpu') return vocoder + def load_model(model_name): - config_path='./pretrained_model/{}/config.yaml'.format(model_name) + config_path = './pretrained_model/{}/config.yaml'.format(model_name) model_path = './pretrained_model/{}/model.pdparams'.format(model_name) - mlm_model, args = build_model_from_file(config_file=config_path, - model_file=model_path) + mlm_model, args = build_model_from_file( + config_file=config_path, model_file=model_path) return mlm_model, args -def read_data(uid,prefix): - mfa_text = read_2column_text(prefix+'/text')[uid] - mfa_wav_path = read_2column_text(prefix+'/wav.scp')[uid] +def read_data(uid, prefix): + mfa_text = read_2column_text(prefix + '/text')[uid] + mfa_wav_path = read_2column_text(prefix + '/wav.scp')[uid] if 'mnt' not in mfa_wav_path: mfa_wav_path = prefix.split('dump')[0] + mfa_wav_path return mfa_text, mfa_wav_path - -def get_align_data(uid,prefix): - mfa_path = prefix+"mfa_" - mfa_text = read_2column_text(mfa_path+'text')[uid] - mfa_start = load_num_sequence_text(mfa_path+'start',loader_type='text_float')[uid] - mfa_end = 
load_num_sequence_text(mfa_path+'end',loader_type='text_float')[uid] - mfa_wav_path = read_2column_text(mfa_path+'wav.scp')[uid] + + +def get_align_data(uid, prefix): + mfa_path = prefix + "mfa_" + mfa_text = read_2column_text(mfa_path + 'text')[uid] + mfa_start = load_num_sequence_text( + mfa_path + 'start', loader_type='text_float')[uid] + mfa_end = load_num_sequence_text( + mfa_path + 'end', loader_type='text_float')[uid] + mfa_wav_path = read_2column_text(mfa_path + 'wav.scp')[uid] return mfa_text, mfa_start, mfa_end, mfa_wav_path -def get_masked_mel_boundary(mfa_start, mfa_end, fs, hop_length, span_tobe_replaced): - align_start=paddle.to_tensor(mfa_start).unsqueeze(0) - align_end =paddle.to_tensor(mfa_end).unsqueeze(0) - align_start = paddle.floor(fs*align_start/hop_length).int() - align_end = paddle.floor(fs*align_end/hop_length).int() - if span_tobe_replaced[0]>=len(mfa_start): - span_boundary = [align_end[0].tolist()[-1],align_end[0].tolist()[-1]] +def get_masked_mel_boundary(mfa_start, mfa_end, fs, hop_length, + span_tobe_replaced): + align_start = paddle.to_tensor(mfa_start).unsqueeze(0) + align_end = paddle.to_tensor(mfa_end).unsqueeze(0) + align_start = paddle.floor(fs * align_start / hop_length).int() + align_end = paddle.floor(fs * align_end / hop_length).int() + if span_tobe_replaced[0] >= len(mfa_start): + span_boundary = [align_end[0].tolist()[-1], align_end[0].tolist()[-1]] else: - span_boundary=[align_start[0].tolist()[span_tobe_replaced[0]],align_end[0].tolist()[span_tobe_replaced[1]-1]] + span_boundary = [ + align_start[0].tolist()[span_tobe_replaced[0]], + align_end[0].tolist()[span_tobe_replaced[1] - 1] + ] return span_boundary - def recover_dict(word2phns, tp_word2phns): dic = {} need_del_key = [] - exist_index = [] - sp_count = 0 - add_sp_count = 0 + exist_index = [] + sp_count = 0 + add_sp_count = 0 for key in word2phns.keys(): idx, wrd = key.split('_') if wrd == 'sp': - sp_count += 1 + sp_count += 1 exist_index.append(int(idx)) else: need_del_key.append(key) - + for key in need_del_key: del word2phns[key] @@ -275,35 +314,36 @@ def recover_dict(word2phns, tp_word2phns): for key in tp_word2phns.keys(): # print("debug: ",key) if cur_id in exist_index: - dic[str(cur_id)+"_sp"] = 'sp' - cur_id += 1 - add_sp_count += 1 + dic[str(cur_id) + "_sp"] = 'sp' + cur_id += 1 + add_sp_count += 1 idx, wrd = key.split('_') - dic[str(cur_id)+"_"+wrd] = tp_word2phns[key] + dic[str(cur_id) + "_" + wrd] = tp_word2phns[key] cur_id += 1 - + if add_sp_count + 1 == sp_count: - dic[str(cur_id)+"_sp"] = 'sp' - add_sp_count += 1 - + dic[str(cur_id) + "_sp"] = 'sp' + add_sp_count += 1 + assert add_sp_count == sp_count, "sp are not added in dic" return dic -def get_phns_and_spans(wav_path, old_str, new_str, source_language, clone_target_language): +def get_phns_and_spans(wav_path, old_str, new_str, source_language, + clone_target_language): append_new_str = (old_str == new_str[:len(old_str)]) old_phns, mfa_start, mfa_end = [], [], [] if source_language == "english": - times2,word2phns = alignment(wav_path, old_str) + times2, word2phns = alignment(wav_path, old_str) elif source_language == "chinese": - times2,word2phns = alignment_zh(wav_path, old_str) - _, tp_word2phns = words2phns_zh(old_str) - - for key,value in tp_word2phns.items(): + times2, word2phns = alignment_zh(wav_path, old_str) + _, tp_word2phns = words2phns_zh(old_str) + + for key, value in tp_word2phns.items(): idx, wrd = key.split('_') cur_val = " ".join(value) - tp_word2phns[key] = cur_val + tp_word2phns[key] = cur_val word2phns 
= recover_dict(word2phns, tp_word2phns) @@ -315,9 +355,8 @@ def get_phns_and_spans(wav_path, old_str, new_str, source_language, clone_target mfa_end.append(float(item[2])) old_phns.append(item[0]) - if append_new_str and (source_language != clone_target_language): - is_cross_lingual_clone = True + is_cross_lingual_clone = True else: is_cross_lingual_clone = False @@ -326,54 +365,59 @@ def get_phns_and_spans(wav_path, old_str, new_str, source_language, clone_target new_str_append = new_str[len(old_str):] if clone_target_language == "chinese": - new_phns_origin,new_origin_word2phns = words2phns(new_str_origin) - new_phns_append,temp_new_append_word2phns = words2phns_zh(new_str_append) + new_phns_origin, new_origin_word2phns = words2phns(new_str_origin) + new_phns_append, temp_new_append_word2phns = words2phns_zh( + new_str_append) elif clone_target_language == "english": - new_phns_origin,new_origin_word2phns = words2phns_zh(new_str_origin) # 原始句子 - new_phns_append,temp_new_append_word2phns = words2phns(new_str_append) # clone句子 + new_phns_origin, new_origin_word2phns = words2phns_zh( + new_str_origin) # 原始句子 + new_phns_append, temp_new_append_word2phns = words2phns( + new_str_append) # clone句子 else: assert clone_target_language == "chinese" or clone_target_language == "english", "cloning is not support for this language, please check it." - + new_phns = new_phns_origin + new_phns_append new_append_word2phns = {} length = len(new_origin_word2phns) - for key,value in temp_new_append_word2phns.items(): + for key, value in temp_new_append_word2phns.items(): idx, wrd = key.split('_') - new_append_word2phns[str(int(idx)+length)+'_'+wrd] = value - - new_word2phns = dict(list(new_origin_word2phns.items()) + list(new_append_word2phns.items())) + new_append_word2phns[str(int(idx) + length) + '_' + wrd] = value - else: + new_word2phns = dict( + list(new_origin_word2phns.items()) + list( + new_append_word2phns.items())) + + else: if source_language == clone_target_language and clone_target_language == "english": new_phns, new_word2phns = words2phns(new_str) elif source_language == clone_target_language and clone_target_language == "chinese": new_phns, new_word2phns = words2phns_zh(new_str) else: assert source_language == clone_target_language, "source language is not same with target language..." 
- - span_tobe_replaced = [0,len(old_phns)-1] - span_tobe_added = [0,len(new_phns)-1] + + span_tobe_replaced = [0, len(old_phns) - 1] + span_tobe_added = [0, len(new_phns) - 1] left_index = 0 new_phns_left = [] sp_count = 0 # find the left different index for key in word2phns.keys(): idx, wrd = key.split('_') - if wrd=='sp': - sp_count +=1 + if wrd == 'sp': + sp_count += 1 new_phns_left.append('sp') else: idx = str(int(idx) - sp_count) - if idx+'_'+wrd in new_word2phns: - left_index+=len(new_word2phns[idx+'_'+wrd]) + if idx + '_' + wrd in new_word2phns: + left_index += len(new_word2phns[idx + '_' + wrd]) new_phns_left.extend(word2phns[key].split()) else: span_tobe_replaced[0] = len(new_phns_left) span_tobe_added[0] = len(new_phns_left) break - + # reverse word2phns and new_word2phns right_index = 0 new_phns_right = [] @@ -381,7 +425,7 @@ def get_phns_and_spans(wav_path, old_str, new_str, source_language, clone_target word2phns_max_index = int(list(word2phns.keys())[-1].split('_')[0]) new_word2phns_max_index = int(list(new_word2phns.keys())[-1].split('_')[0]) new_phns_middle = [] - if append_new_str: + if append_new_str: new_phns_right = [] new_phns_middle = new_phns[left_index:] span_tobe_replaced[0] = len(new_phns_left) @@ -391,176 +435,306 @@ def get_phns_and_spans(wav_path, old_str, new_str, source_language, clone_target else: for key in list(word2phns.keys())[::-1]: idx, wrd = key.split('_') - if wrd=='sp': - sp_count +=1 - new_phns_right = ['sp']+new_phns_right + if wrd == 'sp': + sp_count += 1 + new_phns_right = ['sp'] + new_phns_right else: - idx = str(new_word2phns_max_index-(word2phns_max_index-int(idx)-sp_count)) - if idx+'_'+wrd in new_word2phns: - right_index-=len(new_word2phns[idx+'_'+wrd]) + idx = str(new_word2phns_max_index - (word2phns_max_index - int( + idx) - sp_count)) + if idx + '_' + wrd in new_word2phns: + right_index -= len(new_word2phns[idx + '_' + wrd]) new_phns_right = word2phns[key].split() + new_phns_right else: span_tobe_replaced[1] = len(old_phns) - len(new_phns_right) new_phns_middle = new_phns[left_index:right_index] - span_tobe_added[1] = len(new_phns_left) + len(new_phns_middle) + span_tobe_added[1] = len(new_phns_left) + len( + new_phns_middle) if len(new_phns_middle) == 0: - span_tobe_added[1] = min(span_tobe_added[1]+1, len(new_phns)) - span_tobe_added[0] = max(0, span_tobe_added[0]-1) - span_tobe_replaced[0] = max(0, span_tobe_replaced[0]-1) - span_tobe_replaced[1] = min(span_tobe_replaced[1]+1, len(old_phns)) + span_tobe_added[1] = min(span_tobe_added[1] + 1, + len(new_phns)) + span_tobe_added[0] = max(0, span_tobe_added[0] - 1) + span_tobe_replaced[0] = max(0, + span_tobe_replaced[0] - 1) + span_tobe_replaced[1] = min(span_tobe_replaced[1] + 1, + len(old_phns)) break - new_phns = new_phns_left+new_phns_middle+new_phns_right - + new_phns = new_phns_left + new_phns_middle + new_phns_right return mfa_start, mfa_end, old_phns, new_phns, span_tobe_replaced, span_tobe_added - def duration_adjust_factor(original_dur, pred_dur, phns): length = 0 accumulate = 0 factor_list = [] - for ori,pred,phn in zip(original_dur, pred_dur,phns): - if pred==0 or phn=='sp': + for ori, pred, phn in zip(original_dur, pred_dur, phns): + if pred == 0 or phn == 'sp': continue else: - factor_list.append(ori/pred) + factor_list.append(ori / pred) factor_list = np.array(factor_list) factor_list.sort() - if len(factor_list)<5: + if len(factor_list) < 5: return 1 length = 2 return np.average(factor_list[length:-length]) -def prepare_features_with_duration(uid, prefix, clone_uid, 
clone_prefix, source_language, target_language, mlm_model, old_str, new_str, wav_path,duration_preditor_path,sid=None, mask_reconstruct=False,duration_adjust=True,start_end_sp=False, train_args=None): - wav_org, rate = librosa.load(wav_path, sr=train_args.feats_extract_conf['fs']) + +def prepare_features_with_duration(uid, + prefix, + clone_uid, + clone_prefix, + source_language, + target_language, + mlm_model, + old_str, + new_str, + wav_path, + duration_preditor_path, + sid=None, + mask_reconstruct=False, + duration_adjust=True, + start_end_sp=False, + train_args=None): + wav_org, rate = librosa.load( + wav_path, sr=train_args.feats_extract_conf['fs']) fs = train_args.feats_extract_conf['fs'] hop_length = train_args.feats_extract_conf['hop_length'] - - mfa_start, mfa_end, old_phns, new_phns, span_tobe_replaced, span_tobe_added = get_phns_and_spans(wav_path, old_str, new_str, source_language, target_language) + + mfa_start, mfa_end, old_phns, new_phns, span_tobe_replaced, span_tobe_added = get_phns_and_spans( + wav_path, old_str, new_str, source_language, target_language) if start_end_sp: - if new_phns[-1]!='sp': - new_phns = new_phns+['sp'] - + if new_phns[-1] != 'sp': + new_phns = new_phns + ['sp'] + if target_language == "english": - old_durations = evaluate_durations(old_phns, target_language=target_language) + old_durations = evaluate_durations( + old_phns, target_language=target_language) - elif target_language =="chinese": + elif target_language == "chinese": if source_language == "english": - old_durations = evaluate_durations(old_phns, target_language=source_language) + old_durations = evaluate_durations( + old_phns, target_language=source_language) elif source_language == "chinese": - old_durations = evaluate_durations(old_phns, target_language=source_language) + old_durations = evaluate_durations( + old_phns, target_language=source_language) else: assert target_language == "chinese" or target_language == "english", "calculate duration_predict is not support for this language..." 
- original_old_durations = [e-s for e,s in zip(mfa_end, mfa_start)] + original_old_durations = [e - s for e, s in zip(mfa_end, mfa_start)] if '[MASK]' in new_str: new_phns = old_phns span_tobe_added = span_tobe_replaced - d_factor_left = duration_adjust_factor(original_old_durations[:span_tobe_replaced[0]],old_durations[:span_tobe_replaced[0]], old_phns[:span_tobe_replaced[0]]) - d_factor_right = duration_adjust_factor(original_old_durations[span_tobe_replaced[1]:],old_durations[span_tobe_replaced[1]:], old_phns[span_tobe_replaced[1]:]) - d_factor = (d_factor_left+d_factor_right)/2 - new_durations_adjusted = [d_factor*i for i in old_durations] + d_factor_left = duration_adjust_factor( + original_old_durations[:span_tobe_replaced[0]], + old_durations[:span_tobe_replaced[0]], + old_phns[:span_tobe_replaced[0]]) + d_factor_right = duration_adjust_factor( + original_old_durations[span_tobe_replaced[1]:], + old_durations[span_tobe_replaced[1]:], + old_phns[span_tobe_replaced[1]:]) + d_factor = (d_factor_left + d_factor_right) / 2 + new_durations_adjusted = [d_factor * i for i in old_durations] else: if duration_adjust: - d_factor = duration_adjust_factor(original_old_durations,old_durations, old_phns) - d_factor_paddle = duration_adjust_factor(original_old_durations,old_durations, old_phns) - d_factor = d_factor * 1.25 + d_factor = duration_adjust_factor(original_old_durations, + old_durations, old_phns) + d_factor_paddle = duration_adjust_factor(original_old_durations, + old_durations, old_phns) + d_factor = d_factor * 1.25 else: d_factor = 1 - - if target_language == "english": - new_durations = evaluate_durations(new_phns, target_language=target_language) + if target_language == "english": + new_durations = evaluate_durations( + new_phns, target_language=target_language) - elif target_language =="chinese": - new_durations = evaluate_durations(new_phns, target_language=target_language) + elif target_language == "chinese": + new_durations = evaluate_durations( + new_phns, target_language=target_language) - new_durations_adjusted = [d_factor*i for i in new_durations] + new_durations_adjusted = [d_factor * i for i in new_durations] - if span_tobe_replaced[0]=len(mfa_start): + if span_tobe_replaced[0] >= len(mfa_start): left_index = len(wav_org) right_index = left_index else: - left_index = int(np.floor(mfa_start[span_tobe_replaced[0]]*fs)) - right_index = int(np.ceil(mfa_end[span_tobe_replaced[1]-1]*fs)) - new_blank_wav = np.zeros((int(np.ceil(new_span_duration_sum*fs)),), dtype=wav_org.dtype) - new_wav_org = np.concatenate([wav_org[:left_index], new_blank_wav, wav_org[right_index:]]) - + left_index = int(np.floor(mfa_start[span_tobe_replaced[0]] * fs)) + right_index = int(np.ceil(mfa_end[span_tobe_replaced[1] - 1] * fs)) + new_blank_wav = np.zeros( + (int(np.ceil(new_span_duration_sum * fs)), ), dtype=wav_org.dtype) + new_wav_org = np.concatenate( + [wav_org[:left_index], new_blank_wav, wav_org[right_index:]]) # 4. 
get old and new mel span to be mask - old_span_boundary = get_masked_mel_boundary(mfa_start, mfa_end, fs, hop_length, span_tobe_replaced) # [92, 92] - new_span_boundary=get_masked_mel_boundary(new_mfa_start, new_mfa_end, fs, hop_length, span_tobe_added) # [92, 174] - - + old_span_boundary = get_masked_mel_boundary( + mfa_start, mfa_end, fs, hop_length, span_tobe_replaced) # [92, 92] + new_span_boundary = get_masked_mel_boundary(new_mfa_start, new_mfa_end, fs, + hop_length, + span_tobe_added) # [92, 174] + return new_wav_org, new_phns, new_mfa_start, new_mfa_end, old_span_boundary, new_span_boundary -def prepare_features(uid, prefix, clone_uid, clone_prefix, source_language, target_language, mlm_model,processor, wav_path, old_str,new_str,duration_preditor_path, sid=None,duration_adjust=True,start_end_sp=False, -mask_reconstruct=False, train_args=None): - wav_org, phns_list, mfa_start, mfa_end, old_span_boundary, new_span_boundary = prepare_features_with_duration(uid, prefix, clone_uid, clone_prefix, source_language, target_language, mlm_model, old_str, - new_str, wav_path,duration_preditor_path,sid=sid,duration_adjust=duration_adjust,start_end_sp=start_end_sp,mask_reconstruct=mask_reconstruct, train_args = train_args) - speech = np.array(wav_org,dtype=np.float32) - align_start=np.array(mfa_start) - align_end =np.array(mfa_end) + +def prepare_features(uid, + prefix, + clone_uid, + clone_prefix, + source_language, + target_language, + mlm_model, + processor, + wav_path, + old_str, + new_str, + duration_preditor_path, + sid=None, + duration_adjust=True, + start_end_sp=False, + mask_reconstruct=False, + train_args=None): + wav_org, phns_list, mfa_start, mfa_end, old_span_boundary, new_span_boundary = prepare_features_with_duration( + uid, + prefix, + clone_uid, + clone_prefix, + source_language, + target_language, + mlm_model, + old_str, + new_str, + wav_path, + duration_preditor_path, + sid=sid, + duration_adjust=duration_adjust, + start_end_sp=start_end_sp, + mask_reconstruct=mask_reconstruct, + train_args=train_args) + speech = np.array(wav_org, dtype=np.float32) + align_start = np.array(mfa_start) + align_end = np.array(mfa_end) token_to_id = {item: i for i, item in enumerate(train_args.token_list)} - text = np.array(list(map(lambda x: token_to_id.get(x, token_to_id['']), phns_list))) + text = np.array( + list( + map(lambda x: token_to_id.get(x, token_to_id['']), phns_list))) # print('unk id is', token_to_id['']) # text = np.array(processor(uid='1', data={'text':" ".join(phns_list)})['text']) span_boundary = np.array(new_span_boundary) - batch=[('1', {"speech":speech,"align_start":align_start,"align_end":align_end,"text":text,"span_boundary":span_boundary})] - + batch = [('1', { + "speech": speech, + "align_start": align_start, + "align_end": align_end, + "text": text, + "span_boundary": span_boundary + })] + return batch, old_span_boundary, new_span_boundary -def decode_with_model(uid, prefix, clone_uid, clone_prefix, source_language, target_language, mlm_model, processor, collate_fn, wav_path, old_str, new_str,duration_preditor_path, sid=None, decoder=False,use_teacher_forcing=False,duration_adjust=True,start_end_sp=False, train_args=None): - fs, hop_length = train_args.feats_extract_conf['fs'], train_args.feats_extract_conf['hop_length'] - batch,old_span_boundary,new_span_boundary = prepare_features(uid,prefix, clone_uid, clone_prefix, source_language, target_language, mlm_model,processor,wav_path,old_str,new_str,duration_preditor_path, 
sid,duration_adjust=duration_adjust,start_end_sp=start_end_sp, train_args=train_args) - +def decode_with_model(uid, + prefix, + clone_uid, + clone_prefix, + source_language, + target_language, + mlm_model, + processor, + collate_fn, + wav_path, + old_str, + new_str, + duration_preditor_path, + sid=None, + decoder=False, + use_teacher_forcing=False, + duration_adjust=True, + start_end_sp=False, + train_args=None): + fs, hop_length = train_args.feats_extract_conf[ + 'fs'], train_args.feats_extract_conf['hop_length'] + + batch, old_span_boundary, new_span_boundary = prepare_features( + uid, + prefix, + clone_uid, + clone_prefix, + source_language, + target_language, + mlm_model, + processor, + wav_path, + old_str, + new_str, + duration_preditor_path, + sid, + duration_adjust=duration_adjust, + start_end_sp=start_end_sp, + train_args=train_args) + feats = collate_fn(batch)[1] if 'text_masked_position' in feats.keys(): feats.pop('text_masked_position') for k, v in feats.items(): feats[k] = paddle.to_tensor(v) - rtn = mlm_model.inference(**feats,span_boundary=new_span_boundary,use_teacher_forcing=use_teacher_forcing) - output = rtn['feat_gen'] + rtn = mlm_model.inference( + **feats, + span_boundary=new_span_boundary, + use_teacher_forcing=use_teacher_forcing) + output = rtn['feat_gen'] if 0 in output[0].shape and 0 not in output[-1].shape: - output_feat = paddle.concat(output[1:-1]+[output[-1].squeeze()], axis=0).cpu() + output_feat = paddle.concat( + output[1:-1] + [output[-1].squeeze()], axis=0).cpu() elif 0 not in output[0].shape and 0 in output[-1].shape: - output_feat = paddle.concat([output[0].squeeze()]+output[1:-1], axis=0).cpu() + output_feat = paddle.concat( + [output[0].squeeze()] + output[1:-1], axis=0).cpu() elif 0 in output[0].shape and 0 in output[-1].shape: output_feat = paddle.concat(output[1:-1], axis=0).cpu() else: - output_feat = paddle.concat([output[0].squeeze(0)]+ output[1:-1]+[output[-1].squeeze(0)], axis=0).cpu() - - wav_org, rate = librosa.load(wav_path, sr=train_args.feats_extract_conf['fs']) - origin_speech = paddle.to_tensor(np.array(wav_org,dtype=np.float32)).unsqueeze(0) + output_feat = paddle.concat( + [output[0].squeeze(0)] + output[1:-1] + [output[-1].squeeze(0)], + axis=0).cpu() + + wav_org, rate = librosa.load( + wav_path, sr=train_args.feats_extract_conf['fs']) + origin_speech = paddle.to_tensor( + np.array(wav_org, dtype=np.float32)).unsqueeze(0) speech_lengths = paddle.to_tensor(len(wav_org)).unsqueeze(0) return wav_org, None, output_feat, old_span_boundary, new_span_boundary, fs, hop_length @@ -568,71 +742,64 @@ def decode_with_model(uid, prefix, clone_uid, clone_prefix, source_language, tar class MLMCollateFn: """Functor class of common_collate_fn()""" - def __init__( - self, - feats_extract, - float_pad_value: Union[float, int] = 0.0, - int_pad_value: int = -32768, - not_sequence: Collection[str] = (), - mlm_prob: float=0.8, - mean_phn_span: int=8, - attention_window: int=0, - pad_speech: bool=False, - sega_emb: bool=False, - duration_collect: bool=False, - text_masking: bool=False - - ): - self.mlm_prob=mlm_prob - self.mean_phn_span=mean_phn_span + def __init__(self, + feats_extract, + float_pad_value: Union[float, int]=0.0, + int_pad_value: int=-32768, + not_sequence: Collection[str]=(), + mlm_prob: float=0.8, + mean_phn_span: int=8, + attention_window: int=0, + pad_speech: bool=False, + sega_emb: bool=False, + duration_collect: bool=False, + text_masking: bool=False): + self.mlm_prob = mlm_prob + self.mean_phn_span = mean_phn_span self.feats_extract = 
feats_extract self.float_pad_value = float_pad_value self.int_pad_value = int_pad_value self.not_sequence = set(not_sequence) - self.attention_window=attention_window - self.pad_speech=pad_speech - self.sega_emb=sega_emb + self.attention_window = attention_window + self.pad_speech = pad_speech + self.sega_emb = sega_emb self.duration_collect = duration_collect self.text_masking = text_masking def __repr__(self): - return ( - f"{self.__class__}(float_pad_value={self.float_pad_value}, " - f"int_pad_value={self.float_pad_value})" - ) - - def __call__( - self, data: Collection[Tuple[str, Dict[str, np.ndarray]]] - ) -> Tuple[List[str], Dict[str, paddle.Tensor]]: + return (f"{self.__class__}(float_pad_value={self.float_pad_value}, " + f"int_pad_value={self.float_pad_value})") + + def __call__(self, data: Collection[Tuple[str, Dict[str, np.ndarray]]] + ) -> Tuple[List[str], Dict[str, paddle.Tensor]]: return mlm_collate_fn( data, float_pad_value=self.float_pad_value, int_pad_value=self.int_pad_value, not_sequence=self.not_sequence, - mlm_prob=self.mlm_prob, + mlm_prob=self.mlm_prob, mean_phn_span=self.mean_phn_span, feats_extract=self.feats_extract, attention_window=self.attention_window, pad_speech=self.pad_speech, sega_emb=self.sega_emb, duration_collect=self.duration_collect, - text_masking=self.text_masking - ) + text_masking=self.text_masking) + def mlm_collate_fn( - data: Collection[Tuple[str, Dict[str, np.ndarray]]], - float_pad_value: Union[float, int] = 0.0, - int_pad_value: int = -32768, - not_sequence: Collection[str] = (), - mlm_prob: float = 0.8, - mean_phn_span: int = 8, - feats_extract=None, - attention_window: int = 0, - pad_speech: bool=False, - sega_emb: bool=False, - duration_collect: bool=False, - text_masking: bool=False -) -> Tuple[List[str], Dict[str, paddle.Tensor]]: + data: Collection[Tuple[str, Dict[str, np.ndarray]]], + float_pad_value: Union[float, int]=0.0, + int_pad_value: int=-32768, + not_sequence: Collection[str]=(), + mlm_prob: float=0.8, + mean_phn_span: int=8, + feats_extract=None, + attention_window: int=0, + pad_speech: bool=False, + sega_emb: bool=False, + duration_collect: bool=False, + text_masking: bool=False) -> Tuple[List[str], Dict[str, paddle.Tensor]]: """Concatenate ndarray-list to an array and convert to torch.Tensor. 
Examples: @@ -654,9 +821,8 @@ def mlm_collate_fn( data = [d for _, d in data] assert all(set(data[0]) == set(d) for d in data), "dict-keys mismatching" - assert all( - not k.endswith("_lengths") for k in data[0] - ), f"*_lengths is reserved: {list(data[0])}" + assert all(not k.endswith("_lengths") + for k in data[0]), f"*_lengths is reserved: {list(data[0])}" output = {} for key in data[0]: @@ -679,7 +845,8 @@ def mlm_collate_fn( # lens: (Batch,) if key not in not_sequence: - lens = paddle.to_tensor([d[key].shape[0] for d in data], dtype=paddle.long) + lens = paddle.to_tensor( + [d[key].shape[0] for d in data], dtype=paddle.long) output[key + "_lengths"] = lens feats = feats_extract.get_log_mel_fbank(np.array(output["speech"][0])) @@ -689,71 +856,73 @@ def mlm_collate_fn( feats = paddle.unsqueeze(feats, 0) batch_size = paddle.shape(feats)[0] if 'text' not in output: - text=paddle.zeros_like(feats_lengths.unsqueeze(-1))-2 - text_lengths=paddle.zeros_like(feats_lengths)+1 - max_tlen=1 - align_start=paddle.zeros_like(text) - align_end=paddle.zeros_like(text) - align_start_lengths=paddle.zeros_like(feats_lengths) - align_end_lengths=paddle.zeros_like(feats_lengths) - sega_emb=False + text = paddle.zeros_like(feats_lengths.unsqueeze(-1)) - 2 + text_lengths = paddle.zeros_like(feats_lengths) + 1 + max_tlen = 1 + align_start = paddle.zeros_like(text) + align_end = paddle.zeros_like(text) + align_start_lengths = paddle.zeros_like(feats_lengths) + align_end_lengths = paddle.zeros_like(feats_lengths) + sega_emb = False mean_phn_span = 0 mlm_prob = 0.15 else: text, text_lengths = output["text"], output["text_lengths"] - align_start, align_start_lengths, align_end, align_end_lengths = output["align_start"], output["align_start_lengths"], output["align_end"], output["align_end_lengths"] - align_start = paddle.floor(feats_extract.sr*align_start/feats_extract.hop_length).int() - align_end = paddle.floor(feats_extract.sr*align_end/feats_extract.hop_length).int() + align_start, align_start_lengths, align_end, align_end_lengths = output[ + "align_start"], output["align_start_lengths"], output[ + "align_end"], output["align_end_lengths"] + align_start = paddle.floor(feats_extract.sr * align_start / + feats_extract.hop_length).int() + align_end = paddle.floor(feats_extract.sr * align_end / + feats_extract.hop_length).int() max_tlen = max(text_lengths).item() max_slen = max(feats_lengths).item() - speech_pad = feats[:, : max_slen] - if attention_window>0 and pad_speech: - speech_pad,max_slen = pad_to_longformer_att_window(speech_pad, max_slen, max_slen, attention_window) + speech_pad = feats[:, :max_slen] + if attention_window > 0 and pad_speech: + speech_pad, max_slen = pad_to_longformer_att_window( + speech_pad, max_slen, max_slen, attention_window) max_len = max_slen + max_tlen - if attention_window>0: - text_pad, max_tlen = pad_to_longformer_att_window(text, max_len, max_tlen, attention_window) + if attention_window > 0: + text_pad, max_tlen = pad_to_longformer_att_window( + text, max_len, max_tlen, attention_window) else: text_pad = text - text_mask = make_non_pad_mask(text_lengths.tolist(), text_pad, length_dim=1).unsqueeze(-2) - if attention_window>0: - text_mask = text_mask*2 - speech_mask = make_non_pad_mask(feats_lengths.tolist(), speech_pad[:,:,0], length_dim=1).unsqueeze(-2) + text_mask = make_non_pad_mask( + text_lengths.tolist(), text_pad, length_dim=1).unsqueeze(-2) + if attention_window > 0: + text_mask = text_mask * 2 + speech_mask = make_non_pad_mask( + feats_lengths.tolist(), speech_pad[:, 
:, 0], length_dim=1).unsqueeze(-2) span_boundary = None if 'span_boundary' in output.keys(): span_boundary = output['span_boundary'] if text_masking: - masked_position, text_masked_position,_ = phones_text_masking( - speech_pad, - speech_mask, - text_pad, - text_mask, - align_start, - align_end, - align_start_lengths, - mlm_prob, - mean_phn_span, + masked_position, text_masked_position, _ = phones_text_masking( + speech_pad, speech_mask, text_pad, text_mask, align_start, + align_end, align_start_lengths, mlm_prob, mean_phn_span, span_boundary) else: text_masked_position = np.zeros(text_pad.size()) masked_position, _ = phones_masking( - speech_pad, - speech_mask, - align_start, - align_end, - align_start_lengths, - mlm_prob, - mean_phn_span, - span_boundary) + speech_pad, speech_mask, align_start, align_end, + align_start_lengths, mlm_prob, mean_phn_span, span_boundary) output_dict = {} if duration_collect and 'text' in output: - reordered_index, speech_segment_pos,text_segment_pos, durations,feats_lengths = get_segment_pos_reduce_duration(speech_pad, text_pad, align_start, align_end, align_start_lengths,sega_emb, masked_position, feats_lengths) - speech_mask = make_non_pad_mask(feats_lengths.tolist(), speech_pad[:,:reordered_index.shape[1],0], length_dim=1).unsqueeze(-2) + reordered_index, speech_segment_pos, text_segment_pos, durations, feats_lengths = get_segment_pos_reduce_duration( + speech_pad, text_pad, align_start, align_end, align_start_lengths, + sega_emb, masked_position, feats_lengths) + speech_mask = make_non_pad_mask( + feats_lengths.tolist(), + speech_pad[:, :reordered_index.shape[1], 0], + length_dim=1).unsqueeze(-2) output_dict['durations'] = durations output_dict['reordered_index'] = reordered_index else: - speech_segment_pos, text_segment_pos = get_segment_pos(speech_pad, text_pad, align_start, align_end, align_start_lengths,sega_emb) + speech_segment_pos, text_segment_pos = get_segment_pos( + speech_pad, text_pad, align_start, align_end, align_start_lengths, + sega_emb) output_dict['speech'] = speech_pad output_dict['text'] = text_pad output_dict['masked_position'] = masked_position @@ -767,9 +936,8 @@ def mlm_collate_fn( output = (uttids, output_dict) return output -def build_collate_fn( - args: argparse.Namespace, train: bool, epoch=-1 - ): + +def build_collate_fn(args: argparse.Namespace, train: bool, epoch=-1): # -> Callable[ # [Collection[Tuple[str, Dict[str, np.ndarray]]]], # Tuple[List[str], Dict[str, torch.Tensor]], @@ -793,68 +961,142 @@ def build_collate_fn( sega_emb = True if args.encoder_conf['input_layer'] == 'sega_mlm' else False if args.encoder_conf['selfattention_layer_type'] == 'longformer': attention_window = args.encoder_conf['attention_window'] - pad_speech = True if 'pre_speech_layer' in args.encoder_conf and args.encoder_conf['pre_speech_layer'] >0 else False + pad_speech = True if 'pre_speech_layer' in args.encoder_conf and args.encoder_conf[ + 'pre_speech_layer'] > 0 else False else: - attention_window=0 - pad_speech=False - if epoch==-1: + attention_window = 0 + pad_speech = False + if epoch == -1: mlm_prob_factor = 1 else: mlm_probs = [1.0, 1.0, 0.7, 0.6, 0.5] - mlm_prob_factor = 0.8 #mlm_probs[epoch // 100] - if 'duration_predictor_layers' in args.model_conf.keys() and args.model_conf['duration_predictor_layers']>0: - duration_collect=True + mlm_prob_factor = 0.8 #mlm_probs[epoch // 100] + if 'duration_predictor_layers' in args.model_conf.keys( + ) and args.model_conf['duration_predictor_layers'] > 0: + duration_collect = True else: - 
duration_collect=False - - return MLMCollateFn(feats_extract, float_pad_value=0.0, int_pad_value=0, - mlm_prob=args.model_conf['mlm_prob']*mlm_prob_factor,mean_phn_span=args.model_conf['mean_phn_span'],attention_window=attention_window,pad_speech=pad_speech,sega_emb=sega_emb,duration_collect=duration_collect) - + duration_collect = False -def get_mlm_output(uid, prefix, clone_uid, clone_prefix, source_language, target_language, model_name, wav_path, old_str, new_str,duration_preditor_path, sid=None, decoder=False,use_teacher_forcing=False, dynamic_eval=(0,0),duration_adjust=True,start_end_sp=False): - mlm_model,train_args = load_model(model_name) + return MLMCollateFn( + feats_extract, + float_pad_value=0.0, + int_pad_value=0, + mlm_prob=args.model_conf['mlm_prob'] * mlm_prob_factor, + mean_phn_span=args.model_conf['mean_phn_span'], + attention_window=attention_window, + pad_speech=pad_speech, + sega_emb=sega_emb, + duration_collect=duration_collect) + + +def get_mlm_output(uid, + prefix, + clone_uid, + clone_prefix, + source_language, + target_language, + model_name, + wav_path, + old_str, + new_str, + duration_preditor_path, + sid=None, + decoder=False, + use_teacher_forcing=False, + dynamic_eval=(0, 0), + duration_adjust=True, + start_end_sp=False): + mlm_model, train_args = load_model(model_name) mlm_model.eval() processor = None collate_fn = build_collate_fn(train_args, False) - return decode_with_model(uid,prefix, clone_uid, clone_prefix, source_language, target_language, mlm_model, processor, collate_fn, wav_path, old_str, new_str,duration_preditor_path, sid=sid, decoder=decoder, use_teacher_forcing=use_teacher_forcing, - duration_adjust=duration_adjust,start_end_sp=start_end_sp, train_args = train_args) - -def test_vctk(uid, clone_uid, clone_prefix, source_language, target_language, vocoder, prefix='dump/raw/dev', model_name="conformer", old_str="",new_str="",prompt_decoding=False,dynamic_eval=(0,0), task_name = None): + return decode_with_model( + uid, + prefix, + clone_uid, + clone_prefix, + source_language, + target_language, + mlm_model, + processor, + collate_fn, + wav_path, + old_str, + new_str, + duration_preditor_path, + sid=sid, + decoder=decoder, + use_teacher_forcing=use_teacher_forcing, + duration_adjust=duration_adjust, + start_end_sp=start_end_sp, + train_args=train_args) + + +def test_vctk(uid, + clone_uid, + clone_prefix, + source_language, + target_language, + vocoder, + prefix='dump/raw/dev', + model_name="conformer", + old_str="", + new_str="", + prompt_decoding=False, + dynamic_eval=(0, 0), + task_name=None): duration_preditor_path = None spemd = None - full_origin_str,wav_path = read_data(uid, prefix) - + full_origin_str, wav_path = read_data(uid, prefix) + if task_name == 'edit': new_str = new_str elif task_name == 'synthesize': - new_str = full_origin_str + new_str + new_str = full_origin_str + new_str else: - new_str = full_origin_str + ' '.join([ch for ch in new_str if is_chinese(ch)]) - + new_str = full_origin_str + ' '.join( + [ch for ch in new_str if is_chinese(ch)]) + print('new_str is ', new_str) - + if not old_str: old_str = full_origin_str - results_dict, old_span = plot_mel_and_vocode_wav(uid, prefix, clone_uid, clone_prefix, source_language, target_language, model_name, wav_path,full_origin_str, old_str, new_str,vocoder,duration_preditor_path,sid=spemd) + results_dict, old_span = plot_mel_and_vocode_wav( + uid, + prefix, + clone_uid, + clone_prefix, + source_language, + target_language, + model_name, + wav_path, + full_origin_str, + old_str, + 
new_str, + vocoder, + duration_preditor_path, + sid=spemd) return results_dict + if __name__ == "__main__": + # parse config and args args = parse_args() - print(args) - data_dict = test_vctk(args.uid, - args.clone_uid, - args.clone_prefix, - args.source_language, - args.target_language, + + data_dict = test_vctk( + args.uid, + args.clone_uid, + args.clone_prefix, + args.source_language, + args.target_language, args.use_pt_vocoder, - args.prefix, + args.prefix, args.model_name, new_str=args.new_str, task_name=args.task_name) - sf.write('./wavs/%s' % args.output_name, data_dict['output'], samplerate=24000) + sf.write(args.output_name, data_dict['output'], samplerate=24000) print("finished...") # exit() - diff --git a/ernie-sat/model_paddle.py b/ernie-sat/model_paddle.py index fc3caf93d4614662cb33989934055ec24624868a..a93c7a410595f03d7bebac5d1cbc822fe6fb7472 100644 --- a/ernie-sat/model_paddle.py +++ b/ernie-sat/model_paddle.py @@ -1,80 +1,44 @@ import argparse +import logging +import math +import os +import sys from pathlib import Path -from typing import Any, Callable, Dict, Optional +from typing import Dict from typing import List -from typing import Sequence +from typing import Optional from typing import Tuple from typing import Union -import humanfriendly -from matplotlib.collections import Collection -from matplotlib.pyplot import axis -import librosa -import soundfile as sf import numpy as np import paddle -import paddle.nn.functional as F -from paddle import nn -from typeguard import check_argument_types -import logging -import math import yaml -from abc import ABC, abstractmethod -import warnings -from paddle.amp import auto_cast - -import sys, os +from paddle import nn pypath = '..' for dir_name in os.listdir(pypath): dir_path = os.path.join(pypath, dir_name) if os.path.isdir(dir_path): sys.path.append(dir_path) +from paddlespeech.s2t.utils.error_rate import ErrorCalculator +from paddlespeech.t2s.modules.activation import get_activation +from paddlespeech.t2s.modules.conformer.convolution import ConvolutionModule +from paddlespeech.t2s.modules.conformer.encoder_layer import EncoderLayer +from paddlespeech.t2s.modules.masked_fill import masked_fill from paddlespeech.t2s.modules.nets_utils import initialize -from paddlespeech.t2s.modules.nets_utils import make_non_pad_mask -from paddlespeech.t2s.modules.nets_utils import make_pad_mask -from paddlespeech.t2s.modules.predictor.duration_predictor import DurationPredictor -from paddlespeech.t2s.modules.predictor.duration_predictor import DurationPredictorLoss -from paddlespeech.t2s.modules.predictor.length_regulator import LengthRegulator -from paddlespeech.t2s.modules.predictor.variance_predictor import VariancePredictor from paddlespeech.t2s.modules.tacotron2.decoder import Postnet -from paddlespeech.t2s.modules.transformer.encoder import CNNDecoder -from paddlespeech.t2s.modules.transformer.encoder import CNNPostnet -from paddlespeech.t2s.modules.transformer.encoder import ConformerEncoder -from paddlespeech.t2s.modules.transformer.encoder import TransformerEncoder -from paddlespeech.t2s.modules.transformer.embedding import PositionalEncoding, ScaledPositionalEncoding, RelPositionalEncoding +from paddlespeech.t2s.modules.transformer.embedding import PositionalEncoding +from paddlespeech.t2s.modules.transformer.embedding import ScaledPositionalEncoding +from paddlespeech.t2s.modules.transformer.embedding import RelPositionalEncoding from paddlespeech.t2s.modules.transformer.subsampling import Conv2dSubsampling -from 
paddlespeech.t2s.modules.masked_fill import masked_fill -from paddlespeech.t2s.modules.transformer.attention import MultiHeadedAttention, RelPositionMultiHeadedAttention +from paddlespeech.t2s.modules.transformer.attention import MultiHeadedAttention +from paddlespeech.t2s.modules.transformer.attention import RelPositionMultiHeadedAttention from paddlespeech.t2s.modules.transformer.positionwise_feed_forward import PositionwiseFeedForward -from paddlespeech.t2s.modules.transformer.multi_layer_conv import Conv1dLinear, MultiLayeredConv1d -from paddlespeech.t2s.modules.conformer.convolution import ConvolutionModule +from paddlespeech.t2s.modules.transformer.multi_layer_conv import Conv1dLinear +from paddlespeech.t2s.modules.transformer.multi_layer_conv import MultiLayeredConv1d from paddlespeech.t2s.modules.transformer.repeat import repeat -from paddlespeech.t2s.modules.conformer.encoder_layer import EncoderLayer from paddlespeech.t2s.modules.layer_norm import LayerNorm -from paddlespeech.s2t.utils.error_rate import ErrorCalculator -from paddlespeech.t2s.datasets.get_feats import LogMelFBank - -class Swish(nn.Layer): - """Construct an Swish object.""" - - def forward(self, x): - """Return Swich activation function.""" - return x * F.sigmoid(x) - -def get_activation(act): - """Return activation function.""" - - activation_funcs = { - "hardtanh": nn.Hardtanh, - "tanh": nn.Tanh, - "relu": nn.ReLU, - "selu": nn.SELU, - "swish": Swish, - } - - return activation_funcs[act]() class LegacyRelPositionalEncoding(PositionalEncoding): """Relative positional encoding module (old version). @@ -89,6 +53,7 @@ class LegacyRelPositionalEncoding(PositionalEncoding): max_len (int): Maximum input length. """ + def __init__(self, d_model: int, dropout_rate: float, max_len: int=5000): """ Args: @@ -102,20 +67,18 @@ class LegacyRelPositionalEncoding(PositionalEncoding): """Reset the positional encodings.""" if self.pe is not None: if paddle.shape(self.pe)[1] >= paddle.shape(x)[1]: - # if self.pe.dtype != x.dtype or self.pe.device != x.device: - # self.pe = self.pe.to(dtype=x.dtype, device=x.device) return pe = paddle.zeros((paddle.shape(x)[1], self.d_model)) if self.reverse: position = paddle.arange( - paddle.shape(x)[1] - 1, -1, -1.0, dtype=paddle.float32 - ).unsqueeze(1) + paddle.shape(x)[1] - 1, -1, -1.0, + dtype=paddle.float32).unsqueeze(1) else: - position = paddle.arange(0, paddle.shape(x)[1], dtype=paddle.float32).unsqueeze(1) + position = paddle.arange( + 0, paddle.shape(x)[1], dtype=paddle.float32).unsqueeze(1) div_term = paddle.exp( - paddle.arange(0, self.d_model, 2, dtype=paddle.float32) - * -(math.log(10000.0) / self.d_model) - ) + paddle.arange(0, self.d_model, 2, dtype=paddle.float32) * + -(math.log(10000.0) / self.d_model)) pe[:, 0::2] = paddle.sin(position * div_term) pe[:, 1::2] = paddle.cos(position * div_term) pe = pe.unsqueeze(0) @@ -129,46 +92,11 @@ class LegacyRelPositionalEncoding(PositionalEncoding): paddle.Tensor: Encoded tensor (batch, time, `*`). paddle.Tensor: Positional embedding tensor (1, time, `*`). 
""" - self.extend_pe(x) + self.extend_pe(x) x = x * self.xscale pos_emb = self.pe[:, :paddle.shape(x)[1]] return self.dropout(x), self.dropout(pos_emb) -def dump_tensor(var, do_trans = False): - wf = open('/mnt/home/xiaoran/PaddleSpeech-develop/tmp_var.out', 'w') - for num in var.shape: - wf.write(str(num) + ' ') - wf.write('\n') - if do_trans: - var = paddle.transpose(var, [1,0]) - if len(var.shape)==1: - for _var in var: - s = ("%.10f"%_var.item()) - wf.write(s+' ') - elif len(var.shape)==2: - for __var in var: - for _var in __var: - s = ("%.10f"%_var.item()) - wf.write(s+' ') - wf.write('\n') - elif len(var.shape)==3: - for ___var in var: - for __var in ___var: - for _var in __var: - s = ("%.10f"%_var.item()) - wf.write(s+' ') - wf.write('\n') - wf.write('\n') - elif len(var.shape)==4: - for ____var in var: - for ___var in ____var: - for __var in ___var: - for _var in __var: - s = ("%.10f"%_var.item()) - wf.write(s+' ') - wf.write('\n') - wf.write('\n') - wf.write('\n') class mySequential(nn.Sequential): def forward(self, *inputs): @@ -179,24 +107,29 @@ class mySequential(nn.Sequential): inputs = module(inputs) return inputs + class NewMaskInputLayer(nn.Layer): __constants__ = ['out_features'] out_features: int - def __init__(self, out_features: int, - device=None, dtype=None) -> None: + def __init__(self, out_features: int, device=None, dtype=None) -> None: factory_kwargs = {'device': device, 'dtype': dtype} - super(NewMaskInputLayer, self).__init__() + super().__init__() self.mask_feature = paddle.create_parameter( - shape=(1,1,out_features), - dtype=paddle.float32, - default_initializer=paddle.nn.initializer.Assign(paddle.normal(shape=(1,1,out_features)))) - - def forward(self, input: paddle.Tensor, masked_position=None) -> paddle.Tensor: - masked_position = paddle.expand_as(paddle.unsqueeze(masked_position, -1), input) - masked_input = masked_fill(input, masked_position, 0) + masked_fill(paddle.expand_as(self.mask_feature, input), ~masked_position, 0) + shape=(1, 1, out_features), + dtype=paddle.float32, + default_initializer=paddle.nn.initializer.Assign( + paddle.normal(shape=(1, 1, out_features)))) + + def forward(self, input: paddle.Tensor, + masked_position=None) -> paddle.Tensor: + masked_position = paddle.expand_as( + paddle.unsqueeze(masked_position, -1), input) + masked_input = masked_fill(input, masked_position, 0) + masked_fill( + paddle.expand_as(self.mask_feature, input), ~masked_position, 0) return masked_input + class LegacyRelPositionMultiHeadedAttention(MultiHeadedAttention): """Multi-Head Attention layer with relative position encoding (old version). Details can be found in https://github.com/espnet/espnet/pull/2816. 
@@ -266,7 +199,8 @@ class LegacyRelPositionMultiHeadedAttention(MultiHeadedAttention): q = paddle.transpose(q, [0, 2, 1, 3]) n_batch_pos = paddle.shape(pos_emb)[0] - p = paddle.reshape(self.linear_pos(pos_emb), [n_batch_pos, -1, self.h, self.d_k]) + p = paddle.reshape( + self.linear_pos(pos_emb), [n_batch_pos, -1, self.h, self.d_k]) # (batch, head, time1, d_k) p = paddle.transpose(p, [0, 2, 1, 3]) # (batch, head, time1, d_k) @@ -278,17 +212,20 @@ class LegacyRelPositionMultiHeadedAttention(MultiHeadedAttention): # first compute matrix a and matrix c # as described in https://arxiv.org/abs/1901.02860 Section 3.3 # (batch, head, time1, time2) - matrix_ac = paddle.matmul(q_with_bias_u, paddle.transpose(k, [0, 1, 3, 2])) + matrix_ac = paddle.matmul(q_with_bias_u, + paddle.transpose(k, [0, 1, 3, 2])) # compute matrix b and matrix d # (batch, head, time1, time1) - matrix_bd = paddle.matmul(q_with_bias_v, paddle.transpose(p, [0, 1, 3, 2])) + matrix_bd = paddle.matmul(q_with_bias_v, + paddle.transpose(p, [0, 1, 3, 2])) matrix_bd = self.rel_shift(matrix_bd) # (batch, head, time1, time2) scores = (matrix_ac + matrix_bd) / math.sqrt(self.d_k) return self.forward_attention(v, scores, mask) + class MLMEncoder(nn.Layer): """Conformer encoder module. @@ -324,40 +261,39 @@ class MLMEncoder(nn.Layer): signature.) """ - def __init__( - self, - idim, - vocab_size=0, - pre_speech_layer: int = 0, - attention_dim=256, - attention_heads=4, - linear_units=2048, - num_blocks=6, - dropout_rate=0.1, - positional_dropout_rate=0.1, - attention_dropout_rate=0.0, - input_layer="conv2d", - normalize_before=True, - concat_after=False, - positionwise_layer_type="linear", - positionwise_conv_kernel_size=1, - macaron_style=False, - pos_enc_layer_type="abs_pos", - pos_enc_class=None, - selfattention_layer_type="selfattn", - activation_type="swish", - use_cnn_module=False, - zero_triu=False, - cnn_module_kernel=31, - padding_idx=-1, - stochastic_depth_rate=0.0, - intermediate_layers=None, - text_masking = False - ): + + def __init__(self, + idim, + vocab_size=0, + pre_speech_layer: int=0, + attention_dim=256, + attention_heads=4, + linear_units=2048, + num_blocks=6, + dropout_rate=0.1, + positional_dropout_rate=0.1, + attention_dropout_rate=0.0, + input_layer="conv2d", + normalize_before=True, + concat_after=False, + positionwise_layer_type="linear", + positionwise_conv_kernel_size=1, + macaron_style=False, + pos_enc_layer_type="abs_pos", + pos_enc_class=None, + selfattention_layer_type="selfattn", + activation_type="swish", + use_cnn_module=False, + zero_triu=False, + cnn_module_kernel=31, + padding_idx=-1, + stochastic_depth_rate=0.0, + intermediate_layers=None, + text_masking=False): """Construct an Encoder object.""" - super(MLMEncoder, self).__init__() + super().__init__() self._output_size = attention_dim - self.text_masking=text_masking + self.text_masking = text_masking if self.text_masking: self.text_masking_layer = NewMaskInputLayer(attention_dim) activation = get_activation(activation_type) @@ -381,21 +317,18 @@ class MLMEncoder(nn.Layer): nn.LayerNorm(attention_dim), nn.Dropout(dropout_rate), nn.ReLU(), - pos_enc_class(attention_dim, positional_dropout_rate), - ) + pos_enc_class(attention_dim, positional_dropout_rate), ) elif input_layer == "conv2d": self.embed = Conv2dSubsampling( idim, attention_dim, dropout_rate, - pos_enc_class(attention_dim, positional_dropout_rate), - ) + pos_enc_class(attention_dim, positional_dropout_rate), ) self.conv_subsampling_factor = 4 elif input_layer == "embed": self.embed = 
nn.Sequential( nn.Embedding(idim, attention_dim, padding_idx=padding_idx), - pos_enc_class(attention_dim, positional_dropout_rate), - ) + pos_enc_class(attention_dim, positional_dropout_rate), ) elif input_layer == "mlm": self.segment_emb = None self.speech_embed = mySequential( @@ -403,34 +336,31 @@ class MLMEncoder(nn.Layer): nn.Linear(idim, attention_dim), nn.LayerNorm(attention_dim), nn.ReLU(), - pos_enc_class(attention_dim, positional_dropout_rate) - ) + pos_enc_class(attention_dim, positional_dropout_rate)) self.text_embed = nn.Sequential( - nn.Embedding(vocab_size, attention_dim, padding_idx=padding_idx), - pos_enc_class(attention_dim, positional_dropout_rate), - ) - elif input_layer=="sega_mlm": - self.segment_emb = nn.Embedding(500, attention_dim, padding_idx=padding_idx) + nn.Embedding( + vocab_size, attention_dim, padding_idx=padding_idx), + pos_enc_class(attention_dim, positional_dropout_rate), ) + elif input_layer == "sega_mlm": + self.segment_emb = nn.Embedding( + 500, attention_dim, padding_idx=padding_idx) self.speech_embed = mySequential( NewMaskInputLayer(idim), nn.Linear(idim, attention_dim), nn.LayerNorm(attention_dim), nn.ReLU(), - pos_enc_class(attention_dim, positional_dropout_rate) - ) + pos_enc_class(attention_dim, positional_dropout_rate)) self.text_embed = nn.Sequential( - nn.Embedding(vocab_size, attention_dim, padding_idx=padding_idx), - pos_enc_class(attention_dim, positional_dropout_rate), - ) + nn.Embedding( + vocab_size, attention_dim, padding_idx=padding_idx), + pos_enc_class(attention_dim, positional_dropout_rate), ) elif isinstance(input_layer, nn.Layer): self.embed = nn.Sequential( input_layer, - pos_enc_class(attention_dim, positional_dropout_rate), - ) + pos_enc_class(attention_dim, positional_dropout_rate), ) elif input_layer is None: self.embed = nn.Sequential( - pos_enc_class(attention_dim, positional_dropout_rate) - ) + pos_enc_class(attention_dim, positional_dropout_rate)) else: raise ValueError("unknown input_layer: " + input_layer) self.normalize_before = normalize_before @@ -439,57 +369,39 @@ class MLMEncoder(nn.Layer): if selfattention_layer_type == "selfattn": logging.info("encoder self-attention layer type = self-attention") encoder_selfattn_layer = MultiHeadedAttention - encoder_selfattn_layer_args = ( - attention_heads, - attention_dim, - attention_dropout_rate, - ) + encoder_selfattn_layer_args = (attention_heads, attention_dim, + attention_dropout_rate, ) elif selfattention_layer_type == "legacy_rel_selfattn": assert pos_enc_layer_type == "legacy_rel_pos" encoder_selfattn_layer = LegacyRelPositionMultiHeadedAttention - encoder_selfattn_layer_args = ( - attention_heads, - attention_dim, - attention_dropout_rate, - ) + encoder_selfattn_layer_args = (attention_heads, attention_dim, + attention_dropout_rate, ) elif selfattention_layer_type == "rel_selfattn": - logging.info("encoder self-attention layer type = relative self-attention") + logging.info( + "encoder self-attention layer type = relative self-attention") assert pos_enc_layer_type == "rel_pos" encoder_selfattn_layer = RelPositionMultiHeadedAttention - encoder_selfattn_layer_args = ( - attention_heads, - attention_dim, - attention_dropout_rate, - zero_triu, - ) + encoder_selfattn_layer_args = (attention_heads, attention_dim, + attention_dropout_rate, zero_triu, ) else: - raise ValueError("unknown encoder_attn_layer: " + selfattention_layer_type) + raise ValueError("unknown encoder_attn_layer: " + + selfattention_layer_type) # feed-forward module definition if positionwise_layer_type 
== "linear": positionwise_layer = PositionwiseFeedForward - positionwise_layer_args = ( - attention_dim, - linear_units, - dropout_rate, - activation, - ) + positionwise_layer_args = (attention_dim, linear_units, + dropout_rate, activation, ) elif positionwise_layer_type == "conv1d": positionwise_layer = MultiLayeredConv1d - positionwise_layer_args = ( - attention_dim, - linear_units, - positionwise_conv_kernel_size, - dropout_rate, - ) + positionwise_layer_args = (attention_dim, linear_units, + positionwise_conv_kernel_size, + dropout_rate, ) elif positionwise_layer_type == "conv1d-linear": positionwise_layer = Conv1dLinear - positionwise_layer_args = ( - attention_dim, - linear_units, - positionwise_conv_kernel_size, - dropout_rate, - ) + positionwise_layer_args = (attention_dim, linear_units, + positionwise_conv_kernel_size, + dropout_rate, ) else: raise NotImplementedError("Support only linear or conv1d.") @@ -508,9 +420,7 @@ class MLMEncoder(nn.Layer): dropout_rate, normalize_before, concat_after, - stochastic_depth_rate * float(1 + lnum) / num_blocks, - ), - ) + stochastic_depth_rate * float(1 + lnum) / num_blocks, ), ) self.pre_speech_layer = pre_speech_layer self.pre_speech_encoders = repeat( self.pre_speech_layer, @@ -523,16 +433,21 @@ class MLMEncoder(nn.Layer): dropout_rate, normalize_before, concat_after, - stochastic_depth_rate * float(1 + lnum) / self.pre_speech_layer, - ), + stochastic_depth_rate * float(1 + lnum) / self.pre_speech_layer, ), ) if self.normalize_before: self.after_norm = LayerNorm(attention_dim) self.intermediate_layers = intermediate_layers - - def forward(self, speech_pad, text_pad, masked_position, speech_mask=None, text_mask=None,speech_segment_pos=None, text_segment_pos=None): + def forward(self, + speech_pad, + text_pad, + masked_position, + speech_mask=None, + text_mask=None, + speech_segment_pos=None, + text_segment_pos=None): """Encode input sequence. 
""" @@ -542,12 +457,13 @@ class MLMEncoder(nn.Layer): speech_pad = self.speech_embed(speech_pad) # pure speech input if -2 in np.array(text_pad): - text_pad = text_pad+3 + text_pad = text_pad + 3 text_mask = paddle.unsqueeze(bool(text_pad), 1) text_segment_pos = paddle.zeros_like(text_pad) text_pad = self.text_embed(text_pad) - text_pad = (text_pad[0] + self.segment_emb(text_segment_pos), text_pad[1]) - text_segment_pos=None + text_pad = (text_pad[0] + self.segment_emb(text_segment_pos), + text_pad[1]) + text_segment_pos = None elif text_pad is not None: text_pad = self.text_embed(text_pad) segment_emb = None @@ -556,32 +472,32 @@ class MLMEncoder(nn.Layer): text_segment_emb = self.segment_emb(text_segment_pos) text_pad = (text_pad[0] + text_segment_emb, text_pad[1]) speech_pad = (speech_pad[0] + speech_segment_emb, speech_pad[1]) - segment_emb = paddle.concat([speech_segment_emb, text_segment_emb],axis=1) + segment_emb = paddle.concat( + [speech_segment_emb, text_segment_emb], axis=1) if self.pre_speech_encoders: speech_pad, _ = self.pre_speech_encoders(speech_pad, speech_mask) if text_pad is not None: xs = paddle.concat([speech_pad[0], text_pad[0]], axis=1) xs_pos_emb = paddle.concat([speech_pad[1], text_pad[1]], axis=1) - masks = paddle.concat([speech_mask,text_mask],axis=-1) + masks = paddle.concat([speech_mask, text_mask], axis=-1) else: xs = speech_pad[0] xs_pos_emb = speech_pad[1] masks = speech_mask - xs, masks = self.encoders((xs,xs_pos_emb), masks) + xs, masks = self.encoders((xs, xs_pos_emb), masks) if isinstance(xs, tuple): xs = xs[0] if self.normalize_before: xs = self.after_norm(xs) - return xs, masks #, segment_emb + return xs, masks #, segment_emb class MLMDecoder(MLMEncoder): - - def forward(self, xs, masks, masked_position=None,segment_emb=None): + def forward(self, xs, masks, masked_position=None, segment_emb=None): """Encode input sequence. Args: @@ -596,9 +512,6 @@ class MLMDecoder(MLMEncoder): emb, mlm_position = None, None if not self.training: masked_position = None - # if isinstance(self.embed, (Conv2dSubsampling, VGG2L)): - # xs, masks = self.embed(xs, masks) - # else: xs = self.embed(xs) if segment_emb: xs = (xs[0] + segment_emb, xs[1]) @@ -609,10 +522,8 @@ class MLMDecoder(MLMEncoder): for layer_idx, encoder_layer in enumerate(self.encoders): xs, masks = encoder_layer(xs, masks) - if ( - self.intermediate_layers is not None - and layer_idx + 1 in self.intermediate_layers - ): + if (self.intermediate_layers is not None and + layer_idx + 1 in self.intermediate_layers): encoder_output = xs # intermediate branches also require normalization. if self.normalize_before: @@ -627,104 +538,44 @@ class MLMDecoder(MLMEncoder): return xs, masks, intermediate_outputs return xs, masks -class AbsESPnetModel(nn.Layer, ABC): - """The common abstract class among each tasks - - "ESPnetModel" is referred to a class which inherits paddle.nn.Layer, - and makes the dnn-models forward as its member field, - a.k.a delegate pattern, - and defines "loss", "stats", and "weight" for the task. - - If you intend to implement new task in ESPNet, - the model must inherit this class. - In other words, the "mediator" objects between - our training system and the your task class are - just only these three values, loss, stats, and weight. - - Example: - >>> from espnet2.tasks.abs_task import AbsTask - >>> class YourESPnetModel(AbsESPnetModel): - ... def forward(self, input, input_lengths): - ... ... - ... return loss, stats, weight - >>> class YourTask(AbsTask): - ... @classmethod - ... 
def build_model(cls, args: argparse.Namespace) -> YourESPnetModel: - """ - - @abstractmethod - def forward( - self, **batch: paddle.Tensor - ) -> Tuple[paddle.Tensor, Dict[str, paddle.Tensor], paddle.Tensor]: - raise NotImplementedError - @abstractmethod - def collect_feats(self, **batch: paddle.Tensor) -> Dict[str, paddle.Tensor]: - raise NotImplementedError - -class AbsFeatsExtract(nn.Layer, ABC): - @abstractmethod - def output_size(self) -> int: - raise NotImplementedError - - @abstractmethod - def get_parameters(self) -> Dict[str, Any]: - raise NotImplementedError - - @abstractmethod - def forward( - self, input: paddle.Tensor, input_lengths: paddle.Tensor - ) -> Tuple[paddle.Tensor, paddle.Tensor]: - raise NotImplementedError - -class AbsNormalize(nn.Layer, ABC): - @abstractmethod - def forward( - self, input: paddle.Tensor, input_lengths: paddle.Tensor = None - ) -> Tuple[paddle.Tensor, paddle.Tensor]: - # return output, output_lengths - raise NotImplementedError - - - -def pad_to_longformer_att_window(text, max_len, max_tlen,attention_window): +def pad_to_longformer_att_window(text, max_len, max_tlen, attention_window): round = max_len % attention_window if round != 0: max_tlen += (attention_window - round) n_batch = paddle.shape(text)[0] - text_pad = paddle.zeros(shape = (n_batch, max_tlen, *paddle.shape(text[0])[1:]), dtype=text.dtype) + text_pad = paddle.zeros( + shape=(n_batch, max_tlen, *paddle.shape(text[0])[1:]), + dtype=text.dtype) for i in range(n_batch): - text_pad[i, : paddle.shape(text[i])[0]] = text[i] + text_pad[i, :paddle.shape(text[i])[0]] = text[i] else: - text_pad = text[:, : max_tlen] + text_pad = text[:, :max_tlen] return text_pad, max_tlen -class ESPnetMLMModel(AbsESPnetModel): - def __init__( - self, - token_list: Union[Tuple[str, ...], List[str]], - odim: int, - feats_extract: Optional[AbsFeatsExtract], - normalize: Optional[AbsNormalize], - encoder: nn.Layer, - decoder: Optional[nn.Layer], - postnet_layers: int = 0, - postnet_chans: int = 0, - postnet_filts: int = 0, - ignore_id: int = -1, - lsm_weight: float = 0.0, - length_normalized_loss: bool = False, - report_cer: bool = True, - report_wer: bool = True, - sym_space: str = "", - sym_blank: str = "", - masking_schema: str = "span", - mean_phn_span: int = 3, - mlm_prob: float = 0.25, - dynamic_mlm_prob = False, - decoder_seg_pos=False, - text_masking=False - ): + +class MLMModel(nn.Layer): + def __init__(self, + token_list: Union[Tuple[str, ...], List[str]], + odim: int, + encoder: nn.Layer, + decoder: Optional[nn.Layer], + postnet_layers: int=0, + postnet_chans: int=0, + postnet_filts: int=0, + ignore_id: int=-1, + lsm_weight: float=0.0, + length_normalized_loss: bool=False, + report_cer: bool=True, + report_wer: bool=True, + sym_space: str="", + sym_blank: str="", + masking_schema: str="span", + mean_phn_span: int=3, + mlm_prob: float=0.25, + dynamic_mlm_prob=False, + decoder_seg_pos=False, + text_masking=False): super().__init__() # note that eos is the same as sos (equivalent ID) @@ -732,105 +583,119 @@ class ESPnetMLMModel(AbsESPnetModel): self.ignore_id = ignore_id self.token_list = token_list.copy() - self.normalize = normalize self.encoder = encoder self.decoder = decoder self.vocab_size = encoder.text_embed[0]._num_embeddings if report_cer or report_wer: self.error_calculator = ErrorCalculator( - token_list, sym_space, sym_blank, report_cer, report_wer - ) + token_list, sym_space, sym_blank, report_cer, report_wer) else: self.error_calculator = None - self.feats_extract = feats_extract 
self.mlm_weight = 1.0 self.mlm_prob = mlm_prob self.mlm_layer = 12 - self.finetune_wo_mlm =True + self.finetune_wo_mlm = True self.max_span = 50 self.min_span = 4 self.mean_phn_span = mean_phn_span self.masking_schema = masking_schema - if self.decoder is None or not (hasattr(self.decoder, 'output_layer') and self.decoder.output_layer is not None): + if self.decoder is None or not (hasattr(self.decoder, + 'output_layer') and + self.decoder.output_layer is not None): self.sfc = nn.Linear(self.encoder._output_size, odim) else: - self.sfc=None + self.sfc = None if text_masking: - self.text_sfc = nn.Linear(self.encoder.text_embed[0]._embedding_dim, self.vocab_size, weight_attr = self.encoder.text_embed[0]._weight_attr) + self.text_sfc = nn.Linear( + self.encoder.text_embed[0]._embedding_dim, + self.vocab_size, + weight_attr=self.encoder.text_embed[0]._weight_attr) self.text_mlm_loss = nn.CrossEntropyLoss(ignore_index=ignore_id) else: self.text_sfc = None self.text_mlm_loss = None self.decoder_seg_pos = decoder_seg_pos if lsm_weight > 50: - self.l1_loss_func = nn.MSELoss(reduce=False) + self.l1_loss_func = nn.MSELoss() else: self.l1_loss_func = nn.L1Loss(reduction='none') - self.postnet = ( - None - if postnet_layers == 0 - else Postnet( - idim=self.encoder._output_size, - odim=odim, - n_layers=postnet_layers, - n_chans=postnet_chans, - n_filts=postnet_filts, - use_batch_norm=True, - dropout_rate=0.5, - ) - ) + self.postnet = (None if postnet_layers == 0 else Postnet( + idim=self.encoder._output_size, + odim=odim, + n_layers=postnet_layers, + n_chans=postnet_chans, + n_filts=postnet_filts, + use_batch_norm=True, + dropout_rate=0.5, )) def collect_feats(self, - speech, speech_lengths, text, text_lengths, masked_position, speech_mask, text_mask, speech_segment_pos, text_segment_pos, y_masks=None - ) -> Dict[str, paddle.Tensor]: + speech, + speech_lengths, + text, + text_lengths, + masked_position, + speech_mask, + text_mask, + speech_segment_pos, + text_segment_pos, + y_masks=None) -> Dict[str, paddle.Tensor]: return {"feats": speech, "feats_lengths": speech_lengths} - def _forward(self, batch, speech_segment_pos,y_masks=None): + def forward(self, batch, speech_segment_pos, y_masks=None): # feats: (Batch, Length, Dim) # -> encoder_out: (Batch, Length2, Dim2) speech_pad_placeholder = batch['speech_pad'] if self.decoder is not None: - ys_in = self._add_first_frame_and_remove_last_frame(batch['speech_pad']) + ys_in = self._add_first_frame_and_remove_last_frame( + batch['speech_pad']) encoder_out, h_masks = self.encoder(**batch) if self.decoder is not None: - zs, _ = self.decoder(ys_in, y_masks, encoder_out, bool(h_masks), self.encoder.segment_emb(speech_segment_pos)) + zs, _ = self.decoder(ys_in, y_masks, encoder_out, + bool(h_masks), + self.encoder.segment_emb(speech_segment_pos)) speech_hidden_states = zs else: - speech_hidden_states = encoder_out[:,:paddle.shape(batch['speech_pad'])[1], :] + speech_hidden_states = encoder_out[:, :paddle.shape(batch[ + 'speech_pad'])[1], :] if self.sfc is not None: - before_outs = paddle.reshape(self.sfc(speech_hidden_states), (paddle.shape(speech_hidden_states)[0], -1, self.odim)) + before_outs = paddle.reshape( + self.sfc(speech_hidden_states), + (paddle.shape(speech_hidden_states)[0], -1, self.odim)) else: before_outs = speech_hidden_states if self.postnet is not None: - after_outs = before_outs + paddle.transpose(self.postnet( - paddle.transpose(before_outs, [0, 2, 1]) - ), (0, 2, 1)) + after_outs = before_outs + paddle.transpose( + 
self.postnet(paddle.transpose(before_outs, [0, 2, 1])), + (0, 2, 1)) else: after_outs = None - return before_outs, after_outs, speech_pad_placeholder, batch['masked_position'] + return before_outs, after_outs, speech_pad_placeholder, batch[ + 'masked_position'] - - - def inference( - self, - speech, text, masked_position, speech_mask, text_mask, speech_segment_pos, text_segment_pos, - span_boundary, - y_masks=None, - speech_lengths=None, text_lengths=None, - feats: Optional[paddle.Tensor] = None, - spembs: Optional[paddle.Tensor] = None, - sids: Optional[paddle.Tensor] = None, - lids: Optional[paddle.Tensor] = None, - threshold: float = 0.5, - minlenratio: float = 0.0, - maxlenratio: float = 10.0, - use_teacher_forcing: bool = False, - ) -> Dict[str, paddle.Tensor]: - - + self, + speech, + text, + masked_position, + speech_mask, + text_mask, + speech_segment_pos, + text_segment_pos, + span_boundary, + y_masks=None, + speech_lengths=None, + text_lengths=None, + feats: Optional[paddle.Tensor]=None, + spembs: Optional[paddle.Tensor]=None, + sids: Optional[paddle.Tensor]=None, + lids: Optional[paddle.Tensor]=None, + threshold: float=0.5, + minlenratio: float=0.0, + maxlenratio: float=10.0, + use_teacher_forcing: bool=False, ) -> Dict[str, paddle.Tensor]: + batch = dict( speech_pad=speech, text_pad=text, @@ -838,119 +703,130 @@ class ESPnetMLMModel(AbsESPnetModel): speech_mask=speech_mask, text_mask=text_mask, speech_segment_pos=speech_segment_pos, - text_segment_pos=text_segment_pos, - ) - + text_segment_pos=text_segment_pos, ) # # inference with teacher forcing # hs, h_masks = self.encoder(**batch) - outs = [batch['speech_pad'][:,:span_boundary[0]]] + outs = [batch['speech_pad'][:, :span_boundary[0]]] z_cache = None if use_teacher_forcing: - before,zs, _, _ = self._forward( + before, zs, _, _ = self.forward( batch, speech_segment_pos, y_masks=y_masks) if zs is None: zs = before - outs+=[zs[0][span_boundary[0]:span_boundary[1]]] - outs+=[batch['speech_pad'][:,span_boundary[1]:]] + outs += [zs[0][span_boundary[0]:span_boundary[1]]] + outs += [batch['speech_pad'][:, span_boundary[1]:]] return dict(feat_gen=outs) - - # concatenate attention weights -> (#layers, #heads, T_feats, T_text) - att_ws = paddle.stack(att_ws, axis=0) - outs += [batch['speech_pad'][:,span_boundary[1]:]] - return dict(feat_gen=outs, att_w=att_ws) + return None - - def _add_first_frame_and_remove_last_frame(self, ys: paddle.Tensor) -> paddle.Tensor: + def _add_first_frame_and_remove_last_frame( + self, ys: paddle.Tensor) -> paddle.Tensor: ys_in = paddle.concat( - [paddle.zeros(shape = (paddle.shape(ys)[0], 1, paddle.shape(ys)[2]), dtype = ys.dtype), ys[:, :-1]], axis=1 - ) + [ + paddle.zeros( + shape=(paddle.shape(ys)[0], 1, paddle.shape(ys)[2]), + dtype=ys.dtype), ys[:, :-1] + ], + axis=1) return ys_in -class ESPnetMLMEncAsDecoderModel(ESPnetMLMModel): - - def _forward(self, batch, speech_segment_pos, y_masks=None): +class MLMEncAsDecoderModel(MLMModel): + def forward(self, batch, speech_segment_pos, y_masks=None): # feats: (Batch, Length, Dim) # -> encoder_out: (Batch, Length2, Dim2) speech_pad_placeholder = batch['speech_pad'] - encoder_out, h_masks = self.encoder(**batch) # segment_emb + encoder_out, h_masks = self.encoder(**batch) # segment_emb if self.decoder is not None: zs, _ = self.decoder(encoder_out, h_masks) else: zs = encoder_out - speech_hidden_states = zs[:,:paddle.shape(batch['speech_pad'])[1], :] + speech_hidden_states = zs[:, :paddle.shape(batch['speech_pad'])[1], :] if self.sfc is not None: - before_outs = 
paddle.reshape(self.sfc(speech_hidden_states), (paddle.shape(speech_hidden_states)[0], -1, self.odim)) + before_outs = paddle.reshape( + self.sfc(speech_hidden_states), + (paddle.shape(speech_hidden_states)[0], -1, self.odim)) else: before_outs = speech_hidden_states if self.postnet is not None: - after_outs = before_outs + paddle.transpose(self.postnet( - paddle.transpose(before_outs, [0, 2, 1]) - ), [0, 2, 1]) + after_outs = before_outs + paddle.transpose( + self.postnet(paddle.transpose(before_outs, [0, 2, 1])), + [0, 2, 1]) else: after_outs = None - return before_outs, after_outs, speech_pad_placeholder, batch['masked_position'] + return before_outs, after_outs, speech_pad_placeholder, batch[ + 'masked_position'] -class ESPnetMLMDualMaksingModel(ESPnetMLMModel): - def _calc_mlm_loss( - self, - before_outs: paddle.Tensor, - after_outs: paddle.Tensor, - text_outs: paddle.Tensor, - batch - ): +class MLMDualMaksingModel(MLMModel): + def _calc_mlm_loss(self, + before_outs: paddle.Tensor, + after_outs: paddle.Tensor, + text_outs: paddle.Tensor, + batch): xs_pad = batch['speech_pad'] text_pad = batch['text_pad'] masked_position = batch['masked_position'] text_masked_position = batch['text_masked_position'] - mlm_loss_position = masked_position>0 - loss = paddle.sum(self.l1_loss_func(paddle.reshape(before_outs, (-1, self.odim)), - paddle.reshape(xs_pad, (-1, self.odim))), axis=-1) + mlm_loss_position = masked_position > 0 + loss = paddle.sum( + self.l1_loss_func( + paddle.reshape(before_outs, (-1, self.odim)), + paddle.reshape(xs_pad, (-1, self.odim))), + axis=-1) if after_outs is not None: - loss += paddle.sum(self.l1_loss_func(paddle.reshape(after_outs, (-1, self.odim)), - paddle.reshape(xs_pad, (-1, self.odim))), axis=-1) - loss_mlm = paddle.sum((loss * paddle.reshape(mlm_loss_position, axis=-1).float())) \ - / paddle.sum((mlm_loss_position.float()) + 1e-10) - - loss_text = paddle.sum((self.text_mlm_loss(paddle.reshape(text_outs, (-1,self.vocab_size)), paddle.reshape(text_pad, (-1))) * paddle.reshape(text_masked_position, (-1)).float())) \ - / paddle.sum((text_masked_position.float()) + 1e-10) + loss += paddle.sum( + self.l1_loss_func( + paddle.reshape(after_outs, (-1, self.odim)), + paddle.reshape(xs_pad, (-1, self.odim))), + axis=-1) + loss_mlm = paddle.sum((loss * paddle.reshape( + mlm_loss_position, [-1]))) / paddle.sum((mlm_loss_position) + 1e-10) + + loss_text = paddle.sum((self.text_mlm_loss( + paddle.reshape(text_outs, (-1, self.vocab_size)), + paddle.reshape(text_pad, (-1))) * paddle.reshape( + text_masked_position, + (-1)))) / paddle.sum((text_masked_position) + 1e-10) return loss_mlm, loss_text - - def _forward(self, batch, speech_segment_pos, y_masks=None): + def forward(self, batch, speech_segment_pos, y_masks=None): # feats: (Batch, Length, Dim) # -> encoder_out: (Batch, Length2, Dim2) speech_pad_placeholder = batch['speech_pad'] - encoder_out, h_masks = self.encoder(**batch) # segment_emb + encoder_out, h_masks = self.encoder(**batch) # segment_emb if self.decoder is not None: zs, _ = self.decoder(encoder_out, h_masks) else: zs = encoder_out - speech_hidden_states = zs[:,:paddle.shape(batch['speech_pad'])[1], :] + speech_hidden_states = zs[:, :paddle.shape(batch['speech_pad'])[1], :] if self.text_sfc: - text_hiddent_states = zs[:,paddle.shape(batch['speech_pad'])[1]:,:] - text_outs = paddle.reshape(self.text_sfc(text_hiddent_states), (paddle.shape(text_hiddent_states)[0], -1, self.vocab_size)) + text_hiddent_states = zs[:, paddle.shape(batch['speech_pad'])[ + 1]:, :] + 
text_outs = paddle.reshape( + self.text_sfc(text_hiddent_states), + (paddle.shape(text_hiddent_states)[0], -1, self.vocab_size)) if self.sfc is not None: - before_outs = paddle.reshape(self.sfc(speech_hidden_states), - (paddle.shape(speech_hidden_states)[0], -1, self.odim)) + before_outs = paddle.reshape( + self.sfc(speech_hidden_states), + (paddle.shape(speech_hidden_states)[0], -1, self.odim)) else: before_outs = speech_hidden_states if self.postnet is not None: - after_outs = before_outs + paddle.transpose(self.postnet( - paddle.transpose(before_outs, [0,2,1]) - ), [0, 2, 1]) + after_outs = before_outs + paddle.transpose( + self.postnet(paddle.transpose(before_outs, [0, 2, 1])), + [0, 2, 1]) else: after_outs = None - return before_outs, after_outs,text_outs, None #, speech_pad_placeholder, batch['masked_position'],batch['text_masked_position'] + return before_outs, after_outs, text_outs, None #, speech_pad_placeholder, batch['masked_position'],batch['text_masked_position'] + def build_model_from_file(config_file, model_file): - + state_dict = paddle.load(model_file) - model_class = ESPnetMLMDualMaksingModel if 'conformer_combine_vctk_aishell3_dual_masking' in config_file \ - else ESPnetMLMEncAsDecoderModel + model_class = MLMDualMaksingModel if 'conformer_combine_vctk_aishell3_dual_masking' in config_file \ + else MLMEncAsDecoderModel # 构建模型 args = yaml.safe_load(Path(config_file).open("r", encoding="utf-8")) @@ -962,7 +838,8 @@ def build_model_from_file(config_file, model_file): return model, args -def build_model(args: argparse.Namespace, model_class = ESPnetMLMEncAsDecoderModel) -> ESPnetMLMModel: +def build_model(args: argparse.Namespace, + model_class=MLMEncAsDecoderModel) -> MLMModel: if isinstance(args.token_list, str): with open(args.token_list, encoding="utf-8") as f: token_list = [line.rstrip() for line in f] @@ -975,17 +852,14 @@ def build_model(args: argparse.Namespace, model_class = ESPnetMLMEncAsDecoderMod raise RuntimeError("token_list must be str or list") vocab_size = len(token_list) logging.info(f"Vocabulary size: {vocab_size }") - - odim = 80 - - # Normalization layer - normalize = None + odim = 80 pos_enc_class = ScaledPositionalEncoding if args.use_scaled_pos_enc else PositionalEncoding if "conformer" == args.encoder: - conformer_self_attn_layer_type = args.encoder_conf['selfattention_layer_type'] + conformer_self_attn_layer_type = args.encoder_conf[ + 'selfattention_layer_type'] conformer_pos_enc_layer_type = args.encoder_conf['pos_enc_layer_type'] conformer_rel_pos_type = "legacy" if conformer_rel_pos_type == "legacy": @@ -994,38 +868,42 @@ def build_model(args: argparse.Namespace, model_class = ESPnetMLMEncAsDecoderMod logging.warning( "Fallback to conformer_pos_enc_layer_type = 'legacy_rel_pos' " "due to the compatibility. If you want to use the new one, " - "please use conformer_pos_enc_layer_type = 'latest'." - ) + "please use conformer_pos_enc_layer_type = 'latest'.") if conformer_self_attn_layer_type == "rel_selfattn": conformer_self_attn_layer_type = "legacy_rel_selfattn" logging.warning( "Fallback to " "conformer_self_attn_layer_type = 'legacy_rel_selfattn' " "due to the compatibility. If you want to use the new one, " - "please use conformer_pos_enc_layer_type = 'latest'." 
- ) + "please use conformer_pos_enc_layer_type = 'latest'.") elif conformer_rel_pos_type == "latest": assert conformer_pos_enc_layer_type != "legacy_rel_pos" assert conformer_self_attn_layer_type != "legacy_rel_selfattn" else: raise ValueError(f"Unknown rel_pos_type: {conformer_rel_pos_type}") - args.encoder_conf['selfattention_layer_type'] = conformer_self_attn_layer_type - args.encoder_conf['pos_enc_layer_type'] = conformer_pos_enc_layer_type - if "conformer"==args.decoder: - args.decoder_conf['selfattention_layer_type'] = conformer_self_attn_layer_type - args.decoder_conf['pos_enc_layer_type'] = conformer_pos_enc_layer_type - + args.encoder_conf[ + 'selfattention_layer_type'] = conformer_self_attn_layer_type + args.encoder_conf['pos_enc_layer_type'] = conformer_pos_enc_layer_type + if "conformer" == args.decoder: + args.decoder_conf[ + 'selfattention_layer_type'] = conformer_self_attn_layer_type + args.decoder_conf[ + 'pos_enc_layer_type'] = conformer_pos_enc_layer_type # Encoder encoder_class = MLMEncoder - if 'text_masking' in args.model_conf.keys() and args.model_conf['text_masking']: + if 'text_masking' in args.model_conf.keys() and args.model_conf[ + 'text_masking']: args.encoder_conf['text_masking'] = True else: args.encoder_conf['text_masking'] = False - - encoder = encoder_class(args.input_size,vocab_size=vocab_size, pos_enc_class=pos_enc_class, - **args.encoder_conf) + + encoder = encoder_class( + args.input_size, + vocab_size=vocab_size, + pos_enc_class=pos_enc_class, + **args.encoder_conf) # Decoder if args.decoder != 'no_decoder': @@ -1033,22 +911,17 @@ def build_model(args: argparse.Namespace, model_class = ESPnetMLMEncAsDecoderMod decoder = decoder_class( idim=0, input_layer=None, - **args.decoder_conf, - ) + **args.decoder_conf, ) else: decoder = None # Build model model = model_class( - feats_extract=None, # maybe should be LogMelFbank odim=odim, - normalize=normalize, encoder=encoder, decoder=decoder, token_list=token_list, - **args.model_conf, - ) - + **args.model_conf, ) # Initialize if args.init is not None: diff --git a/ernie-sat/prompt/dev/mfa_end b/ernie-sat/prompt/dev/mfa_end deleted file mode 100644 index 8772d7e4bd2f5f9c74bb904aceaea5d06e035297..0000000000000000000000000000000000000000 --- a/ernie-sat/prompt/dev/mfa_end +++ /dev/null @@ -1,3 +0,0 @@ -p243_new 1.0225 1.0525 1.0925 1.1325 1.1725 1.2625 1.3625 1.4125 1.5125 1.6225 1.6625 1.7925 1.8625 2.0025 2.0925 2.1725 2.2625 2.4325 2.4725 2.5225 2.5825 2.6125 2.6425 2.7425 2.8025 2.9025 2.9525 3.0525 3.0825 3.2125 3.4525 -Prompt_003_new 0.0425 0.0925 0.1825 0.2125 0.2425 0.3225 0.3725 0.4725 0.5325 0.5625 0.6225 0.7425 0.8625 0.9725 0.9975 1.0125 1.0825 1.2625 1.3125 -p299_096 0.7525 0.7925 0.8725 0.9125 0.9425 1.0325 1.0625 1.1925 1.2625 1.3225 1.3725 1.4125 1.5125 1.5425 1.6525 1.6925 1.7325 1.7625 1.8425 1.9625 2.0225 2.1825 2.3325 2.6825 diff --git a/ernie-sat/prompt/dev/mfa_start b/ernie-sat/prompt/dev/mfa_start deleted file mode 100644 index fac9716128f5ec175fe0df8acf22969057ef2ab5..0000000000000000000000000000000000000000 --- a/ernie-sat/prompt/dev/mfa_start +++ /dev/null @@ -1,3 +0,0 @@ -p243_new 0.0125 1.0225 1.0525 1.0925 1.1325 1.1725 1.2625 1.3625 1.4125 1.5125 1.6225 1.6625 1.7925 1.8625 2.0025 2.0925 2.1725 2.2625 2.4325 2.4725 2.5225 2.5825 2.6125 2.6425 2.7425 2.8025 2.9025 2.9525 3.0525 3.0825 3.2125 -Prompt_003_new 0.0125 0.0425 0.0925 0.1825 0.2125 0.2425 0.3225 0.3725 0.4725 0.5325 0.5625 0.6225 0.7425 0.8625 0.9725 0.9975 1.0125 1.0825 1.2625 -p299_096 0.0125 0.7525 0.7925 0.8725 0.9125 
0.9425 1.0325 1.0625 1.1925 1.2625 1.3225 1.3725 1.4125 1.5125 1.5425 1.6525 1.6925 1.7325 1.7625 1.8425 1.9625 2.0225 2.1825 2.3325 diff --git a/ernie-sat/prompt/dev/mfa_text b/ernie-sat/prompt/dev/mfa_text deleted file mode 100644 index 903e2444a0c55e694fff9a7edb92398f04bfefec..0000000000000000000000000000000000000000 --- a/ernie-sat/prompt/dev/mfa_text +++ /dev/null @@ -1,3 +0,0 @@ -p243_new sp F AO1 R DH AE1 T R IY1 Z AH0 N sp K AH1 V ER0 SH UH1 D N AA1 T B IY1 G IH1 V AH0 N sp -Prompt_003_new DH IH1 S W AA1 Z N AA1 T DH AH0 SH OW1 F AO1 R M IY1 sp -p299_096 sp W IY1 AA1 R T R AY1 NG T UW1 AH0 S T AE1 B L IH0 SH AH0 D EY1 T sp diff --git a/ernie-sat/prompt/dev/mfa_wav.scp b/ernie-sat/prompt/dev/mfa_wav.scp deleted file mode 100644 index eb0e8e48d5c50fe65271a8039bdd24628fa9daf5..0000000000000000000000000000000000000000 --- a/ernie-sat/prompt/dev/mfa_wav.scp +++ /dev/null @@ -1,3 +0,0 @@ -p243_new ../../prompt_wav/p243_313.wav -Prompt_003_new ../../prompt_wav/this_was_not_the_show_for_me.wav -p299_096 ../../prompt_wav/p299_096.wav diff --git a/ernie-sat/read_text.py b/ernie-sat/read_text.py index f140c31f24842dab636dd9701c9047ba2b0181dd..bcf1125ef041338eff9efef254caab25311afd31 100644 --- a/ernie-sat/read_text.py +++ b/ernie-sat/read_text.py @@ -5,7 +5,6 @@ from typing import List from typing import Union - def read_2column_text(path: Union[Path, str]) -> Dict[str, str]: """Read a text file having 2 column as dict object. @@ -33,9 +32,8 @@ def read_2column_text(path: Union[Path, str]) -> Dict[str, str]: return data -def load_num_sequence_text( - path: Union[Path, str], loader_type: str = "csv_int" -) -> Dict[str, List[Union[float, int]]]: +def load_num_sequence_text(path: Union[Path, str], loader_type: str="csv_int" + ) -> Dict[str, List[Union[float, int]]]: """Read a text file indicating sequences of number Examples: @@ -73,6 +71,7 @@ def load_num_sequence_text( try: retval[k] = [dtype(i) for i in v.split(delimiter)] except TypeError: - logging.error(f'Error happened with path="{path}", id="{k}", value="{v}"') + logging.error( + f'Error happened with path="{path}", id="{k}", value="{v}"') raise return retval diff --git a/ernie-sat/run_clone_en_to_zh.sh b/ernie-sat/run_clone_en_to_zh.sh old mode 100644 new mode 100755 index a703fa46c8894ed82290fff74dca903a4e0ba7d3..9266b67f1609d338b9fc2ac816dc29ba752eea78 --- a/ernie-sat/run_clone_en_to_zh.sh +++ b/ernie-sat/run_clone_en_to_zh.sh @@ -1,24 +1,25 @@ -# en --> zh 的 语音合成 -# 根据Prompt_003_new作为提示语音: This was not the show for me. 来合成: '今天天气很好' -# 注: 输入的new_str需为中文汉字, 否则会通过预处理只保留中文汉字, 即合成预处理后的中文语音。 +#!/bin/bash +# en --> zh 的 语音合成 +# 根据 Prompt_003_new 作为提示语音: This was not the show for me. 来合成: '今天天气很好' +# 注: 输入的 new_str 需为中文汉字, 否则会通过预处理只保留中文汉字, 即合成预处理后的中文语音。 python inference.py \ ---task_name cross-lingual_clone \ ---model_name paddle_checkpoint_dual_mask_enzh \ ---uid Prompt_003_new \ ---new_str '今天天气很好.' 
\ ---prefix ./prompt/dev/ \ ---source_language english \ ---target_language chinese \ ---output_name pred_clone.wav \ ---use_pt_vocoder False \ ---voc pwgan_aishell3 \ ---voc_config download/pwg_aishell3_ckpt_0.5/default.yaml \ ---voc_ckpt download/pwg_aishell3_ckpt_0.5/snapshot_iter_1000000.pdz \ ---voc_stat download/pwg_aishell3_ckpt_0.5/feats_stats.npy \ ---am fastspeech2_csmsc \ ---am_config download/fastspeech2_conformer_baker_ckpt_0.5/conformer.yaml \ ---am_ckpt download/fastspeech2_conformer_baker_ckpt_0.5/snapshot_iter_76000.pdz \ ---am_stat download/fastspeech2_conformer_baker_ckpt_0.5/speech_stats.npy \ ---phones_dict download/fastspeech2_conformer_baker_ckpt_0.5/phone_id_map.txt \ No newline at end of file + --task_name=cross-lingual_clone \ + --model_name=paddle_checkpoint_dual_mask_enzh \ + --uid=Prompt_003_new \ + --new_str='今天天气很好.' \ + --prefix='./prompt/dev/' \ + --source_language=english \ + --target_language=chinese \ + --output_name=pred_clone.wav \ + --use_pt_vocoder=False \ + --voc=pwgan_aishell3 \ + --voc_config=download/pwg_aishell3_ckpt_0.5/default.yaml \ + --voc_ckpt=download/pwg_aishell3_ckpt_0.5/snapshot_iter_1000000.pdz \ + --voc_stat=download/pwg_aishell3_ckpt_0.5/feats_stats.npy \ + --am=fastspeech2_csmsc \ + --am_config=download/fastspeech2_conformer_baker_ckpt_0.5/conformer.yaml \ + --am_ckpt=download/fastspeech2_conformer_baker_ckpt_0.5/snapshot_iter_76000.pdz \ + --am_stat=download/fastspeech2_conformer_baker_ckpt_0.5/speech_stats.npy \ + --phones_dict=download/fastspeech2_conformer_baker_ckpt_0.5/phone_id_map.txt \ No newline at end of file diff --git a/ernie-sat/run_gen_en.sh b/ernie-sat/run_gen_en.sh old mode 100644 new mode 100755 index 22af6c5760b9e4f7de1226edf6e3ab9fab7f5785..1f7fbf69068199dce3c526a9693b0c6f7bbbd333 --- a/ernie-sat/run_gen_en.sh +++ b/ernie-sat/run_gen_en.sh @@ -1,22 +1,24 @@ +#!/bin/bash + # 纯英文的语音合成 -# 样例为根据p299_096对应的语音作为提示语音: This was not the show for me. 来合成: 'I enjoy my life.' +# 样例为根据 p299_096 对应的语音作为提示语音: This was not the show for me. 来合成: 'I enjoy my life.' python inference.py \ ---task_name synthesize \ ---model_name paddle_checkpoint_en \ ---uid p299_096 \ ---new_str 'I enjoy my life.' \ ---prefix ./prompt/dev/ \ ---source_language english \ ---target_language english \ ---output_name pred_gen.wav \ ---use_pt_vocoder True \ ---voc pwgan_aishell3 \ ---voc_config download/pwg_aishell3_ckpt_0.5/default.yaml \ ---voc_ckpt download/pwg_aishell3_ckpt_0.5/snapshot_iter_1000000.pdz \ ---voc_stat download/pwg_aishell3_ckpt_0.5/feats_stats.npy \ ---am fastspeech2_ljspeech \ ---am_config download/fastspeech2_nosil_ljspeech_ckpt_0.5/default.yaml \ ---am_ckpt download/fastspeech2_nosil_ljspeech_ckpt_0.5/snapshot_iter_100000.pdz \ ---am_stat download/fastspeech2_nosil_ljspeech_ckpt_0.5/speech_stats.npy \ ---phones_dict download/fastspeech2_nosil_ljspeech_ckpt_0.5/phone_id_map.txt \ No newline at end of file + --task_name=synthesize \ + --model_name=paddle_checkpoint_en \ + --uid=p299_096 \ + --new_str='I enjoy my life, do you?' 
\ + --prefix='./prompt/dev/' \ + --source_language=english \ + --target_language=english \ + --output_name=pred_gen.wav \ + --use_pt_vocoder=False \ + --voc=pwgan_aishell3 \ + --voc_config=download/pwg_aishell3_ckpt_0.5/default.yaml \ + --voc_ckpt=download/pwg_aishell3_ckpt_0.5/snapshot_iter_1000000.pdz \ + --voc_stat=download/pwg_aishell3_ckpt_0.5/feats_stats.npy \ + --am=fastspeech2_ljspeech \ + --am_config=download/fastspeech2_nosil_ljspeech_ckpt_0.5/default.yaml \ + --am_ckpt=download/fastspeech2_nosil_ljspeech_ckpt_0.5/snapshot_iter_100000.pdz \ + --am_stat=download/fastspeech2_nosil_ljspeech_ckpt_0.5/speech_stats.npy \ + --phones_dict=download/fastspeech2_nosil_ljspeech_ckpt_0.5/phone_id_map.txt \ No newline at end of file diff --git a/ernie-sat/run_sedit_en.sh b/ernie-sat/run_sedit_en.sh old mode 100644 new mode 100755 index 73aef96d6a185c52424e8f213bda73e41a05ee71..4c2b248b9b3a1803815fe1d8be1903753efb84fb --- a/ernie-sat/run_sedit_en.sh +++ b/ernie-sat/run_sedit_en.sh @@ -1,23 +1,25 @@ +#!/bin/bash + # 纯英文的语音编辑 -# 样例为把p243_new对应的原始语音: For that reason cover should not be given.编辑成'for that reason cover is impossible to be given.'对应的语音 -# NOTE: 语音编辑任务暂支持句子中1个位置的替换或者插入文本操作 +# 样例为把 p243_new 对应的原始语音: For that reason cover should not be given.编辑成 'for that reason cover is impossible to be given.' 对应的语音 +# NOTE: 语音编辑任务暂支持句子中 1 个位置的替换或者插入文本操作 python inference.py \ ---task_name edit \ ---model_name paddle_checkpoint_en \ ---uid p243_new \ ---new_str 'for that reason cover is impossible to be given.' \ ---prefix ./prompt/dev/ \ ---source_language english \ ---target_language english \ ---output_name pred_edit.wav \ ---use_pt_vocoder True \ ---voc pwgan_aishell3 \ ---voc_config download/pwg_aishell3_ckpt_0.5/default.yaml \ ---voc_ckpt download/pwg_aishell3_ckpt_0.5/snapshot_iter_1000000.pdz \ ---voc_stat download/pwg_aishell3_ckpt_0.5/feats_stats.npy \ ---am fastspeech2_ljspeech \ ---am_config download/fastspeech2_nosil_ljspeech_ckpt_0.5/default.yaml \ ---am_ckpt download/fastspeech2_nosil_ljspeech_ckpt_0.5/snapshot_iter_100000.pdz \ ---am_stat download/fastspeech2_nosil_ljspeech_ckpt_0.5/speech_stats.npy \ ---phones_dict download/fastspeech2_nosil_ljspeech_ckpt_0.5/phone_id_map.txt \ No newline at end of file + --task_name=edit \ + --model_name=paddle_checkpoint_en \ + --uid=p243_new \ + --new_str='for that reason cover is impossible to be given.' 
\ + --prefix='./prompt/dev/' \ + --source_language=english \ + --target_language=english \ + --output_name=pred_edit.wav \ + --use_pt_vocoder=False \ + --voc=pwgan_aishell3 \ + --voc_config=download/pwg_aishell3_ckpt_0.5/default.yaml \ + --voc_ckpt=download/pwg_aishell3_ckpt_0.5/snapshot_iter_1000000.pdz \ + --voc_stat=download/pwg_aishell3_ckpt_0.5/feats_stats.npy \ + --am=fastspeech2_ljspeech \ + --am_config=download/fastspeech2_nosil_ljspeech_ckpt_0.5/default.yaml \ + --am_ckpt=download/fastspeech2_nosil_ljspeech_ckpt_0.5/snapshot_iter_100000.pdz \ + --am_stat=download/fastspeech2_nosil_ljspeech_ckpt_0.5/speech_stats.npy \ + --phones_dict=download/fastspeech2_nosil_ljspeech_ckpt_0.5/phone_id_map.txt diff --git a/ernie-sat/sedit_arg_parser.py b/ernie-sat/sedit_arg_parser.py index 0f25c5e96e1b92f86590b641365965b53790b6c1..a8b8f6feec80c6a08655c3a64e95fa816acaa0ca 100644 --- a/ernie-sat/sedit_arg_parser.py +++ b/ernie-sat/sedit_arg_parser.py @@ -86,7 +86,11 @@ def parse_args(): parser.add_argument("--target_language", type=str, help="target language") parser.add_argument("--output_name", type=str, help="output name") parser.add_argument("--task_name", type=str, help="task name") - parser.add_argument("--use_pt_vocoder", default=True, help="use pytorch version vocoder or not. [Note: only in english condition.]") + parser.add_argument( + "--use_pt_vocoder", + type=str2bool, + default=True, + help="use pytorch version vocoder or not. [Note: only in english condition.]") # pre args = parser.parse_args() diff --git a/ernie-sat/test_run.sh b/ernie-sat/test_run.sh new file mode 100755 index 0000000000000000000000000000000000000000..33ef7ea8d03ae2d42eabd6a6b9018166c19f9c5c --- /dev/null +++ b/ernie-sat/test_run.sh @@ -0,0 +1,6 @@ +#!/bin/bash + +rm -rf *.wav +sh run_sedit_en.sh # 语音编辑任务(英文) +sh run_gen_en.sh # 个性化语音合成任务(英文) +sh run_clone_en_to_zh.sh # 跨语言语音合成任务(英文到中文的语音克隆) \ No newline at end of file diff --git a/ernie-sat/tools/.DS_Store b/ernie-sat/tools/.DS_Store deleted file mode 100644 index b09d23bbc8f57834eab775ed1abb636cb50ed2e1..0000000000000000000000000000000000000000 Binary files a/ernie-sat/tools/.DS_Store and /dev/null differ diff --git a/ernie-sat/tools/__init__.py b/ernie-sat/tools/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ernie-sat/tools/parallel_wavegan_pretrained_vocoder.py b/ernie-sat/tools/parallel_wavegan_pretrained_vocoder.py index 5ac5c48cda81f41501b3ce0c99ecad0c426c16c3..856a0c0e8cb004b9b142ec6a88a6c4bcf235fde8 100644 --- a/ernie-sat/tools/parallel_wavegan_pretrained_vocoder.py +++ b/ernie-sat/tools/parallel_wavegan_pretrained_vocoder.py @@ -1,28 +1,21 @@ -# Copyright 2021 Tomoki Hayashi -# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) - """Wrapper class for the vocoder model trained with parallel_wavegan repo.""" - import logging import os - from pathlib import Path from typing import Optional from typing import Union -import yaml - import torch +import yaml class ParallelWaveGANPretrainedVocoder(torch.nn.Module): """Wrapper class to load the vocoder trained with parallel_wavegan repo.""" def __init__( - self, - model_file: Union[Path, str], - config_file: Optional[Union[Path, str]] = None, - ): + self, + model_file: Union[Path, str], + config_file: Optional[Union[Path, str]]=None, ): """Initialize ParallelWaveGANPretrainedVocoder module.""" super().__init__() try: @@ -30,8 +23,7 @@ class ParallelWaveGANPretrainedVocoder(torch.nn.Module): except ImportError: logging.error( 
"`parallel_wavegan` is not installed. " - "Please install via `pip install -U parallel_wavegan`." - ) + "Please install via `pip install -U parallel_wavegan`.") raise if config_file is None: dirname = os.path.dirname(str(model_file)) @@ -59,5 +51,4 @@ class ParallelWaveGANPretrainedVocoder(torch.nn.Module): """ return self.vocoder.inference( feats, - normalize_before=self.normalize_before, - ).view(-1) + normalize_before=self.normalize_before, ).view(-1) diff --git a/ernie-sat/utils.py b/ernie-sat/utils.py index 921b36562bf27dffea5765f63c57533e5f7158b9..cb1e93c265b9baa1ae04d99b6d870a9f5f7fbbba 100644 --- a/ernie-sat/utils.py +++ b/ernie-sat/utils.py @@ -1,45 +1,15 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import os -import argparse -import logging -from pathlib import Path - -import jsonlines import numpy as np import paddle -import soundfile as sf import yaml -from timer import timer +from sedit_arg_parser import parse_args from yacs.config import CfgNode -from paddlespeech.s2t.utils.dynamic_import import dynamic_import -from paddlespeech.t2s.exps.syn_utils import get_test_dataset +from paddlespeech.s2t.utils.dynamic_import import dynamic_import +from paddlespeech.t2s.exps.syn_utils import get_frontend from paddlespeech.t2s.exps.syn_utils import get_voc_inference -from paddlespeech.t2s.utils import str2bool -from paddlespeech.t2s.frontend.zh_frontend import Frontend -from paddlespeech.t2s.models.fastspeech2 import FastSpeech2 -from paddlespeech.t2s.models.fastspeech2 import FastSpeech2Inference from paddlespeech.t2s.modules.normalizer import ZScore -from yacs.config import CfgNode -# new add -import paddle.nn.functional as F -from paddlespeech.t2s.modules.nets_utils import make_pad_mask -from paddlespeech.t2s.exps.syn_utils import get_frontend - from tools.parallel_wavegan_pretrained_vocoder import ParallelWaveGANPretrainedVocoder -from sedit_arg_parser import parse_args +# new add model_alias = { # acoustic model @@ -58,9 +28,6 @@ model_alias = { } - - - def is_chinese(ch): if u'\u4e00' <= ch <= u'\u9fff': return True @@ -69,17 +36,15 @@ def is_chinese(ch): def build_vocoder_from_file( - vocoder_config_file = None, - vocoder_file = None, - model = None, - device = "cpu", - ): + vocoder_config_file=None, + vocoder_file=None, + model=None, + device="cpu", ): # Build vocoder if str(vocoder_file).endswith(".pkl"): # If the extension is ".pkl", the model is trained with parallel_wavegan - vocoder = ParallelWaveGANPretrainedVocoder( - vocoder_file, vocoder_config_file - ) + vocoder = ParallelWaveGANPretrainedVocoder(vocoder_file, + vocoder_config_file) return vocoder.to(device) else: @@ -91,7 +56,7 @@ def get_voc_out(mel, target_language="chinese"): args = parse_args() assert target_language == "chinese" or target_language == "english", "In get_voc_out function, target_language is illegal..." 
- + # print("current vocoder: ", args.voc) with open(args.voc_config) as f: voc_config = CfgNode(yaml.safe_load(f)) @@ -106,6 +71,7 @@ def get_voc_out(mel, target_language="chinese"): # print("shepe of wav (time x n_channels):%s"%wav.shape) return np.squeeze(wav) + # dygraph def get_am_inference(args, am_config): with open(args.phones_dict, "r") as f: @@ -159,11 +125,14 @@ def get_am_inference(args, am_config): return am, am_inference, am_name, am_dataset, phn_id -def evaluate_durations(phns, target_language="chinese", fs=24000, hop_length=300): +def evaluate_durations(phns, + target_language="chinese", + fs=24000, + hop_length=300): args = parse_args() if target_language == 'english': - args.lang='en' + args.lang = 'en' args.am = "fastspeech2_ljspeech" args.am_config = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/default.yaml" args.am_ckpt = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/snapshot_iter_100000.pdz" @@ -171,12 +140,12 @@ def evaluate_durations(phns, target_language="chinese", fs=24000, hop_length=300 args.phones_dict = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/phone_id_map.txt" elif target_language == 'chinese': - args.lang='zh' + args.lang = 'zh' args.am = "fastspeech2_csmsc" - args.am_config="download/fastspeech2_conformer_baker_ckpt_0.5/conformer.yaml" + args.am_config = "download/fastspeech2_conformer_baker_ckpt_0.5/conformer.yaml" args.am_ckpt = "download/fastspeech2_conformer_baker_ckpt_0.5/snapshot_iter_76000.pdz" args.am_stat = "download/fastspeech2_conformer_baker_ckpt_0.5/speech_stats.npy" - args.phones_dict ="download/fastspeech2_conformer_baker_ckpt_0.5/phone_id_map.txt" + args.phones_dict = "download/fastspeech2_conformer_baker_ckpt_0.5/phone_id_map.txt" # args = parser.parse_args(args=[]) if args.ngpu == 0: @@ -186,8 +155,6 @@ def evaluate_durations(phns, target_language="chinese", fs=24000, hop_length=300 else: print("ngpu should >= 0 !") - - assert target_language == "chinese" or target_language == "english", "In evaluate_durations function, target_language is illegal..." # Init body. 
@@ -197,8 +164,8 @@ def evaluate_durations(phns, target_language="chinese", fs=24000, hop_length=300 # print(am_config) # print("---------------------") # acoustic model - am, am_inference, am_name, am_dataset,phn_id = get_am_inference(args, am_config) - + am, am_inference, am_name, am_dataset, phn_id = get_am_inference(args, + am_config) torch_phns = phns vocab_phones = {} @@ -206,33 +173,31 @@ def evaluate_durations(phns, target_language="chinese", fs=24000, hop_length=300 vocab_phones[tone] = int(id) # print("vocab_phones: ", len(vocab_phones)) vocab_size = len(vocab_phones) - phonemes = [ - phn if phn in vocab_phones else "sp" for phn in torch_phns - ] + phonemes = [phn if phn in vocab_phones else "sp" for phn in torch_phns] phone_ids = [vocab_phones[item] for item in phonemes] phone_ids_new = phone_ids - phone_ids_new.append(vocab_size-1) + phone_ids_new.append(vocab_size - 1) phone_ids_new = paddle.to_tensor(np.array(phone_ids_new, np.int64)) - normalized_mel, d_outs, p_outs, e_outs = am.inference(phone_ids_new, spk_id=None, spk_emb=None) + normalized_mel, d_outs, p_outs, e_outs = am.inference( + phone_ids_new, spk_id=None, spk_emb=None) pre_d_outs = d_outs phoneme_durations_new = pre_d_outs * hop_length / fs phoneme_durations_new = phoneme_durations_new.tolist()[:-1] return phoneme_durations_new - def sentence2phns(sentence, target_language="en"): args = parse_args() if target_language == 'en': - args.lang='en' + args.lang = 'en' args.phones_dict = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/phone_id_map.txt" elif target_language == 'zh': - args.lang='zh' - args.phones_dict="download/fastspeech2_conformer_baker_ckpt_0.5/phone_id_map.txt" + args.lang = 'zh' + args.phones_dict = "download/fastspeech2_conformer_baker_ckpt_0.5/phone_id_map.txt" else: print("target_language should in {'zh', 'en'}!") - + frontend = get_frontend(args) merge_sentences = True get_tone_ids = False @@ -246,10 +211,8 @@ def sentence2phns(sentence, target_language="en"): phone_ids = input_ids["phone_ids"] phonemes = frontend.get_phonemes( - sentence, - merge_sentences=merge_sentences, - print_info=False) - + sentence, merge_sentences=merge_sentences, print_info=False) + return phonemes[0], input_ids["phone_ids"][0] elif target_language == 'en': @@ -270,16 +233,11 @@ def sentence2phns(sentence, target_language="en"): phones = [phn for phn in phones if not phn.isspace()] # replace unk phone with sp phones = [ - phn - if (phn in vocab_phones and phn not in punc) else "sp" + phn if (phn in vocab_phones and phn not in punc) else "sp" for phn in phones ] phones_list.append(phones) - return phones_list[0], input_ids["phone_ids"][0] + return phones_list[0], input_ids["phone_ids"][0] else: print("lang should in {'zh', 'en'}!") - - - - diff --git a/ernie-sat/wavs/pred_clone.wav b/ernie-sat/wavs/pred_clone.wav deleted file mode 100644 index 75bc889f4a89b4fab3054af944b7f991da3cb97f..0000000000000000000000000000000000000000 Binary files a/ernie-sat/wavs/pred_clone.wav and /dev/null differ diff --git a/ernie-sat/wavs/pred_edit.wav b/ernie-sat/wavs/pred_edit.wav deleted file mode 100644 index 75ae3e6dbd44ce1df5b973ea1e2ad7253e1b4cb7..0000000000000000000000000000000000000000 Binary files a/ernie-sat/wavs/pred_edit.wav and /dev/null differ diff --git a/ernie-sat/wavs/pred_gen.wav b/ernie-sat/wavs/pred_gen.wav deleted file mode 100644 index 76694ddb938a847a60e464ae9f6477aa08944989..0000000000000000000000000000000000000000 Binary files a/ernie-sat/wavs/pred_gen.wav and /dev/null differ
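A note on the `--use_pt_vocoder` change in `sedit_arg_parser.py` above: the flag now goes through `str2bool`, because a plain string-typed `argparse` option treats any non-empty value (including the literal `False`) as truthy. The sketch below is only an illustration of that pitfall; the `str2bool` helper shown here is a hypothetical stand-in, while the project itself presumably uses its own helper (e.g. the one previously imported from `paddlespeech.t2s.utils`).

```python
import argparse


def str2bool(value):
    """Hypothetical stand-in for the project's str2bool helper."""
    if isinstance(value, bool):
        return value
    if value.lower() in ("yes", "true", "t", "1"):
        return True
    if value.lower() in ("no", "false", "f", "0"):
        return False
    raise argparse.ArgumentTypeError("boolean value expected, got %r" % value)


parser = argparse.ArgumentParser()
# Without type=str2bool, --use_pt_vocoder=False would arrive as the truthy string "False".
parser.add_argument("--use_pt_vocoder", type=str2bool, default=True)

print(parser.parse_args(["--use_pt_vocoder=False"]).use_pt_vocoder)  # -> False
```

This is consistent with the updated shell scripts in this patch, which now pass `--use_pt_vocoder=False` explicitly in the `--flag=value` form.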