Unverified commit 437081f7, authored by Kennycao123, committed by GitHub

Merge pull request #824 from yt605155624/format

[ernie sat]format ernie sat
ernie-sat/.meta/framework.png (image diff: 306.0 KB → 139.9 KB)
ERNIE-SAT is a cross-lingual speech-language cross-modal large model that handles Chinese and English simultaneously. It achieves leading results on multiple tasks, including speech editing, personalized speech synthesis, and cross-lingual speech synthesis, and can be applied to scenarios such as speech editing, personalized synthesis, voice cloning, and simultaneous interpretation. This project is released for research use.

## Model Framework

ERNIE-SAT introduces two innovations:
- During pretraining, phonemes of parallel Chinese and English text are used as input, enabling cross-lingual, personalized soft phoneme mapping
- Joint masked learning over language and speech is used to align the two modalities

@@ -12,14 +12,14 @@ ERNIE-SAT introduces two innovations:
### 1. Install PaddlePaddle and Dependencies
- The code in this project is based on Paddle (version>=2.0)
- The project also supports loading the torch version of the vocoder
  - torch version>=1.8
- Install htk: after registering at the [official site](https://htk.eng.cam.ac.uk/), you can download a recent htk release (e.g. 3.4.1). [Older htk versions](https://htk.eng.cam.ac.uk/ftp/software/) are also available
  - 1. Register an account and download htk
  - 2. Extract the htk archive and **place it in the tools folder at the project root, in a folder named htk**
  - 3. **Note**: if you downloaded 3.4.1 or later, open HTKLib/HRec.c and **modify lines 1626 and 1650**, changing **dur<=0 to dur<0 in both lines**, as shown below (a scripted sed sketch also follows this list):
```bash
Taking htk 3.4.1 as an example:
(1) Line 1626: if (dur<=0 && labid != splabid) HError(8522,"LatFromPaths: Align have dur<=0");
@@ -28,26 +28,23 @@ ERNIE-SAT introduces two innovations:
(2) Line 1650: if (dur<=0 && labid != splabid) HError(8522,"LatFromPaths: Align have dur<=0 ");
    change to: if (dur<0 && labid != splabid) HError(8522,"LatFromPaths: Align have dur<0 ");
```
  - 4. **Compile**: see the README inside the extracted htk for details (the tools cannot run unless compiled)
- Install ParallelWaveGAN: see the [official repo](https://github.com/kan-bayashi/ParallelWaveGAN); following its installation instructions, simply git clone the ParallelWaveGAN project **in the project root directory** and install its dependencies.
- Install the remaining dependencies: **sox, libsndfile**
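
For step 3 above, the two line edits can also be applied non-interactively. The following is a minimal sketch assuming htk 3.4.1 extracted to tools/htk as described; the line addresses 1626/1650 are version-specific, so verify both lines before compiling:

```bash
# Sketch only: rewrite dur<=0 to dur<0 on lines 1626 and 1650 of HRec.c (htk 3.4.1 layout assumed)
sed -i '1626s/dur<=0/dur<0/g; 1650s/dur<=0/dur<0/g' tools/htk/HTKLib/HRec.c
# Print the two lines to confirm the edit took effect
sed -n '1626p;1650p' tools/htk/HTKLib/HRec.c
```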
### 2. Pretrained Models
The pretrained ERNIE-SAT models are listed below:
- [ERNIE-SAT_ZH](http://bj.bcebos.com/wenxin-models/model-ernie-sat-base-zh.tar.gz)
- [ERNIE-SAT_EN](http://bj.bcebos.com/wenxin-models/model-ernie-sat-base-en.tar.gz)
- [ERNIE-SAT_ZH_and_EN](http://bj.bcebos.com/wenxin-models/model-ernie-sat-base-en_zh.tar.gz)

Create a pretrained_model folder, then download and extract the ERNIE-SAT pretrained models above:
```bash
mkdir pretrained_model
cd pretrained_model
@@ -56,13 +53,12 @@ tar -zxvf model-ernie-sat-base-zh.tar.gz
tar -zxvf model-ernie-sat-base-en_zh.tar.gz
```
### 3. Downloads
1. This project uses parallel wavegan as the vocoder:
   - [pwg_aishell3_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_aishell3_ckpt_0.5.zip)

Create a download folder, then download and extract the pretrained vocoder model above:
```bash
mkdir download
@@ -70,11 +66,11 @@ cd download
unzip pwg_aishell3_ckpt_0.5.zip
```
2. This project uses [FastSpeech2](https://arxiv.org/abs/2006.04558) as the phoneme duration predictor:
   - [fastspeech2_conformer_baker_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_conformer_baker_ckpt_0.5.zip) used for Chinese
   - [fastspeech2_nosil_ljspeech_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_ljspeech_ckpt_0.5.zip) used for English

Download and extract the pretrained fastspeech2 models above:
```bash
cd download
@@ -85,7 +81,7 @@ unzip fastspeech2_nosil_ljspeech_ckpt_0.5.zip
```
### 4. Inference
This project currently open-sources the inference code for speech editing, personalized speech synthesis, and cross-lingual speech synthesis; the remainder will be released later.

Note: for English synthesis, the default vocoder is vctk_parallel_wavegan.v1.long, which can be found via [this link](https://github.com/kan-bayashi/ParallelWaveGAN); if the use_pt_vocoder parameter is set to False, the paddle version of the vocoder is used for English.

We provide a specific audio file along with its corresponding text and phoneme files:
- prompt_wav: the provided audio file

@@ -114,19 +110,19 @@ prompt/dev
5. `--lang` the language of the model, `zh` or `en`
6. `--ngpu` the number of GPUs to use; if ngpu==0, the CPU is used
7. `--model_name` the model name
8. `--uid` the id of the specific prompt audio
9. `--new_str` the input text (this release temporarily uses fixed, preset text)
10. `--prefix` the path to the text and phoneme files for the specific audio
11. `--source_language` the source language
12. `--target_language` the target language
13. `--output_name` the name of the synthesized audio
14. `--task_name` the task name: speech editing, personalized speech synthesis, or cross-lingual speech synthesis
15. `--use_pt_vocoder` whether to use the torch version of the vocoder in English scenarios; defaults to False, in which case the paddle vocoder is used for English
Run the following scripts to carry out the experiments:
```shell
sh run_sedit_en.sh          # speech editing task (English)
sh run_gen_en.sh            # personalized speech synthesis task (English)
sh run_clone_en_to_zh.sh    # cross-lingual speech synthesis task (voice cloning from English to Chinese)
```
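
The scripts above are the supported entry points. Purely for orientation, a direct invocation with the flags listed above might look like the sketch below; the script name inference.py and every placeholder value are our assumptions, not taken from this repository:

```bash
# Hypothetical sketch; inference.py and all values are placeholders, not repository facts
python inference.py \
    --task_name <task> \
    --model_name <pretrained_model_dir> \
    --uid <prompt_audio_id> \
    --prefix <path_to_text_and_phoneme_files> \
    --new_str "<target text>" \
    --source_language english \
    --target_language english \
    --output_name <output_wav> \
    --use_pt_vocoder False \
    --lang en \
    --ngpu 0
```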
#!/usr/bin/env python
""" Usage:
align_english.py wavfile trsfile outwordfile outphonefile
"""
import multiprocessing as mp
import os
import sys

from tqdm import tqdm
PHONEME = 'tools/aligner/english_envir/english2phoneme/phoneme'
MODEL_DIR = 'tools/aligner/english'
HVITE = 'tools/htk/HTKTools/HVite'
HCOPY = 'tools/htk/HTKTools/HCopy'


def prep_txt(line, tmpbase, dictfile):
    words = []
    line = line.strip()
@@ -48,7 +47,8 @@ def prep_txt(line, tmpbase, dictfile):
    for unk in unk_words:
        fwid.write(unk + '\n')
    try:
        os.system(PHONEME + ' ' + tmpbase + '_unk.words' + ' ' + tmpbase +
                  '_unk.phons')
    except:
        print('english2phoneme error!')
        sys.exit(1)
@@ -79,7 +79,7 @@ def prep_txt(line, tmpbase, dictfile):
            seq.append(phons[j].upper())
            j += 1
        else:
            p = phons[j:j + 2]
            if (p == 'WH'):
                seq.append('W')
            elif (p in ['TH', 'SH', 'HH', 'DH', 'CH', 'ZH', 'NG']):
@@ -96,6 +96,7 @@ def prep_txt(line, tmpbase, dictfile):
        fw.write('\n')
    fw.close()
def prep_mlf(txt, tmpbase):
    with open(tmpbase + '.mlf', 'w') as fwid:
@@ -108,6 +109,7 @@ def prep_mlf(txt, tmpbase):
            fwid.write('sp\n')
        fwid.write('.\n')


def gen_res(tmpbase, outfile1, outfile2):
    with open(tmpbase + '.txt', 'r') as fid:
        words = fid.readline().strip().split()
@@ -120,19 +122,20 @@ def gen_res(tmpbase, outfile1, outfile2):
    times1 = []
    times2 = []
    while (i < len(lines)):
        if (len(lines[i].split()) >= 4) and (
                lines[i].split()[0] != lines[i].split()[1]):
            phn = lines[i].split()[2]
            pst = (int(lines[i].split()[0]) / 1000 + 125) / 10000
            pen = (int(lines[i].split()[1]) / 1000 + 125) / 10000
            times2.append([phn, pst, pen])
        if (len(lines[i].split()) == 5):
            if (lines[i].split()[0] != lines[i].split()[1]):
                wrd = lines[i].split()[-1].strip()
                st = (int(lines[i].split()[0]) / 1000 + 125) / 10000
                j = i + 1
                while (lines[j] != '.\n') and (len(lines[j].split()) != 5):
                    j += 1
                en = (int(lines[j - 1].split()[1]) / 1000 + 125) / 10000
                times1.append([wrd, st, en])
        i += 1
@@ -151,8 +154,13 @@ def gen_res(tmpbase, outfile1, outfile2):
    for item in times2:
        fwid.write(str(item[1]) + ' ' + str(item[2]) + ' ' + item[0] + '\n')
def _get_user():
    return os.path.expanduser('~').split("/")[-1]


def alignment(wav_path, text_string):
    tmpbase = '/tmp/' + _get_user() + '_' + str(os.getpid())

    #prepare wav and trs files
    try:
@@ -160,7 +168,7 @@ def alignment(wav_path, text_string):
    except:
        print('sox error!')
        return None

    #prepare clean_transcript file
    try:
        prep_txt(text_string, tmpbase, MODEL_DIR + '/dict')
@@ -179,14 +187,19 @@ def alignment(wav_path, text_string):
    #prepare scp
    try:
        os.system(HCOPY + ' -C ' + MODEL_DIR + '/16000/config ' + tmpbase +
                  '.wav' + ' ' + tmpbase + '.plp')
    except:
        print('HCopy error!')
        return None

    #run alignment
    try:
        os.system(HVITE + ' -a -m -t 10000.0 10000.0 100000.0 -I ' + tmpbase +
                  '.mlf -H ' + MODEL_DIR + '/16000/macros -H ' + MODEL_DIR +
                  '/16000/hmmdefs -i ' + tmpbase + '.aligned ' + tmpbase +
                  '.dict ' + MODEL_DIR + '/monophones ' + tmpbase +
                  '.plp 2>&1 > /dev/null')
    except:
        print('HVite error!')
        return None
@@ -207,15 +220,15 @@ def alignment(wav_path, text_string):
        splited_line = lines[i].strip().split()
        if (len(splited_line) >= 4) and (splited_line[0] != splited_line[1]):
            phn = splited_line[2]
            pst = (int(splited_line[0]) / 1000 + 125) / 10000
            pen = (int(splited_line[1]) / 1000 + 125) / 10000
            times2.append([phn, pst, pen])
            # splited_line[-1]!='sp'
            if len(splited_line) == 5:
                current_word = str(index) + '_' + splited_line[-1]
                word2phns[current_word] = phn
                index += 1
            elif len(splited_line) == 4:
                word2phns[current_word] += ' ' + phn
        i += 1
    return times2, word2phns
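
# Output format of alignment() above (phone names illustrative): times2 is a
# list of [phone, start_sec, end_sec] triples, e.g. ['AE', 0.0125, 0.1125];
# word2phns maps '<index>_<WORD>' to a space-joined phone string,
# e.g. {'0_HELLO': 'HH AH L OW'}.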
#!/usr/bin/env python
""" Usage:
align_mandarin.py wavfile trsfile outwordfile outphonefile
"""
import multiprocessing as mp
import os
import sys

from tqdm import tqdm

MODEL_DIR = 'tools/aligner/mandarin'
HVITE = 'tools/htk/HTKTools/HVite'
@@ -19,7 +17,10 @@ def prep_txt(line, tmpbase, dictfile):
    words = []
    line = line.strip()
    for pun in [
            ',', '.', ':', ';', '!', '?', '"', '(', ')', '--', '---', u',',
            u'。', u':', u';', u'!', u'?', u'(', u')'
    ]:
        line = line.replace(pun, ' ')
    for wrd in line.split():
        if (wrd[-1] == '-'):
@@ -43,6 +44,7 @@ def prep_txt(line, tmpbase, dictfile):
        fwid.write('\n')
    return unk_words
def prep_mlf(txt, tmpbase):
    with open(tmpbase + '.mlf', 'w') as fwid:
@@ -55,6 +57,7 @@ def prep_mlf(txt, tmpbase):
            fwid.write('sp\n')
        fwid.write('.\n')


def gen_res(tmpbase, outfile1, outfile2):
    with open(tmpbase + '.txt', 'r') as fid:
        words = fid.readline().strip().split()
@@ -67,19 +70,20 @@ def gen_res(tmpbase, outfile1, outfile2):
    times1 = []
    times2 = []
    while (i < len(lines)):
        if (len(lines[i].split()) >= 4) and (
                lines[i].split()[0] != lines[i].split()[1]):
            phn = lines[i].split()[2]
            pst = (int(lines[i].split()[0]) / 1000 + 125) / 10000
            pen = (int(lines[i].split()[1]) / 1000 + 125) / 10000
            times2.append([phn, pst, pen])
        if (len(lines[i].split()) == 5):
            if (lines[i].split()[0] != lines[i].split()[1]):
                wrd = lines[i].split()[-1].strip()
                st = (int(lines[i].split()[0]) / 1000 + 125) / 10000
                j = i + 1
                while (lines[j] != '.\n') and (len(lines[j].split()) != 5):
                    j += 1
                en = (int(lines[j - 1].split()[1]) / 1000 + 125) / 10000
                times1.append([wrd, st, en])
        i += 1
@@ -99,18 +103,18 @@ def gen_res(tmpbase, outfile1, outfile2):
        fwid.write(str(item[1]) + ' ' + str(item[2]) + ' ' + item[0] + '\n')
def alignment_zh(wav_path, text_string):
    tmpbase = '/tmp/' + os.environ['USER'] + '_' + str(os.getpid())

    #prepare wav and trs files
    try:
        os.system('sox ' + wav_path + ' -r 16000 -b 16 ' + tmpbase +
                  '.wav remix -')
    except:
        print('sox error!')
        return None

    #prepare clean_transcript file
    try:
        unk_words = prep_txt(text_string, tmpbase, MODEL_DIR + '/dict')
@@ -133,14 +137,19 @@ def alignment_zh(wav_path, text_string):
    #prepare scp
    try:
        os.system(HCOPY + ' -C ' + MODEL_DIR + '/16000/config ' + tmpbase +
                  '.wav' + ' ' + tmpbase + '.plp')
    except:
        print('HCopy error!')
        return None

    #run alignment
    try:
        os.system(HVITE + ' -a -m -t 10000.0 10000.0 100000.0 -I ' + tmpbase +
                  '.mlf -H ' + MODEL_DIR + '/16000/macros -H ' + MODEL_DIR +
                  '/16000/hmmdefs -i ' + tmpbase + '.aligned ' + MODEL_DIR +
                  '/dict ' + MODEL_DIR + '/monophones ' + tmpbase +
                  '.plp 2>&1 > /dev/null')
    except:
        print('HVite error!')
@@ -156,23 +165,22 @@ def alignment_zh(wav_path, text_string):
    i = 2
    times2 = []
    word2phns = {}
    current_word = ''
    index = 0
    while (i < len(lines)):
        splited_line = lines[i].strip().split()
        if (len(splited_line) >= 4) and (splited_line[0] != splited_line[1]):
            phn = splited_line[2]
            pst = (int(splited_line[0]) / 1000 + 125) / 10000
            pen = (int(splited_line[1]) / 1000 + 125) / 10000
            times2.append([phn, pst, pen])
            # splited_line[-1]!='sp'
            if len(splited_line) == 5:
                current_word = str(index) + '_' + splited_line[-1]
                word2phns[current_word] = phn
                index += 1
            elif len(splited_line) == 4:
                word2phns[current_word] += ' ' + phn
        i += 1
    return times2, word2phns
import math

import numpy as np
import paddle
def pad_list(xs, pad_value):
@@ -28,23 +26,25 @@ def pad_list(xs, pad_value):
    """
    n_batch = len(xs)
    max_len = max(paddle.shape(x)[0] for x in xs)
    pad = paddle.full((n_batch, max_len), pad_value, dtype=xs[0].dtype)

    for i in range(n_batch):
        pad[i, :paddle.shape(xs[i])[0]] = xs[i]

    return pad
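
# Illustrative use of pad_list (assumed inputs): given xs = [t1, t2] with
# t1 = [1., 2.] and t2 = [3.], pad_list(xs, 0.0) returns [[1., 2.], [3., 0.]],
# i.e. every tensor is right-padded with pad_value to the longest length.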
def pad_to_longformer_att_window(text, max_len, max_tlen, attention_window):
    round = max_len % attention_window
    if round != 0:
        max_tlen += (attention_window - round)
        n_batch = paddle.shape(text)[0]
        text_pad = paddle.zeros(
            (n_batch, max_tlen, *paddle.shape(text[0])[1:]), dtype=text.dtype)
        for i in range(n_batch):
            text_pad[i, :paddle.shape(text[i])[0]] = text[i]
    else:
        text_pad = text[:, :max_tlen]
    return text_pad, max_tlen
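
# Worked example of the rounding above: max_len=10, attention_window=4 gives
# round=2, so max_tlen grows by attention_window - round = 2 and sequences are
# padded out to a multiple of the attention window.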
@@ -139,7 +139,6 @@ def make_pad_mask(lengths, xs=None, length_dim=-1):
    if not isinstance(lengths, list):
        lengths = list(lengths)
    bs = int(len(lengths))
    if xs is None:
        maxlen = int(max(lengths))
@@ -147,10 +146,9 @@ def make_pad_mask(lengths, xs=None, length_dim=-1):
        maxlen = paddle.shape(xs)[length_dim]

    seq_range = paddle.arange(0, maxlen, dtype=paddle.int64)
    seq_range_expand = paddle.expand(
        paddle.unsqueeze(seq_range, 0), (bs, maxlen))
    seq_length_expand = paddle.unsqueeze(paddle.to_tensor(lengths), -1)
    mask = seq_range_expand >= seq_length_expand

    if xs is not None:
@@ -160,16 +158,12 @@ def make_pad_mask(lengths, xs=None, length_dim=-1):
            length_dim = len(paddle.shape(xs)) + length_dim
        # ind = (:, None, ..., None, :, , None, ..., None)
        ind = tuple(
            slice(None) if i in (0, length_dim) else None
            for i in range(len(paddle.shape(xs))))
        mask = paddle.expand(mask[ind], paddle.shape(xs))
    return mask
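
# Illustrative call: make_pad_mask([3, 1, 2]) returns
#   [[False, False, False],
#    [False, True,  True ],
#    [False, False, True ]]
# where True marks padded positions beyond each sequence's length.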
def make_non_pad_mask(lengths, xs=None, length_dim=-1):
    """Make mask tensor containing indices of non-padded part.
@@ -259,8 +253,14 @@ def make_non_pad_mask(lengths, xs=None, length_dim=-1):
    return ~make_pad_mask(lengths, xs, length_dim)


def phones_masking(xs_pad,
                   src_mask,
                   align_start,
                   align_end,
                   align_start_lengths,
                   mlm_prob,
                   mean_phn_span,
                   span_boundary=None):
    bz, sent_len, _ = paddle.shape(xs_pad)
    mask_num_lower = math.ceil(sent_len * mlm_prob)
    masked_position = np.zeros((bz, sent_len))
@@ -273,38 +273,41 @@ def phones_masking(xs_pad, src_mask, align_start, align_end, align_start_lengths
    elif mean_phn_span == 0:
        # only speech
        length = sent_len
        mean_phn_span = min(length * mlm_prob // 3, 50)
        masked_phn_indices = random_spans_noise_mask(length, mlm_prob,
                                                     mean_phn_span).nonzero()
        masked_position[:, masked_phn_indices] = 1
    else:
        for idx in range(bz):
            if span_boundary is not None:
                for s, e in zip(span_boundary[idx][::2],
                                span_boundary[idx][1::2]):
                    masked_position[idx, s:e] = 1
                    # y_masks[idx, :, s:e] = tril_masks[idx, :, s:e]
                    # y_masks[idx, e:, s:e ] = 0
            else:
                length = align_start_lengths[idx].item()
                if length < 2:
                    continue
                masked_phn_indices = random_spans_noise_mask(
                    length, mlm_prob, mean_phn_span).nonzero()
                masked_start = align_start[idx][masked_phn_indices].tolist()
                masked_end = align_end[idx][masked_phn_indices].tolist()
                for s, e in zip(masked_start, masked_end):
                    masked_position[idx, s:e] = 1
                    # y_masks[idx, :, s:e] = tril_masks[idx, :, s:e]
                    # y_masks[idx, e:, s:e ] = 0
    non_eos_mask = np.array(
        paddle.reshape(src_mask, paddle.shape(xs_pad)[:2]).float().cpu())
    masked_position = masked_position * non_eos_mask
    # y_masks = src_mask & y_masks.bool()
    return paddle.cast(paddle.to_tensor(masked_position), paddle.bool), y_masks
def get_segment_pos(speech_pad, text_pad, align_start, align_end,
                    align_start_lengths, sega_emb):
    bz, speech_len, _ = speech_pad.size()
    _, text_len = text_pad.size()
@@ -313,7 +316,6 @@ def get_segment_pos(speech_pad, text_pad, align_start, align_end, align_start_le
    text_segment_pos = np.zeros((bz, text_len)).astype('int64')
    speech_segment_pos = np.zeros((bz, speech_len)).astype('int64')
    if not sega_emb:
        text_segment_pos = paddle.to_tensor(text_segment_pos)
        speech_segment_pos = paddle.to_tensor(speech_segment_pos)
@@ -321,11 +323,11 @@ def get_segment_pos(speech_pad, text_pad, align_start, align_end, align_start_le
    for idx in range(bz):
        align_length = align_start_lengths[idx].item()
        for j in range(align_length):
            s, e = align_start[idx][j].item(), align_end[idx][j].item()
            speech_segment_pos[idx][s:e] = j + 1
            text_segment_pos[idx][j] = j + 1
    text_segment_pos = paddle.to_tensor(text_segment_pos)
    speech_segment_pos = paddle.to_tensor(speech_segment_pos)
    return speech_segment_pos, text_segment_pos
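
# Illustrative behaviour: for the j-th aligned phone covering speech frames
# [s, e), those frames and text position j both receive segment id j + 1,
# leaving id 0 for padding and unaligned positions.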
#!/usr/bin/env python3
import argparse
import math
import os
import pickle
import random
import string
import sys
from pathlib import Path
from typing import Collection
from typing import Dict
from typing import List
from typing import Tuple
from typing import Union

import librosa
import numpy as np
import paddle
import soundfile as sf
import torch
from ParallelWaveGAN.parallel_wavegan.utils.utils import download_pretrained_model

from align_english import alignment
from align_mandarin import alignment_zh
from dataset import get_segment_pos
from dataset import make_non_pad_mask
from dataset import make_pad_mask
from dataset import pad_list
from dataset import pad_to_longformer_att_window
from dataset import phones_masking
from model_paddle import build_model_from_file
from read_text import load_num_sequence_text
from read_text import read_2column_text
from sedit_arg_parser import parse_args
from utils import build_vocoder_from_file
from utils import evaluate_durations
from utils import get_voc_out
from utils import is_chinese
from utils import sentence2phns

from paddlespeech.t2s.datasets.get_feats import LogMelFBank
random.seed(0)
np.random.seed(0)

PHONEME = 'tools/aligner/english_envir/english2phoneme/phoneme'
MODEL_DIR_EN = 'tools/aligner/english'
MODEL_DIR_ZH = 'tools/aligner/mandarin'
def plot_mel_and_vocode_wav(uid,
                            prefix,
                            clone_uid,
                            clone_prefix,
                            source_language,
                            target_language,
                            model_name,
                            wav_path,
                            full_origin_str,
                            old_str,
                            new_str,
                            use_pt_vocoder,
                            duration_preditor_path,
                            sid=None,
                            non_autoreg=True):
    wav_org, input_feat, output_feat, old_span_boundary, new_span_boundary, fs, hop_length = get_mlm_output(
        uid,
        prefix,
        clone_uid,
        clone_prefix,
        source_language,
        target_language,
        model_name,
        wav_path,
        old_str,
        new_str,
        duration_preditor_path,
        use_teacher_forcing=non_autoreg,
        sid=sid)

    masked_feat = output_feat[new_span_boundary[0]:new_span_boundary[
        1]].detach().float().cpu().numpy()
    if target_language == 'english':
        if use_pt_vocoder:
            output_feat = output_feat.detach().float().cpu().numpy()
            output_feat = torch.tensor(output_feat, dtype=torch.float)
            vocoder = load_vocoder('vctk_parallel_wavegan.v1.long')
            replaced_wav = vocoder(
                output_feat).detach().float().data.cpu().numpy()
        else:
            output_feat_np = output_feat.detach().float().cpu().numpy()
            replaced_wav = get_voc_out(output_feat_np, target_language)

    elif target_language == 'chinese':
        output_feat_np = output_feat.detach().float().cpu().numpy()
        replaced_wav_only_mask_fst2_voc = get_voc_out(masked_feat,
                                                      target_language)

    old_time_boundary = [hop_length * x for x in old_span_boundary]
    new_time_boundary = [hop_length * x for x in new_span_boundary]
    if target_language == 'english':
        wav_org_replaced_paddle_voc = np.concatenate([
            wav_org[:old_time_boundary[0]],
            replaced_wav[new_time_boundary[0]:new_time_boundary[1]],
            wav_org[old_time_boundary[1]:]
        ])

        data_dict = {"origin": wav_org, "output": wav_org_replaced_paddle_voc}

    elif target_language == 'chinese':
        wav_org_replaced_only_mask_fst2_voc = np.concatenate([
            wav_org[:old_time_boundary[0]], replaced_wav_only_mask_fst2_voc,
            wav_org[old_time_boundary[1]:]
        ])
        data_dict = {
            "origin": wav_org,
            "output": wav_org_replaced_only_mask_fst2_voc,
        }

    return data_dict, old_span_boundary
def get_unk_phns(word_str):
@@ -97,7 +126,8 @@ def get_unk_phns(word_str):
    f = open(tmpbase + 'temp.words', 'w')
    f.write(word_str)
    f.close()
    os.system(PHONEME + ' ' + tmpbase + 'temp.words' + ' ' + tmpbase +
              'temp.phons')
    f = open(tmpbase + 'temp.phons', 'r')
    lines2 = f.readline().strip().split()
    f.close()
@@ -116,7 +146,7 @@ def get_unk_phns(word_str):
            seq.append(phons[j].upper())
            j += 1
        else:
            p = phons[j:j + 2]
            if (p == 'WH'):
                seq.append('W')
            elif (p in ['TH', 'SH', 'HH', 'DH', 'CH', 'ZH', 'NG']):
@@ -129,8 +159,9 @@ def get_unk_phns(word_str):
    phns.extend(seq)
    return phns
def words2phns(line):
    dictfile = MODEL_DIR_EN + '/dict'
    tmpbase = '/tmp/tp.'
    line = line.strip()
    words = []
@@ -151,30 +182,33 @@ def words2phns(line):
            ds.add(word)
            if word not in word2phns_dict.keys():
                word2phns_dict[word] = " ".join(line.split()[1:])

    phns = []
    wrd2phns = {}
    for index, wrd in enumerate(words):
        if wrd == '[MASK]':
            wrd2phns[str(index) + "_" + wrd] = [wrd]
            phns.append(wrd)
        elif (wrd.upper() not in ds):
            wrd2phns[str(index) + "_" + wrd.upper()] = get_unk_phns(wrd)
            phns.extend(get_unk_phns(wrd))
        else:
            wrd2phns[str(index) +
                     "_" + wrd.upper()] = word2phns_dict[wrd.upper()].split()
            phns.extend(word2phns_dict[wrd.upper()].split())
    return phns, wrd2phns
def words2phns_zh(line):
    dictfile = MODEL_DIR_ZH + '/dict'
    tmpbase = '/tmp/tp.'
    line = line.strip()
    words = []
    for pun in [
            ',', '.', ':', ';', '!', '?', '"', '(', ')', '--', '---', u',',
            u'。', u':', u';', u'!', u'?', u'(', u')'
    ]:
        line = line.replace(pun, ' ')
    for wrd in line.split():
        if (wrd[-1] == '-'):
@@ -183,7 +217,7 @@ def words2phns_zh(line):
            wrd = wrd[1:]
        if wrd:
            words.append(wrd)

    ds = set([])
    word2phns_dict = {}
    with open(dictfile, 'r') as fid:
@@ -192,17 +226,17 @@ def words2phns_zh(line):
            ds.add(word)
            if word not in word2phns_dict.keys():
                word2phns_dict[word] = " ".join(line.split()[1:])

    phns = []
    wrd2phns = {}
    for index, wrd in enumerate(words):
        if wrd == '[MASK]':
            wrd2phns[str(index) + "_" + wrd] = [wrd]
            phns.append(wrd)
        elif (wrd.upper() not in ds):
            # "Illegal word encountered, please enter valid text..."
            print("出现非法词错误,请输入正确的文本...")
        else:
            wrd2phns[str(index) + "_" + wrd] = word2phns_dict[wrd].split()
            phns.extend(word2phns_dict[wrd].split())
    return phns, wrd2phns
@@ -212,62 +246,67 @@ def load_vocoder(vocoder_tag="vctk_parallel_wavegan.v1.long"):
    vocoder_tag = vocoder_tag.replace("parallel_wavegan/", "")
    vocoder_file = download_pretrained_model(vocoder_tag)
    vocoder_config = Path(vocoder_file).parent / "config.yml"
    vocoder = build_vocoder_from_file(vocoder_config, vocoder_file, None, 'cpu')
    return vocoder
def load_model(model_name):
    config_path = './pretrained_model/{}/config.yaml'.format(model_name)
    model_path = './pretrained_model/{}/model.pdparams'.format(model_name)
    mlm_model, args = build_model_from_file(
        config_file=config_path, model_file=model_path)
    return mlm_model, args


def read_data(uid, prefix):
    mfa_text = read_2column_text(prefix + '/text')[uid]
    mfa_wav_path = read_2column_text(prefix + '/wav.scp')[uid]
    if 'mnt' not in mfa_wav_path:
        mfa_wav_path = prefix.split('dump')[0] + mfa_wav_path
    return mfa_text, mfa_wav_path


def get_align_data(uid, prefix):
    mfa_path = prefix + "mfa_"
    mfa_text = read_2column_text(mfa_path + 'text')[uid]
    mfa_start = load_num_sequence_text(
        mfa_path + 'start', loader_type='text_float')[uid]
    mfa_end = load_num_sequence_text(
        mfa_path + 'end', loader_type='text_float')[uid]
    mfa_wav_path = read_2column_text(mfa_path + 'wav.scp')[uid]
    return mfa_text, mfa_start, mfa_end, mfa_wav_path
def get_masked_mel_boundary(mfa_start, mfa_end, fs, hop_length,
                            span_tobe_replaced):
    align_start = paddle.to_tensor(mfa_start).unsqueeze(0)
    align_end = paddle.to_tensor(mfa_end).unsqueeze(0)
    align_start = paddle.floor(fs * align_start / hop_length).int()
    align_end = paddle.floor(fs * align_end / hop_length).int()
    if span_tobe_replaced[0] >= len(mfa_start):
        span_boundary = [align_end[0].tolist()[-1], align_end[0].tolist()[-1]]
    else:
        span_boundary = [
            align_start[0].tolist()[span_tobe_replaced[0]],
            align_end[0].tolist()[span_tobe_replaced[1] - 1]
        ]
    return span_boundary
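
# Illustrative numbers (fs=24000 and hop_length=300 are assumed values, e.g.
# aishell3-style configs): a phone boundary at 0.5 s maps to mel frame
# floor(24000 * 0.5 / 300) = 40.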
def recover_dict(word2phns, tp_word2phns):
    dic = {}
    need_del_key = []
    exist_index = []
    sp_count = 0
    add_sp_count = 0
    for key in word2phns.keys():
        idx, wrd = key.split('_')
        if wrd == 'sp':
            sp_count += 1
            exist_index.append(int(idx))
        else:
            need_del_key.append(key)

    for key in need_del_key:
        del word2phns[key]
@@ -275,35 +314,36 @@ def recover_dict(word2phns, tp_word2phns):
    for key in tp_word2phns.keys():
        # print("debug: ",key)
        if cur_id in exist_index:
            dic[str(cur_id) + "_sp"] = 'sp'
            cur_id += 1
            add_sp_count += 1
        idx, wrd = key.split('_')
        dic[str(cur_id) + "_" + wrd] = tp_word2phns[key]
        cur_id += 1

    if add_sp_count + 1 == sp_count:
        dic[str(cur_id) + "_sp"] = 'sp'
        add_sp_count += 1

    assert add_sp_count == sp_count, "sp are not added in dic"
    return dic
def get_phns_and_spans(wav_path, old_str, new_str, source_language,
                       clone_target_language):
    append_new_str = (old_str == new_str[:len(old_str)])
    old_phns, mfa_start, mfa_end = [], [], []

    if source_language == "english":
        times2, word2phns = alignment(wav_path, old_str)
    elif source_language == "chinese":
        times2, word2phns = alignment_zh(wav_path, old_str)
        _, tp_word2phns = words2phns_zh(old_str)
        for key, value in tp_word2phns.items():
            idx, wrd = key.split('_')
            cur_val = " ".join(value)
            tp_word2phns[key] = cur_val

        word2phns = recover_dict(word2phns, tp_word2phns)
@@ -315,9 +355,8 @@ def get_phns_and_spans(wav_path, old_str, new_str, source_language, clone_target
        mfa_end.append(float(item[2]))
        old_phns.append(item[0])

    if append_new_str and (source_language != clone_target_language):
        is_cross_lingual_clone = True
    else:
        is_cross_lingual_clone = False
@@ -326,54 +365,59 @@ def get_phns_and_spans(wav_path, old_str, new_str, source_language, clone_target
        new_str_append = new_str[len(old_str):]

        if clone_target_language == "chinese":
            new_phns_origin, new_origin_word2phns = words2phns(new_str_origin)
            new_phns_append, temp_new_append_word2phns = words2phns_zh(
                new_str_append)

        elif clone_target_language == "english":
            new_phns_origin, new_origin_word2phns = words2phns_zh(
                new_str_origin)  # the original sentence
            new_phns_append, temp_new_append_word2phns = words2phns(
                new_str_append)  # the cloned sentence
        else:
            assert clone_target_language == "chinese" or clone_target_language == "english", "cloning is not supported for this language, please check it."

        new_phns = new_phns_origin + new_phns_append

        new_append_word2phns = {}
        length = len(new_origin_word2phns)
        for key, value in temp_new_append_word2phns.items():
            idx, wrd = key.split('_')
            new_append_word2phns[str(int(idx) + length) + '_' + wrd] = value

        new_word2phns = dict(
            list(new_origin_word2phns.items()) + list(
                new_append_word2phns.items()))
    else:
        if source_language == clone_target_language and clone_target_language == "english":
            new_phns, new_word2phns = words2phns(new_str)
        elif source_language == clone_target_language and clone_target_language == "chinese":
            new_phns, new_word2phns = words2phns_zh(new_str)
        else:
            assert source_language == clone_target_language, "source language is not the same as the target language..."

    span_tobe_replaced = [0, len(old_phns) - 1]
    span_tobe_added = [0, len(new_phns) - 1]
    left_index = 0
    new_phns_left = []
    sp_count = 0

    # find the left different index
    for key in word2phns.keys():
        idx, wrd = key.split('_')
        if wrd == 'sp':
            sp_count += 1
            new_phns_left.append('sp')
        else:
            idx = str(int(idx) - sp_count)
            if idx + '_' + wrd in new_word2phns:
                left_index += len(new_word2phns[idx + '_' + wrd])
                new_phns_left.extend(word2phns[key].split())
            else:
                span_tobe_replaced[0] = len(new_phns_left)
                span_tobe_added[0] = len(new_phns_left)
                break

    # reverse word2phns and new_word2phns
    right_index = 0
    new_phns_right = []
@@ -381,7 +425,7 @@ def get_phns_and_spans(wav_path, old_str, new_str, source_language, clone_target
    word2phns_max_index = int(list(word2phns.keys())[-1].split('_')[0])
    new_word2phns_max_index = int(list(new_word2phns.keys())[-1].split('_')[0])
    new_phns_middle = []

    if append_new_str:
        new_phns_right = []
        new_phns_middle = new_phns[left_index:]
        span_tobe_replaced[0] = len(new_phns_left)
@@ -391,176 +435,306 @@ def get_phns_and_spans(wav_path, old_str, new_str, source_language, clone_target
    else:
        for key in list(word2phns.keys())[::-1]:
            idx, wrd = key.split('_')
            if wrd == 'sp':
                sp_count += 1
                new_phns_right = ['sp'] + new_phns_right
            else:
                idx = str(new_word2phns_max_index - (word2phns_max_index - int(
                    idx) - sp_count))
                if idx + '_' + wrd in new_word2phns:
                    right_index -= len(new_word2phns[idx + '_' + wrd])
                    new_phns_right = word2phns[key].split() + new_phns_right
                else:
                    span_tobe_replaced[1] = len(old_phns) - len(new_phns_right)
                    new_phns_middle = new_phns[left_index:right_index]
                    span_tobe_added[1] = len(new_phns_left) + len(
                        new_phns_middle)
                    if len(new_phns_middle) == 0:
                        span_tobe_added[1] = min(span_tobe_added[1] + 1,
                                                 len(new_phns))
                        span_tobe_added[0] = max(0, span_tobe_added[0] - 1)
                        span_tobe_replaced[0] = max(0,
                                                    span_tobe_replaced[0] - 1)
                        span_tobe_replaced[1] = min(span_tobe_replaced[1] + 1,
                                                    len(old_phns))
                    break

    new_phns = new_phns_left + new_phns_middle + new_phns_right
    return mfa_start, mfa_end, old_phns, new_phns, span_tobe_replaced, span_tobe_added
def duration_adjust_factor(original_dur, pred_dur, phns):
    length = 0
    accumulate = 0
    factor_list = []
    for ori, pred, phn in zip(original_dur, pred_dur, phns):
        if pred == 0 or phn == 'sp':
            continue
        else:
            factor_list.append(ori / pred)
    factor_list = np.array(factor_list)
    factor_list.sort()
    if len(factor_list) < 5:
        return 1
    length = 2
    return np.average(factor_list[length:-length])
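
# Illustrative trimmed mean: with sorted ratios [0.8, 0.9, 1.0, 1.1, 1.2, 2.0],
# the two smallest and two largest are dropped and the result is
# (1.0 + 1.1) / 2 = 1.05; with fewer than 5 usable ratios the factor is 1.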
def prepare_features_with_duration(uid,
                                   prefix,
                                   clone_uid,
                                   clone_prefix,
                                   source_language,
                                   target_language,
                                   mlm_model,
                                   old_str,
                                   new_str,
                                   wav_path,
                                   duration_preditor_path,
                                   sid=None,
                                   mask_reconstruct=False,
                                   duration_adjust=True,
                                   start_end_sp=False,
                                   train_args=None):
    wav_org, rate = librosa.load(
        wav_path, sr=train_args.feats_extract_conf['fs'])
    fs = train_args.feats_extract_conf['fs']
    hop_length = train_args.feats_extract_conf['hop_length']

    mfa_start, mfa_end, old_phns, new_phns, span_tobe_replaced, span_tobe_added = get_phns_and_spans(
        wav_path, old_str, new_str, source_language, target_language)
    if start_end_sp:
        if new_phns[-1] != 'sp':
            new_phns = new_phns + ['sp']

    if target_language == "english":
        old_durations = evaluate_durations(
            old_phns, target_language=target_language)

    elif target_language == "chinese":
        if source_language == "english":
            old_durations = evaluate_durations(
                old_phns, target_language=source_language)

        elif source_language == "chinese":
            old_durations = evaluate_durations(
                old_phns, target_language=source_language)

    else:
        assert target_language == "chinese" or target_language == "english", "duration prediction is not supported for this language..."
original_old_durations = [e-s for e,s in zip(mfa_end, mfa_start)] original_old_durations = [e - s for e, s in zip(mfa_end, mfa_start)]
if '[MASK]' in new_str: if '[MASK]' in new_str:
new_phns = old_phns new_phns = old_phns
span_tobe_added = span_tobe_replaced span_tobe_added = span_tobe_replaced
d_factor_left = duration_adjust_factor(original_old_durations[:span_tobe_replaced[0]],old_durations[:span_tobe_replaced[0]], old_phns[:span_tobe_replaced[0]]) d_factor_left = duration_adjust_factor(
d_factor_right = duration_adjust_factor(original_old_durations[span_tobe_replaced[1]:],old_durations[span_tobe_replaced[1]:], old_phns[span_tobe_replaced[1]:]) original_old_durations[:span_tobe_replaced[0]],
d_factor = (d_factor_left+d_factor_right)/2 old_durations[:span_tobe_replaced[0]],
new_durations_adjusted = [d_factor*i for i in old_durations] old_phns[:span_tobe_replaced[0]])
d_factor_right = duration_adjust_factor(
original_old_durations[span_tobe_replaced[1]:],
old_durations[span_tobe_replaced[1]:],
old_phns[span_tobe_replaced[1]:])
d_factor = (d_factor_left + d_factor_right) / 2
new_durations_adjusted = [d_factor * i for i in old_durations]
else: else:
if duration_adjust: if duration_adjust:
d_factor = duration_adjust_factor(original_old_durations,old_durations, old_phns) d_factor = duration_adjust_factor(original_old_durations,
d_factor_paddle = duration_adjust_factor(original_old_durations,old_durations, old_phns) old_durations, old_phns)
d_factor = d_factor * 1.25 d_factor_paddle = duration_adjust_factor(original_old_durations,
old_durations, old_phns)
d_factor = d_factor * 1.25
else: else:
d_factor = 1 d_factor = 1
if target_language == "english":
new_durations = evaluate_durations(new_phns, target_language=target_language)
if target_language == "english":
new_durations = evaluate_durations(
new_phns, target_language=target_language)
elif target_language =="chinese": elif target_language == "chinese":
new_durations = evaluate_durations(new_phns, target_language=target_language) new_durations = evaluate_durations(
new_phns, target_language=target_language)
new_durations_adjusted = [d_factor*i for i in new_durations] new_durations_adjusted = [d_factor * i for i in new_durations]
if span_tobe_replaced[0]<len(old_phns) and old_phns[span_tobe_replaced[0]] == new_phns[span_tobe_added[0]]: if span_tobe_replaced[0] < len(old_phns) and old_phns[
new_durations_adjusted[span_tobe_added[0]] = original_old_durations[span_tobe_replaced[0]] span_tobe_replaced[0]] == new_phns[span_tobe_added[0]]:
if span_tobe_replaced[1]<len(old_phns) and span_tobe_added[1]<len(new_phns): new_durations_adjusted[span_tobe_added[0]] = original_old_durations[
span_tobe_replaced[0]]
if span_tobe_replaced[1] < len(old_phns) and span_tobe_added[1] < len(
new_phns):
if old_phns[span_tobe_replaced[1]] == new_phns[span_tobe_added[1]]: if old_phns[span_tobe_replaced[1]] == new_phns[span_tobe_added[1]]:
new_durations_adjusted[span_tobe_added[1]] = original_old_durations[span_tobe_replaced[1]] new_durations_adjusted[span_tobe_added[
new_span_duration_sum = sum(new_durations_adjusted[span_tobe_added[0]:span_tobe_added[1]]) 1]] = original_old_durations[span_tobe_replaced[1]]
old_span_duration_sum = sum(original_old_durations[span_tobe_replaced[0]:span_tobe_replaced[1]]) new_span_duration_sum = sum(
duration_offset = new_span_duration_sum - old_span_duration_sum new_durations_adjusted[span_tobe_added[0]:span_tobe_added[1]])
old_span_duration_sum = sum(
original_old_durations[span_tobe_replaced[0]:span_tobe_replaced[1]])
duration_offset = new_span_duration_sum - old_span_duration_sum
new_mfa_start = mfa_start[:span_tobe_replaced[0]] new_mfa_start = mfa_start[:span_tobe_replaced[0]]
new_mfa_end = mfa_end[:span_tobe_replaced[0]] new_mfa_end = mfa_end[:span_tobe_replaced[0]]
for i in new_durations_adjusted[span_tobe_added[0]:span_tobe_added[1]]: for i in new_durations_adjusted[span_tobe_added[0]:span_tobe_added[1]]:
if len(new_mfa_end) ==0: if len(new_mfa_end) == 0:
new_mfa_start.append(0) new_mfa_start.append(0)
new_mfa_end.append(i) new_mfa_end.append(i)
else: else:
new_mfa_start.append(new_mfa_end[-1]) new_mfa_start.append(new_mfa_end[-1])
new_mfa_end.append(new_mfa_end[-1]+i) new_mfa_end.append(new_mfa_end[-1] + i)
new_mfa_start += [i+duration_offset for i in mfa_start[span_tobe_replaced[1]:]] new_mfa_start += [
new_mfa_end += [i+duration_offset for i in mfa_end[span_tobe_replaced[1]:]] i + duration_offset for i in mfa_start[span_tobe_replaced[1]:]
]
new_mfa_end += [
i + duration_offset for i in mfa_end[span_tobe_replaced[1]:]
]
# 3. get new wav # 3. get new wav
if span_tobe_replaced[0]>=len(mfa_start): if span_tobe_replaced[0] >= len(mfa_start):
left_index = len(wav_org) left_index = len(wav_org)
right_index = left_index right_index = left_index
else: else:
left_index = int(np.floor(mfa_start[span_tobe_replaced[0]]*fs)) left_index = int(np.floor(mfa_start[span_tobe_replaced[0]] * fs))
right_index = int(np.ceil(mfa_end[span_tobe_replaced[1]-1]*fs)) right_index = int(np.ceil(mfa_end[span_tobe_replaced[1] - 1] * fs))
new_blank_wav = np.zeros((int(np.ceil(new_span_duration_sum*fs)),), dtype=wav_org.dtype) new_blank_wav = np.zeros(
new_wav_org = np.concatenate([wav_org[:left_index], new_blank_wav, wav_org[right_index:]]) (int(np.ceil(new_span_duration_sum * fs)), ), dtype=wav_org.dtype)
new_wav_org = np.concatenate(
[wav_org[:left_index], new_blank_wav, wav_org[right_index:]])
# 4. get old and new mel span to be mask # 4. get old and new mel span to be mask
old_span_boundary = get_masked_mel_boundary(mfa_start, mfa_end, fs, hop_length, span_tobe_replaced) # [92, 92] old_span_boundary = get_masked_mel_boundary(
new_span_boundary=get_masked_mel_boundary(new_mfa_start, new_mfa_end, fs, hop_length, span_tobe_added) # [92, 174] mfa_start, mfa_end, fs, hop_length, span_tobe_replaced) # [92, 92]
new_span_boundary = get_masked_mel_boundary(new_mfa_start, new_mfa_end, fs,
hop_length,
span_tobe_added) # [92, 174]
return new_wav_org, new_phns, new_mfa_start, new_mfa_end, old_span_boundary, new_span_boundary return new_wav_org, new_phns, new_mfa_start, new_mfa_end, old_span_boundary, new_span_boundary
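
# Timeline bookkeeping, in short: phones left of the edited span keep their
# MFA times, the span itself is re-laid out from the adjusted durations, and
# every phone to the right is shifted by
# duration_offset = new_span_duration_sum - old_span_duration_sum.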

def prepare_features(uid,
                     prefix,
                     clone_uid,
                     clone_prefix,
                     source_language,
                     target_language,
                     mlm_model,
                     processor,
                     wav_path,
                     old_str,
                     new_str,
                     duration_preditor_path,
                     sid=None,
                     duration_adjust=True,
                     start_end_sp=False,
                     mask_reconstruct=False,
                     train_args=None):
    wav_org, phns_list, mfa_start, mfa_end, old_span_boundary, new_span_boundary = prepare_features_with_duration(
        uid,
        prefix,
        clone_uid,
        clone_prefix,
        source_language,
        target_language,
        mlm_model,
        old_str,
        new_str,
        wav_path,
        duration_preditor_path,
        sid=sid,
        duration_adjust=duration_adjust,
        start_end_sp=start_end_sp,
        mask_reconstruct=mask_reconstruct,
        train_args=train_args)
    speech = np.array(wav_org, dtype=np.float32)
    align_start = np.array(mfa_start)
    align_end = np.array(mfa_end)
    token_to_id = {item: i for i, item in enumerate(train_args.token_list)}
    text = np.array(
        list(
            map(lambda x: token_to_id.get(x, token_to_id['<unk>']),
                phns_list)))
    # print('unk id is', token_to_id['<unk>'])
    # text = np.array(processor(uid='1', data={'text':" ".join(phns_list)})['text'])
    span_boundary = np.array(new_span_boundary)
    batch = [('1', {
        "speech": speech,
        "align_start": align_start,
        "align_end": align_end,
        "text": text,
        "span_boundary": span_boundary
    })]

    return batch, old_span_boundary, new_span_boundary
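
# Rough shape sketch of the single-item batch built above (hypothetical
# lengths, assuming a ~3 s clip at fs=24000 with 40 phones):
#   speech        float32 waveform, shape (72000,)
#   align_start   per-phone start times in seconds, shape (40,)
#   align_end     per-phone end times in seconds, shape (40,)
#   text          phone ids looked up in train_args.token_list, shape (40,)
#   span_boundary [start_frame, end_frame] of the mel region to regenerate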

def decode_with_model(uid,
                      prefix,
                      clone_uid,
                      clone_prefix,
                      source_language,
                      target_language,
                      mlm_model,
                      processor,
                      collate_fn,
                      wav_path,
                      old_str,
                      new_str,
                      duration_preditor_path,
                      sid=None,
                      decoder=False,
                      use_teacher_forcing=False,
                      duration_adjust=True,
                      start_end_sp=False,
                      train_args=None):
    fs, hop_length = train_args.feats_extract_conf[
        'fs'], train_args.feats_extract_conf['hop_length']

    batch, old_span_boundary, new_span_boundary = prepare_features(
        uid,
        prefix,
        clone_uid,
        clone_prefix,
        source_language,
        target_language,
        mlm_model,
        processor,
        wav_path,
        old_str,
        new_str,
        duration_preditor_path,
        sid,
        duration_adjust=duration_adjust,
        start_end_sp=start_end_sp,
        train_args=train_args)

    feats = collate_fn(batch)[1]

    if 'text_masked_position' in feats.keys():
        feats.pop('text_masked_position')
    for k, v in feats.items():
        feats[k] = paddle.to_tensor(v)
    rtn = mlm_model.inference(
        **feats,
        span_boundary=new_span_boundary,
        use_teacher_forcing=use_teacher_forcing)
    output = rtn['feat_gen']
    # stitch the generated mel segments back together, squeezing away the
    # empty end segments when they are present
    if 0 in output[0].shape and 0 not in output[-1].shape:
        output_feat = paddle.concat(
            output[1:-1] + [output[-1].squeeze()], axis=0).cpu()
    elif 0 not in output[0].shape and 0 in output[-1].shape:
        output_feat = paddle.concat(
            [output[0].squeeze()] + output[1:-1], axis=0).cpu()
    elif 0 in output[0].shape and 0 in output[-1].shape:
        output_feat = paddle.concat(output[1:-1], axis=0).cpu()
    else:
        output_feat = paddle.concat(
            [output[0].squeeze(0)] + output[1:-1] + [output[-1].squeeze(0)],
            axis=0).cpu()

    wav_org, rate = librosa.load(
        wav_path, sr=train_args.feats_extract_conf['fs'])
    origin_speech = paddle.to_tensor(
        np.array(wav_org, dtype=np.float32)).unsqueeze(0)
    speech_lengths = paddle.to_tensor(len(wav_org)).unsqueeze(0)
    return wav_org, None, output_feat, old_span_boundary, new_span_boundary, fs, hop_length
@@ -568,71 +742,64 @@ def decode_with_model(uid, prefix, clone_uid, clone_prefix, source_language, tar
class MLMCollateFn:
    """Functor class of common_collate_fn()"""

    def __init__(self,
                 feats_extract,
                 float_pad_value: Union[float, int]=0.0,
                 int_pad_value: int=-32768,
                 not_sequence: Collection[str]=(),
                 mlm_prob: float=0.8,
                 mean_phn_span: int=8,
                 attention_window: int=0,
                 pad_speech: bool=False,
                 sega_emb: bool=False,
                 duration_collect: bool=False,
                 text_masking: bool=False):
        self.mlm_prob = mlm_prob
        self.mean_phn_span = mean_phn_span
        self.feats_extract = feats_extract
        self.float_pad_value = float_pad_value
        self.int_pad_value = int_pad_value
        self.not_sequence = set(not_sequence)
        self.attention_window = attention_window
        self.pad_speech = pad_speech
        self.sega_emb = sega_emb
        self.duration_collect = duration_collect
        self.text_masking = text_masking

    def __repr__(self):
        return (f"{self.__class__}(float_pad_value={self.float_pad_value}, "
                f"int_pad_value={self.int_pad_value})")

    def __call__(self, data: Collection[Tuple[str, Dict[str, np.ndarray]]]
                 ) -> Tuple[List[str], Dict[str, paddle.Tensor]]:
        return mlm_collate_fn(
            data,
            float_pad_value=self.float_pad_value,
            int_pad_value=self.int_pad_value,
            not_sequence=self.not_sequence,
            mlm_prob=self.mlm_prob,
            mean_phn_span=self.mean_phn_span,
            feats_extract=self.feats_extract,
            attention_window=self.attention_window,
            pad_speech=self.pad_speech,
            sega_emb=self.sega_emb,
            duration_collect=self.duration_collect,
            text_masking=self.text_masking)
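
# A minimal usage sketch (argument values are placeholders; the real
# feats_extract settings come from the dumped training config):
#   feats_extract = LogMelFBank(sr=24000, n_fft=2048, hop_length=300,
#                               win_length=1200, n_mels=80)
#   collate = MLMCollateFn(feats_extract, mlm_prob=0.8, mean_phn_span=8)
#   uttids, feats = collate(batch)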

def mlm_collate_fn(
        data: Collection[Tuple[str, Dict[str, np.ndarray]]],
        float_pad_value: Union[float, int]=0.0,
        int_pad_value: int=-32768,
        not_sequence: Collection[str]=(),
        mlm_prob: float=0.8,
        mean_phn_span: int=8,
        feats_extract=None,
        attention_window: int=0,
        pad_speech: bool=False,
        sega_emb: bool=False,
        duration_collect: bool=False,
        text_masking: bool=False) -> Tuple[List[str], Dict[str, paddle.Tensor]]:
    """Concatenate ndarray-list to an array and convert to paddle.Tensor.

    Examples:
@@ -654,9 +821,8 @@ def mlm_collate_fn(
    data = [d for _, d in data]
    assert all(set(data[0]) == set(d) for d in data), "dict-keys mismatching"
    assert all(not k.endswith("_lengths")
               for k in data[0]), f"*_lengths is reserved: {list(data[0])}"
    output = {}
    for key in data[0]:
@@ -679,7 +845,8 @@ def mlm_collate_fn(
        # lens: (Batch,)
        if key not in not_sequence:
            lens = paddle.to_tensor(
                [d[key].shape[0] for d in data], dtype='int64')
            output[key + "_lengths"] = lens

    feats = feats_extract.get_log_mel_fbank(np.array(output["speech"][0]))
@@ -689,71 +856,73 @@ def mlm_collate_fn(
    feats = paddle.unsqueeze(feats, 0)
    batch_size = paddle.shape(feats)[0]
    if 'text' not in output:
        text = paddle.zeros_like(feats_lengths.unsqueeze(-1)) - 2
        text_lengths = paddle.zeros_like(feats_lengths) + 1
        max_tlen = 1
        align_start = paddle.zeros_like(text)
        align_end = paddle.zeros_like(text)
        align_start_lengths = paddle.zeros_like(feats_lengths)
        align_end_lengths = paddle.zeros_like(feats_lengths)
        sega_emb = False
        mean_phn_span = 0
        mlm_prob = 0.15
    else:
        text, text_lengths = output["text"], output["text_lengths"]
        align_start, align_start_lengths, align_end, align_end_lengths = output[
            "align_start"], output["align_start_lengths"], output[
                "align_end"], output["align_end_lengths"]
        align_start = paddle.floor(feats_extract.sr * align_start /
                                   feats_extract.hop_length).int()
        align_end = paddle.floor(feats_extract.sr * align_end /
                                 feats_extract.hop_length).int()
        max_tlen = max(text_lengths).item()
    max_slen = max(feats_lengths).item()
    speech_pad = feats[:, :max_slen]
    if attention_window > 0 and pad_speech:
        speech_pad, max_slen = pad_to_longformer_att_window(
            speech_pad, max_slen, max_slen, attention_window)
    max_len = max_slen + max_tlen
    if attention_window > 0:
        text_pad, max_tlen = pad_to_longformer_att_window(
            text, max_len, max_tlen, attention_window)
    else:
        text_pad = text
    text_mask = make_non_pad_mask(
        text_lengths.tolist(), text_pad, length_dim=1).unsqueeze(-2)
    if attention_window > 0:
        text_mask = text_mask * 2
    speech_mask = make_non_pad_mask(
        feats_lengths.tolist(), speech_pad[:, :, 0], length_dim=1).unsqueeze(-2)
    span_boundary = None
    if 'span_boundary' in output.keys():
        span_boundary = output['span_boundary']

    if text_masking:
        masked_position, text_masked_position, _ = phones_text_masking(
            speech_pad, speech_mask, text_pad, text_mask, align_start,
            align_end, align_start_lengths, mlm_prob, mean_phn_span,
            span_boundary)
    else:
        text_masked_position = np.zeros(text_pad.shape)
        masked_position, _ = phones_masking(
            speech_pad, speech_mask, align_start, align_end,
            align_start_lengths, mlm_prob, mean_phn_span, span_boundary)

    output_dict = {}
    if duration_collect and 'text' in output:
        reordered_index, speech_segment_pos, text_segment_pos, durations, feats_lengths = get_segment_pos_reduce_duration(
            speech_pad, text_pad, align_start, align_end, align_start_lengths,
            sega_emb, masked_position, feats_lengths)
        speech_mask = make_non_pad_mask(
            feats_lengths.tolist(),
            speech_pad[:, :reordered_index.shape[1], 0],
            length_dim=1).unsqueeze(-2)
        output_dict['durations'] = durations
        output_dict['reordered_index'] = reordered_index
    else:
        speech_segment_pos, text_segment_pos = get_segment_pos(
            speech_pad, text_pad, align_start, align_end, align_start_lengths,
            sega_emb)
    output_dict['speech'] = speech_pad
    output_dict['text'] = text_pad
    output_dict['masked_position'] = masked_position
@@ -767,9 +936,8 @@ def mlm_collate_fn(
    output = (uttids, output_dict)
    return output
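
# Illustrative shapes of the collated output (hypothetical sizes: batch of 1,
# 80-dim log-mel, 240 frames, 40 phones):
#   speech           (1, 240, 80)  padded log-mel features
#   text             (1, 40)       padded phone ids
#   masked_position  (1, 240)      1 where mel frames are masked for MLM
#   speech_mask      (1, 1, 240)   non-padding mask over frames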

def build_collate_fn(args: argparse.Namespace, train: bool, epoch=-1):
    # -> Callable[
    #     [Collection[Tuple[str, Dict[str, np.ndarray]]]],
    #     Tuple[List[str], Dict[str, paddle.Tensor]],
@@ -793,68 +961,142 @@ def build_collate_fn(
    sega_emb = True if args.encoder_conf['input_layer'] == 'sega_mlm' else False
    if args.encoder_conf['selfattention_layer_type'] == 'longformer':
        attention_window = args.encoder_conf['attention_window']
        pad_speech = True if 'pre_speech_layer' in args.encoder_conf and args.encoder_conf[
            'pre_speech_layer'] > 0 else False
    else:
        attention_window = 0
        pad_speech = False
    if epoch == -1:
        mlm_prob_factor = 1
    else:
        mlm_probs = [1.0, 1.0, 0.7, 0.6, 0.5]
        mlm_prob_factor = 0.8  # mlm_probs[epoch // 100]
    if 'duration_predictor_layers' in args.model_conf.keys(
    ) and args.model_conf['duration_predictor_layers'] > 0:
        duration_collect = True
    else:
        duration_collect = False

    return MLMCollateFn(
        feats_extract,
        float_pad_value=0.0,
        int_pad_value=0,
        mlm_prob=args.model_conf['mlm_prob'] * mlm_prob_factor,
        mean_phn_span=args.model_conf['mean_phn_span'],
        attention_window=attention_window,
        pad_speech=pad_speech,
        sega_emb=sega_emb,
        duration_collect=duration_collect)
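
# Sketch of intended use (mirrors get_mlm_output below; train_args is the
# config namespace loaded together with the checkpoint):
#   collate_fn = build_collate_fn(train_args, train=False)
#   uttids, feats = collate_fn(batch)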

def get_mlm_output(uid,
                   prefix,
                   clone_uid,
                   clone_prefix,
                   source_language,
                   target_language,
                   model_name,
                   wav_path,
                   old_str,
                   new_str,
                   duration_preditor_path,
                   sid=None,
                   decoder=False,
                   use_teacher_forcing=False,
                   dynamic_eval=(0, 0),
                   duration_adjust=True,
                   start_end_sp=False):
    mlm_model, train_args = load_model(model_name)
    mlm_model.eval()
    processor = None
    collate_fn = build_collate_fn(train_args, False)

    return decode_with_model(
        uid,
        prefix,
        clone_uid,
        clone_prefix,
        source_language,
        target_language,
        mlm_model,
        processor,
        collate_fn,
        wav_path,
        old_str,
        new_str,
        duration_preditor_path,
        sid=sid,
        decoder=decoder,
        use_teacher_forcing=use_teacher_forcing,
        duration_adjust=duration_adjust,
        start_end_sp=start_end_sp,
        train_args=train_args)

def test_vctk(uid,
              clone_uid,
              clone_prefix,
              source_language,
              target_language,
              vocoder,
              prefix='dump/raw/dev',
              model_name="conformer",
              old_str="",
              new_str="",
              prompt_decoding=False,
              dynamic_eval=(0, 0),
              task_name=None):
    duration_preditor_path = None
    spemd = None
    full_origin_str, wav_path = read_data(uid, prefix)

    if task_name == 'edit':
        new_str = new_str
    elif task_name == 'synthesize':
        new_str = full_origin_str + new_str
    else:
        new_str = full_origin_str + ' '.join(
            [ch for ch in new_str if is_chinese(ch)])

    print('new_str is ', new_str)

    if not old_str:
        old_str = full_origin_str

    results_dict, old_span = plot_mel_and_vocode_wav(
        uid,
        prefix,
        clone_uid,
        clone_prefix,
        source_language,
        target_language,
        model_name,
        wav_path,
        full_origin_str,
        old_str,
        new_str,
        vocoder,
        duration_preditor_path,
        sid=spemd)
    return results_dict
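
# The three task modes resolve new_str differently (illustrative behaviour):
#   'edit'        new_str is the edited transcript used as-is
#   'synthesize'  new_str is appended: full_origin_str + new_str
#   otherwise     only the Chinese characters of new_str are appended, for
#                 the cross-lingual cloning path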
if __name__ == "__main__": if __name__ == "__main__":
# parse config and args
args = parse_args() args = parse_args()
print(args)
data_dict = test_vctk(args.uid, data_dict = test_vctk(
args.clone_uid, args.uid,
args.clone_prefix, args.clone_uid,
args.source_language, args.clone_prefix,
args.target_language, args.source_language,
args.target_language,
args.use_pt_vocoder, args.use_pt_vocoder,
args.prefix, args.prefix,
args.model_name, args.model_name,
new_str=args.new_str, new_str=args.new_str,
task_name=args.task_name) task_name=args.task_name)
sf.write('./wavs/%s' % args.output_name, data_dict['output'], samplerate=24000) sf.write(args.output_name, data_dict['output'], samplerate=24000)
print("finished...") print("finished...")
# exit() # exit()
import argparse
import logging
import math
import os
import sys
from pathlib import Path
# Collection is used in the collate type hints below
from typing import Collection
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union

import librosa
import numpy as np
import paddle
import soundfile as sf
import yaml
from paddle import nn
pypath = '..'
for dir_name in os.listdir(pypath):
    dir_path = os.path.join(pypath, dir_name)
    if os.path.isdir(dir_path):
        sys.path.append(dir_path)
from paddlespeech.s2t.utils.error_rate import ErrorCalculator
from paddlespeech.t2s.datasets.get_feats import LogMelFBank
from paddlespeech.t2s.modules.activation import get_activation
from paddlespeech.t2s.modules.conformer.convolution import ConvolutionModule
from paddlespeech.t2s.modules.conformer.encoder_layer import EncoderLayer
from paddlespeech.t2s.modules.layer_norm import LayerNorm
from paddlespeech.t2s.modules.masked_fill import masked_fill
from paddlespeech.t2s.modules.nets_utils import initialize
from paddlespeech.t2s.modules.nets_utils import make_non_pad_mask
from paddlespeech.t2s.modules.nets_utils import make_pad_mask
from paddlespeech.t2s.modules.predictor.duration_predictor import DurationPredictor
from paddlespeech.t2s.modules.predictor.duration_predictor import DurationPredictorLoss
from paddlespeech.t2s.modules.predictor.length_regulator import LengthRegulator
from paddlespeech.t2s.modules.predictor.variance_predictor import VariancePredictor
from paddlespeech.t2s.modules.tacotron2.decoder import Postnet
from paddlespeech.t2s.modules.transformer.attention import MultiHeadedAttention
from paddlespeech.t2s.modules.transformer.attention import RelPositionMultiHeadedAttention
from paddlespeech.t2s.modules.transformer.embedding import PositionalEncoding
from paddlespeech.t2s.modules.transformer.embedding import RelPositionalEncoding
from paddlespeech.t2s.modules.transformer.embedding import ScaledPositionalEncoding
from paddlespeech.t2s.modules.transformer.multi_layer_conv import Conv1dLinear
from paddlespeech.t2s.modules.transformer.multi_layer_conv import MultiLayeredConv1d
from paddlespeech.t2s.modules.transformer.positionwise_feed_forward import PositionwiseFeedForward
from paddlespeech.t2s.modules.transformer.repeat import repeat
from paddlespeech.t2s.modules.transformer.subsampling import Conv2dSubsampling

class LegacyRelPositionalEncoding(PositionalEncoding):
    """Relative positional encoding module (old version).
@@ -89,6 +53,7 @@ class LegacyRelPositionalEncoding(PositionalEncoding):
        max_len (int): Maximum input length.
    """

    def __init__(self, d_model: int, dropout_rate: float, max_len: int=5000):
        """
        Args:
@@ -102,20 +67,18 @@ class LegacyRelPositionalEncoding(PositionalEncoding):
"""Reset the positional encodings.""" """Reset the positional encodings."""
if self.pe is not None: if self.pe is not None:
if paddle.shape(self.pe)[1] >= paddle.shape(x)[1]: if paddle.shape(self.pe)[1] >= paddle.shape(x)[1]:
# if self.pe.dtype != x.dtype or self.pe.device != x.device:
# self.pe = self.pe.to(dtype=x.dtype, device=x.device)
return return
pe = paddle.zeros((paddle.shape(x)[1], self.d_model)) pe = paddle.zeros((paddle.shape(x)[1], self.d_model))
if self.reverse: if self.reverse:
position = paddle.arange( position = paddle.arange(
paddle.shape(x)[1] - 1, -1, -1.0, dtype=paddle.float32 paddle.shape(x)[1] - 1, -1, -1.0,
).unsqueeze(1) dtype=paddle.float32).unsqueeze(1)
else: else:
position = paddle.arange(0, paddle.shape(x)[1], dtype=paddle.float32).unsqueeze(1) position = paddle.arange(
0, paddle.shape(x)[1], dtype=paddle.float32).unsqueeze(1)
div_term = paddle.exp( div_term = paddle.exp(
paddle.arange(0, self.d_model, 2, dtype=paddle.float32) paddle.arange(0, self.d_model, 2, dtype=paddle.float32) *
* -(math.log(10000.0) / self.d_model) -(math.log(10000.0) / self.d_model))
)
pe[:, 0::2] = paddle.sin(position * div_term) pe[:, 0::2] = paddle.sin(position * div_term)
pe[:, 1::2] = paddle.cos(position * div_term) pe[:, 1::2] = paddle.cos(position * div_term)
pe = pe.unsqueeze(0) pe = pe.unsqueeze(0)
@@ -129,46 +92,11 @@ class LegacyRelPositionalEncoding(PositionalEncoding):
            paddle.Tensor: Encoded tensor (batch, time, `*`).
            paddle.Tensor: Positional embedding tensor (1, time, `*`).
        """
        self.extend_pe(x)
        x = x * self.xscale
        pos_emb = self.pe[:, :paddle.shape(x)[1]]
        return self.dropout(x), self.dropout(pos_emb)
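
# For reference, the table built in extend_pe() is the standard sinusoidal
# encoding (position axis reversed when self.reverse is set):
#   pe[pos, 2i]   = sin(pos / 10000^(2i / d_model))
#   pe[pos, 2i+1] = cos(pos / 10000^(2i / d_model))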

class mySequential(nn.Sequential):
    def forward(self, *inputs):
@@ -179,24 +107,29 @@ class mySequential(nn.Sequential):
                inputs = module(inputs)
        return inputs

class NewMaskInputLayer(nn.Layer):
    __constants__ = ['out_features']
    out_features: int

    def __init__(self, out_features: int, device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.mask_feature = paddle.create_parameter(
            shape=(1, 1, out_features),
            dtype=paddle.float32,
            default_initializer=paddle.nn.initializer.Assign(
                paddle.normal(shape=(1, 1, out_features))))

    def forward(self, input: paddle.Tensor,
                masked_position=None) -> paddle.Tensor:
        masked_position = paddle.expand_as(
            paddle.unsqueeze(masked_position, -1), input)
        masked_input = masked_fill(input, masked_position, 0) + masked_fill(
            paddle.expand_as(self.mask_feature, input), ~masked_position, 0)
        return masked_input
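
# NewMaskInputLayer replaces masked frames with a single learned vector:
# out = input * (1 - m) + mask_feature * m, where m is masked_position
# broadcast over the feature axis; the two masked_fill() calls above compute
# exactly those two terms.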

class LegacyRelPositionMultiHeadedAttention(MultiHeadedAttention):
    """Multi-Head Attention layer with relative position encoding (old version).

    Details can be found in https://github.com/espnet/espnet/pull/2816.
@@ -266,7 +199,8 @@ class LegacyRelPositionMultiHeadedAttention(MultiHeadedAttention):
        q = paddle.transpose(q, [0, 2, 1, 3])

        n_batch_pos = paddle.shape(pos_emb)[0]
        p = paddle.reshape(
            self.linear_pos(pos_emb), [n_batch_pos, -1, self.h, self.d_k])
        # (batch, head, time1, d_k)
        p = paddle.transpose(p, [0, 2, 1, 3])
        # (batch, head, time1, d_k)
@@ -278,17 +212,20 @@ class LegacyRelPositionMultiHeadedAttention(MultiHeadedAttention):
        # first compute matrix a and matrix c
        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
        # (batch, head, time1, time2)
        matrix_ac = paddle.matmul(q_with_bias_u,
                                  paddle.transpose(k, [0, 1, 3, 2]))

        # compute matrix b and matrix d
        # (batch, head, time1, time1)
        matrix_bd = paddle.matmul(q_with_bias_v,
                                  paddle.transpose(p, [0, 1, 3, 2]))
        matrix_bd = self.rel_shift(matrix_bd)
        # (batch, head, time1, time2)
        scores = (matrix_ac + matrix_bd) / math.sqrt(self.d_k)

        return self.forward_attention(v, scores, mask)
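
# The score decomposition above follows Transformer-XL
# (https://arxiv.org/abs/1901.02860, Section 3.3): with content bias u and
# position bias v,
#   score(i, j) = (q_i + u) . k_j + (q_i + v) . r_{i-j}
# matrix_ac is the content term, matrix_bd the rel-shifted position term,
# and the sum is scaled by 1 / sqrt(d_k).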

class MLMEncoder(nn.Layer):
    """Conformer encoder module.
@@ -324,40 +261,39 @@ class MLMEncoder(nn.Layer):
            signature.)

    """

    def __init__(self,
                 idim,
                 vocab_size=0,
                 pre_speech_layer: int=0,
                 attention_dim=256,
                 attention_heads=4,
                 linear_units=2048,
                 num_blocks=6,
                 dropout_rate=0.1,
                 positional_dropout_rate=0.1,
                 attention_dropout_rate=0.0,
                 input_layer="conv2d",
                 normalize_before=True,
                 concat_after=False,
                 positionwise_layer_type="linear",
                 positionwise_conv_kernel_size=1,
                 macaron_style=False,
                 pos_enc_layer_type="abs_pos",
                 pos_enc_class=None,
                 selfattention_layer_type="selfattn",
                 activation_type="swish",
                 use_cnn_module=False,
                 zero_triu=False,
                 cnn_module_kernel=31,
                 padding_idx=-1,
                 stochastic_depth_rate=0.0,
                 intermediate_layers=None,
                 text_masking=False):
        """Construct an Encoder object."""
        super().__init__()
        self._output_size = attention_dim
        self.text_masking = text_masking
        if self.text_masking:
            self.text_masking_layer = NewMaskInputLayer(attention_dim)
        activation = get_activation(activation_type)
@@ -381,21 +317,18 @@ class MLMEncoder(nn.Layer):
                nn.LayerNorm(attention_dim),
                nn.Dropout(dropout_rate),
                nn.ReLU(),
                pos_enc_class(attention_dim, positional_dropout_rate), )
        elif input_layer == "conv2d":
            self.embed = Conv2dSubsampling(
                idim,
                attention_dim,
                dropout_rate,
                pos_enc_class(attention_dim, positional_dropout_rate), )
            self.conv_subsampling_factor = 4
        elif input_layer == "embed":
            self.embed = nn.Sequential(
                nn.Embedding(idim, attention_dim, padding_idx=padding_idx),
                pos_enc_class(attention_dim, positional_dropout_rate), )
        elif input_layer == "mlm":
            self.segment_emb = None
            self.speech_embed = mySequential(
@@ -403,34 +336,31 @@ class MLMEncoder(nn.Layer):
                nn.Linear(idim, attention_dim),
                nn.LayerNorm(attention_dim),
                nn.ReLU(),
                pos_enc_class(attention_dim, positional_dropout_rate))
            self.text_embed = nn.Sequential(
                nn.Embedding(
                    vocab_size, attention_dim, padding_idx=padding_idx),
                pos_enc_class(attention_dim, positional_dropout_rate), )
        elif input_layer == "sega_mlm":
            self.segment_emb = nn.Embedding(
                500, attention_dim, padding_idx=padding_idx)
            self.speech_embed = mySequential(
                NewMaskInputLayer(idim),
                nn.Linear(idim, attention_dim),
                nn.LayerNorm(attention_dim),
                nn.ReLU(),
                pos_enc_class(attention_dim, positional_dropout_rate))
            self.text_embed = nn.Sequential(
                nn.Embedding(
                    vocab_size, attention_dim, padding_idx=padding_idx),
                pos_enc_class(attention_dim, positional_dropout_rate), )
        elif isinstance(input_layer, nn.Layer):
            self.embed = nn.Sequential(
                input_layer,
                pos_enc_class(attention_dim, positional_dropout_rate), )
        elif input_layer is None:
            self.embed = nn.Sequential(
                pos_enc_class(attention_dim, positional_dropout_rate))
        else:
            raise ValueError("unknown input_layer: " + input_layer)
        self.normalize_before = normalize_before
@@ -439,57 +369,39 @@ class MLMEncoder(nn.Layer):
        if selfattention_layer_type == "selfattn":
            logging.info("encoder self-attention layer type = self-attention")
            encoder_selfattn_layer = MultiHeadedAttention
            encoder_selfattn_layer_args = (attention_heads, attention_dim,
                                           attention_dropout_rate, )
        elif selfattention_layer_type == "legacy_rel_selfattn":
            assert pos_enc_layer_type == "legacy_rel_pos"
            encoder_selfattn_layer = LegacyRelPositionMultiHeadedAttention
            encoder_selfattn_layer_args = (attention_heads, attention_dim,
                                           attention_dropout_rate, )
        elif selfattention_layer_type == "rel_selfattn":
            logging.info(
                "encoder self-attention layer type = relative self-attention")
            assert pos_enc_layer_type == "rel_pos"
            encoder_selfattn_layer = RelPositionMultiHeadedAttention
            encoder_selfattn_layer_args = (attention_heads, attention_dim,
                                           attention_dropout_rate, zero_triu, )
        else:
            raise ValueError("unknown encoder_attn_layer: " +
                             selfattention_layer_type)

        # feed-forward module definition
        if positionwise_layer_type == "linear":
            positionwise_layer = PositionwiseFeedForward
            positionwise_layer_args = (attention_dim, linear_units,
                                       dropout_rate, activation, )
        elif positionwise_layer_type == "conv1d":
            positionwise_layer = MultiLayeredConv1d
            positionwise_layer_args = (attention_dim, linear_units,
                                       positionwise_conv_kernel_size,
                                       dropout_rate, )
        elif positionwise_layer_type == "conv1d-linear":
            positionwise_layer = Conv1dLinear
            positionwise_layer_args = (attention_dim, linear_units,
                                       positionwise_conv_kernel_size,
                                       dropout_rate, )
        else:
            raise NotImplementedError("Support only linear or conv1d.")
@@ -508,9 +420,7 @@ class MLMEncoder(nn.Layer):
                dropout_rate,
                normalize_before,
                concat_after,
                stochastic_depth_rate * float(1 + lnum) / num_blocks, ), )
        self.pre_speech_layer = pre_speech_layer
        self.pre_speech_encoders = repeat(
            self.pre_speech_layer,
@@ -523,16 +433,21 @@ class MLMEncoder(nn.Layer):
                dropout_rate,
                normalize_before,
                concat_after,
                stochastic_depth_rate * float(1 + lnum) / self.pre_speech_layer, ),
        )
        if self.normalize_before:
            self.after_norm = LayerNorm(attention_dim)

        self.intermediate_layers = intermediate_layers

    def forward(self,
                speech_pad,
                text_pad,
                masked_position,
                speech_mask=None,
                text_mask=None,
                speech_segment_pos=None,
                text_segment_pos=None):
        """Encode input sequence.

        """
@@ -542,12 +457,13 @@ class MLMEncoder(nn.Layer):
            speech_pad = self.speech_embed(speech_pad)
        # pure speech input
        if -2 in np.array(text_pad):
            text_pad = text_pad + 3
            text_mask = paddle.unsqueeze(bool(text_pad), 1)
            text_segment_pos = paddle.zeros_like(text_pad)
            text_pad = self.text_embed(text_pad)
            text_pad = (text_pad[0] + self.segment_emb(text_segment_pos),
                        text_pad[1])
            text_segment_pos = None
        elif text_pad is not None:
            text_pad = self.text_embed(text_pad)
        segment_emb = None
@@ -556,32 +472,32 @@ class MLMEncoder(nn.Layer):
            text_segment_emb = self.segment_emb(text_segment_pos)
            text_pad = (text_pad[0] + text_segment_emb, text_pad[1])
            speech_pad = (speech_pad[0] + speech_segment_emb, speech_pad[1])
            segment_emb = paddle.concat(
                [speech_segment_emb, text_segment_emb], axis=1)
        if self.pre_speech_encoders:
            speech_pad, _ = self.pre_speech_encoders(speech_pad, speech_mask)

        if text_pad is not None:
            xs = paddle.concat([speech_pad[0], text_pad[0]], axis=1)
            xs_pos_emb = paddle.concat([speech_pad[1], text_pad[1]], axis=1)
            masks = paddle.concat([speech_mask, text_mask], axis=-1)
        else:
            xs = speech_pad[0]
            xs_pos_emb = speech_pad[1]
            masks = speech_mask

        xs, masks = self.encoders((xs, xs_pos_emb), masks)

        if isinstance(xs, tuple):
            xs = xs[0]
        if self.normalize_before:
            xs = self.after_norm(xs)

        return xs, masks  #, segment_emb
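
# The encoder consumes one joint sequence: masked speech frames and phone
# embeddings are concatenated along the time axis and attended over with a
# combined non-padding mask, which is what realizes the joint speech-text
# masked modelling of ERNIE-SAT.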

class MLMDecoder(MLMEncoder):
    def forward(self, xs, masks, masked_position=None, segment_emb=None):
        """Encode input sequence.

        Args:
@@ -596,9 +512,6 @@ class MLMDecoder(MLMEncoder):
        emb, mlm_position = None, None
        if not self.training:
            masked_position = None
        xs = self.embed(xs)
        if segment_emb:
            xs = (xs[0] + segment_emb, xs[1])
@@ -609,10 +522,8 @@ class MLMDecoder(MLMEncoder):
        for layer_idx, encoder_layer in enumerate(self.encoders):
            xs, masks = encoder_layer(xs, masks)

            if (self.intermediate_layers is not None and
                    layer_idx + 1 in self.intermediate_layers):
                encoder_output = xs
                # intermediate branches also require normalization.
                if self.normalize_before:
@@ -627,104 +538,44 @@ class MLMDecoder(MLMEncoder):
            return xs, masks, intermediate_outputs
        return xs, masks
def pad_to_longformer_att_window(text, max_len, max_tlen, attention_window):
    round = max_len % attention_window
    if round != 0:
        max_tlen += (attention_window - round)
        n_batch = paddle.shape(text)[0]
        text_pad = paddle.zeros(
            shape=(n_batch, max_tlen, *paddle.shape(text[0])[1:]),
            dtype=text.dtype)
        for i in range(n_batch):
            text_pad[i, :paddle.shape(text[i])[0]] = text[i]
    else:
        text_pad = text[:, :max_tlen]
    return text_pad, max_tlen
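
# A minimal usage sketch (hypothetical shapes, not from the original code):
# longformer-style attention needs the padded length to be a multiple of the
# attention window, so a 37-step batch with a window of 8 is padded to 40:
#
#   text = paddle.randint(0, 10, shape=[2, 37])  # (batch, max_len)
#   text_pad, max_tlen = pad_to_longformer_att_window(
#       text, max_len=37, max_tlen=37, attention_window=8)
#   # max_tlen == 40, text_pad shape: [2, 40]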

class MLMModel(nn.Layer):
    def __init__(self,
                 token_list: Union[Tuple[str, ...], List[str]],
                 odim: int,
                 encoder: nn.Layer,
                 decoder: Optional[nn.Layer],
                 postnet_layers: int=0,
                 postnet_chans: int=0,
                 postnet_filts: int=0,
                 ignore_id: int=-1,
                 lsm_weight: float=0.0,
                 length_normalized_loss: bool=False,
                 report_cer: bool=True,
                 report_wer: bool=True,
                 sym_space: str="<space>",
                 sym_blank: str="<blank>",
                 masking_schema: str="span",
                 mean_phn_span: int=3,
                 mlm_prob: float=0.25,
                 dynamic_mlm_prob=False,
                 decoder_seg_pos=False,
                 text_masking=False):

        super().__init__()
        # note that eos is the same as sos (equivalent ID)
@@ -732,105 +583,119 @@ class ESPnetMLMModel(AbsESPnetModel):
        self.ignore_id = ignore_id
        self.token_list = token_list.copy()

        self.encoder = encoder
        self.decoder = decoder
        self.vocab_size = encoder.text_embed[0]._num_embeddings

        if report_cer or report_wer:
            self.error_calculator = ErrorCalculator(
                token_list, sym_space, sym_blank, report_cer, report_wer)
        else:
            self.error_calculator = None

        self.mlm_weight = 1.0
        self.mlm_prob = mlm_prob
        self.mlm_layer = 12
        self.finetune_wo_mlm = True
        self.max_span = 50
        self.min_span = 4
        self.mean_phn_span = mean_phn_span
        self.masking_schema = masking_schema
        if self.decoder is None or not (hasattr(self.decoder,
                                                'output_layer') and
                                        self.decoder.output_layer is not None):
            self.sfc = nn.Linear(self.encoder._output_size, odim)
        else:
            self.sfc = None
        if text_masking:
            self.text_sfc = nn.Linear(
                self.encoder.text_embed[0]._embedding_dim,
                self.vocab_size,
                weight_attr=self.encoder.text_embed[0]._weight_attr)
            self.text_mlm_loss = nn.CrossEntropyLoss(ignore_index=ignore_id)
        else:
            self.text_sfc = None
            self.text_mlm_loss = None
        self.decoder_seg_pos = decoder_seg_pos
        if lsm_weight > 50:
            self.l1_loss_func = nn.MSELoss()
        else:
            self.l1_loss_func = nn.L1Loss(reduction='none')
        self.postnet = (None if postnet_layers == 0 else Postnet(
            idim=self.encoder._output_size,
            odim=odim,
            n_layers=postnet_layers,
            n_chans=postnet_chans,
            n_filts=postnet_filts,
            use_batch_norm=True,
            dropout_rate=0.5, ))

    def collect_feats(self,
                      speech,
                      speech_lengths,
                      text,
                      text_lengths,
                      masked_position,
                      speech_mask,
                      text_mask,
                      speech_segment_pos,
                      text_segment_pos,
                      y_masks=None) -> Dict[str, paddle.Tensor]:
        return {"feats": speech, "feats_lengths": speech_lengths}
    def forward(self, batch, speech_segment_pos, y_masks=None):
        # feats: (Batch, Length, Dim)
        # -> encoder_out: (Batch, Length2, Dim2)
        speech_pad_placeholder = batch['speech_pad']
        if self.decoder is not None:
            ys_in = self._add_first_frame_and_remove_last_frame(
                batch['speech_pad'])
        encoder_out, h_masks = self.encoder(**batch)
        if self.decoder is not None:
            zs, _ = self.decoder(ys_in, y_masks, encoder_out,
                                 bool(h_masks),
                                 self.encoder.segment_emb(speech_segment_pos))
            speech_hidden_states = zs
        else:
            speech_hidden_states = encoder_out[:, :paddle.shape(batch[
                'speech_pad'])[1], :]
        if self.sfc is not None:
            before_outs = paddle.reshape(
                self.sfc(speech_hidden_states),
                (paddle.shape(speech_hidden_states)[0], -1, self.odim))
        else:
            before_outs = speech_hidden_states
        if self.postnet is not None:
            after_outs = before_outs + paddle.transpose(
                self.postnet(paddle.transpose(before_outs, [0, 2, 1])),
                (0, 2, 1))
        else:
            after_outs = None
        return before_outs, after_outs, speech_pad_placeholder, batch[
            'masked_position']
    def inference(
            self,
            speech,
            text,
            masked_position,
            speech_mask,
            text_mask,
            speech_segment_pos,
            text_segment_pos,
            span_boundary,
            y_masks=None,
            speech_lengths=None,
            text_lengths=None,
            feats: Optional[paddle.Tensor]=None,
            spembs: Optional[paddle.Tensor]=None,
            sids: Optional[paddle.Tensor]=None,
            lids: Optional[paddle.Tensor]=None,
            threshold: float=0.5,
            minlenratio: float=0.0,
            maxlenratio: float=10.0,
            use_teacher_forcing: bool=False, ) -> Dict[str, paddle.Tensor]:

        batch = dict(
            speech_pad=speech,
            text_pad=text,
@@ -838,119 +703,130 @@ class ESPnetMLMModel(AbsESPnetModel):
            speech_mask=speech_mask,
            text_mask=text_mask,
            speech_segment_pos=speech_segment_pos,
            text_segment_pos=text_segment_pos, )

        # # inference with teacher forcing
        # hs, h_masks = self.encoder(**batch)

        outs = [batch['speech_pad'][:, :span_boundary[0]]]
        z_cache = None
        if use_teacher_forcing:
            before, zs, _, _ = self.forward(
                batch, speech_segment_pos, y_masks=y_masks)
            if zs is None:
                zs = before

            outs += [zs[0][span_boundary[0]:span_boundary[1]]]
            outs += [batch['speech_pad'][:, span_boundary[1]:]]
            return dict(feat_gen=outs)

    def _add_first_frame_and_remove_last_frame(
            self, ys: paddle.Tensor) -> paddle.Tensor:
        ys_in = paddle.concat(
            [
                paddle.zeros(
                    shape=(paddle.shape(ys)[0], 1, paddle.shape(ys)[2]),
                    dtype=ys.dtype), ys[:, :-1]
            ],
            axis=1)
        return ys_in
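
# _add_first_frame_and_remove_last_frame builds the teacher-forcing input for
# the decoder: the target mel frames are shifted right by one, starting from
# an all-zero frame. A tiny sketch (batch=1, 3 frames, odim=2):
#
#   ys    = [[[1., 1.], [2., 2.], [3., 3.]]]
#   ys_in = [[[0., 0.], [1., 1.], [2., 2.]]]  # zero start frame, last dropped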

class MLMEncAsDecoderModel(MLMModel):
    def forward(self, batch, speech_segment_pos, y_masks=None):
        # feats: (Batch, Length, Dim)
        # -> encoder_out: (Batch, Length2, Dim2)
        speech_pad_placeholder = batch['speech_pad']
        encoder_out, h_masks = self.encoder(**batch)  # segment_emb
        if self.decoder is not None:
            zs, _ = self.decoder(encoder_out, h_masks)
        else:
            zs = encoder_out
        speech_hidden_states = zs[:, :paddle.shape(batch['speech_pad'])[1], :]
        if self.sfc is not None:
            before_outs = paddle.reshape(
                self.sfc(speech_hidden_states),
                (paddle.shape(speech_hidden_states)[0], -1, self.odim))
        else:
            before_outs = speech_hidden_states
        if self.postnet is not None:
            after_outs = before_outs + paddle.transpose(
                self.postnet(paddle.transpose(before_outs, [0, 2, 1])),
                [0, 2, 1])
        else:
            after_outs = None
        return before_outs, after_outs, speech_pad_placeholder, batch[
            'masked_position']

class MLMDualMaksingModel(MLMModel):
    def _calc_mlm_loss(self,
                       before_outs: paddle.Tensor,
                       after_outs: paddle.Tensor,
                       text_outs: paddle.Tensor,
                       batch):
        xs_pad = batch['speech_pad']
        text_pad = batch['text_pad']
        masked_position = batch['masked_position']
        text_masked_position = batch['text_masked_position']
        mlm_loss_position = masked_position > 0
        loss = paddle.sum(
            self.l1_loss_func(
                paddle.reshape(before_outs, (-1, self.odim)),
                paddle.reshape(xs_pad, (-1, self.odim))),
            axis=-1)
        if after_outs is not None:
            loss += paddle.sum(
                self.l1_loss_func(
                    paddle.reshape(after_outs, (-1, self.odim)),
                    paddle.reshape(xs_pad, (-1, self.odim))),
                axis=-1)
        loss_mlm = paddle.sum((loss * paddle.reshape(
            mlm_loss_position, [-1]))) / paddle.sum((mlm_loss_position) + 1e-10)
        loss_text = paddle.sum((self.text_mlm_loss(
            paddle.reshape(text_outs, (-1, self.vocab_size)),
            paddle.reshape(text_pad, (-1))) * paddle.reshape(
                text_masked_position,
                (-1)))) / paddle.sum((text_masked_position) + 1e-10)
        return loss_mlm, loss_text

    def forward(self, batch, speech_segment_pos, y_masks=None):
        # feats: (Batch, Length, Dim)
        # -> encoder_out: (Batch, Length2, Dim2)
        speech_pad_placeholder = batch['speech_pad']
        encoder_out, h_masks = self.encoder(**batch)  # segment_emb
        if self.decoder is not None:
            zs, _ = self.decoder(encoder_out, h_masks)
        else:
            zs = encoder_out
        speech_hidden_states = zs[:, :paddle.shape(batch['speech_pad'])[1], :]
        if self.text_sfc:
            text_hiddent_states = zs[:, paddle.shape(batch['speech_pad'])[
                1]:, :]
            text_outs = paddle.reshape(
                self.text_sfc(text_hiddent_states),
                (paddle.shape(text_hiddent_states)[0], -1, self.vocab_size))
        if self.sfc is not None:
            before_outs = paddle.reshape(
                self.sfc(speech_hidden_states),
                (paddle.shape(speech_hidden_states)[0], -1, self.odim))
        else:
            before_outs = speech_hidden_states
        if self.postnet is not None:
            after_outs = before_outs + paddle.transpose(
                self.postnet(paddle.transpose(before_outs, [0, 2, 1])),
                [0, 2, 1])
        else:
            after_outs = None
        return before_outs, after_outs, text_outs, None  #, speech_pad_placeholder, batch['masked_position'], batch['text_masked_position']
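
# Note on _calc_mlm_loss above: the per-position L1 (or MSE) reconstruction
# loss is averaged over masked speech frames only, by weighting with the 0/1
# mask and dividing by its sum (the 1e-10 term guards against an all-zero
# mask). A tiny numeric sketch: loss = [4., 6., 8.] with mask = [0, 1, 1]
# gives loss_mlm = (6 + 8) / 2 = 7. The text branch applies the same masked
# averaging to a cross-entropy loss over the phone vocabulary.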

def build_model_from_file(config_file, model_file):

    state_dict = paddle.load(model_file)
    model_class = MLMDualMaksingModel if 'conformer_combine_vctk_aishell3_dual_masking' in config_file \
        else MLMEncAsDecoderModel

    # build the model
    args = yaml.safe_load(Path(config_file).open("r", encoding="utf-8"))
@@ -962,7 +838,8 @@ def build_model_from_file(config_file, model_file):
    return model, args
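
# A minimal usage sketch (hypothetical paths): the model class is picked from
# the config filename, so dual-masking checkpoints keep
# 'conformer_combine_vctk_aishell3_dual_masking' in their config path:
#
#   model, args = build_model_from_file(
#       config_file='exp/conformer_combine_vctk_aishell3_dual_masking/config.yaml',
#       model_file='exp/conformer_combine_vctk_aishell3_dual_masking/model.pdparams')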

def build_model(args: argparse.Namespace,
                model_class=MLMEncAsDecoderModel) -> MLMModel:
    if isinstance(args.token_list, str):
        with open(args.token_list, encoding="utf-8") as f:
            token_list = [line.rstrip() for line in f]
@@ -975,17 +852,14 @@ def build_model(args: argparse.Namespace, model_class = ESPnetMLMEncAsDecoderMod
        raise RuntimeError("token_list must be str or list")

    vocab_size = len(token_list)
    logging.info(f"Vocabulary size: {vocab_size }")
    odim = 80
    pos_enc_class = ScaledPositionalEncoding if args.use_scaled_pos_enc else PositionalEncoding
    if "conformer" == args.encoder:
        conformer_self_attn_layer_type = args.encoder_conf[
            'selfattention_layer_type']
        conformer_pos_enc_layer_type = args.encoder_conf['pos_enc_layer_type']
        conformer_rel_pos_type = "legacy"
        if conformer_rel_pos_type == "legacy":
@@ -994,38 +868,42 @@ def build_model(args: argparse.Namespace, model_class = ESPnetMLMEncAsDecoderMod
                logging.warning(
                    "Fallback to conformer_pos_enc_layer_type = 'legacy_rel_pos' "
                    "due to the compatibility. If you want to use the new one, "
                    "please use conformer_pos_enc_layer_type = 'latest'.")
            if conformer_self_attn_layer_type == "rel_selfattn":
                conformer_self_attn_layer_type = "legacy_rel_selfattn"
                logging.warning(
                    "Fallback to "
                    "conformer_self_attn_layer_type = 'legacy_rel_selfattn' "
                    "due to the compatibility. If you want to use the new one, "
                    "please use conformer_pos_enc_layer_type = 'latest'.")
        elif conformer_rel_pos_type == "latest":
            assert conformer_pos_enc_layer_type != "legacy_rel_pos"
            assert conformer_self_attn_layer_type != "legacy_rel_selfattn"
        else:
            raise ValueError(f"Unknown rel_pos_type: {conformer_rel_pos_type}")
        args.encoder_conf[
            'selfattention_layer_type'] = conformer_self_attn_layer_type
        args.encoder_conf['pos_enc_layer_type'] = conformer_pos_enc_layer_type
        if "conformer" == args.decoder:
            args.decoder_conf[
                'selfattention_layer_type'] = conformer_self_attn_layer_type
            args.decoder_conf[
                'pos_enc_layer_type'] = conformer_pos_enc_layer_type

    # Encoder
    encoder_class = MLMEncoder

    if 'text_masking' in args.model_conf.keys() and args.model_conf[
            'text_masking']:
        args.encoder_conf['text_masking'] = True
    else:
        args.encoder_conf['text_masking'] = False

    encoder = encoder_class(
        args.input_size,
        vocab_size=vocab_size,
        pos_enc_class=pos_enc_class,
        **args.encoder_conf)

    # Decoder
    if args.decoder != 'no_decoder':
@@ -1033,22 +911,17 @@ def build_model(args: argparse.Namespace, model_class = ESPnetMLMEncAsDecoderMod
        decoder = decoder_class(
            idim=0,
            input_layer=None,
            **args.decoder_conf, )
    else:
        decoder = None

    # Build model
    model = model_class(
        odim=odim,
        encoder=encoder,
        decoder=decoder,
        token_list=token_list,
        **args.model_conf, )

    # Initialize
    if args.init is not None:
......
p243_new 0.0125 1.0225 1.0525 1.0925 1.1325 1.1725 1.2625 1.3625 1.4125 1.5125 1.6225 1.6625 1.7925 1.8625 2.0025 2.0925 2.1725 2.2625 2.4325 2.4725 2.5225 2.5825 2.6125 2.6425 2.7425 2.8025 2.9025 2.9525 3.0525 3.0825 3.2125
Prompt_003_new 0.0125 0.0425 0.0925 0.1825 0.2125 0.2425 0.3225 0.3725 0.4725 0.5325 0.5625 0.6225 0.7425 0.8625 0.9725 0.9975 1.0125 1.0825 1.2625
p299_096 0.0125 0.7525 0.7925 0.8725 0.9125 0.9425 1.0325 1.0625 1.1925 1.2625 1.3225 1.3725 1.4125 1.5125 1.5425 1.6525 1.6925 1.7325 1.7625 1.8425 1.9625 2.0225 2.1825 2.3325
p243_new sp F AO1 R DH AE1 T R IY1 Z AH0 N sp K AH1 V ER0 SH UH1 D N AA1 T B IY1 G IH1 V AH0 N sp
Prompt_003_new DH IH1 S W AA1 Z N AA1 T DH AH0 SH OW1 F AO1 R M IY1 sp
p299_096 sp W IY1 AA1 R T R AY1 NG T UW1 AH0 S T AE1 B L IH0 SH AH0 D EY1 T sp
p243_new ../../prompt_wav/p243_313.wav
Prompt_003_new ../../prompt_wav/this_was_not_the_show_for_me.wav
p299_096 ../../prompt_wav/p299_096.wav
@@ -5,7 +5,6 @@ from typing import List
from typing import Union


def read_2column_text(path: Union[Path, str]) -> Dict[str, str]:
    """Read a text file having 2 column as dict object.

@@ -33,9 +32,8 @@ def read_2column_text(path: Union[Path, str]) -> Dict[str, str]:
    return data


def load_num_sequence_text(path: Union[Path, str], loader_type: str="csv_int"
                           ) -> Dict[str, List[Union[float, int]]]:
    """Read a text file indicating sequences of number

    Examples:
@@ -73,6 +71,7 @@ def load_num_sequence_text(
        try:
            retval[k] = [dtype(i) for i in v.split(delimiter)]
        except TypeError:
            logging.error(
                f'Error happened with path="{path}", id="{k}", value="{v}"')
            raise
    return retval
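
# Sketch of the expected inputs (hypothetical file contents):
#
#   read_2column_text on "key1 /some/path/a.wav\nkey2 /some/path/b.wav\n"
#       -> {'key1': '/some/path/a.wav', 'key2': '/some/path/b.wav'}
#   load_num_sequence_text with the default loader_type="csv_int" (assumed to
#   split on commas) on "key1 1,2,3\n" -> {'key1': [1, 2, 3]}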
#!/bin/bash
# en --> zh speech synthesis (cross-lingual voice cloning)
# Use Prompt_003_new ("This was not the show for me.") as the prompt speech to synthesize: '今天天气很好'
# Note: new_str must consist of Chinese characters; otherwise preprocessing keeps only the Chinese characters and the preprocessed text is synthesized.
python inference.py \
--task_name=cross-lingual_clone \
--model_name=paddle_checkpoint_dual_mask_enzh \
--uid=Prompt_003_new \
--new_str='今天天气很好.' \
--prefix='./prompt/dev/' \
--source_language=english \
--target_language=chinese \
--output_name=pred_clone.wav \
--use_pt_vocoder=False \
--voc=pwgan_aishell3 \
--voc_config=download/pwg_aishell3_ckpt_0.5/default.yaml \
--voc_ckpt=download/pwg_aishell3_ckpt_0.5/snapshot_iter_1000000.pdz \
--voc_stat=download/pwg_aishell3_ckpt_0.5/feats_stats.npy \
--am=fastspeech2_csmsc \
--am_config=download/fastspeech2_conformer_baker_ckpt_0.5/conformer.yaml \
--am_ckpt=download/fastspeech2_conformer_baker_ckpt_0.5/snapshot_iter_76000.pdz \
--am_stat=download/fastspeech2_conformer_baker_ckpt_0.5/speech_stats.npy \
--phones_dict=download/fastspeech2_conformer_baker_ckpt_0.5/phone_id_map.txt
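# Note: --use_pt_vocoder (the torch ParallelWaveGAN path) only applies to
# English synthesis, so it stays False here and the Paddle pwgan_aishell3
# vocoder configured above is used; the synthesized audio is saved under the
# name given by --output_name (pred_clone.wav).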
#!/bin/bash
# English personalized speech synthesis
# Example: use the speech corresponding to p299_096 as the prompt to synthesize: 'I enjoy my life, do you?'
python inference.py \
--task_name=synthesize \
--model_name=paddle_checkpoint_en \
--uid=p299_096 \
--new_str='I enjoy my life, do you?' \
--prefix='./prompt/dev/' \
--source_language=english \
--target_language=english \
--output_name=pred_gen.wav \
--use_pt_vocoder=False \
--voc=pwgan_aishell3 \
--voc_config=download/pwg_aishell3_ckpt_0.5/default.yaml \
--voc_ckpt=download/pwg_aishell3_ckpt_0.5/snapshot_iter_1000000.pdz \
--voc_stat=download/pwg_aishell3_ckpt_0.5/feats_stats.npy \
--am=fastspeech2_ljspeech \
--am_config=download/fastspeech2_nosil_ljspeech_ckpt_0.5/default.yaml \
--am_ckpt=download/fastspeech2_nosil_ljspeech_ckpt_0.5/snapshot_iter_100000.pdz \
--am_stat=download/fastspeech2_nosil_ljspeech_ckpt_0.5/speech_stats.npy \
--phones_dict=download/fastspeech2_nosil_ljspeech_ckpt_0.5/phone_id_map.txt
#!/bin/bash
# English speech editing
# Example: edit the original speech of p243_new ("For that reason cover should not be given.") into the speech for 'for that reason cover is impossible to be given.'
# NOTE: the speech editing task currently supports replacing or inserting text at only one position in a sentence.
python inference.py \
--task_name=edit \
--model_name=paddle_checkpoint_en \
--uid=p243_new \
--new_str='for that reason cover is impossible to be given.' \
--prefix='./prompt/dev/' \
--source_language=english \
--target_language=english \
--output_name=pred_edit.wav \
--use_pt_vocoder=False \
--voc=pwgan_aishell3 \
--voc_config=download/pwg_aishell3_ckpt_0.5/default.yaml \
--voc_ckpt=download/pwg_aishell3_ckpt_0.5/snapshot_iter_1000000.pdz \
--voc_stat=download/pwg_aishell3_ckpt_0.5/feats_stats.npy \
--am=fastspeech2_ljspeech \
--am_config=download/fastspeech2_nosil_ljspeech_ckpt_0.5/default.yaml \
--am_ckpt=download/fastspeech2_nosil_ljspeech_ckpt_0.5/snapshot_iter_100000.pdz \
--am_stat=download/fastspeech2_nosil_ljspeech_ckpt_0.5/speech_stats.npy \
--phones_dict=download/fastspeech2_nosil_ljspeech_ckpt_0.5/phone_id_map.txt
@@ -86,7 +86,11 @@ def parse_args():
    parser.add_argument("--target_language", type=str, help="target language")
    parser.add_argument("--output_name", type=str, help="output name")
    parser.add_argument("--task_name", type=str, help="task name")
    parser.add_argument(
        "--use_pt_vocoder",
        type=str2bool,
        default=True,
        help="use pytorch version vocoder or not. [Note: only in english condition.]")
    # pre
    args = parser.parse_args()
......
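# Why type=str2bool above: with argparse, type=bool would turn any non-empty
# string, including "False", into True, so boolean flags passed as
# --use_pt_vocoder=False need an explicit converter. str2bool (imported from
# paddlespeech.t2s.utils) is assumed to map the usual "true"/"false"
# spellings to the corresponding bool.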
#!/bin/bash
rm -rf *.wav
sh run_sedit_en.sh        # speech editing task (English)
sh run_gen_en.sh          # personalized speech synthesis task (English)
sh run_clone_en_to_zh.sh  # cross-lingual speech synthesis task (English-to-Chinese voice cloning)
# Copyright 2021 Tomoki Hayashi
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Wrapper class for the vocoder model trained with parallel_wavegan repo."""
import logging
import os
from pathlib import Path
from typing import Optional
from typing import Union

import torch
import yaml


class ParallelWaveGANPretrainedVocoder(torch.nn.Module):
    """Wrapper class to load the vocoder trained with parallel_wavegan repo."""

    def __init__(
            self,
            model_file: Union[Path, str],
            config_file: Optional[Union[Path, str]]=None, ):
        """Initialize ParallelWaveGANPretrainedVocoder module."""
        super().__init__()
        try:
@@ -30,8 +23,7 @@ class ParallelWaveGANPretrainedVocoder(torch.nn.Module):
        except ImportError:
            logging.error(
                "`parallel_wavegan` is not installed. "
                "Please install via `pip install -U parallel_wavegan`.")
            raise
        if config_file is None:
            dirname = os.path.dirname(str(model_file))
@@ -59,5 +51,4 @@ class ParallelWaveGANPretrainedVocoder(torch.nn.Module):
        """
        return self.vocoder.inference(
            feats,
            normalize_before=self.normalize_before, ).view(-1)
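
# A minimal usage sketch (hypothetical checkpoint path); when config_file is
# omitted, the wrapper falls back to a config stored next to the model file:
#
#   vocoder = ParallelWaveGANPretrainedVocoder(
#       "exp/vocoder/checkpoint-400000steps.pkl")
#   wav = vocoder(mel)  # assumed call: mel (T_feats, n_mels) -> 1-D waveform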
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
import yaml
from sedit_arg_parser import parse_args
from yacs.config import CfgNode

from paddlespeech.s2t.utils.dynamic_import import dynamic_import
from paddlespeech.t2s.exps.syn_utils import get_frontend
from paddlespeech.t2s.exps.syn_utils import get_voc_inference
from paddlespeech.t2s.models.fastspeech2 import FastSpeech2
from paddlespeech.t2s.models.fastspeech2 import FastSpeech2Inference
from paddlespeech.t2s.modules.normalizer import ZScore
from paddlespeech.t2s.utils import str2bool
from tools.parallel_wavegan_pretrained_vocoder import ParallelWaveGANPretrainedVocoder
# new add

model_alias = {
    # acoustic model
@@ -58,9 +28,6 @@ model_alias = {
}


def is_chinese(ch):
    if u'\u4e00' <= ch <= u'\u9fff':
        return True
@@ -69,17 +36,15 @@ def is_chinese(ch):


def build_vocoder_from_file(
        vocoder_config_file=None,
        vocoder_file=None,
        model=None,
        device="cpu", ):
    # Build vocoder
    if str(vocoder_file).endswith(".pkl"):
        # If the extension is ".pkl", the model is trained with parallel_wavegan
        vocoder = ParallelWaveGANPretrainedVocoder(vocoder_file,
                                                   vocoder_config_file)
        return vocoder.to(device)
    else:
@@ -91,7 +56,7 @@ def get_voc_out(mel, target_language="chinese"):
    args = parse_args()
    assert target_language == "chinese" or target_language == "english", "In get_voc_out function, target_language is illegal..."
    # print("current vocoder: ", args.voc)
    with open(args.voc_config) as f:
        voc_config = CfgNode(yaml.safe_load(f))
@@ -106,6 +71,7 @@ def get_voc_out(mel, target_language="chinese"):
    # print("shape of wav (time x n_channels): %s" % wav.shape)
    return np.squeeze(wav)


# dygraph
def get_am_inference(args, am_config):
    with open(args.phones_dict, "r") as f:
@@ -159,11 +125,14 @@ def get_am_inference(args, am_config):
    return am, am_inference, am_name, am_dataset, phn_id

def evaluate_durations(phns,
                       target_language="chinese",
                       fs=24000,
                       hop_length=300):
    args = parse_args()
    if target_language == 'english':
        args.lang = 'en'
        args.am = "fastspeech2_ljspeech"
        args.am_config = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/default.yaml"
        args.am_ckpt = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/snapshot_iter_100000.pdz"
@@ -171,12 +140,12 @@ def evaluate_durations(phns, target_language="chinese", fs=24000, hop_length=300
        args.phones_dict = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/phone_id_map.txt"

    elif target_language == 'chinese':
        args.lang = 'zh'
        args.am = "fastspeech2_csmsc"
        args.am_config = "download/fastspeech2_conformer_baker_ckpt_0.5/conformer.yaml"
        args.am_ckpt = "download/fastspeech2_conformer_baker_ckpt_0.5/snapshot_iter_76000.pdz"
        args.am_stat = "download/fastspeech2_conformer_baker_ckpt_0.5/speech_stats.npy"
        args.phones_dict = "download/fastspeech2_conformer_baker_ckpt_0.5/phone_id_map.txt"

    # args = parser.parse_args(args=[])
    if args.ngpu == 0:
@@ -186,8 +155,6 @@ def evaluate_durations(phns, target_language="chinese", fs=24000, hop_length=300
    else:
        print("ngpu should >= 0 !")

    assert target_language == "chinese" or target_language == "english", "In evaluate_durations function, target_language is illegal..."

    # Init body.
@@ -197,8 +164,8 @@ def evaluate_durations(phns, target_language="chinese", fs=24000, hop_length=300
    # print(am_config)
    # print("---------------------")
    # acoustic model
    am, am_inference, am_name, am_dataset, phn_id = get_am_inference(args,
                                                                     am_config)
    torch_phns = phns
    vocab_phones = {}
@@ -206,33 +173,31 @@ def evaluate_durations(phns, target_language="chinese", fs=24000, hop_length=300
        vocab_phones[tone] = int(id)
    # print("vocab_phones: ", len(vocab_phones))
    vocab_size = len(vocab_phones)
    phonemes = [phn if phn in vocab_phones else "sp" for phn in torch_phns]

    phone_ids = [vocab_phones[item] for item in phonemes]
    phone_ids_new = phone_ids
    phone_ids_new.append(vocab_size - 1)
    phone_ids_new = paddle.to_tensor(np.array(phone_ids_new, np.int64))
    normalized_mel, d_outs, p_outs, e_outs = am.inference(
        phone_ids_new, spk_id=None, spk_emb=None)
    pre_d_outs = d_outs
    phoneme_durations_new = pre_d_outs * hop_length / fs
    phoneme_durations_new = phoneme_durations_new.tolist()[:-1]
    return phoneme_durations_new
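
# A minimal usage sketch (hypothetical phone list): the returned durations are
# in seconds, i.e. predicted frame counts scaled by hop_length / fs, which is
# 0.0125 s per frame with the defaults (300 / 24000):
#
#   durs = evaluate_durations(['sp', 'HH', 'AH0', 'sp'],
#                             target_language='english')
#   # durs: one float per input phone, e.g. [0.05, 0.1375, 0.0625, 0.1]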

def sentence2phns(sentence, target_language="en"):
    args = parse_args()
    if target_language == 'en':
        args.lang = 'en'
        args.phones_dict = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/phone_id_map.txt"
    elif target_language == 'zh':
        args.lang = 'zh'
        args.phones_dict = "download/fastspeech2_conformer_baker_ckpt_0.5/phone_id_map.txt"
    else:
        print("target_language should in {'zh', 'en'}!")

    frontend = get_frontend(args)
    merge_sentences = True
    get_tone_ids = False
@@ -246,10 +211,8 @@ def sentence2phns(sentence, target_language="en"):
        phone_ids = input_ids["phone_ids"]

        phonemes = frontend.get_phonemes(
            sentence, merge_sentences=merge_sentences, print_info=False)

        return phonemes[0], input_ids["phone_ids"][0]

    elif target_language == 'en':
@@ -270,16 +233,11 @@ def sentence2phns(sentence, target_language="en"):
        phones = [phn for phn in phones if not phn.isspace()]
        # replace unk phone with sp
        phones = [
            phn if (phn in vocab_phones and phn not in punc) else "sp"
            for phn in phones
        ]
        phones_list.append(phones)

        return phones_list[0], input_ids["phone_ids"][0]

    else:
        print("lang should in {'zh', 'en'}!")