Unverified Commit 437081f7, authored by Kennycao123, committed by GitHub

Merge pull request #824 from yt605155624/format

[ernie sat]format ernie sat
ernie-sat/.meta/framework.png (image updated: 306.0 KB → 139.9 KB)
ERNIE-SAT is a cross-lingual speech-text cross-modal large model that handles Chinese and English simultaneously. It achieves leading results on multiple tasks, including speech editing, personalized speech synthesis, and cross-lingual speech synthesis, and can be applied to scenarios such as speech editing, personalized synthesis, voice cloning, and simultaneous interpretation. This project is provided for research use.
## Model Framework
ERNIE-SAT introduces two innovations:
- During pretraining, phonemes of the corresponding Chinese and English text are used as input, enabling cross-lingual, personalized soft phoneme mapping
- Joint masked learning over text and speech aligns the two modalities
### 1. Install PaddlePaddle and Dependencies
- The code in this project is based on Paddle (version>=2.0)
- This project also supports loading the torch version of the vocoder
  - torch version>=1.8
- Install htk: after registering at the [official site](https://htk.eng.cam.ac.uk/), you can download a recent version of htk (e.g. 3.4.1). [Older htk versions](https://htk.eng.cam.ac.uk/ftp/software/) are also available for download.
  - 1. Register an account and download htk
  - 2. Unpack the htk archive and **place it in the tools folder at the project root, under the folder name htk**
  - 3. **Note**: if you downloaded 3.4.1 or later, open HTKLib/HRec.c and **modify lines 1626 and 1650**, i.e. change **dur<=0 to dur<0 in the following two lines**, as shown below:
```bash
Taking htk 3.4.1 as an example:
(1) line 1626: if (dur<=0 && labid != splabid) HError(8522,"LatFromPaths: Align have dur<=0");
    change to:  if (dur<0 && labid != splabid) HError(8522,"LatFromPaths: Align have dur<0");
(2) line 1650: if (dur<=0 && labid != splabid) HError(8522,"LatFromPaths: Align have dur<=0 ");
    change to:  if (dur<0 && labid != splabid) HError(8522,"LatFromPaths: Align have dur<0 ");
```
- 4. **Compile**: see the README in the unpacked htk for details (without compiling, the aligner cannot run)
- Install ParallelWaveGAN: see the [official repo](https://github.com/kan-bayashi/ParallelWaveGAN): following the installation steps in that link, simply git clone the ParallelWaveGAN project **in the project root directory** and install its dependencies.
- Install other dependencies: **sox, libsndfile**
### 2. Pretrained Models
The pretrained ERNIE-SAT models are listed below:
- [ERNIE-SAT_ZH](http://bj.bcebos.com/wenxin-models/model-ernie-sat-base-zh.tar.gz)
- [ERNIE-SAT_EN](http://bj.bcebos.com/wenxin-models/model-ernie-sat-base-en.tar.gz)
- [ERNIE-SAT_ZH_and_EN](http://bj.bcebos.com/wenxin-models/model-ernie-sat-base-en_zh.tar.gz)
Create a pretrained_model folder, download the ERNIE-SAT pretrained models above, and unpack them:
```bash
mkdir pretrained_model
cd pretrained_model
tar -zxvf model-ernie-sat-base-zh.tar.gz
tar -zxvf model-ernie-sat-base-en_zh.tar.gz
```
### 3. Downloads
1. This project uses parallel wavegan as the vocoder:
- [pwg_aishell3_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_aishell3_ckpt_0.5.zip)
Create a download folder, download the pretrained vocoder model above, and unpack it:
```bash
mkdir download
cd download
unzip pwg_aishell3_ckpt_0.5.zip
```
2. This project uses [FastSpeech2](https://arxiv.org/abs/2006.04558) as the phoneme duration predictor:
- [fastspeech2_conformer_baker_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_conformer_baker_ckpt_0.5.zip) used for Chinese
- [fastspeech2_nosil_ljspeech_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_ljspeech_ckpt_0.5.zip) used for English
Download the pretrained fastspeech2 models above and unpack them:
```bash
cd download
unzip fastspeech2_conformer_baker_ckpt_0.5.zip
unzip fastspeech2_nosil_ljspeech_ckpt_0.5.zip
```
### 4. Inference
This project currently open-sources inference code for speech editing, personalized speech synthesis, and cross-lingual speech synthesis; more will be open-sourced later.
Note: the default vocoder for speech synthesized in the English scenario is vctk_parallel_wavegan.v1.long, which can be found at [this link](https://github.com/kan-bayashi/ParallelWaveGAN); if the use_pt_vocoder parameter is set to False, the paddle version of the vocoder is used in the English scenario.
We provide a specific audio file together with its corresponding text and phoneme files:
- prompt_wav: the provided audio file
prompt/dev
5. `--lang` the language of the model, `zh` or `en`
6. `--ngpu` the number of GPUs to use; if ngpu==0, the CPU is used
7. `--model_name` the model name
8. `--uid` the id of the specific prompt speech
9. `--new_str` the input text (this release temporarily uses fixed, preset text)
10. `--prefix` path to the text and phoneme files for the specific audio
11. `--source_language` the source language
12. `--target_language` the target language
13. `--output_name` the name of the synthesized audio
14. `--task_name` the task name, one of: speech editing, personalized speech synthesis, cross-lingual speech synthesis
15. `--use_pt_vocoder` whether to use the torch version of the vocoder in the English scenario; defaults to False, in which case the paddle version of the vocoder is used for English
Run the following scripts to perform the experiments:
```shell
sh run_sedit_en.sh          # speech editing task (English)
sh run_gen_en.sh            # personalized speech synthesis task (English)
sh run_clone_en_to_zh.sh    # cross-lingual speech synthesis task (voice cloning from English to Chinese)
```
#!/usr/bin/env python
""" Usage:
align_english.py wavfile trsfile outwordfile outphonefile
"""
import multiprocessing as mp
import os
import sys

from tqdm import tqdm
PHONEME = 'tools/aligner/english_envir/english2phoneme/phoneme'
MODEL_DIR = 'tools/aligner/english'
HVITE = 'tools/htk/HTKTools/HVite'
HCOPY = 'tools/htk/HTKTools/HCopy'
def prep_txt(line, tmpbase, dictfile):
words = []
line = line.strip()
for unk in unk_words:
fwid.write(unk + '\n')
try:
os.system(PHONEME + ' ' + tmpbase + '_unk.words' + ' ' + tmpbase +
'_unk.phons')
except:
print('english2phoneme error!')
sys.exit(1)
seq.append(phons[j].upper())
j += 1
else:
p = phons[j:j + 2]
if (p == 'WH'):
seq.append('W')
elif (p in ['TH', 'SH', 'HH', 'DH', 'CH', 'ZH', 'NG']):
fw.write('\n')
fw.close()
def prep_mlf(txt, tmpbase):
with open(tmpbase + '.mlf', 'w') as fwid:
fwid.write('sp\n')
fwid.write('.\n')
def gen_res(tmpbase, outfile1, outfile2):
with open(tmpbase + '.txt', 'r') as fid:
words = fid.readline().strip().split()
times1 = []
times2 = []
while (i < len(lines)):
if (len(lines[i].split()) >= 4) and (
lines[i].split()[0] != lines[i].split()[1]):
phn = lines[i].split()[2]
pst = (int(lines[i].split()[0]) / 1000 + 125) / 10000
pen = (int(lines[i].split()[1]) / 1000 + 125) / 10000
times2.append([phn, pst, pen])
if (len(lines[i].split()) == 5):
if (lines[i].split()[0] != lines[i].split()[1]):
wrd = lines[i].split()[-1].strip()
st = (int(lines[i].split()[0]) / 1000 + 125) / 10000
j = i + 1
while (lines[j] != '.\n') and (len(lines[j].split()) != 5):
j += 1
en = (int(lines[j - 1].split()[1]) / 1000 + 125) / 10000
times1.append([wrd, st, en])
i += 1
for item in times2:
fwid.write(str(item[1]) + ' ' + str(item[2]) + ' ' + item[0] + '\n')
def _get_user():
return os.path.expanduser('~').split("/")[-1]
def alignment(wav_path, text_string):
tmpbase = '/tmp/' + _get_user() + '_' + str(os.getpid())
#prepare wav and trs files
try:
except:
print('sox error!')
return None
#prepare clean_transcript file
try:
prep_txt(text_string, tmpbase, MODEL_DIR + '/dict')
#prepare scp
try:
os.system(HCOPY + ' -C ' + MODEL_DIR + '/16000/config ' + tmpbase +
'.wav' + ' ' + tmpbase + '.plp')
except:
print('HCopy error!')
return None
#run alignment
try:
os.system(HVITE + ' -a -m -t 10000.0 10000.0 100000.0 -I ' + tmpbase +
'.mlf -H ' + MODEL_DIR + '/16000/macros -H ' + MODEL_DIR +
'/16000/hmmdefs -i ' + tmpbase + '.aligned ' + tmpbase +
'.dict ' + MODEL_DIR + '/monophones ' + tmpbase +
'.plp 2>&1 > /dev/null')
except:
print('HVite error!')
return None
splited_line = lines[i].strip().split()
if (len(splited_line) >= 4) and (splited_line[0] != splited_line[1]):
phn = splited_line[2]
pst = (int(splited_line[0]) / 1000 + 125) / 10000
pen = (int(splited_line[1]) / 1000 + 125) / 10000
times2.append([phn, pst, pen])
# splited_line[-1]!='sp'
if len(splited_line) == 5:
current_word = str(index) + '_' + splited_line[-1]
word2phns[current_word] = phn
index += 1
elif len(splited_line) == 4:
word2phns[current_word] += ' ' + phn
i += 1
return times2, word2phns
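For reference, a sketch of what `alignment` returns (hypothetical values; actually running it requires the compiled htk tools and the models under tools/aligner/english):

```python
# Hypothetical output sketch for alignment(); phone labels and times below are
# illustrative, not real aligner output.
times2, word2phns = alignment('prompt.wav', 'hello world')
# times2:    [[phone, start_sec, end_sec], ...], e.g. [['HH', 0.0125, 0.05], ...]
# word2phns: {'0_hello': 'HH AH0 L OW1', '1_world': 'W ER1 L D'}
```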
#!/usr/bin/env python
""" Usage:
align_mandarin.py wavfile trsfile outwordfile outphonefile
"""
import multiprocessing as mp
import os
import sys

from tqdm import tqdm
MODEL_DIR = 'tools/aligner/mandarin'
HVITE = 'tools/htk/HTKTools/HVite'
words = []
line = line.strip()
for pun in [
',', '.', ':', ';', '!', '?', '"', '(', ')', '--', '---', u',',
u'。', u':', u';', u'!', u'?', u'(', u')'
]:
line = line.replace(pun, ' ')
for wrd in line.split():
if (wrd[-1] == '-'):
fwid.write('\n')
return unk_words
def prep_mlf(txt, tmpbase):
with open(tmpbase + '.mlf', 'w') as fwid:
fwid.write('sp\n')
fwid.write('.\n')
def gen_res(tmpbase, outfile1, outfile2):
with open(tmpbase + '.txt', 'r') as fid:
words = fid.readline().strip().split()
times1 = []
times2 = []
while (i < len(lines)):
if (len(lines[i].split()) >= 4) and (
lines[i].split()[0] != lines[i].split()[1]):
phn = lines[i].split()[2]
pst = (int(lines[i].split()[0]) / 1000 + 125) / 10000
pen = (int(lines[i].split()[1]) / 1000 + 125) / 10000
times2.append([phn, pst, pen])
if (len(lines[i].split()) == 5):
if (lines[i].split()[0] != lines[i].split()[1]):
wrd = lines[i].split()[-1].strip()
st = (int(lines[i].split()[0]) / 1000 + 125) / 10000
j = i + 1
while (lines[j] != '.\n') and (len(lines[j].split()) != 5):
j += 1
en = (int(lines[j - 1].split()[1]) / 1000 + 125) / 10000
times1.append([wrd, st, en])
i += 1
fwid.write(str(item[1]) + ' ' + str(item[2]) + ' ' + item[0] + '\n')
def alignment_zh(wav_path, text_string):
tmpbase = '/tmp/' + os.environ['USER'] + '_' + str(os.getpid())
#prepare wav and trs files
try:
os.system('sox ' + wav_path + ' -r 16000 -b 16 ' + tmpbase +
'.wav remix -')
except:
print('sox error!')
return None
#prepare clean_transcript file
try:
unk_words = prep_txt(text_string, tmpbase, MODEL_DIR + '/dict')
#prepare scp
try:
os.system(HCOPY + ' -C ' + MODEL_DIR + '/16000/config ' + tmpbase +
'.wav' + ' ' + tmpbase + '.plp')
except:
print('HCopy error!')
return None
#run alignment
try:
os.system(HVITE + ' -a -m -t 10000.0 10000.0 100000.0 -I ' + tmpbase +
'.mlf -H ' + MODEL_DIR + '/16000/macros -H ' + MODEL_DIR +
'/16000/hmmdefs -i ' + tmpbase + '.aligned ' + MODEL_DIR +
'/dict ' + MODEL_DIR + '/monophones ' + tmpbase +
'.plp 2>&1 > /dev/null')
except:
print('HVite error!')
i = 2
times2 = []
word2phns = {}
current_word = ''
index = 0
while (i < len(lines)):
splited_line = lines[i].strip().split()
if (len(splited_line) >= 4) and (splited_line[0] != splited_line[1]):
phn = splited_line[2]
pst = (int(splited_line[0]) / 1000 + 125) / 10000
pen = (int(splited_line[1]) / 1000 + 125) / 10000
times2.append([phn, pst, pen])
# splited_line[-1]!='sp'
if len(splited_line) == 5:
current_word = str(index) + '_' + splited_line[-1]
word2phns[current_word] = phn
index += 1
elif len(splited_line) == 4:
word2phns[current_word] += ' ' + phn
i += 1
return times2, word2phns
import math

import numpy as np
import paddle
def pad_list(xs, pad_value):
"""
n_batch = len(xs)
max_len = max(paddle.shape(x)[0] for x in xs)
pad = paddle.full((n_batch, max_len), pad_value, dtype=xs[0].dtype)
for i in range(n_batch):
pad[i, :paddle.shape(xs[i])[0]] = xs[i]
return pad
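As a quick illustration of `pad_list` (a minimal sketch, assuming the definition above is in scope):

```python
import paddle

# Pad a ragged batch of 1-D tensors to the batch max length with zeros.
xs = [paddle.to_tensor([1.0, 2.0, 3.0]), paddle.to_tensor([4.0, 5.0])]
padded = pad_list(xs, pad_value=0.0)
# padded -> [[1., 2., 3.],
#            [4., 5., 0.]]
```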
def pad_to_longformer_att_window(text, max_len, max_tlen, attention_window):
round = max_len % attention_window
if round != 0:
max_tlen += (attention_window - round)
n_batch = paddle.shape(text)[0]
text_pad = paddle.zeros(
(n_batch, max_tlen, *paddle.shape(text[0])[1:]), dtype=text.dtype)
for i in range(n_batch):
text_pad[i, :paddle.shape(text[i])[0]] = text[i]
else:
text_pad = text[:, :max_tlen]
return text_pad, max_tlen
if not isinstance(lengths, list):
lengths = list(lengths)
bs = int(len(lengths))
if xs is None:
maxlen = int(max(lengths))
maxlen = paddle.shape(xs)[length_dim]
seq_range = paddle.arange(0, maxlen, dtype=paddle.int64)
seq_range_expand = paddle.expand(
paddle.unsqueeze(seq_range, 0), (bs, maxlen))
seq_length_expand = paddle.unsqueeze(paddle.to_tensor(lengths), -1)
mask = seq_range_expand >= seq_length_expand
if xs is not None:
length_dim = len(paddle.shape(xs)) + length_dim
# ind = (:, None, ..., None, :, , None, ..., None)
ind = tuple(
slice(None) if i in (0, length_dim) else None
for i in range(len(paddle.shape(xs))))
mask = paddle.expand(mask[ind], paddle.shape(xs))
return mask
def make_non_pad_mask(lengths, xs=None, length_dim=-1):
"""Make mask tensor containing indices of non-padded part.
return ~make_pad_mask(lengths, xs, length_dim)
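A worked sketch of the two masks (assuming the definitions above):

```python
# For lengths [3, 1, 2], make_pad_mask marks padded positions as True:
print(make_pad_mask([3, 1, 2]))
# [[False, False, False],
#  [False, True , True ],
#  [False, False, True ]]
# make_non_pad_mask returns the elementwise negation of the same mask.
```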
def phones_masking(xs_pad,
src_mask,
align_start,
align_end,
align_start_lengths,
mlm_prob,
mean_phn_span,
span_boundary=None):
bz, sent_len, _ = paddle.shape(xs_pad)
mask_num_lower = math.ceil(sent_len * mlm_prob)
masked_position = np.zeros((bz, sent_len))
elif mean_phn_span == 0:
# only speech
length = sent_len
mean_phn_span = min(length * mlm_prob // 3, 50)
masked_phn_indices = random_spans_noise_mask(length, mlm_prob,
mean_phn_span).nonzero()
masked_position[:, masked_phn_indices] = 1
else:
for idx in range(bz):
if span_boundary is not None:
for s, e in zip(span_boundary[idx][::2],
span_boundary[idx][1::2]):
masked_position[idx, s:e] = 1
# y_masks[idx, :, s:e] = tril_masks[idx, :, s:e]
# y_masks[idx, e:, s:e ] = 0
else:
length = align_start_lengths[idx].item()
if length < 2:
continue
masked_phn_indices = random_spans_noise_mask(
length, mlm_prob, mean_phn_span).nonzero()
masked_start = align_start[idx][masked_phn_indices].tolist()
masked_end = align_end[idx][masked_phn_indices].tolist()
for s, e in zip(masked_start, masked_end):
masked_position[idx, s:e] = 1
# y_masks[idx, :, s:e] = tril_masks[idx, :, s:e]
# y_masks[idx, e:, s:e ] = 0
non_eos_mask = np.array(
paddle.reshape(src_mask, paddle.shape(xs_pad)[:2]).float().cpu())
masked_position = masked_position * non_eos_mask
# y_masks = src_mask & y_masks.bool()
return paddle.cast(paddle.to_tensor(masked_position), paddle.bool), y_masks
def get_segment_pos(speech_pad, text_pad, align_start, align_end,
align_start_lengths, sega_emb):
bz, speech_len, _ = speech_pad.size()
_, text_len = text_pad.size()
text_segment_pos = np.zeros((bz, text_len)).astype('int64')
speech_segment_pos = np.zeros((bz, speech_len)).astype('int64')
if not sega_emb:
text_segment_pos = paddle.to_tensor(text_segment_pos)
speech_segment_pos = paddle.to_tensor(speech_segment_pos)
for idx in range(bz):
align_length = align_start_lengths[idx].item()
for j in range(align_length):
s, e = align_start[idx][j].item(), align_end[idx][j].item()
speech_segment_pos[idx][s:e] = j + 1
text_segment_pos[idx][j] = j + 1
text_segment_pos = paddle.to_tensor(text_segment_pos)
speech_segment_pos = paddle.to_tensor(speech_segment_pos)
return speech_segment_pos, text_segment_pos
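A minimal sketch of how `get_segment_pos` numbers segments (toy shapes assumed; with sega_emb=True, every mel frame covered by aligned phone j gets segment id j+1 and uncovered frames stay 0):

```python
import paddle

# One utterance: 10 mel frames, 2 aligned phones covering frames [0, 5) and [5, 9).
speech_pad = paddle.zeros((1, 10, 80))           # (batch, frames, mel bins)
text_pad = paddle.zeros((1, 2), dtype='int64')   # 2 phone tokens
align_start = paddle.to_tensor([[0, 5]])
align_end = paddle.to_tensor([[5, 9]])
align_start_lengths = paddle.to_tensor([2])
speech_pos, text_pos = get_segment_pos(speech_pad, text_pad, align_start,
                                       align_end, align_start_lengths, True)
# speech_pos -> [[1, 1, 1, 1, 1, 2, 2, 2, 2, 0]]   (frame 9 is unaligned)
# text_pos   -> [[1, 2]]
```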
#!/usr/bin/env python3
import argparse
import math
import os
import pickle
import random
import string
import sys
from pathlib import Path
from typing import Collection
from typing import Dict
from typing import List
from typing import Tuple
from typing import Union

import librosa
import numpy as np
import paddle
import soundfile as sf
import torch
from align_english import alignment
from align_mandarin import alignment_zh
from ParallelWaveGAN.parallel_wavegan.utils.utils import download_pretrained_model
from dataset import get_segment_pos
from dataset import make_non_pad_mask
from dataset import make_pad_mask
from dataset import pad_list
from dataset import pad_to_longformer_att_window
from dataset import phones_masking
from model_paddle import build_model_from_file
from read_text import load_num_sequence_text
from read_text import read_2column_text
from sedit_arg_parser import parse_args
from utils import build_vocoder_from_file
from utils import evaluate_durations
from utils import get_voc_out
from utils import is_chinese
from utils import sentence2phns
from paddlespeech.t2s.datasets.get_feats import LogMelFBank
random.seed(0)
np.random.seed(0)
PHONEME = 'tools/aligner/english_envir/english2phoneme/phoneme'
MODEL_DIR_EN = 'tools/aligner/english'
MODEL_DIR_ZH = 'tools/aligner/mandarin'
def plot_mel_and_vocode_wav(uid,
prefix,
clone_uid,
clone_prefix,
source_language,
target_language,
model_name,
wav_path,
full_origin_str,
old_str,
new_str,
use_pt_vocoder,
duration_preditor_path,
sid=None,
non_autoreg=True):
wav_org, input_feat, output_feat, old_span_boundary, new_span_boundary, fs, hop_length = get_mlm_output(
uid,
prefix,
clone_uid,
clone_prefix,
source_language,
target_language,
model_name,
wav_path,
old_str,
new_str,
duration_preditor_path,
use_teacher_forcing=non_autoreg,
sid=sid)
masked_feat = output_feat[new_span_boundary[0]:new_span_boundary[
1]].detach().float().cpu().numpy()
if target_language == 'english':
if use_pt_vocoder:
output_feat = output_feat.detach().float().cpu().numpy()
output_feat = torch.tensor(output_feat, dtype=torch.float)
vocoder = load_vocoder('vctk_parallel_wavegan.v1.long')
replaced_wav = vocoder(
output_feat).detach().float().data.cpu().numpy()
else:
output_feat_np = output_feat.detach().float().cpu().numpy()
replaced_wav = get_voc_out(output_feat_np, target_language)
elif target_language == 'chinese':
output_feat_np = output_feat.detach().float().cpu().numpy()
replaced_wav_only_mask_fst2_voc = get_voc_out(masked_feat,
target_language)
old_time_boundary = [hop_length * x for x in old_span_boundary]
new_time_boundary = [hop_length * x for x in new_span_boundary]
if target_language == 'english':
wav_org_replaced_paddle_voc = np.concatenate([
wav_org[:old_time_boundary[0]],
replaced_wav[new_time_boundary[0]:new_time_boundary[1]],
wav_org[old_time_boundary[1]:]
])
data_dict = {"origin": wav_org, "output": wav_org_replaced_paddle_voc}
elif target_language == 'chinese':
wav_org_replaced_only_mask_fst2_voc = np.concatenate([
wav_org[:old_time_boundary[0]], replaced_wav_only_mask_fst2_voc,
wav_org[old_time_boundary[1]:]
])
data_dict = {
"origin":wav_org,
"output": wav_org_replaced_only_mask_fst2_voc,}
return data_dict, old_span_boundary
"origin": wav_org,
"output": wav_org_replaced_only_mask_fst2_voc,
}
return data_dict, old_span_boundary
def get_unk_phns(word_str):
f = open(tmpbase + 'temp.words', 'w')
f.write(word_str)
f.close()
os.system(PHONEME + ' ' + tmpbase + 'temp.words' + ' ' + tmpbase +
'temp.phons')
f = open(tmpbase + 'temp.phons', 'r')
lines2 = f.readline().strip().split()
f.close()
seq.append(phons[j].upper())
j += 1
else:
p = phons[j:j + 2]
if (p == 'WH'):
seq.append('W')
elif (p in ['TH', 'SH', 'HH', 'DH', 'CH', 'ZH', 'NG']):
phns.extend(seq)
return phns
def words2phns(line):
dictfile = MODEL_DIR_EN + '/dict'
tmpbase = '/tmp/tp.'
line = line.strip()
words = []
ds.add(word)
if word not in word2phns_dict.keys():
word2phns_dict[word] = " ".join(line.split()[1:])
phns = []
wrd2phns = {}
for index, wrd in enumerate(words):
if wrd == '[MASK]':
wrd2phns[str(index)+"_"+wrd] = [wrd]
wrd2phns[str(index) + "_" + wrd] = [wrd]
phns.append(wrd)
elif (wrd.upper() not in ds):
wrd2phns[str(index)+"_"+wrd.upper()] = get_unk_phns(wrd)
wrd2phns[str(index) + "_" + wrd.upper()] = get_unk_phns(wrd)
phns.extend(get_unk_phns(wrd))
else:
wrd2phns[str(index)+"_"+wrd.upper()] = word2phns_dict[wrd.upper()].split()
wrd2phns[str(index) +
"_" + wrd.upper()] = word2phns_dict[wrd.upper()].split()
phns.extend(word2phns_dict[wrd.upper()].split())
return phns, wrd2phns
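A sketch of the mapping `words2phns` builds (phones shown are illustrative; running it requires the dictionary under tools/aligner/english):

```python
# '[MASK]' tokens pass through unchanged; out-of-dictionary words fall back
# to the letter-to-sound rules in get_unk_phns.
phns, wrd2phns = words2phns('hello [MASK] world')
# phns     -> ['HH', 'AH0', 'L', 'OW1', '[MASK]', 'W', 'ER1', 'L', 'D']
# wrd2phns -> {'0_HELLO': [...], '1_[MASK]': ['[MASK]'], '2_WORLD': [...]}
```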
def words2phns_zh(line):
dictfile = MODEL_DIR_ZH + '/dict'
tmpbase = '/tmp/tp.'
line = line.strip()
words = []
for pun in [
',', '.', ':', ';', '!', '?', '"', '(', ')', '--', '---', u',',
u'。', u':', u';', u'!', u'?', u'(', u')'
]:
line = line.replace(pun, ' ')
for wrd in line.split():
if (wrd[-1] == '-'):
wrd = wrd[1:]
if wrd:
words.append(wrd)
ds = set([])
word2phns_dict = {}
with open(dictfile, 'r') as fid:
ds.add(word)
if word not in word2phns_dict.keys():
word2phns_dict[word] = " ".join(line.split()[1:])
phns = []
wrd2phns = {}
for index, wrd in enumerate(words):
if wrd == '[MASK]':
wrd2phns[str(index)+"_"+wrd] = [wrd]
wrd2phns[str(index) + "_" + wrd] = [wrd]
phns.append(wrd)
elif (wrd.upper() not in ds):
print("出现非法词错误,请输入正确的文本...")
else:
wrd2phns[str(index)+"_"+wrd] = word2phns_dict[wrd].split()
wrd2phns[str(index) + "_" + wrd] = word2phns_dict[wrd].split()
phns.extend(word2phns_dict[wrd].split())
return phns, wrd2phns
vocoder_tag = vocoder_tag.replace("parallel_wavegan/", "")
vocoder_file = download_pretrained_model(vocoder_tag)
vocoder_config = Path(vocoder_file).parent / "config.yml"
vocoder = build_vocoder_from_file(vocoder_config, vocoder_file, None, 'cpu')
return vocoder
def load_model(model_name):
config_path = './pretrained_model/{}/config.yaml'.format(model_name)
model_path = './pretrained_model/{}/model.pdparams'.format(model_name)
mlm_model, args = build_model_from_file(
config_file=config_path, model_file=model_path)
return mlm_model, args
def read_data(uid, prefix):
mfa_text = read_2column_text(prefix + '/text')[uid]
mfa_wav_path = read_2column_text(prefix + '/wav.scp')[uid]
if 'mnt' not in mfa_wav_path:
mfa_wav_path = prefix.split('dump')[0] + mfa_wav_path
return mfa_text, mfa_wav_path
def get_align_data(uid, prefix):
mfa_path = prefix + "mfa_"
mfa_text = read_2column_text(mfa_path + 'text')[uid]
mfa_start = load_num_sequence_text(
mfa_path + 'start', loader_type='text_float')[uid]
mfa_end = load_num_sequence_text(
mfa_path + 'end', loader_type='text_float')[uid]
mfa_wav_path = read_2column_text(mfa_path + 'wav.scp')[uid]
return mfa_text, mfa_start, mfa_end, mfa_wav_path
def get_masked_mel_boundary(mfa_start, mfa_end, fs, hop_length,
span_tobe_replaced):
align_start = paddle.to_tensor(mfa_start).unsqueeze(0)
align_end = paddle.to_tensor(mfa_end).unsqueeze(0)
align_start = paddle.floor(fs * align_start / hop_length).int()
align_end = paddle.floor(fs * align_end / hop_length).int()
if span_tobe_replaced[0] >= len(mfa_start):
span_boundary = [align_end[0].tolist()[-1], align_end[0].tolist()[-1]]
else:
span_boundary = [
align_start[0].tolist()[span_tobe_replaced[0]],
align_end[0].tolist()[span_tobe_replaced[1] - 1]
]
return span_boundary
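A worked example of the time-to-frame conversion above, frame = floor(fs * t / hop_length), assuming for illustration fs=24000 and hop_length=300 (the AISHELL-3 vocoder configuration):

```python
# Two phones at 0.00-0.25 s and 0.25-0.60 s; replace the second one.
mfa_start, mfa_end = [0.00, 0.25], [0.25, 0.60]
span = get_masked_mel_boundary(mfa_start, mfa_end, fs=24000, hop_length=300,
                               span_tobe_replaced=[1, 2])
# floor(24000*0.25/300) = 20, floor(24000*0.60/300) = 48 -> span == [20, 48]
```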
def recover_dict(word2phns, tp_word2phns):
dic = {}
need_del_key = []
exist_index = []
sp_count = 0
add_sp_count = 0
for key in word2phns.keys():
idx, wrd = key.split('_')
if wrd == 'sp':
sp_count += 1
exist_index.append(int(idx))
else:
need_del_key.append(key)
for key in need_del_key:
del word2phns[key]
for key in tp_word2phns.keys():
# print("debug: ",key)
if cur_id in exist_index:
dic[str(cur_id)+"_sp"] = 'sp'
cur_id += 1
add_sp_count += 1
dic[str(cur_id) + "_sp"] = 'sp'
cur_id += 1
add_sp_count += 1
idx, wrd = key.split('_')
dic[str(cur_id)+"_"+wrd] = tp_word2phns[key]
dic[str(cur_id) + "_" + wrd] = tp_word2phns[key]
cur_id += 1
if add_sp_count + 1 == sp_count:
dic[str(cur_id)+"_sp"] = 'sp'
add_sp_count += 1
dic[str(cur_id) + "_sp"] = 'sp'
add_sp_count += 1
assert add_sp_count == sp_count, "sp are not added in dic"
return dic
def get_phns_and_spans(wav_path, old_str, new_str, source_language,
clone_target_language):
append_new_str = (old_str == new_str[:len(old_str)])
old_phns, mfa_start, mfa_end = [], [], []
if source_language == "english":
times2, word2phns = alignment(wav_path, old_str)
elif source_language == "chinese":
times2, word2phns = alignment_zh(wav_path, old_str)
_, tp_word2phns = words2phns_zh(old_str)
for key, value in tp_word2phns.items():
idx, wrd = key.split('_')
cur_val = " ".join(value)
tp_word2phns[key] = cur_val
word2phns = recover_dict(word2phns, tp_word2phns)
mfa_end.append(float(item[2]))
old_phns.append(item[0])
if append_new_str and (source_language != clone_target_language):
is_cross_lingual_clone = True
else:
is_cross_lingual_clone = False
new_str_append = new_str[len(old_str):]
if clone_target_language == "chinese":
new_phns_origin, new_origin_word2phns = words2phns(new_str_origin)
new_phns_append, temp_new_append_word2phns = words2phns_zh(
new_str_append)
elif clone_target_language == "english":
new_phns_origin, new_origin_word2phns = words2phns_zh(
new_str_origin)  # original sentence
new_phns_append, temp_new_append_word2phns = words2phns(
new_str_append)  # sentence to clone
else:
assert clone_target_language == "chinese" or clone_target_language == "english", "cloning is not supported for this language, please check it."
new_phns = new_phns_origin + new_phns_append
new_append_word2phns = {}
length = len(new_origin_word2phns)
for key, value in temp_new_append_word2phns.items():
idx, wrd = key.split('_')
new_append_word2phns[str(int(idx) + length) + '_' + wrd] = value
new_word2phns = dict(
list(new_origin_word2phns.items()) + list(
new_append_word2phns.items()))
else:
if source_language == clone_target_language and clone_target_language == "english":
new_phns, new_word2phns = words2phns(new_str)
elif source_language == clone_target_language and clone_target_language == "chinese":
new_phns, new_word2phns = words2phns_zh(new_str)
else:
assert source_language == clone_target_language, "source language is not same with target language..."
span_tobe_replaced = [0, len(old_phns) - 1]
span_tobe_added = [0, len(new_phns) - 1]
left_index = 0
new_phns_left = []
sp_count = 0
# find the left different index
for key in word2phns.keys():
idx, wrd = key.split('_')
if wrd == 'sp':
sp_count += 1
new_phns_left.append('sp')
else:
idx = str(int(idx) - sp_count)
if idx + '_' + wrd in new_word2phns:
left_index += len(new_word2phns[idx + '_' + wrd])
new_phns_left.extend(word2phns[key].split())
else:
span_tobe_replaced[0] = len(new_phns_left)
span_tobe_added[0] = len(new_phns_left)
break
# reverse word2phns and new_word2phns
right_index = 0
new_phns_right = []
word2phns_max_index = int(list(word2phns.keys())[-1].split('_')[0])
new_word2phns_max_index = int(list(new_word2phns.keys())[-1].split('_')[0])
new_phns_middle = []
if append_new_str:
new_phns_right = []
new_phns_middle = new_phns[left_index:]
span_tobe_replaced[0] = len(new_phns_left)
else:
for key in list(word2phns.keys())[::-1]:
idx, wrd = key.split('_')
if wrd == 'sp':
sp_count += 1
new_phns_right = ['sp'] + new_phns_right
else:
idx = str(new_word2phns_max_index - (word2phns_max_index - int(
idx) - sp_count))
if idx + '_' + wrd in new_word2phns:
right_index -= len(new_word2phns[idx + '_' + wrd])
new_phns_right = word2phns[key].split() + new_phns_right
else:
span_tobe_replaced[1] = len(old_phns) - len(new_phns_right)
new_phns_middle = new_phns[left_index:right_index]
span_tobe_added[1] = len(new_phns_left) + len(
new_phns_middle)
if len(new_phns_middle) == 0:
span_tobe_added[1] = min(span_tobe_added[1] + 1,
len(new_phns))
span_tobe_added[0] = max(0, span_tobe_added[0] - 1)
span_tobe_replaced[0] = max(0,
span_tobe_replaced[0] - 1)
span_tobe_replaced[1] = min(span_tobe_replaced[1] + 1,
len(old_phns))
break
new_phns = new_phns_left + new_phns_middle + new_phns_right
return mfa_start, mfa_end, old_phns, new_phns, span_tobe_replaced, span_tobe_added
def duration_adjust_factor(original_dur, pred_dur, phns):
length = 0
accumulate = 0
factor_list = []
for ori, pred, phn in zip(original_dur, pred_dur, phns):
if pred == 0 or phn == 'sp':
continue
else:
factor_list.append(ori / pred)
factor_list = np.array(factor_list)
factor_list.sort()
if len(factor_list) < 5:
return 1
length = 2
return np.average(factor_list[length:-length])
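A worked sketch of the trimmed-mean rescaling factor computed above:

```python
# Ratios ori/pred are collected (skipping 'sp' and zero predictions), sorted,
# then the two smallest and two largest are trimmed and the rest averaged.
original_dur = [0.08, 0.09, 0.10, 0.10, 0.11, 0.12, 0.30]
pred_dur = [0.10, 0.10, 0.10, 0.10, 0.10, 0.10, 0.10]
phns = ['K', 'AE1', 'T', 'S', 'IH0', 'T', 'sp']   # the final 'sp' is skipped
factor = duration_adjust_factor(original_dur, pred_dur, phns)
# sorted ratios [0.8, 0.9, 1.0, 1.0, 1.1, 1.2] -> average of [1.0, 1.0] == 1.0
```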
def prepare_features_with_duration(uid,
prefix,
clone_uid,
clone_prefix,
source_language,
target_language,
mlm_model,
old_str,
new_str,
wav_path,
duration_preditor_path,
sid=None,
mask_reconstruct=False,
duration_adjust=True,
start_end_sp=False,
train_args=None):
wav_org, rate = librosa.load(
wav_path, sr=train_args.feats_extract_conf['fs'])
fs = train_args.feats_extract_conf['fs']
hop_length = train_args.feats_extract_conf['hop_length']
mfa_start, mfa_end, old_phns, new_phns, span_tobe_replaced, span_tobe_added = get_phns_and_spans(
wav_path, old_str, new_str, source_language, target_language)
if start_end_sp:
if new_phns[-1] != 'sp':
new_phns = new_phns + ['sp']
if target_language == "english":
old_durations = evaluate_durations(
old_phns, target_language=target_language)
elif target_language =="chinese":
elif target_language == "chinese":
if source_language == "english":
old_durations = evaluate_durations(
old_phns, target_language=source_language)
elif source_language == "chinese":
old_durations = evaluate_durations(
old_phns, target_language=source_language)
else:
assert target_language == "chinese" or target_language == "english", "calculating duration_predict is not supported for this language..."
original_old_durations = [e - s for e, s in zip(mfa_end, mfa_start)]
if '[MASK]' in new_str:
new_phns = old_phns
span_tobe_added = span_tobe_replaced
d_factor_left = duration_adjust_factor(
original_old_durations[:span_tobe_replaced[0]],
old_durations[:span_tobe_replaced[0]],
old_phns[:span_tobe_replaced[0]])
d_factor_right = duration_adjust_factor(
original_old_durations[span_tobe_replaced[1]:],
old_durations[span_tobe_replaced[1]:],
old_phns[span_tobe_replaced[1]:])
d_factor = (d_factor_left + d_factor_right) / 2
new_durations_adjusted = [d_factor * i for i in old_durations]
else:
if duration_adjust:
d_factor = duration_adjust_factor(original_old_durations,
old_durations, old_phns)
d_factor_paddle = duration_adjust_factor(original_old_durations,
old_durations, old_phns)
d_factor = d_factor * 1.25
else:
d_factor = 1
if target_language == "english":
new_durations = evaluate_durations(new_phns, target_language=target_language)
if target_language == "english":
new_durations = evaluate_durations(
new_phns, target_language=target_language)
elif target_language =="chinese":
new_durations = evaluate_durations(new_phns, target_language=target_language)
elif target_language == "chinese":
new_durations = evaluate_durations(
new_phns, target_language=target_language)
new_durations_adjusted = [d_factor * i for i in new_durations]
if span_tobe_replaced[0] < len(old_phns) and old_phns[
span_tobe_replaced[0]] == new_phns[span_tobe_added[0]]:
new_durations_adjusted[span_tobe_added[0]] = original_old_durations[
span_tobe_replaced[0]]
if span_tobe_replaced[1] < len(old_phns) and span_tobe_added[1] < len(
new_phns):
if old_phns[span_tobe_replaced[1]] == new_phns[span_tobe_added[1]]:
new_durations_adjusted[span_tobe_added[
1]] = original_old_durations[span_tobe_replaced[1]]
new_span_duration_sum = sum(
new_durations_adjusted[span_tobe_added[0]:span_tobe_added[1]])
old_span_duration_sum = sum(
original_old_durations[span_tobe_replaced[0]:span_tobe_replaced[1]])
duration_offset = new_span_duration_sum - old_span_duration_sum
new_mfa_start = mfa_start[:span_tobe_replaced[0]]
new_mfa_end = mfa_end[:span_tobe_replaced[0]]
for i in new_durations_adjusted[span_tobe_added[0]:span_tobe_added[1]]:
if len(new_mfa_end) == 0:
new_mfa_start.append(0)
new_mfa_end.append(i)
else:
new_mfa_start.append(new_mfa_end[-1])
new_mfa_end.append(new_mfa_end[-1] + i)
new_mfa_start += [
i + duration_offset for i in mfa_start[span_tobe_replaced[1]:]
]
new_mfa_end += [
i + duration_offset for i in mfa_end[span_tobe_replaced[1]:]
]
# 3. get new wav
if span_tobe_replaced[0] >= len(mfa_start):
left_index = len(wav_org)
right_index = left_index
else:
left_index = int(np.floor(mfa_start[span_tobe_replaced[0]] * fs))
right_index = int(np.ceil(mfa_end[span_tobe_replaced[1] - 1] * fs))
new_blank_wav = np.zeros(
(int(np.ceil(new_span_duration_sum * fs)), ), dtype=wav_org.dtype)
new_wav_org = np.concatenate(
[wav_org[:left_index], new_blank_wav, wav_org[right_index:]])
# 4. get old and new mel span to be mask
old_span_boundary = get_masked_mel_boundary(
mfa_start, mfa_end, fs, hop_length, span_tobe_replaced) # [92, 92]
new_span_boundary = get_masked_mel_boundary(new_mfa_start, new_mfa_end, fs,
hop_length,
span_tobe_added) # [92, 174]
return new_wav_org, new_phns, new_mfa_start, new_mfa_end, old_span_boundary, new_span_boundary
def prepare_features(uid,
prefix,
clone_uid,
clone_prefix,
source_language,
target_language,
mlm_model,
processor,
wav_path,
old_str,
new_str,
duration_preditor_path,
sid=None,
duration_adjust=True,
start_end_sp=False,
mask_reconstruct=False,
train_args=None):
wav_org, phns_list, mfa_start, mfa_end, old_span_boundary, new_span_boundary = prepare_features_with_duration(
uid,
prefix,
clone_uid,
clone_prefix,
source_language,
target_language,
mlm_model,
old_str,
new_str,
wav_path,
duration_preditor_path,
sid=sid,
duration_adjust=duration_adjust,
start_end_sp=start_end_sp,
mask_reconstruct=mask_reconstruct,
train_args=train_args)
speech = np.array(wav_org, dtype=np.float32)
align_start = np.array(mfa_start)
align_end = np.array(mfa_end)
token_to_id = {item: i for i, item in enumerate(train_args.token_list)}
text = np.array(
list(
map(lambda x: token_to_id.get(x, token_to_id['<unk>']), phns_list)))
span_boundary = np.array(new_span_boundary)
batch=[('1', {"speech":speech,"align_start":align_start,"align_end":align_end,"text":text,"span_boundary":span_boundary})]
batch = [('1', {
"speech": speech,
"align_start": align_start,
"align_end": align_end,
"text": text,
"span_boundary": span_boundary
})]
return batch, old_span_boundary, new_span_boundary
def decode_with_model(uid,
prefix,
clone_uid,
clone_prefix,
source_language,
target_language,
mlm_model,
processor,
collate_fn,
wav_path,
old_str,
new_str,
duration_preditor_path,
sid=None,
decoder=False,
use_teacher_forcing=False,
duration_adjust=True,
start_end_sp=False,
train_args=None):
fs, hop_length = train_args.feats_extract_conf[
'fs'], train_args.feats_extract_conf['hop_length']
batch, old_span_boundary, new_span_boundary = prepare_features(
uid,
prefix,
clone_uid,
clone_prefix,
source_language,
target_language,
mlm_model,
processor,
wav_path,
old_str,
new_str,
duration_preditor_path,
sid,
duration_adjust=duration_adjust,
start_end_sp=start_end_sp,
train_args=train_args)
feats = collate_fn(batch)[1]
if 'text_masked_position' in feats.keys():
feats.pop('text_masked_position')
for k, v in feats.items():
feats[k] = paddle.to_tensor(v)
rtn = mlm_model.inference(
**feats,
span_boundary=new_span_boundary,
use_teacher_forcing=use_teacher_forcing)
output = rtn['feat_gen']
if 0 in output[0].shape and 0 not in output[-1].shape:
output_feat = paddle.concat(
output[1:-1] + [output[-1].squeeze()], axis=0).cpu()
elif 0 not in output[0].shape and 0 in output[-1].shape:
output_feat = paddle.concat(
[output[0].squeeze()] + output[1:-1], axis=0).cpu()
elif 0 in output[0].shape and 0 in output[-1].shape:
output_feat = paddle.concat(output[1:-1], axis=0).cpu()
else:
output_feat = paddle.concat(
[output[0].squeeze(0)] + output[1:-1] + [output[-1].squeeze(0)],
axis=0).cpu()
wav_org, rate = librosa.load(
wav_path, sr=train_args.feats_extract_conf['fs'])
origin_speech = paddle.to_tensor(
np.array(wav_org, dtype=np.float32)).unsqueeze(0)
speech_lengths = paddle.to_tensor(len(wav_org)).unsqueeze(0)
return wav_org, None, output_feat, old_span_boundary, new_span_boundary, fs, hop_length
class MLMCollateFn:
"""Functor class of common_collate_fn()"""
def __init__(self,
feats_extract,
float_pad_value: Union[float, int]=0.0,
int_pad_value: int=-32768,
not_sequence: Collection[str]=(),
mlm_prob: float=0.8,
mean_phn_span: int=8,
attention_window: int=0,
pad_speech: bool=False,
sega_emb: bool=False,
duration_collect: bool=False,
text_masking: bool=False):
self.mlm_prob = mlm_prob
self.mean_phn_span = mean_phn_span
self.feats_extract = feats_extract
self.float_pad_value = float_pad_value
self.int_pad_value = int_pad_value
self.not_sequence = set(not_sequence)
self.attention_window = attention_window
self.pad_speech = pad_speech
self.sega_emb = sega_emb
self.duration_collect = duration_collect
self.text_masking = text_masking
def __repr__(self):
return (f"{self.__class__}(float_pad_value={self.float_pad_value}, "
f"int_pad_value={self.float_pad_value})")
def __call__(self, data: Collection[Tuple[str, Dict[str, np.ndarray]]]
) -> Tuple[List[str], Dict[str, paddle.Tensor]]:
return mlm_collate_fn(
data,
float_pad_value=self.float_pad_value,
int_pad_value=self.int_pad_value,
not_sequence=self.not_sequence,
mlm_prob=self.mlm_prob,
mean_phn_span=self.mean_phn_span,
feats_extract=self.feats_extract,
attention_window=self.attention_window,
pad_speech=self.pad_speech,
sega_emb=self.sega_emb,
duration_collect=self.duration_collect,
text_masking=self.text_masking)
def mlm_collate_fn(
data: Collection[Tuple[str, Dict[str, np.ndarray]]],
float_pad_value: Union[float, int]=0.0,
int_pad_value: int=-32768,
not_sequence: Collection[str]=(),
mlm_prob: float=0.8,
mean_phn_span: int=8,
feats_extract=None,
attention_window: int=0,
pad_speech: bool=False,
sega_emb: bool=False,
duration_collect: bool=False,
text_masking: bool=False) -> Tuple[List[str], Dict[str, paddle.Tensor]]:
"""Concatenate ndarray-list to an array and convert to torch.Tensor.
Examples:
data = [d for _, d in data]
assert all(set(data[0]) == set(d) for d in data), "dict-keys mismatching"
assert all(not k.endswith("_lengths")
for k in data[0]), f"*_lengths is reserved: {list(data[0])}"
output = {}
for key in data[0]:
# lens: (Batch,)
if key not in not_sequence:
lens = paddle.to_tensor(
[d[key].shape[0] for d in data], dtype=paddle.long)
output[key + "_lengths"] = lens
feats = feats_extract.get_log_mel_fbank(np.array(output["speech"][0]))
feats = paddle.unsqueeze(feats, 0)
batch_size = paddle.shape(feats)[0]
if 'text' not in output:
text = paddle.zeros_like(feats_lengths.unsqueeze(-1)) - 2
text_lengths = paddle.zeros_like(feats_lengths) + 1
max_tlen = 1
align_start = paddle.zeros_like(text)
align_end = paddle.zeros_like(text)
align_start_lengths = paddle.zeros_like(feats_lengths)
align_end_lengths = paddle.zeros_like(feats_lengths)
sega_emb = False
mean_phn_span = 0
mlm_prob = 0.15
else:
text, text_lengths = output["text"], output["text_lengths"]
align_start, align_start_lengths, align_end, align_end_lengths = output["align_start"], output["align_start_lengths"], output["align_end"], output["align_end_lengths"]
align_start = paddle.floor(feats_extract.sr*align_start/feats_extract.hop_length).int()
align_end = paddle.floor(feats_extract.sr*align_end/feats_extract.hop_length).int()
align_start, align_start_lengths, align_end, align_end_lengths = output[
"align_start"], output["align_start_lengths"], output[
"align_end"], output["align_end_lengths"]
align_start = paddle.floor(feats_extract.sr * align_start /
feats_extract.hop_length).int()
align_end = paddle.floor(feats_extract.sr * align_end /
feats_extract.hop_length).int()
max_tlen = max(text_lengths).item()
max_slen = max(feats_lengths).item()
speech_pad = feats[:, :max_slen]
if attention_window > 0 and pad_speech:
speech_pad, max_slen = pad_to_longformer_att_window(
speech_pad, max_slen, max_slen, attention_window)
max_len = max_slen + max_tlen
if attention_window > 0:
text_pad, max_tlen = pad_to_longformer_att_window(
text, max_len, max_tlen, attention_window)
else:
text_pad = text
text_mask = make_non_pad_mask(
text_lengths.tolist(), text_pad, length_dim=1).unsqueeze(-2)
if attention_window > 0:
text_mask = text_mask * 2
speech_mask = make_non_pad_mask(
feats_lengths.tolist(), speech_pad[:, :, 0], length_dim=1).unsqueeze(-2)
span_boundary = None
if 'span_boundary' in output.keys():
span_boundary = output['span_boundary']
if text_masking:
masked_position, text_masked_position, _ = phones_text_masking(
speech_pad, speech_mask, text_pad, text_mask, align_start,
align_end, align_start_lengths, mlm_prob, mean_phn_span,
span_boundary)
else:
text_masked_position = np.zeros(text_pad.size())
masked_position, _ = phones_masking(
speech_pad, speech_mask, align_start, align_end,
align_start_lengths, mlm_prob, mean_phn_span, span_boundary)
output_dict = {}
if duration_collect and 'text' in output:
reordered_index, speech_segment_pos, text_segment_pos, durations, feats_lengths = get_segment_pos_reduce_duration(
speech_pad, text_pad, align_start, align_end, align_start_lengths,
sega_emb, masked_position, feats_lengths)
speech_mask = make_non_pad_mask(
feats_lengths.tolist(),
speech_pad[:, :reordered_index.shape[1], 0],
length_dim=1).unsqueeze(-2)
output_dict['durations'] = durations
output_dict['reordered_index'] = reordered_index
else:
speech_segment_pos, text_segment_pos = get_segment_pos(
speech_pad, text_pad, align_start, align_end, align_start_lengths,
sega_emb)
output_dict['speech'] = speech_pad
output_dict['text'] = text_pad
output_dict['masked_position'] = masked_position
output = (uttids, output_dict)
return output
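For reference, a minimal sketch of how the collate pair is typically consumed. This example is not part of the original file; the feature-extractor settings and the sample layout are assumptions modeled on a 24 kHz AISHELL-3 setup:

```python
import numpy as np

from paddlespeech.t2s.datasets.get_feats import LogMelFBank

# hypothetical hyperparameters; only sr/hop_length must match the model config
collate_fn = MLMCollateFn(
    feats_extract=LogMelFBank(sr=24000, n_fft=2048, hop_length=300, n_mels=80),
    mlm_prob=0.8,
    mean_phn_span=8,
    sega_emb=True)

# each sample pairs an utterance id with its raw arrays
samples = [("utt1", {"speech": np.random.randn(24000).astype(np.float32)})]
uttids, batch = collate_fn(samples)
# batch holds padded tensors such as "speech", "masked_position" and
# "speech_segment_pos", ready to feed into the MLM model
```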
def build_collate_fn(args: argparse.Namespace, train: bool, epoch=-1):
# -> Callable[
# [Collection[Tuple[str, Dict[str, np.ndarray]]]],
# Tuple[List[str], Dict[str, torch.Tensor]],
sega_emb = True if args.encoder_conf['input_layer'] == 'sega_mlm' else False
if args.encoder_conf['selfattention_layer_type'] == 'longformer':
attention_window = args.encoder_conf['attention_window']
pad_speech = True if 'pre_speech_layer' in args.encoder_conf and args.encoder_conf[
'pre_speech_layer'] > 0 else False
else:
attention_window = 0
pad_speech = False
if epoch == -1:
mlm_prob_factor = 1
else:
mlm_probs = [1.0, 1.0, 0.7, 0.6, 0.5]
mlm_prob_factor = 0.8 #mlm_probs[epoch // 100]
if 'duration_predictor_layers' in args.model_conf.keys(
) and args.model_conf['duration_predictor_layers'] > 0:
duration_collect = True
else:
        duration_collect = False
return MLMCollateFn(
feats_extract,
float_pad_value=0.0,
int_pad_value=0,
mlm_prob=args.model_conf['mlm_prob'] * mlm_prob_factor,
mean_phn_span=args.model_conf['mean_phn_span'],
attention_window=attention_window,
pad_speech=pad_speech,
sega_emb=sega_emb,
duration_collect=duration_collect)
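build_collate_fn itself is normally driven by the yaml config that ships with a checkpoint. A hedged sketch (the config path is hypothetical):

```python
import argparse
from pathlib import Path

import yaml

config_path = "pretrained_model/model-ernie-sat-base-en/config.yaml"  # hypothetical
train_args = argparse.Namespace(
    **yaml.safe_load(Path(config_path).open("r", encoding="utf-8")))
# epoch=-1 keeps mlm_prob_factor at 1.0, i.e. no schedule-based scaling
collate_fn = build_collate_fn(train_args, train=False, epoch=-1)
```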
def get_mlm_output(uid,
prefix,
clone_uid,
clone_prefix,
source_language,
target_language,
model_name,
wav_path,
old_str,
new_str,
duration_preditor_path,
sid=None,
decoder=False,
use_teacher_forcing=False,
dynamic_eval=(0, 0),
duration_adjust=True,
start_end_sp=False):
mlm_model, train_args = load_model(model_name)
mlm_model.eval()
processor = None
collate_fn = build_collate_fn(train_args, False)
return decode_with_model(
uid,
prefix,
clone_uid,
clone_prefix,
source_language,
target_language,
mlm_model,
processor,
collate_fn,
wav_path,
old_str,
new_str,
duration_preditor_path,
sid=sid,
decoder=decoder,
use_teacher_forcing=use_teacher_forcing,
duration_adjust=duration_adjust,
start_end_sp=start_end_sp,
train_args=train_args)
def test_vctk(uid,
clone_uid,
clone_prefix,
source_language,
target_language,
vocoder,
prefix='dump/raw/dev',
model_name="conformer",
old_str="",
new_str="",
prompt_decoding=False,
dynamic_eval=(0, 0),
task_name=None):
duration_preditor_path = None
spemd = None
full_origin_str, wav_path = read_data(uid, prefix)
if task_name == 'edit':
new_str = new_str
elif task_name == 'synthesize':
        new_str = full_origin_str + new_str
else:
new_str = full_origin_str + ' '.join(
[ch for ch in new_str if is_chinese(ch)])
print('new_str is ', new_str)
if not old_str:
old_str = full_origin_str
results_dict, old_span = plot_mel_and_vocode_wav(
uid,
prefix,
clone_uid,
clone_prefix,
source_language,
target_language,
model_name,
wav_path,
full_origin_str,
old_str,
new_str,
vocoder,
duration_preditor_path,
sid=spemd)
return results_dict
if __name__ == "__main__":
# parse config and args
args = parse_args()
print(args)
data_dict = test_vctk(
args.uid,
args.clone_uid,
args.clone_prefix,
args.source_language,
args.target_language,
args.use_pt_vocoder,
        args.prefix,
args.model_name,
new_str=args.new_str,
task_name=args.task_name)
sf.write(args.output_name, data_dict['output'], samplerate=24000)
print("finished...")
import argparse
import logging
import math
import os
import sys
from pathlib import Path
from typing import Collection
from typing import Dict
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
import humanfriendly
import librosa
import soundfile as sf
import numpy as np
import paddle
import paddle.nn.functional as F
from paddle import nn
from typeguard import check_argument_types
import warnings

import yaml
from paddle.amp import auto_cast
pypath = '..'
for dir_name in os.listdir(pypath):
dir_path = os.path.join(pypath, dir_name)
if os.path.isdir(dir_path):
sys.path.append(dir_path)
from paddlespeech.s2t.utils.error_rate import ErrorCalculator
from paddlespeech.t2s.modules.activation import get_activation
from paddlespeech.t2s.modules.conformer.convolution import ConvolutionModule
from paddlespeech.t2s.modules.conformer.encoder_layer import EncoderLayer
from paddlespeech.t2s.modules.masked_fill import masked_fill
from paddlespeech.t2s.modules.nets_utils import initialize
from paddlespeech.t2s.modules.nets_utils import make_non_pad_mask
from paddlespeech.t2s.modules.nets_utils import make_pad_mask
from paddlespeech.t2s.modules.predictor.duration_predictor import DurationPredictor
from paddlespeech.t2s.modules.predictor.duration_predictor import DurationPredictorLoss
from paddlespeech.t2s.modules.predictor.length_regulator import LengthRegulator
from paddlespeech.t2s.modules.predictor.variance_predictor import VariancePredictor
from paddlespeech.t2s.modules.tacotron2.decoder import Postnet
from paddlespeech.t2s.modules.transformer.encoder import CNNDecoder
from paddlespeech.t2s.modules.transformer.encoder import CNNPostnet
from paddlespeech.t2s.modules.transformer.encoder import ConformerEncoder
from paddlespeech.t2s.modules.transformer.encoder import TransformerEncoder
from paddlespeech.t2s.modules.transformer.embedding import PositionalEncoding
from paddlespeech.t2s.modules.transformer.embedding import ScaledPositionalEncoding
from paddlespeech.t2s.modules.transformer.embedding import RelPositionalEncoding
from paddlespeech.t2s.modules.transformer.subsampling import Conv2dSubsampling
from paddlespeech.t2s.modules.transformer.attention import MultiHeadedAttention
from paddlespeech.t2s.modules.transformer.attention import RelPositionMultiHeadedAttention
from paddlespeech.t2s.modules.transformer.positionwise_feed_forward import PositionwiseFeedForward
from paddlespeech.t2s.modules.transformer.multi_layer_conv import Conv1dLinear
from paddlespeech.t2s.modules.transformer.multi_layer_conv import MultiLayeredConv1d
from paddlespeech.t2s.modules.transformer.repeat import repeat
from paddlespeech.t2s.modules.layer_norm import LayerNorm
from paddlespeech.t2s.datasets.get_feats import LogMelFBank
class Swish(nn.Layer):
"""Construct an Swish object."""
def forward(self, x):
"""Return Swich activation function."""
return x * F.sigmoid(x)
def get_activation(act):
"""Return activation function."""
activation_funcs = {
"hardtanh": nn.Hardtanh,
"tanh": nn.Tanh,
"relu": nn.ReLU,
"selu": nn.SELU,
"swish": Swish,
}
return activation_funcs[act]()
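The local Swish and get_activation duplicate the helper imported from paddlespeech.t2s.modules.activation. A quick sanity check (illustrative only, not from the source):

```python
act = get_activation("swish")
x = paddle.to_tensor([-1.0, 0.0, 1.0])
# swish(x) = x * sigmoid(x) -> approx. [-0.269, 0.0, 0.731]
print(act(x).numpy())
```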
class LegacyRelPositionalEncoding(PositionalEncoding):
"""Relative positional encoding module (old version).
max_len (int): Maximum input length.
"""
def __init__(self, d_model: int, dropout_rate: float, max_len: int=5000):
"""
Args:
"""Reset the positional encodings."""
if self.pe is not None:
if paddle.shape(self.pe)[1] >= paddle.shape(x)[1]:
return
pe = paddle.zeros((paddle.shape(x)[1], self.d_model))
if self.reverse:
position = paddle.arange(
                paddle.shape(x)[1] - 1, -1, -1.0,
                dtype=paddle.float32).unsqueeze(1)
else:
position = paddle.arange(
0, paddle.shape(x)[1], dtype=paddle.float32).unsqueeze(1)
div_term = paddle.exp(
paddle.arange(0, self.d_model, 2, dtype=paddle.float32) *
-(math.log(10000.0) / self.d_model))
pe[:, 0::2] = paddle.sin(position * div_term)
pe[:, 1::2] = paddle.cos(position * div_term)
pe = pe.unsqueeze(0)
paddle.Tensor: Encoded tensor (batch, time, `*`).
paddle.Tensor: Positional embedding tensor (1, time, `*`).
"""
        self.extend_pe(x)
x = x * self.xscale
pos_emb = self.pe[:, :paddle.shape(x)[1]]
return self.dropout(x), self.dropout(pos_emb)
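Shape behaviour of the legacy relative positional encoding, as a small sketch (hyperparameters invented; the PositionalEncoding base class from paddlespeech.t2s.modules.transformer.embedding is assumed):

```python
pos_enc = LegacyRelPositionalEncoding(d_model=256, dropout_rate=0.0)
x = paddle.randn([2, 100, 256])       # (batch, time, d_model)
x_scaled, pos_emb = pos_enc(x)
print(x_scaled.shape, pos_emb.shape)  # [2, 100, 256] [1, 100, 256]
```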
class mySequential(nn.Sequential):
def forward(self, *inputs):
inputs = module(inputs)
return inputs
class NewMaskInputLayer(nn.Layer):
__constants__ = ['out_features']
out_features: int
def __init__(self, out_features: int, device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__()
self.mask_feature = paddle.create_parameter(
shape=(1, 1, out_features),
dtype=paddle.float32,
default_initializer=paddle.nn.initializer.Assign(
paddle.normal(shape=(1, 1, out_features))))
def forward(self, input: paddle.Tensor,
masked_position=None) -> paddle.Tensor:
masked_position = paddle.expand_as(
paddle.unsqueeze(masked_position, -1), input)
masked_input = masked_fill(input, masked_position, 0) + masked_fill(
paddle.expand_as(self.mask_feature, input), ~masked_position, 0)
return masked_input
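A minimal sketch of what NewMaskInputLayer does to a masked frame (values invented): frames flagged in masked_position are replaced by the learnable mask_feature embedding, everything else passes through unchanged.

```python
layer = NewMaskInputLayer(out_features=80)
feats = paddle.randn([1, 10, 80])  # (batch, time, feat)
mask = paddle.to_tensor(
    [[False, False, False, True, True, True, False, False, False, False]])
out = layer(feats, masked_position=mask)
# out[:, 3:6] now equals the broadcast mask_feature; other frames are intact
```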
class LegacyRelPositionMultiHeadedAttention(MultiHeadedAttention):
"""Multi-Head Attention layer with relative position encoding (old version).
Details can be found in https://github.com/espnet/espnet/pull/2816.
q = paddle.transpose(q, [0, 2, 1, 3])
n_batch_pos = paddle.shape(pos_emb)[0]
p = paddle.reshape(
self.linear_pos(pos_emb), [n_batch_pos, -1, self.h, self.d_k])
# (batch, head, time1, d_k)
p = paddle.transpose(p, [0, 2, 1, 3])
# (batch, head, time1, d_k)
# first compute matrix a and matrix c
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
# (batch, head, time1, time2)
matrix_ac = paddle.matmul(q_with_bias_u,
paddle.transpose(k, [0, 1, 3, 2]))
# compute matrix b and matrix d
# (batch, head, time1, time1)
matrix_bd = paddle.matmul(q_with_bias_v,
paddle.transpose(p, [0, 1, 3, 2]))
matrix_bd = self.rel_shift(matrix_bd)
# (batch, head, time1, time2)
scores = (matrix_ac + matrix_bd) / math.sqrt(self.d_k)
return self.forward_attention(v, scores, mask)
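The legacy attention variant consumes the (1, T, d_model) pos_emb produced by LegacyRelPositionalEncoding above. A hedged pairing sketch; the hyperparameters are invented and the constructor is assumed to match MultiHeadedAttention's (n_head, n_feat, dropout_rate) signature:

```python
attn = LegacyRelPositionMultiHeadedAttention(
    n_head=4, n_feat=256, dropout_rate=0.0)
pos_enc = LegacyRelPositionalEncoding(d_model=256, dropout_rate=0.0)
x = paddle.randn([2, 50, 256])
x_scaled, pos_emb = pos_enc(x)                  # pos_emb: (1, 50, 256)
out = attn(x_scaled, x_scaled, x_scaled, pos_emb, mask=None)  # (2, 50, 256)
```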
class MLMEncoder(nn.Layer):
"""Conformer encoder module.
signature.)
"""
def __init__(self,
idim,
vocab_size=0,
pre_speech_layer: int=0,
attention_dim=256,
attention_heads=4,
linear_units=2048,
num_blocks=6,
dropout_rate=0.1,
positional_dropout_rate=0.1,
attention_dropout_rate=0.0,
input_layer="conv2d",
normalize_before=True,
concat_after=False,
positionwise_layer_type="linear",
positionwise_conv_kernel_size=1,
macaron_style=False,
pos_enc_layer_type="abs_pos",
pos_enc_class=None,
selfattention_layer_type="selfattn",
activation_type="swish",
use_cnn_module=False,
zero_triu=False,
cnn_module_kernel=31,
padding_idx=-1,
stochastic_depth_rate=0.0,
intermediate_layers=None,
text_masking=False):
"""Construct an Encoder object."""
super().__init__()
self._output_size = attention_dim
self.text_masking = text_masking
if self.text_masking:
self.text_masking_layer = NewMaskInputLayer(attention_dim)
activation = get_activation(activation_type)
nn.LayerNorm(attention_dim),
nn.Dropout(dropout_rate),
nn.ReLU(),
pos_enc_class(attention_dim, positional_dropout_rate), )
elif input_layer == "conv2d":
self.embed = Conv2dSubsampling(
idim,
attention_dim,
dropout_rate,
pos_enc_class(attention_dim, positional_dropout_rate), )
self.conv_subsampling_factor = 4
elif input_layer == "embed":
self.embed = nn.Sequential(
nn.Embedding(idim, attention_dim, padding_idx=padding_idx),
pos_enc_class(attention_dim, positional_dropout_rate), )
elif input_layer == "mlm":
self.segment_emb = None
self.speech_embed = mySequential(
nn.Linear(idim, attention_dim),
nn.LayerNorm(attention_dim),
nn.ReLU(),
pos_enc_class(attention_dim, positional_dropout_rate))
self.text_embed = nn.Sequential(
nn.Embedding(
vocab_size, attention_dim, padding_idx=padding_idx),
pos_enc_class(attention_dim, positional_dropout_rate), )
elif input_layer == "sega_mlm":
self.segment_emb = nn.Embedding(
500, attention_dim, padding_idx=padding_idx)
self.speech_embed = mySequential(
NewMaskInputLayer(idim),
nn.Linear(idim, attention_dim),
nn.LayerNorm(attention_dim),
nn.ReLU(),
pos_enc_class(attention_dim, positional_dropout_rate))
self.text_embed = nn.Sequential(
nn.Embedding(
vocab_size, attention_dim, padding_idx=padding_idx),
pos_enc_class(attention_dim, positional_dropout_rate), )
elif isinstance(input_layer, nn.Layer):
self.embed = nn.Sequential(
input_layer,
pos_enc_class(attention_dim, positional_dropout_rate), )
elif input_layer is None:
self.embed = nn.Sequential(
pos_enc_class(attention_dim, positional_dropout_rate))
else:
raise ValueError("unknown input_layer: " + input_layer)
self.normalize_before = normalize_before
if selfattention_layer_type == "selfattn":
logging.info("encoder self-attention layer type = self-attention")
encoder_selfattn_layer = MultiHeadedAttention
encoder_selfattn_layer_args = (attention_heads, attention_dim,
attention_dropout_rate, )
elif selfattention_layer_type == "legacy_rel_selfattn":
assert pos_enc_layer_type == "legacy_rel_pos"
encoder_selfattn_layer = LegacyRelPositionMultiHeadedAttention
encoder_selfattn_layer_args = (attention_heads, attention_dim,
attention_dropout_rate, )
elif selfattention_layer_type == "rel_selfattn":
logging.info("encoder self-attention layer type = relative self-attention")
logging.info(
"encoder self-attention layer type = relative self-attention")
assert pos_enc_layer_type == "rel_pos"
encoder_selfattn_layer = RelPositionMultiHeadedAttention
encoder_selfattn_layer_args = (attention_heads, attention_dim,
attention_dropout_rate, zero_triu, )
else:
raise ValueError("unknown encoder_attn_layer: " + selfattention_layer_type)
raise ValueError("unknown encoder_attn_layer: " +
selfattention_layer_type)
# feed-forward module definition
if positionwise_layer_type == "linear":
positionwise_layer = PositionwiseFeedForward
positionwise_layer_args = (attention_dim, linear_units,
dropout_rate, activation, )
elif positionwise_layer_type == "conv1d":
positionwise_layer = MultiLayeredConv1d
positionwise_layer_args = (attention_dim, linear_units,
positionwise_conv_kernel_size,
dropout_rate, )
elif positionwise_layer_type == "conv1d-linear":
positionwise_layer = Conv1dLinear
positionwise_layer_args = (attention_dim, linear_units,
positionwise_conv_kernel_size,
dropout_rate, )
else:
raise NotImplementedError("Support only linear or conv1d.")
dropout_rate,
normalize_before,
concat_after,
stochastic_depth_rate * float(1 + lnum) / num_blocks, ), )
self.pre_speech_layer = pre_speech_layer
self.pre_speech_encoders = repeat(
self.pre_speech_layer,
dropout_rate,
normalize_before,
concat_after,
stochastic_depth_rate * float(1 + lnum) / self.pre_speech_layer, ),
)
if self.normalize_before:
self.after_norm = LayerNorm(attention_dim)
self.intermediate_layers = intermediate_layers
def forward(self,
speech_pad,
text_pad,
masked_position,
speech_mask=None,
text_mask=None,
speech_segment_pos=None,
text_segment_pos=None):
"""Encode input sequence.
"""
speech_pad = self.speech_embed(speech_pad)
# pure speech input
if -2 in np.array(text_pad):
text_pad = text_pad + 3
text_mask = paddle.unsqueeze(bool(text_pad), 1)
text_segment_pos = paddle.zeros_like(text_pad)
text_pad = self.text_embed(text_pad)
text_pad = (text_pad[0] + self.segment_emb(text_segment_pos),
text_pad[1])
text_segment_pos = None
elif text_pad is not None:
text_pad = self.text_embed(text_pad)
segment_emb = None
text_segment_emb = self.segment_emb(text_segment_pos)
text_pad = (text_pad[0] + text_segment_emb, text_pad[1])
speech_pad = (speech_pad[0] + speech_segment_emb, speech_pad[1])
segment_emb = paddle.concat(
[speech_segment_emb, text_segment_emb], axis=1)
if self.pre_speech_encoders:
speech_pad, _ = self.pre_speech_encoders(speech_pad, speech_mask)
if text_pad is not None:
xs = paddle.concat([speech_pad[0], text_pad[0]], axis=1)
xs_pos_emb = paddle.concat([speech_pad[1], text_pad[1]], axis=1)
masks = paddle.concat([speech_mask, text_mask], axis=-1)
else:
xs = speech_pad[0]
xs_pos_emb = speech_pad[1]
masks = speech_mask
xs, masks = self.encoders((xs, xs_pos_emb), masks)
if isinstance(xs, tuple):
xs = xs[0]
if self.normalize_before:
xs = self.after_norm(xs)
return xs, masks #, segment_emb
class MLMDecoder(MLMEncoder):
def forward(self, xs, masks, masked_position=None, segment_emb=None):
"""Encode input sequence.
Args:
emb, mlm_position = None, None
if not self.training:
masked_position = None
xs = self.embed(xs)
if segment_emb:
xs = (xs[0] + segment_emb, xs[1])
for layer_idx, encoder_layer in enumerate(self.encoders):
xs, masks = encoder_layer(xs, masks)
if (self.intermediate_layers is not None and
layer_idx + 1 in self.intermediate_layers):
encoder_output = xs
# intermediate branches also require normalization.
if self.normalize_before:
return xs, masks, intermediate_outputs
return xs, masks
def pad_to_longformer_att_window(text, max_len, max_tlen, attention_window):
round = max_len % attention_window
if round != 0:
max_tlen += (attention_window - round)
n_batch = paddle.shape(text)[0]
text_pad = paddle.zeros(
shape=(n_batch, max_tlen, *paddle.shape(text[0])[1:]),
dtype=text.dtype)
for i in range(n_batch):
text_pad[i, :paddle.shape(text[i])[0]] = text[i]
else:
text_pad = text[:, :max_tlen]
return text_pad, max_tlen
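Worked example (illustrative only) of the padding rule: lengths are rounded up to the next multiple of attention_window, with zero padding appended.

```python
text = paddle.ones([2, 10], dtype=paddle.int64)
text_pad, new_len = pad_to_longformer_att_window(
    text, max_len=10, max_tlen=10, attention_window=8)
print(new_len)         # 16, the next multiple of 8 above 10
print(text_pad.shape)  # [2, 16]; positions 10..15 are zero padding
```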
class MLMModel(nn.Layer):
def __init__(self,
token_list: Union[Tuple[str, ...], List[str]],
odim: int,
encoder: nn.Layer,
decoder: Optional[nn.Layer],
postnet_layers: int=0,
postnet_chans: int=0,
postnet_filts: int=0,
ignore_id: int=-1,
lsm_weight: float=0.0,
length_normalized_loss: bool=False,
report_cer: bool=True,
report_wer: bool=True,
sym_space: str="<space>",
sym_blank: str="<blank>",
masking_schema: str="span",
mean_phn_span: int=3,
mlm_prob: float=0.25,
dynamic_mlm_prob=False,
decoder_seg_pos=False,
text_masking=False):
super().__init__()
# note that eos is the same as sos (equivalent ID)
self.ignore_id = ignore_id
self.token_list = token_list.copy()
self.encoder = encoder
self.decoder = decoder
self.vocab_size = encoder.text_embed[0]._num_embeddings
if report_cer or report_wer:
self.error_calculator = ErrorCalculator(
token_list, sym_space, sym_blank, report_cer, report_wer)
else:
self.error_calculator = None
self.mlm_weight = 1.0
self.mlm_prob = mlm_prob
self.mlm_layer = 12
self.finetune_wo_mlm = True
self.max_span = 50
self.min_span = 4
self.mean_phn_span = mean_phn_span
self.masking_schema = masking_schema
if self.decoder is None or not (hasattr(self.decoder,
'output_layer') and
self.decoder.output_layer is not None):
self.sfc = nn.Linear(self.encoder._output_size, odim)
else:
self.sfc = None
if text_masking:
self.text_sfc = nn.Linear(
self.encoder.text_embed[0]._embedding_dim,
self.vocab_size,
weight_attr=self.encoder.text_embed[0]._weight_attr)
self.text_mlm_loss = nn.CrossEntropyLoss(ignore_index=ignore_id)
else:
self.text_sfc = None
self.text_mlm_loss = None
self.decoder_seg_pos = decoder_seg_pos
if lsm_weight > 50:
self.l1_loss_func = nn.MSELoss()
else:
self.l1_loss_func = nn.L1Loss(reduction='none')
self.postnet = (None if postnet_layers == 0 else Postnet(
idim=self.encoder._output_size,
odim=odim,
n_layers=postnet_layers,
n_chans=postnet_chans,
n_filts=postnet_filts,
use_batch_norm=True,
dropout_rate=0.5, ))
def collect_feats(self,
speech,
speech_lengths,
text,
text_lengths,
masked_position,
speech_mask,
text_mask,
speech_segment_pos,
text_segment_pos,
y_masks=None) -> Dict[str, paddle.Tensor]:
return {"feats": speech, "feats_lengths": speech_lengths}
def forward(self, batch, speech_segment_pos, y_masks=None):
# feats: (Batch, Length, Dim)
# -> encoder_out: (Batch, Length2, Dim2)
speech_pad_placeholder = batch['speech_pad']
if self.decoder is not None:
ys_in = self._add_first_frame_and_remove_last_frame(
batch['speech_pad'])
encoder_out, h_masks = self.encoder(**batch)
if self.decoder is not None:
zs, _ = self.decoder(ys_in, y_masks, encoder_out,
bool(h_masks),
self.encoder.segment_emb(speech_segment_pos))
speech_hidden_states = zs
else:
speech_hidden_states = encoder_out[:, :paddle.shape(batch[
'speech_pad'])[1], :]
if self.sfc is not None:
before_outs = paddle.reshape(
self.sfc(speech_hidden_states),
(paddle.shape(speech_hidden_states)[0], -1, self.odim))
else:
before_outs = speech_hidden_states
if self.postnet is not None:
after_outs = before_outs + paddle.transpose(
self.postnet(paddle.transpose(before_outs, [0, 2, 1])),
(0, 2, 1))
else:
after_outs = None
return before_outs, after_outs, speech_pad_placeholder, batch[
'masked_position']
def inference(
self,
speech,
text,
masked_position,
speech_mask,
text_mask,
speech_segment_pos,
text_segment_pos,
span_boundary,
y_masks=None,
speech_lengths=None,
text_lengths=None,
feats: Optional[paddle.Tensor]=None,
spembs: Optional[paddle.Tensor]=None,
sids: Optional[paddle.Tensor]=None,
lids: Optional[paddle.Tensor]=None,
threshold: float=0.5,
minlenratio: float=0.0,
maxlenratio: float=10.0,
use_teacher_forcing: bool=False, ) -> Dict[str, paddle.Tensor]:
batch = dict(
speech_pad=speech,
text_pad=text,
speech_mask=speech_mask,
text_mask=text_mask,
speech_segment_pos=speech_segment_pos,
text_segment_pos=text_segment_pos, )
# # inference with teacher forcing
# hs, h_masks = self.encoder(**batch)
outs = [batch['speech_pad'][:, :span_boundary[0]]]
z_cache = None
if use_teacher_forcing:
before, zs, _, _ = self.forward(
batch, speech_segment_pos, y_masks=y_masks)
if zs is None:
zs = before
outs += [zs[0][span_boundary[0]:span_boundary[1]]]
outs += [batch['speech_pad'][:, span_boundary[1]:]]
return dict(feat_gen=outs)
return None
def _add_first_frame_and_remove_last_frame(
self, ys: paddle.Tensor) -> paddle.Tensor:
ys_in = paddle.concat(
[
paddle.zeros(
shape=(paddle.shape(ys)[0], 1, paddle.shape(ys)[2]),
dtype=ys.dtype), ys[:, :-1]
],
axis=1)
return ys_in
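The helper produces the standard teacher-forcing shift: a zero frame is prepended and the last target frame is dropped, so the decoder predicts frame t from frames earlier than t. Illustrative values, assuming `model` is any MLMModel instance:

```python
ys = paddle.to_tensor([[[1., 1.], [2., 2.], [3., 3.]]])  # (1, 3, 2)
ys_in = model._add_first_frame_and_remove_last_frame(ys)
# ys_in -> [[[0., 0.], [1., 1.], [2., 2.]]]
```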
class MLMEncAsDecoderModel(MLMModel):
def forward(self, batch, speech_segment_pos, y_masks=None):
# feats: (Batch, Length, Dim)
# -> encoder_out: (Batch, Length2, Dim2)
speech_pad_placeholder = batch['speech_pad']
encoder_out, h_masks = self.encoder(**batch) # segment_emb
if self.decoder is not None:
zs, _ = self.decoder(encoder_out, h_masks)
else:
zs = encoder_out
speech_hidden_states = zs[:, :paddle.shape(batch['speech_pad'])[1], :]
if self.sfc is not None:
before_outs = paddle.reshape(
self.sfc(speech_hidden_states),
(paddle.shape(speech_hidden_states)[0], -1, self.odim))
else:
before_outs = speech_hidden_states
if self.postnet is not None:
after_outs = before_outs + paddle.transpose(
self.postnet(paddle.transpose(before_outs, [0, 2, 1])),
[0, 2, 1])
else:
after_outs = None
return before_outs, after_outs, speech_pad_placeholder, batch[
'masked_position']
class MLMDualMaksingModel(MLMModel):
def _calc_mlm_loss(self,
before_outs: paddle.Tensor,
after_outs: paddle.Tensor,
text_outs: paddle.Tensor,
batch):
xs_pad = batch['speech_pad']
text_pad = batch['text_pad']
masked_position = batch['masked_position']
text_masked_position = batch['text_masked_position']
mlm_loss_position = masked_position > 0
loss = paddle.sum(
self.l1_loss_func(
paddle.reshape(before_outs, (-1, self.odim)),
paddle.reshape(xs_pad, (-1, self.odim))),
axis=-1)
if after_outs is not None:
loss += paddle.sum(
self.l1_loss_func(
paddle.reshape(after_outs, (-1, self.odim)),
paddle.reshape(xs_pad, (-1, self.odim))),
axis=-1)
loss_mlm = paddle.sum((loss * paddle.reshape(
mlm_loss_position, [-1]))) / paddle.sum((mlm_loss_position) + 1e-10)
loss_text = paddle.sum((self.text_mlm_loss(
paddle.reshape(text_outs, (-1, self.vocab_size)),
paddle.reshape(text_pad, (-1))) * paddle.reshape(
text_masked_position,
(-1)))) / paddle.sum((text_masked_position) + 1e-10)
return loss_mlm, loss_text
def forward(self, batch, speech_segment_pos, y_masks=None):
# feats: (Batch, Length, Dim)
# -> encoder_out: (Batch, Length2, Dim2)
speech_pad_placeholder = batch['speech_pad']
encoder_out, h_masks = self.encoder(**batch) # segment_emb
if self.decoder is not None:
zs, _ = self.decoder(encoder_out, h_masks)
else:
zs = encoder_out
speech_hidden_states = zs[:, :paddle.shape(batch['speech_pad'])[1], :]
if self.text_sfc:
            text_hidden_states = zs[:, paddle.shape(batch['speech_pad'])[
                1]:, :]
            text_outs = paddle.reshape(
                self.text_sfc(text_hidden_states),
                (paddle.shape(text_hidden_states)[0], -1, self.vocab_size))
if self.sfc is not None:
before_outs = paddle.reshape(
self.sfc(speech_hidden_states),
(paddle.shape(speech_hidden_states)[0], -1, self.odim))
else:
before_outs = speech_hidden_states
if self.postnet is not None:
after_outs = before_outs + paddle.transpose(
self.postnet(paddle.transpose(before_outs, [0, 2, 1])),
[0, 2, 1])
else:
after_outs = None
return before_outs, after_outs, text_outs, None #, speech_pad_placeholder, batch['masked_position'],batch['text_masked_position']
def build_model_from_file(config_file, model_file):
state_dict = paddle.load(model_file)
model_class = MLMDualMaksingModel if 'conformer_combine_vctk_aishell3_dual_masking' in config_file \
else MLMEncAsDecoderModel
    # build the model
args = yaml.safe_load(Path(config_file).open("r", encoding="utf-8"))
return model, args
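Hedged usage sketch; the exact file names inside the pretrained tarballs are assumptions:

```python
model, train_args = build_model_from_file(
    config_file="pretrained_model/model-ernie-sat-base-en/config.yaml",
    model_file="pretrained_model/model-ernie-sat-base-en/model.pdparams")
model.eval()
```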
def build_model(args: argparse.Namespace,
model_class=MLMEncAsDecoderModel) -> MLMModel:
if isinstance(args.token_list, str):
with open(args.token_list, encoding="utf-8") as f:
token_list = [line.rstrip() for line in f]
raise RuntimeError("token_list must be str or list")
vocab_size = len(token_list)
logging.info(f"Vocabulary size: {vocab_size }")
odim = 80
pos_enc_class = ScaledPositionalEncoding if args.use_scaled_pos_enc else PositionalEncoding
if "conformer" == args.encoder:
conformer_self_attn_layer_type = args.encoder_conf[
'selfattention_layer_type']
conformer_pos_enc_layer_type = args.encoder_conf['pos_enc_layer_type']
conformer_rel_pos_type = "legacy"
if conformer_rel_pos_type == "legacy":
logging.warning(
"Fallback to conformer_pos_enc_layer_type = 'legacy_rel_pos' "
"due to the compatibility. If you want to use the new one, "
"please use conformer_pos_enc_layer_type = 'latest'."
)
"please use conformer_pos_enc_layer_type = 'latest'.")
if conformer_self_attn_layer_type == "rel_selfattn":
conformer_self_attn_layer_type = "legacy_rel_selfattn"
logging.warning(
"Fallback to "
"conformer_self_attn_layer_type = 'legacy_rel_selfattn' "
"due to the compatibility. If you want to use the new one, "
"please use conformer_pos_enc_layer_type = 'latest'."
)
"please use conformer_pos_enc_layer_type = 'latest'.")
elif conformer_rel_pos_type == "latest":
assert conformer_pos_enc_layer_type != "legacy_rel_pos"
assert conformer_self_attn_layer_type != "legacy_rel_selfattn"
else:
raise ValueError(f"Unknown rel_pos_type: {conformer_rel_pos_type}")
args.encoder_conf[
'selfattention_layer_type'] = conformer_self_attn_layer_type
args.encoder_conf['pos_enc_layer_type'] = conformer_pos_enc_layer_type
if "conformer" == args.decoder:
args.decoder_conf[
'selfattention_layer_type'] = conformer_self_attn_layer_type
args.decoder_conf[
'pos_enc_layer_type'] = conformer_pos_enc_layer_type
# Encoder
encoder_class = MLMEncoder
if 'text_masking' in args.model_conf.keys() and args.model_conf[
'text_masking']:
args.encoder_conf['text_masking'] = True
else:
args.encoder_conf['text_masking'] = False
encoder = encoder_class(
args.input_size,
vocab_size=vocab_size,
pos_enc_class=pos_enc_class,
**args.encoder_conf)
# Decoder
if args.decoder != 'no_decoder':
decoder = decoder_class(
idim=0,
input_layer=None,
**args.decoder_conf, )
else:
decoder = None
# Build model
model = model_class(
odim=odim,
encoder=encoder,
decoder=decoder,
token_list=token_list,
**args.model_conf, )
# Initialize
if args.init is not None:
p243_new 0.0125 1.0225 1.0525 1.0925 1.1325 1.1725 1.2625 1.3625 1.4125 1.5125 1.6225 1.6625 1.7925 1.8625 2.0025 2.0925 2.1725 2.2625 2.4325 2.4725 2.5225 2.5825 2.6125 2.6425 2.7425 2.8025 2.9025 2.9525 3.0525 3.0825 3.2125
Prompt_003_new 0.0125 0.0425 0.0925 0.1825 0.2125 0.2425 0.3225 0.3725 0.4725 0.5325 0.5625 0.6225 0.7425 0.8625 0.9725 0.9975 1.0125 1.0825 1.2625
p299_096 0.0125 0.7525 0.7925 0.8725 0.9125 0.9425 1.0325 1.0625 1.1925 1.2625 1.3225 1.3725 1.4125 1.5125 1.5425 1.6525 1.6925 1.7325 1.7625 1.8425 1.9625 2.0225 2.1825 2.3325
p243_new sp F AO1 R DH AE1 T R IY1 Z AH0 N sp K AH1 V ER0 SH UH1 D N AA1 T B IY1 G IH1 V AH0 N sp
Prompt_003_new DH IH1 S W AA1 Z N AA1 T DH AH0 SH OW1 F AO1 R M IY1 sp
p299_096 sp W IY1 AA1 R T R AY1 NG T UW1 AH0 S T AE1 B L IH0 SH AH0 D EY1 T sp
p243_new ../../prompt_wav/p243_313.wav
Prompt_003_new ../../prompt_wav/this_was_not_the_show_for_me.wav
p299_096 ../../prompt_wav/p299_096.wav
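The three blocks above are the prompt metadata: per-utterance phone end times in seconds, the phone sequences (ARPABET plus sp for silence), and the prompt wav paths, each keyed by utterance id. A hedged sketch of a parser (the file names are hypothetical; the format is "<uid> <value> <value> ..."):

```python
def read_prompt_table(path):
    """Parse lines of the form "<uid> <value> <value> ..." into a dict."""
    table = {}
    with open(path, encoding="utf-8") as f:
        for line in f:
            uid, *values = line.strip().split()
            table[uid] = values
    return table

durations = read_prompt_table("prompt/dev/durations")  # end times in seconds
phones = read_prompt_table("prompt/dev/text")          # phone sequences
wav_paths = read_prompt_table("prompt/dev/wav.scp")    # prompt wav locations
```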
import logging
from pathlib import Path
from typing import Dict
from typing import List
from typing import Union
def read_2column_text(path: Union[Path, str]) -> Dict[str, str]:
"""Read a text file having 2 column as dict object.
return data
def load_num_sequence_text(path: Union[Path, str], loader_type: str="csv_int"
) -> Dict[str, List[Union[float, int]]]:
"""Read a text file indicating sequences of number
Examples:
try:
retval[k] = [dtype(i) for i in v.split(delimiter)]
except TypeError:
logging.error(
f'Error happened with path="{path}", id="{k}", value="{v}"')
raise
return retval
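Illustrative call, assuming the "text_float" loader type (one space-delimited float sequence per utterance id, as in the duration files above; the path is hypothetical):

```python
durations = load_num_sequence_text(
    "prompt/dev/durations", loader_type="text_float")
# -> {"p243_new": [0.0125, 1.0225, ...], "Prompt_003_new": [...], ...}
```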
#!/bin/bash
# en --> zh speech synthesis
# Use Prompt_003_new as the prompt speech ("This was not the show for me.") to synthesize: '今天天气很好'
# Note: new_str must be Chinese characters; otherwise preprocessing keeps only the Chinese characters and synthesizes the speech of the preprocessed text.
python inference.py \
--task_name=cross-lingual_clone \
--model_name=paddle_checkpoint_dual_mask_enzh \
--uid=Prompt_003_new \
--new_str='今天天气很好.' \
--prefix='./prompt/dev/' \
--source_language=english \
--target_language=chinese \
--output_name=pred_clone.wav \
--use_pt_vocoder=False \
--voc=pwgan_aishell3 \
--voc_config=download/pwg_aishell3_ckpt_0.5/default.yaml \
--voc_ckpt=download/pwg_aishell3_ckpt_0.5/snapshot_iter_1000000.pdz \
--voc_stat=download/pwg_aishell3_ckpt_0.5/feats_stats.npy \
--am=fastspeech2_csmsc \
--am_config=download/fastspeech2_conformer_baker_ckpt_0.5/conformer.yaml \
--am_ckpt=download/fastspeech2_conformer_baker_ckpt_0.5/snapshot_iter_76000.pdz \
--am_stat=download/fastspeech2_conformer_baker_ckpt_0.5/speech_stats.npy \
--phones_dict=download/fastspeech2_conformer_baker_ckpt_0.5/phone_id_map.txt
#!/bin/bash
# English-only speech synthesis
# Example: use the speech of p299_096 as the prompt ("This was not the show for me.") to synthesize: 'I enjoy my life, do you?'
python inference.py \
--task_name=synthesize \
--model_name=paddle_checkpoint_en \
--uid=p299_096 \
--new_str='I enjoy my life, do you?' \
--prefix='./prompt/dev/' \
--source_language=english \
--target_language=english \
--output_name=pred_gen.wav \
--use_pt_vocoder=False \
--voc=pwgan_aishell3 \
--voc_config=download/pwg_aishell3_ckpt_0.5/default.yaml \
--voc_ckpt=download/pwg_aishell3_ckpt_0.5/snapshot_iter_1000000.pdz \
--voc_stat=download/pwg_aishell3_ckpt_0.5/feats_stats.npy \
--am=fastspeech2_ljspeech \
--am_config=download/fastspeech2_nosil_ljspeech_ckpt_0.5/default.yaml \
--am_ckpt=download/fastspeech2_nosil_ljspeech_ckpt_0.5/snapshot_iter_100000.pdz \
--am_stat=download/fastspeech2_nosil_ljspeech_ckpt_0.5/speech_stats.npy \
--phones_dict=download/fastspeech2_nosil_ljspeech_ckpt_0.5/phone_id_map.txt
#!/bin/bash
# English-only speech editing
# Example: edit the original speech of p243_new ("For that reason cover should not be given.") into the speech of 'for that reason cover is impossible to be given.'
# NOTE: the speech-editing task currently supports replacing or inserting text at only one position per sentence
python inference.py \
--task_name=edit \
--model_name=paddle_checkpoint_en \
--uid=p243_new \
--new_str='for that reason cover is impossible to be given.' \
--prefix='./prompt/dev/' \
--source_language=english \
--target_language=english \
--output_name=pred_edit.wav \
--use_pt_vocoder=False \
--voc=pwgan_aishell3 \
--voc_config=download/pwg_aishell3_ckpt_0.5/default.yaml \
--voc_ckpt=download/pwg_aishell3_ckpt_0.5/snapshot_iter_1000000.pdz \
--voc_stat=download/pwg_aishell3_ckpt_0.5/feats_stats.npy \
--am=fastspeech2_ljspeech \
--am_config=download/fastspeech2_nosil_ljspeech_ckpt_0.5/default.yaml \
--am_ckpt=download/fastspeech2_nosil_ljspeech_ckpt_0.5/snapshot_iter_100000.pdz \
--am_stat=download/fastspeech2_nosil_ljspeech_ckpt_0.5/speech_stats.npy \
--phones_dict=download/fastspeech2_nosil_ljspeech_ckpt_0.5/phone_id_map.txt
parser.add_argument("--target_language", type=str, help="target language")
parser.add_argument("--output_name", type=str, help="output name")
parser.add_argument("--task_name", type=str, help="task name")
parser.add_argument("--use_pt_vocoder", default=True, help="use pytorch version vocoder or not. [Note: only in english condition.]")
parser.add_argument(
"--use_pt_vocoder",
type=str2bool,
default=True,
help="use pytorch version vocoder or not. [Note: only in english condition.]")
# pre
args = parser.parse_args()
......
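The `type=str2bool` added above matters: without a converter, argparse hands the option its raw string, so `--use_pt_vocoder False` would arrive as the truthy string `'False'`. A minimal sketch of such a converter (the actual `paddlespeech.t2s.utils.str2bool` may differ in detail):

```python
import argparse

def str2bool(value):
    # hypothetical stand-in for paddlespeech.t2s.utils.str2bool
    if isinstance(value, bool):
        return value
    if value.lower() in ("yes", "true", "t", "y", "1"):
        return True
    if value.lower() in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"boolean value expected, got {value!r}")
```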
#!/bin/bash
rm -rf *.wav
sh run_sedit_en.sh # speech editing task (English)
sh run_gen_en.sh # personalized speech synthesis task (English)
sh run_clone_en_to_zh.sh # cross-lingual speech synthesis task (English-to-Chinese voice cloning)
\ No newline at end of file
# Copyright 2021 Tomoki Hayashi
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Wrapper class for the vocoder model trained with parallel_wavegan repo."""
import logging
import os
from pathlib import Path
from typing import Optional
from typing import Union
import torch
import yaml
class ParallelWaveGANPretrainedVocoder(torch.nn.Module):
"""Wrapper class to load the vocoder trained with parallel_wavegan repo."""
    def __init__(
            self,
            model_file: Union[Path, str],
            config_file: Optional[Union[Path, str]]=None, ):
"""Initialize ParallelWaveGANPretrainedVocoder module."""
super().__init__()
try:
......@@ -30,8 +23,7 @@ class ParallelWaveGANPretrainedVocoder(torch.nn.Module):
except ImportError:
logging.error(
"`parallel_wavegan` is not installed. "
"Please install via `pip install -U parallel_wavegan`."
)
"Please install via `pip install -U parallel_wavegan`.")
raise
if config_file is None:
dirname = os.path.dirname(str(model_file))
......@@ -59,5 +51,4 @@ class ParallelWaveGANPretrainedVocoder(torch.nn.Module):
"""
return self.vocoder.inference(
feats,
            normalize_before=self.normalize_before, ).view(-1)
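A hedged usage sketch of the wrapper (the paths and feature shape are illustrative assumptions, and the truncated method above is taken to be the module's forward):

```python
import torch

# hypothetical checkpoint/config paths from a parallel_wavegan training run
vocoder = ParallelWaveGANPretrainedVocoder(
    "exp/train_pwg/checkpoint-1000000steps.pkl",
    "exp/train_pwg/config.yml")
vocoder.eval()

with torch.no_grad():
    mel = torch.randn(200, 80)  # (frames, n_mels), dummy mel features
    wav = vocoder(mel)          # 1-D waveform, roughly frames * hop_size samples
```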
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import os
from pathlib import Path

import jsonlines
import numpy as np
import paddle
import paddle.nn.functional as F
import soundfile as sf
import yaml
from sedit_arg_parser import parse_args
from timer import timer
from yacs.config import CfgNode

from paddlespeech.s2t.utils.dynamic_import import dynamic_import
from paddlespeech.t2s.exps.syn_utils import get_frontend
from paddlespeech.t2s.exps.syn_utils import get_test_dataset
from paddlespeech.t2s.exps.syn_utils import get_voc_inference
from paddlespeech.t2s.frontend.zh_frontend import Frontend
from paddlespeech.t2s.models.fastspeech2 import FastSpeech2
from paddlespeech.t2s.models.fastspeech2 import FastSpeech2Inference
from paddlespeech.t2s.modules.nets_utils import make_pad_mask
from paddlespeech.t2s.modules.normalizer import ZScore
from paddlespeech.t2s.utils import str2bool
from tools.parallel_wavegan_pretrained_vocoder import ParallelWaveGANPretrainedVocoder
model_alias = {
# acoustic model
......@@ -58,9 +28,6 @@ model_alias = {
}
def is_chinese(ch):
if u'\u4e00' <= ch <= u'\u9fff':
return True
......@@ -69,17 +36,15 @@ def is_chinese(ch):
def build_vocoder_from_file(
        vocoder_config_file=None,
        vocoder_file=None,
        model=None,
        device="cpu", ):
# Build vocoder
if str(vocoder_file).endswith(".pkl"):
# If the extension is ".pkl", the model is trained with parallel_wavegan
        vocoder = ParallelWaveGANPretrainedVocoder(vocoder_file,
                                                   vocoder_config_file)
return vocoder.to(device)
else:
......@@ -91,7 +56,7 @@ def get_voc_out(mel, target_language="chinese"):
args = parse_args()
assert target_language == "chinese" or target_language == "english", "In get_voc_out function, target_language is illegal..."
# print("current vocoder: ", args.voc)
with open(args.voc_config) as f:
voc_config = CfgNode(yaml.safe_load(f))
......@@ -106,6 +71,7 @@ def get_voc_out(mel, target_language="chinese"):
# print("shepe of wav (time x n_channels):%s"%wav.shape)
return np.squeeze(wav)
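A hedged usage sketch of get_voc_out (the mel shape and the 24 kHz rate are assumptions based on the aishell3 vocoder configs; note the helper calls parse_args() internally, so the script must be launched with the vocoder flags shown earlier):

```python
import numpy as np
import soundfile as sf

mel = np.random.randn(200, 80).astype("float32")  # (frames, n_mels), dummy input
wav = get_voc_out(mel, target_language="english")
sf.write("vocoder_check.wav", wav, 24000)         # sample rate assumed from pwg_aishell3
```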
# dygraph
def get_am_inference(args, am_config):
with open(args.phones_dict, "r") as f:
......@@ -159,11 +125,14 @@ def get_am_inference(args, am_config):
return am, am_inference, am_name, am_dataset, phn_id
def evaluate_durations(phns,
                       target_language="chinese",
                       fs=24000,
                       hop_length=300):
args = parse_args()
if target_language == 'english':
        args.lang = 'en'
args.am = "fastspeech2_ljspeech"
args.am_config = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/default.yaml"
args.am_ckpt = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/snapshot_iter_100000.pdz"
......@@ -171,12 +140,12 @@ def evaluate_durations(phns, target_language="chinese", fs=24000, hop_length=300
args.phones_dict = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/phone_id_map.txt"
elif target_language == 'chinese':
        args.lang = 'zh'
        args.am = "fastspeech2_csmsc"
        args.am_config = "download/fastspeech2_conformer_baker_ckpt_0.5/conformer.yaml"
        args.am_ckpt = "download/fastspeech2_conformer_baker_ckpt_0.5/snapshot_iter_76000.pdz"
        args.am_stat = "download/fastspeech2_conformer_baker_ckpt_0.5/speech_stats.npy"
        args.phones_dict = "download/fastspeech2_conformer_baker_ckpt_0.5/phone_id_map.txt"
# args = parser.parse_args(args=[])
if args.ngpu == 0:
......@@ -186,8 +155,6 @@ def evaluate_durations(phns, target_language="chinese", fs=24000, hop_length=300
else:
print("ngpu should >= 0 !")
assert target_language == "chinese" or target_language == "english", "In evaluate_durations function, target_language is illegal..."
# Init body.
......@@ -197,8 +164,8 @@ def evaluate_durations(phns, target_language="chinese", fs=24000, hop_length=300
# print(am_config)
# print("---------------------")
# acoustic model
    am, am_inference, am_name, am_dataset, phn_id = get_am_inference(args,
                                                                     am_config)
torch_phns = phns
vocab_phones = {}
......@@ -206,33 +173,31 @@ def evaluate_durations(phns, target_language="chinese", fs=24000, hop_length=300
vocab_phones[tone] = int(id)
# print("vocab_phones: ", len(vocab_phones))
vocab_size = len(vocab_phones)
    phonemes = [phn if phn in vocab_phones else "sp" for phn in torch_phns]
phone_ids = [vocab_phones[item] for item in phonemes]
phone_ids_new = phone_ids
    phone_ids_new.append(vocab_size - 1)
phone_ids_new = paddle.to_tensor(np.array(phone_ids_new, np.int64))
    normalized_mel, d_outs, p_outs, e_outs = am.inference(
        phone_ids_new, spk_id=None, spk_emb=None)
pre_d_outs = d_outs
phoneme_durations_new = pre_d_outs * hop_length / fs
phoneme_durations_new = phoneme_durations_new.tolist()[:-1]
return phoneme_durations_new
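The returned values are per-phone durations in seconds: FastSpeech2 predicts per-phone frame counts, which the function scales by hop_length / fs. A quick sanity check of that conversion:

```python
# with the defaults above, each frame spans hop_length / fs = 300 / 24000 = 0.0125 s
fs, hop_length = 24000, 300
frames_for_one_phone = 8                       # hypothetical predicted frame count
print(frames_for_one_phone * hop_length / fs)  # 0.1 -> this phone lasts 0.1 s
```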
def sentence2phns(sentence, target_language="en"):
args = parse_args()
if target_language == 'en':
        args.lang = 'en'
args.phones_dict = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/phone_id_map.txt"
elif target_language == 'zh':
        args.lang = 'zh'
        args.phones_dict = "download/fastspeech2_conformer_baker_ckpt_0.5/phone_id_map.txt"
else:
print("target_language should in {'zh', 'en'}!")
frontend = get_frontend(args)
merge_sentences = True
get_tone_ids = False
......@@ -246,10 +211,8 @@ def sentence2phns(sentence, target_language="en"):
phone_ids = input_ids["phone_ids"]
phonemes = frontend.get_phonemes(
            sentence, merge_sentences=merge_sentences, print_info=False)
return phonemes[0], input_ids["phone_ids"][0]
elif target_language == 'en':
......@@ -270,16 +233,11 @@ def sentence2phns(sentence, target_language="en"):
phones = [phn for phn in phones if not phn.isspace()]
# replace unk phone with sp
phones = [
            phn if (phn in vocab_phones and phn not in punc) else "sp"
for phn in phones
]
phones_list.append(phones)
        return phones_list[0], input_ids["phone_ids"][0]
else:
print("lang should in {'zh', 'en'}!")