未验证 提交 437081f7 编写于 作者: K Kennycao123 提交者: GitHub

Merge pull request #824 from yt605155624/format

[ernie sat]format ernie sat
ernie-sat/.meta/framework.png

306.0 KB | W: | H:

ernie-sat/.meta/framework.png

139.9 KB | W: | H:

ernie-sat/.meta/framework.png
ernie-sat/.meta/framework.png
ernie-sat/.meta/framework.png
ernie-sat/.meta/framework.png
  • 2-up
  • Swipe
  • Onion skin
ERNIE-SAT是可以同时处理中英文的跨语言的语音-语言跨模态大模型,其在语音编辑、个性化语音合成以及跨语言的语音合成等多个任务取得了领先效果。可以应用于语音编辑、个性化合成、语音克隆、同传翻译等一系列场景,该项目供研究使用。
ERNIE-SAT 是可以同时处理中英文的跨语言的语音-语言跨模态大模型,其在语音编辑、个性化语音合成以及跨语言的语音合成等多个任务取得了领先效果。可以应用于语音编辑、个性化合成、语音克隆、同传翻译等一系列场景,该项目供研究使用。
## 模型框架
ERNIE-SAT中我们提出了两项创新:
ERNIE-SAT 中我们提出了两项创新:
- 在预训练过程中将中英双语对应的音素作为输入,实现了跨语言、个性化的软音素映射
- 采用语言和语音的联合掩码学习实现了语言和语音的对齐
......@@ -12,14 +12,14 @@ ERNIE-SAT中我们提出了两项创新:
### 1.安装飞桨与环境依赖
- 本项目的代码基于 Paddle(version>=2.0)
- 本项目开放提供加载torch版本的vocoder的功能
- 本项目开放提供加载 torch 版本的 vocoder 的功能
- torch version>=1.8
- 安装htk: 在[官方地址](https://htk.eng.cam.ac.uk/)注册完成后,即可进行下载较新版本的htk(例如3.4.1)。同时提供[历史版本htk下载地址](https://htk.eng.cam.ac.uk/ftp/software/)
- 安装 htk: 在[官方地址](https://htk.eng.cam.ac.uk/)注册完成后,即可进行下载较新版本的 htk (例如 3.4.1)。同时提供[历史版本 htk 下载地址](https://htk.eng.cam.ac.uk/ftp/software/)
- 1.注册账号,下载htk
- 2.解压htk文件,**放入项目根目录的tools文件夹中, 以htk文件夹名称放入**
- 3.**注意**: 如果您下载的是3.4.1或者更高版本,需要进入HTKLib/HRec.c文件中, **修改1626行和1650行**, 即把**以下两行的dur<=0 都修改为 dur<0**,如下所示:
- 1.注册账号,下载 htk
- 2.解压 htk 文件,**放入项目根目录的 tools 文件夹中, 以 htk 文件夹名称放入**
- 3.**注意**: 如果您下载的是 3.4.1 或者更高版本, 需要进入 HTKLib/HRec.c 文件中, **修改 1626 行和 1650 行**, 即把**以下两行的 dur<=0 都修改为 dur<0**,如下所示:
```bash
以htk3.4.1版本举例:
(1)第1626行: if (dur<=0 && labid != splabid) HError(8522,"LatFromPaths: Align have dur<=0");
......@@ -28,26 +28,23 @@ ERNIE-SAT中我们提出了两项创新:
(2)1650行: if (dur<=0 && labid != splabid) HError(8522,"LatFromPaths: Align have dur<=0 ");
修改为: if (dur<0 && labid != splabid) HError(8522,"LatFromPaths: Align have dur<0 ");
```
- 4.**编译**: 详情参见解压后的htk中的README文件(如果未编译, 则无法正常运行)
- 4.**编译**: 详情参见解压后的 htk 中的 README 文件(如果未编译, 则无法正常运行)
- 安装ParallelWaveGAN: 参见[官方地址](https://github.com/kan-bayashi/ParallelWaveGAN):按照该官方链接的安装流程,直接在**项目的根目录下** git clone ParallelWaveGAN项目并且安装相关依赖即可。
- 安装 ParallelWaveGAN: 参见[官方地址](https://github.com/kan-bayashi/ParallelWaveGAN):按照该官方链接的安装流程,直接在**项目的根目录下** git clone ParallelWaveGAN 项目并且安装相关依赖即可。
- 安装其他依赖: **sox, libsndfile**
### 2.预训练模型
预训练模型ERNIE-SAT的模型如下所示:
预训练模型 ERNIE-SAT 的模型如下所示:
- [ERNIE-SAT_ZH](http://bj.bcebos.com/wenxin-models/model-ernie-sat-base-zh.tar.gz)
- [ERNIE-SAT_EN](http://bj.bcebos.com/wenxin-models/model-ernie-sat-base-en.tar.gz)
- [ERNIE-SAT_ZH_and_EN](http://bj.bcebos.com/wenxin-models/model-ernie-sat-base-en_zh.tar.gz)
创建download文件夹,下载上述ERNIE-SAT预训练模型并将其解压:
创建 pretrained_model 文件夹,下载上述 ERNIE-SAT 预训练模型并将其解压:
```bash
mkdir pretrained_model
cd pretrained_model
......@@ -56,13 +53,12 @@ tar -zxvf model-ernie-sat-base-zh.tar.gz
tar -zxvf model-ernie-sat-base-en_zh.tar.gz
```
### 3.下载
1. 本项目使用parallel wavegan作为声码器(vocoder):
1. 本项目使用 parallel wavegan 作为声码器(vocoder):
- [pwg_aishell3_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_aishell3_ckpt_0.5.zip)
创建download文件夹,下载上述预训练的声码器(vocoder)模型并将其解压:
创建 download 文件夹,下载上述预训练的声码器(vocoder)模型并将其解压:
```bash
mkdir download
......@@ -70,11 +66,11 @@ cd download
unzip pwg_aishell3_ckpt_0.5.zip
```
2. 本项目使用[FastSpeech2](https://arxiv.org/abs/2006.04558) 作为音素(phoneme)的持续时间预测器:
2. 本项目使用 [FastSpeech2](https://arxiv.org/abs/2006.04558) 作为音素(phoneme)的持续时间预测器:
- [fastspeech2_conformer_baker_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_conformer_baker_ckpt_0.5.zip) 中文场景下使用
- [fastspeech2_nosil_ljspeech_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_ljspeech_ckpt_0.5.zip) 英文场景下使用
下载上述预训练的fastspeech2模型并将其解压
下载上述预训练的 fastspeech2 模型并将其解压
```bash
cd download
......@@ -85,7 +81,7 @@ unzip fastspeech2_nosil_ljspeech_ckpt_0.5.zip
### 4.推理
本项目当前开源了语音编辑、个性化语音合成、跨语言语音合成的推理代码,后续会逐步开源。
注:当前英文场下的合成语音采用的声码器默认为vctk_parallel_wavegan.v1.long, 可在[该链接](https://github.com/kan-bayashi/ParallelWaveGAN)中找到; 若use_pt_vocoder参数设置为False,则英文场景下使用paddle版本的声码器。
注:当前英文场下的合成语音采用的声码器默认为 vctk_parallel_wavegan.v1.long, 可在[该链接](https://github.com/kan-bayashi/ParallelWaveGAN)中找到; 若 use_pt_vocoder 参数设置为 False,则英文场景下使用 paddle 版本的声码器。
我们提供特定音频文件, 以及其对应的文本、音素相关文件:
- prompt_wav: 提供的音频文件
......@@ -114,19 +110,19 @@ prompt/dev
5. `--lang` 对应模型的语言可以是 `zh``en`
6. `--ngpu` 要使用的GPU数,如果 ngpu==0,则使用 cpu。
7. ` --model_name` 模型名称
8. ` --uid` 特定提示(prompt)语音的id
8. ` --uid` 特定提示(prompt)语音的 id
9. ` --new_str` 输入的文本(本次开源暂时先设置特定的文本)
10. ` --prefix` 特定音频对应的文本、音素相关文件的地址
11. ` --source_language` , 源语言
12. ` --target_language` , 目标语言
13. ` --output_name` , 合成语音名称
14. ` --task_name` , 任务名称, 包括:语音编辑任务、个性化语音合成任务、跨语言语音合成任务
15. ` use_pt_vocoder`,英文场景下是否使用torch版本的vocoder, 默认情况下为True; 设置为False则在英文场景下使用paddle版本vocoder
15. ` --use_pt_vocoder`, 英文场景下是否使用 torch 版本的 vocoder, 默认情况下为 False; 设置为 False 则在英文场景下使用 paddle 版本 vocoder
运行以下脚本即可进行实验
```shell
sh run_sedit_en.sh # 语音编辑任务(英文)
sh run_gen_en.sh # 个性化语音合成任务(英文)
sh run_sedit_en.sh # 语音编辑任务(英文)
sh run_gen_en.sh # 个性化语音合成任务(英文)
sh run_clone_en_to_zh.sh # 跨语言语音合成任务(英文到中文的语音克隆)
```
#!/usr/bin/env python
""" Usage:
align_english.py wavfile trsfile outwordfile outphonefile
align_english.py wavfile trsfile outwordfile outphonefile
"""
import multiprocessing as mp
import os
import sys
from tqdm import tqdm
import multiprocessing as mp
from tqdm import tqdm
PHONEME = 'tools/aligner/english_envir/english2phoneme/phoneme'
MODEL_DIR = 'tools/aligner/english'
HVITE = 'tools/htk/HTKTools/HVite'
HCOPY = 'tools/htk/HTKTools/HCopy'
def prep_txt(line, tmpbase, dictfile):
words = []
line = line.strip()
......@@ -48,7 +47,8 @@ def prep_txt(line, tmpbase, dictfile):
for unk in unk_words:
fwid.write(unk + '\n')
try:
os.system(PHONEME + ' ' + tmpbase + '_unk.words' + ' ' + tmpbase + '_unk.phons')
os.system(PHONEME + ' ' + tmpbase + '_unk.words' + ' ' + tmpbase +
'_unk.phons')
except:
print('english2phoneme error!')
sys.exit(1)
......@@ -79,7 +79,7 @@ def prep_txt(line, tmpbase, dictfile):
seq.append(phons[j].upper())
j += 1
else:
p = phons[j:j+2]
p = phons[j:j + 2]
if (p == 'WH'):
seq.append('W')
elif (p in ['TH', 'SH', 'HH', 'DH', 'CH', 'ZH', 'NG']):
......@@ -96,6 +96,7 @@ def prep_txt(line, tmpbase, dictfile):
fw.write('\n')
fw.close()
def prep_mlf(txt, tmpbase):
with open(tmpbase + '.mlf', 'w') as fwid:
......@@ -108,6 +109,7 @@ def prep_mlf(txt, tmpbase):
fwid.write('sp\n')
fwid.write('.\n')
def gen_res(tmpbase, outfile1, outfile2):
with open(tmpbase + '.txt', 'r') as fid:
words = fid.readline().strip().split()
......@@ -120,19 +122,20 @@ def gen_res(tmpbase, outfile1, outfile2):
times1 = []
times2 = []
while (i < len(lines)):
if (len(lines[i].split()) >= 4) and (lines[i].split()[0] != lines[i].split()[1]):
if (len(lines[i].split()) >= 4) and (
lines[i].split()[0] != lines[i].split()[1]):
phn = lines[i].split()[2]
pst = (int(lines[i].split()[0])/1000+125)/10000
pen = (int(lines[i].split()[1])/1000+125)/10000
pst = (int(lines[i].split()[0]) / 1000 + 125) / 10000
pen = (int(lines[i].split()[1]) / 1000 + 125) / 10000
times2.append([phn, pst, pen])
if (len(lines[i].split()) == 5):
if (lines[i].split()[0] != lines[i].split()[1]):
wrd = lines[i].split()[-1].strip()
st = (int(lines[i].split()[0])/1000+125)/10000
st = (int(lines[i].split()[0]) / 1000 + 125) / 10000
j = i + 1
while (lines[j] != '.\n') and (len(lines[j].split()) != 5):
j += 1
en = (int(lines[j-1].split()[1])/1000+125)/10000
en = (int(lines[j - 1].split()[1]) / 1000 + 125) / 10000
times1.append([wrd, st, en])
i += 1
......@@ -151,8 +154,13 @@ def gen_res(tmpbase, outfile1, outfile2):
for item in times2:
fwid.write(str(item[1]) + ' ' + str(item[2]) + ' ' + item[0] + '\n')
def _get_user():
return os.path.expanduser('~').split("/")[-1]
def alignment(wav_path, text_string):
tmpbase = '/tmp/' + os.environ['USER'] + '_' + str(os.getpid())
tmpbase = '/tmp/' + _get_user() + '_' + str(os.getpid())
#prepare wav and trs files
try:
......@@ -160,7 +168,7 @@ def alignment(wav_path, text_string):
except:
print('sox error!')
return None
#prepare clean_transcript file
try:
prep_txt(text_string, tmpbase, MODEL_DIR + '/dict')
......@@ -179,14 +187,19 @@ def alignment(wav_path, text_string):
#prepare scp
try:
os.system(HCOPY + ' -C ' + MODEL_DIR + '/16000/config ' + tmpbase + '.wav' + ' ' + tmpbase + '.plp')
os.system(HCOPY + ' -C ' + MODEL_DIR + '/16000/config ' + tmpbase +
'.wav' + ' ' + tmpbase + '.plp')
except:
print('HCopy error!')
return None
#run alignment
try:
os.system(HVITE + ' -a -m -t 10000.0 10000.0 100000.0 -I ' + tmpbase + '.mlf -H ' + MODEL_DIR + '/16000/macros -H ' + MODEL_DIR + '/16000/hmmdefs -i ' + tmpbase + '.aligned ' + tmpbase + '.dict ' + MODEL_DIR + '/monophones ' + tmpbase + '.plp 2>&1 > /dev/null')
os.system(HVITE + ' -a -m -t 10000.0 10000.0 100000.0 -I ' + tmpbase +
'.mlf -H ' + MODEL_DIR + '/16000/macros -H ' + MODEL_DIR +
'/16000/hmmdefs -i ' + tmpbase + '.aligned ' + tmpbase +
'.dict ' + MODEL_DIR + '/monophones ' + tmpbase +
'.plp 2>&1 > /dev/null')
except:
print('HVite error!')
return None
......@@ -207,15 +220,15 @@ def alignment(wav_path, text_string):
splited_line = lines[i].strip().split()
if (len(splited_line) >= 4) and (splited_line[0] != splited_line[1]):
phn = splited_line[2]
pst = (int(splited_line[0])/1000+125)/10000
pen = (int(splited_line[1])/1000+125)/10000
pst = (int(splited_line[0]) / 1000 + 125) / 10000
pen = (int(splited_line[1]) / 1000 + 125) / 10000
times2.append([phn, pst, pen])
# splited_line[-1]!='sp'
if len(splited_line)==5:
current_word = str(index)+'_'+splited_line[-1]
if len(splited_line) == 5:
current_word = str(index) + '_' + splited_line[-1]
word2phns[current_word] = phn
index+=1
elif len(splited_line)==4:
word2phns[current_word] += ' '+phn
i+=1
return times2,word2phns
index += 1
elif len(splited_line) == 4:
word2phns[current_word] += ' ' + phn
i += 1
return times2, word2phns
#!/usr/bin/env python
""" Usage:
align_mandarin.py wavfile trsfile outwordfile putphonefile
align_mandarin.py wavfile trsfile outwordfile putphonefile
"""
import multiprocessing as mp
import os
import sys
from tqdm import tqdm
import multiprocessing as mp
from tqdm import tqdm
MODEL_DIR = 'tools/aligner/mandarin'
HVITE = 'tools/htk/HTKTools/HVite'
......@@ -19,7 +17,10 @@ def prep_txt(line, tmpbase, dictfile):
words = []
line = line.strip()
for pun in [',', '.', ':', ';', '!', '?', '"', '(', ')', '--', '---', u',', u'。', u':', u';', u'!', u'?', u'(', u')']:
for pun in [
',', '.', ':', ';', '!', '?', '"', '(', ')', '--', '---', u',',
u'。', u':', u';', u'!', u'?', u'(', u')'
]:
line = line.replace(pun, ' ')
for wrd in line.split():
if (wrd[-1] == '-'):
......@@ -43,6 +44,7 @@ def prep_txt(line, tmpbase, dictfile):
fwid.write('\n')
return unk_words
def prep_mlf(txt, tmpbase):
with open(tmpbase + '.mlf', 'w') as fwid:
......@@ -55,6 +57,7 @@ def prep_mlf(txt, tmpbase):
fwid.write('sp\n')
fwid.write('.\n')
def gen_res(tmpbase, outfile1, outfile2):
with open(tmpbase + '.txt', 'r') as fid:
words = fid.readline().strip().split()
......@@ -67,19 +70,20 @@ def gen_res(tmpbase, outfile1, outfile2):
times1 = []
times2 = []
while (i < len(lines)):
if (len(lines[i].split()) >= 4) and (lines[i].split()[0] != lines[i].split()[1]):
if (len(lines[i].split()) >= 4) and (
lines[i].split()[0] != lines[i].split()[1]):
phn = lines[i].split()[2]
pst = (int(lines[i].split()[0])/1000+125)/10000
pen = (int(lines[i].split()[1])/1000+125)/10000
pst = (int(lines[i].split()[0]) / 1000 + 125) / 10000
pen = (int(lines[i].split()[1]) / 1000 + 125) / 10000
times2.append([phn, pst, pen])
if (len(lines[i].split()) == 5):
if (lines[i].split()[0] != lines[i].split()[1]):
wrd = lines[i].split()[-1].strip()
st = (int(lines[i].split()[0])/1000+125)/10000
st = (int(lines[i].split()[0]) / 1000 + 125) / 10000
j = i + 1
while (lines[j] != '.\n') and (len(lines[j].split()) != 5):
j += 1
en = (int(lines[j-1].split()[1])/1000+125)/10000
en = (int(lines[j - 1].split()[1]) / 1000 + 125) / 10000
times1.append([wrd, st, en])
i += 1
......@@ -99,18 +103,18 @@ def gen_res(tmpbase, outfile1, outfile2):
fwid.write(str(item[1]) + ' ' + str(item[2]) + ' ' + item[0] + '\n')
def alignment_zh(wav_path, text_string):
tmpbase = '/tmp/' + os.environ['USER'] + '_' + str(os.getpid())
#prepare wav and trs files
try:
os.system('sox ' + wav_path + ' -r 16000 -b 16 ' + tmpbase + '.wav remix -')
os.system('sox ' + wav_path + ' -r 16000 -b 16 ' + tmpbase +
'.wav remix -')
except:
print('sox error!')
return None
#prepare clean_transcript file
try:
unk_words = prep_txt(text_string, tmpbase, MODEL_DIR + '/dict')
......@@ -133,14 +137,19 @@ def alignment_zh(wav_path, text_string):
#prepare scp
try:
os.system(HCOPY + ' -C ' + MODEL_DIR + '/16000/config ' + tmpbase + '.wav' + ' ' + tmpbase + '.plp')
os.system(HCOPY + ' -C ' + MODEL_DIR + '/16000/config ' + tmpbase +
'.wav' + ' ' + tmpbase + '.plp')
except:
print('HCopy error!')
return None
#run alignment
try:
os.system(HVITE + ' -a -m -t 10000.0 10000.0 100000.0 -I ' + tmpbase + '.mlf -H ' + MODEL_DIR + '/16000/macros -H ' + MODEL_DIR + '/16000/hmmdefs -i ' + tmpbase + '.aligned ' + MODEL_DIR + '/dict ' + MODEL_DIR + '/monophones ' + tmpbase + '.plp 2>&1 > /dev/null')
os.system(HVITE + ' -a -m -t 10000.0 10000.0 100000.0 -I ' + tmpbase +
'.mlf -H ' + MODEL_DIR + '/16000/macros -H ' + MODEL_DIR +
'/16000/hmmdefs -i ' + tmpbase + '.aligned ' + MODEL_DIR +
'/dict ' + MODEL_DIR + '/monophones ' + tmpbase +
'.plp 2>&1 > /dev/null')
except:
print('HVite error!')
......@@ -156,23 +165,22 @@ def alignment_zh(wav_path, text_string):
i = 2
times2 = []
word2phns = {}
word2phns = {}
current_word = ''
index = 0
while (i < len(lines)):
splited_line = lines[i].strip().split()
if (len(splited_line) >= 4) and (splited_line[0] != splited_line[1]):
phn = splited_line[2]
pst = (int(splited_line[0])/1000+125)/10000
pen = (int(splited_line[1])/1000+125)/10000
pst = (int(splited_line[0]) / 1000 + 125) / 10000
pen = (int(splited_line[1]) / 1000 + 125) / 10000
times2.append([phn, pst, pen])
# splited_line[-1]!='sp'
if len(splited_line)==5:
current_word = str(index)+'_'+splited_line[-1]
if len(splited_line) == 5:
current_word = str(index) + '_' + splited_line[-1]
word2phns[current_word] = phn
index+=1
elif len(splited_line)==4:
word2phns[current_word] += ' '+phn
i+=1
return times2,word2phns
index += 1
elif len(splited_line) == 4:
word2phns[current_word] += ' ' + phn
i += 1
return times2, word2phns
import math
import paddle
import numpy as np
import math
import paddle
def pad_list(xs, pad_value):
......@@ -28,23 +26,25 @@ def pad_list(xs, pad_value):
"""
n_batch = len(xs)
max_len = max(paddle.shape(x)[0] for x in xs)
pad = paddle.full((n_batch, max_len), pad_value, dtype = xs[0].dtype)
pad = paddle.full((n_batch, max_len), pad_value, dtype=xs[0].dtype)
for i in range(n_batch):
pad[i, : paddle.shape(xs[i])[0]] = xs[i]
pad[i, :paddle.shape(xs[i])[0]] = xs[i]
return pad
def pad_to_longformer_att_window(text, max_len, max_tlen,attention_window):
def pad_to_longformer_att_window(text, max_len, max_tlen, attention_window):
round = max_len % attention_window
if round != 0:
max_tlen += (attention_window - round)
n_batch = paddle.shape(text)[0]
text_pad = paddle.zeros((n_batch, max_tlen, *paddle.shape(text[0])[1:]), dtype=text.dtype)
text_pad = paddle.zeros(
(n_batch, max_tlen, *paddle.shape(text[0])[1:]), dtype=text.dtype)
for i in range(n_batch):
text_pad[i, : paddle.shape(text[i])[0]] = text[i]
text_pad[i, :paddle.shape(text[i])[0]] = text[i]
else:
text_pad = text[:, : max_tlen]
text_pad = text[:, :max_tlen]
return text_pad, max_tlen
......@@ -139,7 +139,6 @@ def make_pad_mask(lengths, xs=None, length_dim=-1):
if not isinstance(lengths, list):
lengths = list(lengths)
# print('lengths', lengths)
bs = int(len(lengths))
if xs is None:
maxlen = int(max(lengths))
......@@ -147,10 +146,9 @@ def make_pad_mask(lengths, xs=None, length_dim=-1):
maxlen = paddle.shape(xs)[length_dim]
seq_range = paddle.arange(0, maxlen, dtype=paddle.int64)
seq_range_expand = paddle.expand(paddle.unsqueeze(seq_range, 0), (bs, maxlen))
seq_range_expand = paddle.expand(
paddle.unsqueeze(seq_range, 0), (bs, maxlen))
seq_length_expand = paddle.unsqueeze(paddle.to_tensor(lengths), -1)
# print('seq_length_expand', paddle.shape(seq_length_expand))
# print('seq_range_expand', paddle.shape(seq_range_expand))
mask = seq_range_expand >= seq_length_expand
if xs is not None:
......@@ -160,16 +158,12 @@ def make_pad_mask(lengths, xs=None, length_dim=-1):
length_dim = len(paddle.shape(xs)) + length_dim
# ind = (:, None, ..., None, :, , None, ..., None)
ind = tuple(
slice(None) if i in (0, length_dim) else None for i in range(len(paddle.shape(xs)))
)
# print('0:', paddle.shape(mask))
# print('1:', paddle.shape(mask[ind]))
# print('2:', paddle.shape(xs))
slice(None) if i in (0, length_dim) else None
for i in range(len(paddle.shape(xs))))
mask = paddle.expand(mask[ind], paddle.shape(xs))
return mask
def make_non_pad_mask(lengths, xs=None, length_dim=-1):
"""Make mask tensor containing indices of non-padded part.
......@@ -259,8 +253,14 @@ def make_non_pad_mask(lengths, xs=None, length_dim=-1):
return ~make_pad_mask(lengths, xs, length_dim)
def phones_masking(xs_pad, src_mask, align_start, align_end, align_start_lengths, mlm_prob, mean_phn_span, span_boundary=None):
def phones_masking(xs_pad,
src_mask,
align_start,
align_end,
align_start_lengths,
mlm_prob,
mean_phn_span,
span_boundary=None):
bz, sent_len, _ = paddle.shape(xs_pad)
mask_num_lower = math.ceil(sent_len * mlm_prob)
masked_position = np.zeros((bz, sent_len))
......@@ -273,38 +273,41 @@ def phones_masking(xs_pad, src_mask, align_start, align_end, align_start_lengths
elif mean_phn_span == 0:
# only speech
length = sent_len
mean_phn_span = min(length*mlm_prob//3, 50)
masked_phn_indices = random_spans_noise_mask(length,mlm_prob, mean_phn_span).nonzero()
masked_position[:,masked_phn_indices]=1
mean_phn_span = min(length * mlm_prob // 3, 50)
masked_phn_indices = random_spans_noise_mask(length, mlm_prob,
mean_phn_span).nonzero()
masked_position[:, masked_phn_indices] = 1
else:
for idx in range(bz):
if span_boundary is not None:
for s,e in zip(span_boundary[idx][::2], span_boundary[idx][1::2]):
for s, e in zip(span_boundary[idx][::2],
span_boundary[idx][1::2]):
masked_position[idx, s:e] = 1
# y_masks[idx, :, s:e] = tril_masks[idx, :, s:e]
# y_masks[idx, e:, s:e ] = 0
else:
length = align_start_lengths[idx].item()
if length<2:
if length < 2:
continue
masked_phn_indices = random_spans_noise_mask(length,mlm_prob, mean_phn_span).nonzero()
masked_phn_indices = random_spans_noise_mask(
length, mlm_prob, mean_phn_span).nonzero()
masked_start = align_start[idx][masked_phn_indices].tolist()
masked_end = align_end[idx][masked_phn_indices].tolist()
for s,e in zip(masked_start, masked_end):
for s, e in zip(masked_start, masked_end):
masked_position[idx, s:e] = 1
# y_masks[idx, :, s:e] = tril_masks[idx, :, s:e]
# y_masks[idx, e:, s:e ] = 0
non_eos_mask = np.array(paddle.reshape(src_mask, paddle.shape(xs_pad)[:2]).float().cpu())
non_eos_mask = np.array(
paddle.reshape(src_mask, paddle.shape(xs_pad)[:2]).float().cpu())
masked_position = masked_position * non_eos_mask
# y_masks = src_mask & y_masks.bool()
return paddle.cast(paddle.to_tensor(masked_position), paddle.bool), y_masks
def get_segment_pos(speech_pad, text_pad, align_start, align_end, align_start_lengths,sega_emb):
def get_segment_pos(speech_pad, text_pad, align_start, align_end,
align_start_lengths, sega_emb):
bz, speech_len, _ = speech_pad.size()
_, text_len = text_pad.size()
......@@ -313,7 +316,6 @@ def get_segment_pos(speech_pad, text_pad, align_start, align_end, align_start_le
text_segment_pos = np.zeros((bz, text_len)).astype('int64')
speech_segment_pos = np.zeros((bz, speech_len)).astype('int64')
if not sega_emb:
text_segment_pos = paddle.to_tensor(text_segment_pos)
speech_segment_pos = paddle.to_tensor(speech_segment_pos)
......@@ -321,11 +323,11 @@ def get_segment_pos(speech_pad, text_pad, align_start, align_end, align_start_le
for idx in range(bz):
align_length = align_start_lengths[idx].item()
for j in range(align_length):
s,e = align_start[idx][j].item(), align_end[idx][j].item()
speech_segment_pos[idx][s:e] = j+1
text_segment_pos[idx][j] = j+1
s, e = align_start[idx][j].item(), align_end[idx][j].item()
speech_segment_pos[idx][s:e] = j + 1
text_segment_pos[idx][j] = j + 1
text_segment_pos = paddle.to_tensor(text_segment_pos)
speech_segment_pos = paddle.to_tensor(speech_segment_pos)
return speech_segment_pos, text_segment_pos
\ No newline at end of file
return speech_segment_pos, text_segment_pos
此差异已折叠。
此差异已折叠。
p243_new 1.0225 1.0525 1.0925 1.1325 1.1725 1.2625 1.3625 1.4125 1.5125 1.6225 1.6625 1.7925 1.8625 2.0025 2.0925 2.1725 2.2625 2.4325 2.4725 2.5225 2.5825 2.6125 2.6425 2.7425 2.8025 2.9025 2.9525 3.0525 3.0825 3.2125 3.4525
Prompt_003_new 0.0425 0.0925 0.1825 0.2125 0.2425 0.3225 0.3725 0.4725 0.5325 0.5625 0.6225 0.7425 0.8625 0.9725 0.9975 1.0125 1.0825 1.2625 1.3125
p299_096 0.7525 0.7925 0.8725 0.9125 0.9425 1.0325 1.0625 1.1925 1.2625 1.3225 1.3725 1.4125 1.5125 1.5425 1.6525 1.6925 1.7325 1.7625 1.8425 1.9625 2.0225 2.1825 2.3325 2.6825
p243_new 0.0125 1.0225 1.0525 1.0925 1.1325 1.1725 1.2625 1.3625 1.4125 1.5125 1.6225 1.6625 1.7925 1.8625 2.0025 2.0925 2.1725 2.2625 2.4325 2.4725 2.5225 2.5825 2.6125 2.6425 2.7425 2.8025 2.9025 2.9525 3.0525 3.0825 3.2125
Prompt_003_new 0.0125 0.0425 0.0925 0.1825 0.2125 0.2425 0.3225 0.3725 0.4725 0.5325 0.5625 0.6225 0.7425 0.8625 0.9725 0.9975 1.0125 1.0825 1.2625
p299_096 0.0125 0.7525 0.7925 0.8725 0.9125 0.9425 1.0325 1.0625 1.1925 1.2625 1.3225 1.3725 1.4125 1.5125 1.5425 1.6525 1.6925 1.7325 1.7625 1.8425 1.9625 2.0225 2.1825 2.3325
p243_new sp F AO1 R DH AE1 T R IY1 Z AH0 N sp K AH1 V ER0 SH UH1 D N AA1 T B IY1 G IH1 V AH0 N sp
Prompt_003_new DH IH1 S W AA1 Z N AA1 T DH AH0 SH OW1 F AO1 R M IY1 sp
p299_096 sp W IY1 AA1 R T R AY1 NG T UW1 AH0 S T AE1 B L IH0 SH AH0 D EY1 T sp
p243_new ../../prompt_wav/p243_313.wav
Prompt_003_new ../../prompt_wav/this_was_not_the_show_for_me.wav
p299_096 ../../prompt_wav/p299_096.wav
......@@ -5,7 +5,6 @@ from typing import List
from typing import Union
def read_2column_text(path: Union[Path, str]) -> Dict[str, str]:
"""Read a text file having 2 column as dict object.
......@@ -33,9 +32,8 @@ def read_2column_text(path: Union[Path, str]) -> Dict[str, str]:
return data
def load_num_sequence_text(
path: Union[Path, str], loader_type: str = "csv_int"
) -> Dict[str, List[Union[float, int]]]:
def load_num_sequence_text(path: Union[Path, str], loader_type: str="csv_int"
) -> Dict[str, List[Union[float, int]]]:
"""Read a text file indicating sequences of number
Examples:
......@@ -73,6 +71,7 @@ def load_num_sequence_text(
try:
retval[k] = [dtype(i) for i in v.split(delimiter)]
except TypeError:
logging.error(f'Error happened with path="{path}", id="{k}", value="{v}"')
logging.error(
f'Error happened with path="{path}", id="{k}", value="{v}"')
raise
return retval
# en --> zh 的 语音合成
# 根据Prompt_003_new作为提示语音: This was not the show for me. 来合成: '今天天气很好'
# 注: 输入的new_str需为中文汉字, 否则会通过预处理只保留中文汉字, 即合成预处理后的中文语音。
#!/bin/bash
# en --> zh 的 语音合成
# 根据 Prompt_003_new 作为提示语音: This was not the show for me. 来合成: '今天天气很好'
# 注: 输入的 new_str 需为中文汉字, 否则会通过预处理只保留中文汉字, 即合成预处理后的中文语音。
python inference.py \
--task_name cross-lingual_clone \
--model_name paddle_checkpoint_dual_mask_enzh \
--uid Prompt_003_new \
--new_str '今天天气很好.' \
--prefix ./prompt/dev/ \
--source_language english \
--target_language chinese \
--output_name pred_clone.wav \
--use_pt_vocoder False \
--voc pwgan_aishell3 \
--voc_config download/pwg_aishell3_ckpt_0.5/default.yaml \
--voc_ckpt download/pwg_aishell3_ckpt_0.5/snapshot_iter_1000000.pdz \
--voc_stat download/pwg_aishell3_ckpt_0.5/feats_stats.npy \
--am fastspeech2_csmsc \
--am_config download/fastspeech2_conformer_baker_ckpt_0.5/conformer.yaml \
--am_ckpt download/fastspeech2_conformer_baker_ckpt_0.5/snapshot_iter_76000.pdz \
--am_stat download/fastspeech2_conformer_baker_ckpt_0.5/speech_stats.npy \
--phones_dict download/fastspeech2_conformer_baker_ckpt_0.5/phone_id_map.txt
\ No newline at end of file
--task_name=cross-lingual_clone \
--model_name=paddle_checkpoint_dual_mask_enzh \
--uid=Prompt_003_new \
--new_str='今天天气很好.' \
--prefix='./prompt/dev/' \
--source_language=english \
--target_language=chinese \
--output_name=pred_clone.wav \
--use_pt_vocoder=False \
--voc=pwgan_aishell3 \
--voc_config=download/pwg_aishell3_ckpt_0.5/default.yaml \
--voc_ckpt=download/pwg_aishell3_ckpt_0.5/snapshot_iter_1000000.pdz \
--voc_stat=download/pwg_aishell3_ckpt_0.5/feats_stats.npy \
--am=fastspeech2_csmsc \
--am_config=download/fastspeech2_conformer_baker_ckpt_0.5/conformer.yaml \
--am_ckpt=download/fastspeech2_conformer_baker_ckpt_0.5/snapshot_iter_76000.pdz \
--am_stat=download/fastspeech2_conformer_baker_ckpt_0.5/speech_stats.npy \
--phones_dict=download/fastspeech2_conformer_baker_ckpt_0.5/phone_id_map.txt
\ No newline at end of file
#!/bin/bash
# 纯英文的语音合成
# 样例为根据p299_096对应的语音作为提示语音: This was not the show for me. 来合成: 'I enjoy my life.'
# 样例为根据 p299_096 对应的语音作为提示语音: This was not the show for me. 来合成: 'I enjoy my life.'
python inference.py \
--task_name synthesize \
--model_name paddle_checkpoint_en \
--uid p299_096 \
--new_str 'I enjoy my life.' \
--prefix ./prompt/dev/ \
--source_language english \
--target_language english \
--output_name pred_gen.wav \
--use_pt_vocoder True \
--voc pwgan_aishell3 \
--voc_config download/pwg_aishell3_ckpt_0.5/default.yaml \
--voc_ckpt download/pwg_aishell3_ckpt_0.5/snapshot_iter_1000000.pdz \
--voc_stat download/pwg_aishell3_ckpt_0.5/feats_stats.npy \
--am fastspeech2_ljspeech \
--am_config download/fastspeech2_nosil_ljspeech_ckpt_0.5/default.yaml \
--am_ckpt download/fastspeech2_nosil_ljspeech_ckpt_0.5/snapshot_iter_100000.pdz \
--am_stat download/fastspeech2_nosil_ljspeech_ckpt_0.5/speech_stats.npy \
--phones_dict download/fastspeech2_nosil_ljspeech_ckpt_0.5/phone_id_map.txt
\ No newline at end of file
--task_name=synthesize \
--model_name=paddle_checkpoint_en \
--uid=p299_096 \
--new_str='I enjoy my life, do you?' \
--prefix='./prompt/dev/' \
--source_language=english \
--target_language=english \
--output_name=pred_gen.wav \
--use_pt_vocoder=False \
--voc=pwgan_aishell3 \
--voc_config=download/pwg_aishell3_ckpt_0.5/default.yaml \
--voc_ckpt=download/pwg_aishell3_ckpt_0.5/snapshot_iter_1000000.pdz \
--voc_stat=download/pwg_aishell3_ckpt_0.5/feats_stats.npy \
--am=fastspeech2_ljspeech \
--am_config=download/fastspeech2_nosil_ljspeech_ckpt_0.5/default.yaml \
--am_ckpt=download/fastspeech2_nosil_ljspeech_ckpt_0.5/snapshot_iter_100000.pdz \
--am_stat=download/fastspeech2_nosil_ljspeech_ckpt_0.5/speech_stats.npy \
--phones_dict=download/fastspeech2_nosil_ljspeech_ckpt_0.5/phone_id_map.txt
\ No newline at end of file
#!/bin/bash
# 纯英文的语音编辑
# 样例为把p243_new对应的原始语音: For that reason cover should not be given.编辑成'for that reason cover is impossible to be given.'对应的语音
# NOTE: 语音编辑任务暂支持句子中1个位置的替换或者插入文本操作
# 样例为把 p243_new 对应的原始语音: For that reason cover should not be given.编辑成 'for that reason cover is impossible to be given.' 对应的语音
# NOTE: 语音编辑任务暂支持句子中 1 个位置的替换或者插入文本操作
python inference.py \
--task_name edit \
--model_name paddle_checkpoint_en \
--uid p243_new \
--new_str 'for that reason cover is impossible to be given.' \
--prefix ./prompt/dev/ \
--source_language english \
--target_language english \
--output_name pred_edit.wav \
--use_pt_vocoder True \
--voc pwgan_aishell3 \
--voc_config download/pwg_aishell3_ckpt_0.5/default.yaml \
--voc_ckpt download/pwg_aishell3_ckpt_0.5/snapshot_iter_1000000.pdz \
--voc_stat download/pwg_aishell3_ckpt_0.5/feats_stats.npy \
--am fastspeech2_ljspeech \
--am_config download/fastspeech2_nosil_ljspeech_ckpt_0.5/default.yaml \
--am_ckpt download/fastspeech2_nosil_ljspeech_ckpt_0.5/snapshot_iter_100000.pdz \
--am_stat download/fastspeech2_nosil_ljspeech_ckpt_0.5/speech_stats.npy \
--phones_dict download/fastspeech2_nosil_ljspeech_ckpt_0.5/phone_id_map.txt
\ No newline at end of file
--task_name=edit \
--model_name=paddle_checkpoint_en \
--uid=p243_new \
--new_str='for that reason cover is impossible to be given.' \
--prefix='./prompt/dev/' \
--source_language=english \
--target_language=english \
--output_name=pred_edit.wav \
--use_pt_vocoder=False \
--voc=pwgan_aishell3 \
--voc_config=download/pwg_aishell3_ckpt_0.5/default.yaml \
--voc_ckpt=download/pwg_aishell3_ckpt_0.5/snapshot_iter_1000000.pdz \
--voc_stat=download/pwg_aishell3_ckpt_0.5/feats_stats.npy \
--am=fastspeech2_ljspeech \
--am_config=download/fastspeech2_nosil_ljspeech_ckpt_0.5/default.yaml \
--am_ckpt=download/fastspeech2_nosil_ljspeech_ckpt_0.5/snapshot_iter_100000.pdz \
--am_stat=download/fastspeech2_nosil_ljspeech_ckpt_0.5/speech_stats.npy \
--phones_dict=download/fastspeech2_nosil_ljspeech_ckpt_0.5/phone_id_map.txt
......@@ -86,7 +86,11 @@ def parse_args():
parser.add_argument("--target_language", type=str, help="target language")
parser.add_argument("--output_name", type=str, help="output name")
parser.add_argument("--task_name", type=str, help="task name")
parser.add_argument("--use_pt_vocoder", default=True, help="use pytorch version vocoder or not. [Note: only in english condition.]")
parser.add_argument(
"--use_pt_vocoder",
type=str2bool,
default=True,
help="use pytorch version vocoder or not. [Note: only in english condition.]")
# pre
args = parser.parse_args()
......
#!/bin/bash
rm -rf *.wav
sh run_sedit_en.sh # 语音编辑任务(英文)
sh run_gen_en.sh # 个性化语音合成任务(英文)
sh run_clone_en_to_zh.sh # 跨语言语音合成任务(英文到中文的语音克隆)
\ No newline at end of file
# Copyright 2021 Tomoki Hayashi
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Wrapper class for the vocoder model trained with parallel_wavegan repo."""
import logging
import os
from pathlib import Path
from typing import Optional
from typing import Union
import yaml
import torch
import yaml
class ParallelWaveGANPretrainedVocoder(torch.nn.Module):
"""Wrapper class to load the vocoder trained with parallel_wavegan repo."""
def __init__(
self,
model_file: Union[Path, str],
config_file: Optional[Union[Path, str]] = None,
):
self,
model_file: Union[Path, str],
config_file: Optional[Union[Path, str]]=None, ):
"""Initialize ParallelWaveGANPretrainedVocoder module."""
super().__init__()
try:
......@@ -30,8 +23,7 @@ class ParallelWaveGANPretrainedVocoder(torch.nn.Module):
except ImportError:
logging.error(
"`parallel_wavegan` is not installed. "
"Please install via `pip install -U parallel_wavegan`."
)
"Please install via `pip install -U parallel_wavegan`.")
raise
if config_file is None:
dirname = os.path.dirname(str(model_file))
......@@ -59,5 +51,4 @@ class ParallelWaveGANPretrainedVocoder(torch.nn.Module):
"""
return self.vocoder.inference(
feats,
normalize_before=self.normalize_before,
).view(-1)
normalize_before=self.normalize_before, ).view(-1)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import argparse
import logging
from pathlib import Path
import jsonlines
import numpy as np
import paddle
import soundfile as sf
import yaml
from timer import timer
from sedit_arg_parser import parse_args
from yacs.config import CfgNode
from paddlespeech.s2t.utils.dynamic_import import dynamic_import
from paddlespeech.t2s.exps.syn_utils import get_test_dataset
from paddlespeech.s2t.utils.dynamic_import import dynamic_import
from paddlespeech.t2s.exps.syn_utils import get_frontend
from paddlespeech.t2s.exps.syn_utils import get_voc_inference
from paddlespeech.t2s.utils import str2bool
from paddlespeech.t2s.frontend.zh_frontend import Frontend
from paddlespeech.t2s.models.fastspeech2 import FastSpeech2
from paddlespeech.t2s.models.fastspeech2 import FastSpeech2Inference
from paddlespeech.t2s.modules.normalizer import ZScore
from yacs.config import CfgNode
# new add
import paddle.nn.functional as F
from paddlespeech.t2s.modules.nets_utils import make_pad_mask
from paddlespeech.t2s.exps.syn_utils import get_frontend
from tools.parallel_wavegan_pretrained_vocoder import ParallelWaveGANPretrainedVocoder
from sedit_arg_parser import parse_args
# new add
model_alias = {
# acoustic model
......@@ -58,9 +28,6 @@ model_alias = {
}
def is_chinese(ch):
if u'\u4e00' <= ch <= u'\u9fff':
return True
......@@ -69,17 +36,15 @@ def is_chinese(ch):
def build_vocoder_from_file(
vocoder_config_file = None,
vocoder_file = None,
model = None,
device = "cpu",
):
vocoder_config_file=None,
vocoder_file=None,
model=None,
device="cpu", ):
# Build vocoder
if str(vocoder_file).endswith(".pkl"):
# If the extension is ".pkl", the model is trained with parallel_wavegan
vocoder = ParallelWaveGANPretrainedVocoder(
vocoder_file, vocoder_config_file
)
vocoder = ParallelWaveGANPretrainedVocoder(vocoder_file,
vocoder_config_file)
return vocoder.to(device)
else:
......@@ -91,7 +56,7 @@ def get_voc_out(mel, target_language="chinese"):
args = parse_args()
assert target_language == "chinese" or target_language == "english", "In get_voc_out function, target_language is illegal..."
# print("current vocoder: ", args.voc)
with open(args.voc_config) as f:
voc_config = CfgNode(yaml.safe_load(f))
......@@ -106,6 +71,7 @@ def get_voc_out(mel, target_language="chinese"):
# print("shepe of wav (time x n_channels):%s"%wav.shape)
return np.squeeze(wav)
# dygraph
def get_am_inference(args, am_config):
with open(args.phones_dict, "r") as f:
......@@ -159,11 +125,14 @@ def get_am_inference(args, am_config):
return am, am_inference, am_name, am_dataset, phn_id
def evaluate_durations(phns, target_language="chinese", fs=24000, hop_length=300):
def evaluate_durations(phns,
target_language="chinese",
fs=24000,
hop_length=300):
args = parse_args()
if target_language == 'english':
args.lang='en'
args.lang = 'en'
args.am = "fastspeech2_ljspeech"
args.am_config = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/default.yaml"
args.am_ckpt = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/snapshot_iter_100000.pdz"
......@@ -171,12 +140,12 @@ def evaluate_durations(phns, target_language="chinese", fs=24000, hop_length=300
args.phones_dict = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/phone_id_map.txt"
elif target_language == 'chinese':
args.lang='zh'
args.lang = 'zh'
args.am = "fastspeech2_csmsc"
args.am_config="download/fastspeech2_conformer_baker_ckpt_0.5/conformer.yaml"
args.am_config = "download/fastspeech2_conformer_baker_ckpt_0.5/conformer.yaml"
args.am_ckpt = "download/fastspeech2_conformer_baker_ckpt_0.5/snapshot_iter_76000.pdz"
args.am_stat = "download/fastspeech2_conformer_baker_ckpt_0.5/speech_stats.npy"
args.phones_dict ="download/fastspeech2_conformer_baker_ckpt_0.5/phone_id_map.txt"
args.phones_dict = "download/fastspeech2_conformer_baker_ckpt_0.5/phone_id_map.txt"
# args = parser.parse_args(args=[])
if args.ngpu == 0:
......@@ -186,8 +155,6 @@ def evaluate_durations(phns, target_language="chinese", fs=24000, hop_length=300
else:
print("ngpu should >= 0 !")
assert target_language == "chinese" or target_language == "english", "In evaluate_durations function, target_language is illegal..."
# Init body.
......@@ -197,8 +164,8 @@ def evaluate_durations(phns, target_language="chinese", fs=24000, hop_length=300
# print(am_config)
# print("---------------------")
# acoustic model
am, am_inference, am_name, am_dataset,phn_id = get_am_inference(args, am_config)
am, am_inference, am_name, am_dataset, phn_id = get_am_inference(args,
am_config)
torch_phns = phns
vocab_phones = {}
......@@ -206,33 +173,31 @@ def evaluate_durations(phns, target_language="chinese", fs=24000, hop_length=300
vocab_phones[tone] = int(id)
# print("vocab_phones: ", len(vocab_phones))
vocab_size = len(vocab_phones)
phonemes = [
phn if phn in vocab_phones else "sp" for phn in torch_phns
]
phonemes = [phn if phn in vocab_phones else "sp" for phn in torch_phns]
phone_ids = [vocab_phones[item] for item in phonemes]
phone_ids_new = phone_ids
phone_ids_new.append(vocab_size-1)
phone_ids_new.append(vocab_size - 1)
phone_ids_new = paddle.to_tensor(np.array(phone_ids_new, np.int64))
normalized_mel, d_outs, p_outs, e_outs = am.inference(phone_ids_new, spk_id=None, spk_emb=None)
normalized_mel, d_outs, p_outs, e_outs = am.inference(
phone_ids_new, spk_id=None, spk_emb=None)
pre_d_outs = d_outs
phoneme_durations_new = pre_d_outs * hop_length / fs
phoneme_durations_new = phoneme_durations_new.tolist()[:-1]
return phoneme_durations_new
def sentence2phns(sentence, target_language="en"):
args = parse_args()
if target_language == 'en':
args.lang='en'
args.lang = 'en'
args.phones_dict = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/phone_id_map.txt"
elif target_language == 'zh':
args.lang='zh'
args.phones_dict="download/fastspeech2_conformer_baker_ckpt_0.5/phone_id_map.txt"
args.lang = 'zh'
args.phones_dict = "download/fastspeech2_conformer_baker_ckpt_0.5/phone_id_map.txt"
else:
print("target_language should in {'zh', 'en'}!")
frontend = get_frontend(args)
merge_sentences = True
get_tone_ids = False
......@@ -246,10 +211,8 @@ def sentence2phns(sentence, target_language="en"):
phone_ids = input_ids["phone_ids"]
phonemes = frontend.get_phonemes(
sentence,
merge_sentences=merge_sentences,
print_info=False)
sentence, merge_sentences=merge_sentences, print_info=False)
return phonemes[0], input_ids["phone_ids"][0]
elif target_language == 'en':
......@@ -270,16 +233,11 @@ def sentence2phns(sentence, target_language="en"):
phones = [phn for phn in phones if not phn.isspace()]
# replace unk phone with sp
phones = [
phn
if (phn in vocab_phones and phn not in punc) else "sp"
phn if (phn in vocab_phones and phn not in punc) else "sp"
for phn in phones
]
phones_list.append(phones)
return phones_list[0], input_ids["phone_ids"][0]
return phones_list[0], input_ids["phone_ids"][0]
else:
print("lang should in {'zh', 'en'}!")
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册