未验证 提交 7b864e8f 编写于 作者: 小湉湉's avatar 小湉湉 提交者: GitHub

clean old ernie sat inference scripts (#2316)

上级 d21e03c0
([简体中文](./README_cn.md)|English)
<p align="center">
<img src="./docs/images/PaddleSpeech_logo.png" />
......@@ -535,7 +534,7 @@ PaddleSpeech supports a series of most popular models. They are summarized in [r
</td>
</tr>
<tr>
<td rowspan="4">Acoustic Model</td>
<td rowspan="5">Acoustic Model</td>
<td>Tacotron2</td>
<td>LJSpeech / CSMSC</td>
<td>
......@@ -563,6 +562,13 @@ PaddleSpeech supports a series of most popular models. They are summarized in [r
<a href = "./examples/ljspeech/tts3">fastspeech2-ljspeech</a> / <a href = "./examples/vctk/tts3">fastspeech2-vctk</a> / <a href = "./examples/csmsc/tts3">fastspeech2-csmsc</a> / <a href = "./examples/aishell3/tts3">fastspeech2-aishell3</a> / <a href = "./examples/zh_en_tts/tts3">fastspeech2-zh_en</a>
</td>
</tr>
<tr>
<td>ERNIE-SAT</td>
<td>VCTK / AISHELL-3 / ZH_EN</td>
<td>
<a href = "./examples/vctk/ernie_sat">ERNIE-SAT-vctk</a> / <a href = "./examples/aishell3/ernie_sat">ERNIE-SAT-aishell3</a> / <a href = "./examples/aishell3_vctk/ernie_sat">ERNIE-SAT-zh_en</a>
</td>
</tr>
<tr>
<td rowspan="6">Vocoder</td>
<td >WaveFlow</td>
......
(简体中文|[English](./README.md))
<p align="center">
<img src="./docs/images/PaddleSpeech_logo.png" />
......@@ -530,7 +529,7 @@ PaddleSpeech 的 **语音合成** 主要包含三个模块:文本前端、声
</td>
</tr>
<tr>
<td rowspan="4">声学模型</td>
<td rowspan="5">声学模型</td>
<td>Tacotron2</td>
<td>LJSpeech / CSMSC</td>
<td>
......@@ -558,6 +557,13 @@ PaddleSpeech 的 **语音合成** 主要包含三个模块:文本前端、声
<a href = "./examples/ljspeech/tts3">fastspeech2-ljspeech</a> / <a href = "./examples/vctk/tts3">fastspeech2-vctk</a> / <a href = "./examples/csmsc/tts3">fastspeech2-csmsc</a> / <a href = "./examples/aishell3/tts3">fastspeech2-aishell3</a> / <a href = "./examples/zh_en_tts/tts3">fastspeech2-zh_en</a>
</td>
</tr>
<tr>
<td>ERNIE-SAT</td>
<td>VCTK / AISHELL-3 / ZH_EN</td>
<td>
<a href = "./examples/vctk/ernie_sat">ERNIE-SAT-vctk</a> / <a href = "./examples/aishell3/ernie_sat">ERNIE-SAT-aishell3</a> / <a href = "./examples/aishell3_vctk/ernie_sat">ERNIE-SAT-zh_en</a>
</td>
</tr>
<tr>
<td rowspan="6">声码器</td>
<td >WaveFlow</td>
......
ERNIE-SAT 是可以同时处理中英文的跨语言的语音-语言跨模态大模型,其在语音编辑、个性化语音合成以及跨语言的语音合成等多个任务取得了领先效果。可以应用于语音编辑、个性化合成、语音克隆、同传翻译等一系列场景,该项目供研究使用。
## 模型框架
ERNIE-SAT 中我们提出了两项创新:
- 在预训练过程中将中英双语对应的音素作为输入,实现了跨语言、个性化的软音素映射
- 采用语言和语音的联合掩码学习实现了语言和语音的对齐
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-3lOXKJXE-1655380879339)(.meta/framework.png)]
## 使用说明
### 1.安装飞桨与环境依赖
- 本项目的代码基于 Paddle(version>=2.0)
- 本项目开放提供加载 torch 版本的 vocoder 的功能
- torch version>=1.8
- 安装 htk: 在[官方地址](https://htk.eng.cam.ac.uk/)注册完成后,即可进行下载较新版本的 htk (例如 3.4.1)。同时提供[历史版本 htk 下载地址](https://htk.eng.cam.ac.uk/ftp/software/)
- 1.注册账号,下载 htk
- 2.解压 htk 文件,**放入项目根目录的 tools 文件夹中, 以 htk 文件夹名称放入**
- 3.**注意**: 如果您下载的是 3.4.1 或者更高版本, 需要进入 HTKLib/HRec.c 文件中, **修改 1626 行和 1650 行**, 即把**以下两行的 dur<=0 都修改为 dur<0**,如下所示:
```bash
以htk3.4.1版本举例:
(1)第1626行: if (dur<=0 && labid != splabid) HError(8522,"LatFromPaths: Align have dur<=0");
修改为: if (dur<0 && labid != splabid) HError(8522,"LatFromPaths: Align have dur<0");
(2)1650行: if (dur<=0 && labid != splabid) HError(8522,"LatFromPaths: Align have dur<=0 ");
修改为: if (dur<0 && labid != splabid) HError(8522,"LatFromPaths: Align have dur<0 ");
```
- 4.**编译**: 详情参见解压后的 htk 中的 README 文件(如果未编译, 则无法正常运行)
- 安装 ParallelWaveGAN: 参见[官方地址](https://github.com/kan-bayashi/ParallelWaveGAN):按照该官方链接的安装流程,直接在**项目的根目录下** git clone ParallelWaveGAN 项目并且安装相关依赖即可。
- 安装其他依赖: **sox, libsndfile**
### 2.预训练模型
预训练模型 ERNIE-SAT 的模型如下所示:
- [ERNIE-SAT_ZH](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/ernie_sat/old/model-ernie-sat-base-zh.tar.gz)
- [ERNIE-SAT_EN](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/ernie_sat/old/model-ernie-sat-base-en.tar.gz)
- [ERNIE-SAT_ZH_and_EN](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/ernie_sat/old/model-ernie-sat-base-en_zh.tar.gz)
创建 pretrained_model 文件夹,下载上述 ERNIE-SAT 预训练模型并将其解压:
```bash
mkdir pretrained_model
cd pretrained_model
tar -zxvf model-ernie-sat-base-en.tar.gz
tar -zxvf model-ernie-sat-base-zh.tar.gz
tar -zxvf model-ernie-sat-base-en_zh.tar.gz
```
### 3.下载
1. 本项目使用 parallel wavegan 作为声码器(vocoder):
- [pwg_aishell3_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_aishell3_ckpt_0.5.zip)
创建 download 文件夹,下载上述预训练的声码器(vocoder)模型并将其解压:
```bash
mkdir download
cd download
unzip pwg_aishell3_ckpt_0.5.zip
```
2. 本项目使用 [FastSpeech2](https://arxiv.org/abs/2006.04558) 作为音素(phoneme)的持续时间预测器:
- [fastspeech2_conformer_baker_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_conformer_baker_ckpt_0.5.zip) 中文场景下使用
- [fastspeech2_nosil_ljspeech_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_ljspeech_ckpt_0.5.zip) 英文场景下使用
下载上述预训练的 fastspeech2 模型并将其解压:
```bash
cd download
unzip fastspeech2_conformer_baker_ckpt_0.5.zip
unzip fastspeech2_nosil_ljspeech_ckpt_0.5.zip
```
3. 本项目使用 HTK 获取输入音频和文本的对齐信息:
- [aligner.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/ernie_sat/old/aligner.zip)
下载上述文件到 tools 文件夹并将其解压:
```bash
cd tools
unzip aligner.zip
```
### 4.推理
本项目当前开源了语音编辑、个性化语音合成、跨语言语音合成的推理代码,后续会逐步开源。
注:当前英文场下的合成语音采用的声码器默认为 vctk_parallel_wavegan.v1.long, 可在[该链接](https://github.com/kan-bayashi/ParallelWaveGAN)中找到; 若 use_pt_vocoder 参数设置为 False,则英文场景下使用 paddle 版本的声码器。
我们提供特定音频文件, 以及其对应的文本、音素相关文件:
- prompt_wav: 提供的音频文件
- prompt/dev: 基于上述特定音频对应的文本、音素相关文件
```text
prompt_wav
├── p299_096.wav # 样例语音文件1
├── p243_313.wav # 样例语音文件2
└── ...
```
```text
prompt/dev
├── text # 样例语音对应文本
├── wav.scp # 样例语音路径
├── mfa_text # 样例语音对应音素
├── mfa_start # 样例语音中各个音素的开始时间
└── mfa_end # 样例语音中各个音素的结束时间
```
1. `--am` 声学模型格式符合 {model_name}_{dataset}
2. `--am_config`, `--am_checkpoint`, `--am_stat``--phones_dict` 是声学模型的参数,对应于 fastspeech2 预训练模型中的 4 个文件。
3. `--voc` 声码器(vocoder)格式是否符合 {model_name}_{dataset}
4. `--voc_config`, `--voc_checkpoint`, `--voc_stat` 是声码器的参数,对应于 parallel wavegan 预训练模型中的 3 个文件。
5. `--lang` 对应模型的语言可以是 `zh``en`
6. `--ngpu` 要使用的 GPU 数,如果 ngpu==0,则使用 cpu。
7. `--model_name` 模型名称
8. `--uid` 特定提示(prompt)语音的 id
9. `--new_str` 输入的文本(本次开源暂时先设置特定的文本)
10. `--prefix` 特定音频对应的文本、音素相关文件的地址
11. `--source_lang` , 源语言
12. `--target_lang` , 目标语言
13. `--output_name` , 合成语音名称
14. `--task_name` , 任务名称, 包括:语音编辑任务、个性化语音合成任务、跨语言语音合成任务
运行以下脚本即可进行实验
```shell
./run_sedit_en.sh # 语音编辑任务(英文)
./run_gen_en.sh # 个性化语音合成任务(英文)
./run_clone_en_to_zh.sh # 跨语言语音合成任务(英文到中文的语音克隆)
```
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Usage:
align.py wavfile trsfile outwordfile outphonefile
"""
import os
import sys
PHONEME = 'tools/aligner/english_envir/english2phoneme/phoneme'
MODEL_DIR_EN = 'tools/aligner/english'
MODEL_DIR_ZH = 'tools/aligner/mandarin'
HVITE = 'tools/htk/HTKTools/HVite'
HCOPY = 'tools/htk/HTKTools/HCopy'
def get_unk_phns(word_str: str):
tmpbase = '/tmp/tp.'
f = open(tmpbase + 'temp.words', 'w')
f.write(word_str)
f.close()
os.system(PHONEME + ' ' + tmpbase + 'temp.words' + ' ' + tmpbase +
'temp.phons')
f = open(tmpbase + 'temp.phons', 'r')
lines2 = f.readline().strip().split()
f.close()
phns = []
for phn in lines2:
phons = phn.replace('\n', '').replace(' ', '')
seq = []
j = 0
while (j < len(phons)):
if (phons[j] > 'Z'):
if (phons[j] == 'j'):
seq.append('JH')
elif (phons[j] == 'h'):
seq.append('HH')
else:
seq.append(phons[j].upper())
j += 1
else:
p = phons[j:j + 2]
if (p == 'WH'):
seq.append('W')
elif (p in ['TH', 'SH', 'HH', 'DH', 'CH', 'ZH', 'NG']):
seq.append(p)
elif (p == 'AX'):
seq.append('AH0')
else:
seq.append(p + '1')
j += 2
phns.extend(seq)
return phns
def words2phns(line: str):
'''
Args:
line (str): input text.
eg: for that reason cover is impossible to be given.
Returns:
List[str]: phones of input text.
eg:
['F', 'AO1', 'R', 'DH', 'AE1', 'T', 'R', 'IY1', 'Z', 'AH0', 'N', 'K', 'AH1', 'V', 'ER0',
'IH1', 'Z', 'IH2', 'M', 'P', 'AA1', 'S', 'AH0', 'B', 'AH0', 'L', 'T', 'UW1', 'B', 'IY1',
'G', 'IH1', 'V', 'AH0', 'N']
Dict(str, str): key - idx_word
value - phones
eg:
{'0_FOR': ['F', 'AO1', 'R'], '1_THAT': ['DH', 'AE1', 'T'], '2_REASON': ['R', 'IY1', 'Z', 'AH0', 'N'],
'3_COVER': ['K', 'AH1', 'V', 'ER0'], '4_IS': ['IH1', 'Z'], '5_IMPOSSIBLE': ['IH2', 'M', 'P', 'AA1', 'S', 'AH0', 'B', 'AH0', 'L'],
'6_TO': ['T', 'UW1'], '7_BE': ['B', 'IY1'], '8_GIVEN': ['G', 'IH1', 'V', 'AH0', 'N']}
'''
dictfile = MODEL_DIR_EN + '/dict'
line = line.strip()
words = []
for pun in [',', '.', ':', ';', '!', '?', '"', '(', ')', '--', '---']:
line = line.replace(pun, ' ')
for wrd in line.split():
if (wrd[-1] == '-'):
wrd = wrd[:-1]
if (wrd[0] == "'"):
wrd = wrd[1:]
if wrd:
words.append(wrd)
ds = set([])
word2phns_dict = {}
with open(dictfile, 'r') as fid:
for line in fid:
word = line.split()[0]
ds.add(word)
if word not in word2phns_dict.keys():
word2phns_dict[word] = " ".join(line.split()[1:])
phns = []
wrd2phns = {}
for index, wrd in enumerate(words):
if wrd == '[MASK]':
wrd2phns[str(index) + "_" + wrd] = [wrd]
phns.append(wrd)
elif (wrd.upper() not in ds):
wrd2phns[str(index) + "_" + wrd.upper()] = get_unk_phns(wrd)
phns.extend(get_unk_phns(wrd))
else:
wrd2phns[str(index) +
"_" + wrd.upper()] = word2phns_dict[wrd.upper()].split()
phns.extend(word2phns_dict[wrd.upper()].split())
return phns, wrd2phns
def words2phns_zh(line: str):
dictfile = MODEL_DIR_ZH + '/dict'
line = line.strip()
words = []
for pun in [
',', '.', ':', ';', '!', '?', '"', '(', ')', '--', '---', u',',
u'。', u':', u';', u'!', u'?', u'(', u')'
]:
line = line.replace(pun, ' ')
for wrd in line.split():
if (wrd[-1] == '-'):
wrd = wrd[:-1]
if (wrd[0] == "'"):
wrd = wrd[1:]
if wrd:
words.append(wrd)
ds = set([])
word2phns_dict = {}
with open(dictfile, 'r') as fid:
for line in fid:
word = line.split()[0]
ds.add(word)
if word not in word2phns_dict.keys():
word2phns_dict[word] = " ".join(line.split()[1:])
phns = []
wrd2phns = {}
for index, wrd in enumerate(words):
if wrd == '[MASK]':
wrd2phns[str(index) + "_" + wrd] = [wrd]
phns.append(wrd)
elif (wrd.upper() not in ds):
print("出现非法词错误,请输入正确的文本...")
else:
wrd2phns[str(index) + "_" + wrd] = word2phns_dict[wrd].split()
phns.extend(word2phns_dict[wrd].split())
return phns, wrd2phns
def prep_txt_zh(line: str, tmpbase: str, dictfile: str):
words = []
line = line.strip()
for pun in [
',', '.', ':', ';', '!', '?', '"', '(', ')', '--', '---', u',',
u'。', u':', u';', u'!', u'?', u'(', u')'
]:
line = line.replace(pun, ' ')
for wrd in line.split():
if (wrd[-1] == '-'):
wrd = wrd[:-1]
if (wrd[0] == "'"):
wrd = wrd[1:]
if wrd:
words.append(wrd)
ds = set([])
with open(dictfile, 'r') as fid:
for line in fid:
ds.add(line.split()[0])
unk_words = set([])
with open(tmpbase + '.txt', 'w') as fwid:
for wrd in words:
if (wrd not in ds):
unk_words.add(wrd)
fwid.write(wrd + ' ')
fwid.write('\n')
return unk_words
def prep_txt_en(line: str, tmpbase, dictfile):
words = []
line = line.strip()
for pun in [',', '.', ':', ';', '!', '?', '"', '(', ')', '--', '---']:
line = line.replace(pun, ' ')
for wrd in line.split():
if (wrd[-1] == '-'):
wrd = wrd[:-1]
if (wrd[0] == "'"):
wrd = wrd[1:]
if wrd:
words.append(wrd)
ds = set([])
with open(dictfile, 'r') as fid:
for line in fid:
ds.add(line.split()[0])
unk_words = set([])
with open(tmpbase + '.txt', 'w') as fwid:
for wrd in words:
if (wrd.upper() not in ds):
unk_words.add(wrd.upper())
fwid.write(wrd + ' ')
fwid.write('\n')
#generate pronounciations for unknows words using 'letter to sound'
with open(tmpbase + '_unk.words', 'w') as fwid:
for unk in unk_words:
fwid.write(unk + '\n')
try:
os.system(PHONEME + ' ' + tmpbase + '_unk.words' + ' ' + tmpbase +
'_unk.phons')
except Exception:
print('english2phoneme error!')
sys.exit(1)
#add unknown words to the standard dictionary, generate a tmp dictionary for alignment
fw = open(tmpbase + '.dict', 'w')
with open(dictfile, 'r') as fid:
for line in fid:
fw.write(line)
f = open(tmpbase + '_unk.words', 'r')
lines1 = f.readlines()
f.close()
f = open(tmpbase + '_unk.phons', 'r')
lines2 = f.readlines()
f.close()
for i in range(len(lines1)):
wrd = lines1[i].replace('\n', '')
phons = lines2[i].replace('\n', '').replace(' ', '')
seq = []
j = 0
while (j < len(phons)):
if (phons[j] > 'Z'):
if (phons[j] == 'j'):
seq.append('JH')
elif (phons[j] == 'h'):
seq.append('HH')
else:
seq.append(phons[j].upper())
j += 1
else:
p = phons[j:j + 2]
if (p == 'WH'):
seq.append('W')
elif (p in ['TH', 'SH', 'HH', 'DH', 'CH', 'ZH', 'NG']):
seq.append(p)
elif (p == 'AX'):
seq.append('AH0')
else:
seq.append(p + '1')
j += 2
fw.write(wrd + ' ')
for s in seq:
fw.write(' ' + s)
fw.write('\n')
fw.close()
def prep_mlf(txt: str, tmpbase: str):
with open(tmpbase + '.mlf', 'w') as fwid:
fwid.write('#!MLF!#\n')
fwid.write('"' + tmpbase + '.lab"\n')
fwid.write('sp\n')
wrds = txt.split()
for wrd in wrds:
fwid.write(wrd.upper() + '\n')
fwid.write('sp\n')
fwid.write('.\n')
def _get_user():
return os.path.expanduser('~').split("/")[-1]
def alignment(wav_path: str, text: str):
'''
intervals: List[phn, start, end]
'''
tmpbase = '/tmp/' + _get_user() + '_' + str(os.getpid())
#prepare wav and trs files
try:
os.system('sox ' + wav_path + ' -r 16000 ' + tmpbase + '.wav remix -')
except Exception:
print('sox error!')
return None
#prepare clean_transcript file
try:
prep_txt_en(line=text, tmpbase=tmpbase, dictfile=MODEL_DIR_EN + '/dict')
except Exception:
print('prep_txt error!')
return None
#prepare mlf file
try:
with open(tmpbase + '.txt', 'r') as fid:
txt = fid.readline()
prep_mlf(txt, tmpbase)
except Exception:
print('prep_mlf error!')
return None
#prepare scp
try:
os.system(HCOPY + ' -C ' + MODEL_DIR_EN + '/16000/config ' + tmpbase +
'.wav' + ' ' + tmpbase + '.plp')
except Exception:
print('HCopy error!')
return None
#run alignment
try:
os.system(HVITE + ' -a -m -t 10000.0 10000.0 100000.0 -I ' + tmpbase +
'.mlf -H ' + MODEL_DIR_EN + '/16000/macros -H ' + MODEL_DIR_EN
+ '/16000/hmmdefs -i ' + tmpbase + '.aligned ' + tmpbase +
'.dict ' + MODEL_DIR_EN + '/monophones ' + tmpbase +
'.plp 2>&1 > /dev/null')
except Exception:
print('HVite error!')
return None
with open(tmpbase + '.txt', 'r') as fid:
words = fid.readline().strip().split()
words = txt.strip().split()
words.reverse()
with open(tmpbase + '.aligned', 'r') as fid:
lines = fid.readlines()
i = 2
intervals = []
word2phns = {}
current_word = ''
index = 0
while (i < len(lines)):
splited_line = lines[i].strip().split()
if (len(splited_line) >= 4) and (splited_line[0] != splited_line[1]):
phn = splited_line[2]
pst = (int(splited_line[0]) / 1000 + 125) / 10000
pen = (int(splited_line[1]) / 1000 + 125) / 10000
intervals.append([phn, pst, pen])
# splited_line[-1]!='sp'
if len(splited_line) == 5:
current_word = str(index) + '_' + splited_line[-1]
word2phns[current_word] = phn
index += 1
elif len(splited_line) == 4:
word2phns[current_word] += ' ' + phn
i += 1
return intervals, word2phns
def alignment_zh(wav_path: str, text: str):
tmpbase = '/tmp/' + _get_user() + '_' + str(os.getpid())
#prepare wav and trs files
try:
os.system('sox ' + wav_path + ' -r 16000 -b 16 ' + tmpbase +
'.wav remix -')
except Exception:
print('sox error!')
return None
#prepare clean_transcript file
try:
unk_words = prep_txt_zh(
line=text, tmpbase=tmpbase, dictfile=MODEL_DIR_ZH + '/dict')
if unk_words:
print('Error! Please add the following words to dictionary:')
for unk in unk_words:
print("非法words: ", unk)
except Exception:
print('prep_txt error!')
return None
#prepare mlf file
try:
with open(tmpbase + '.txt', 'r') as fid:
txt = fid.readline()
prep_mlf(txt, tmpbase)
except Exception:
print('prep_mlf error!')
return None
#prepare scp
try:
os.system(HCOPY + ' -C ' + MODEL_DIR_ZH + '/16000/config ' + tmpbase +
'.wav' + ' ' + tmpbase + '.plp')
except Exception:
print('HCopy error!')
return None
#run alignment
try:
os.system(HVITE + ' -a -m -t 10000.0 10000.0 100000.0 -I ' + tmpbase +
'.mlf -H ' + MODEL_DIR_ZH + '/16000/macros -H ' + MODEL_DIR_ZH
+ '/16000/hmmdefs -i ' + tmpbase + '.aligned ' + MODEL_DIR_ZH
+ '/dict ' + MODEL_DIR_ZH + '/monophones ' + tmpbase +
'.plp 2>&1 > /dev/null')
except Exception:
print('HVite error!')
return None
with open(tmpbase + '.txt', 'r') as fid:
words = fid.readline().strip().split()
words = txt.strip().split()
words.reverse()
with open(tmpbase + '.aligned', 'r') as fid:
lines = fid.readlines()
i = 2
intervals = []
word2phns = {}
current_word = ''
index = 0
while (i < len(lines)):
splited_line = lines[i].strip().split()
if (len(splited_line) >= 4) and (splited_line[0] != splited_line[1]):
phn = splited_line[2]
pst = (int(splited_line[0]) / 1000 + 125) / 10000
pen = (int(splited_line[1]) / 1000 + 125) / 10000
intervals.append([phn, pst, pen])
# splited_line[-1]!='sp'
if len(splited_line) == 5:
current_word = str(index) + '_' + splited_line[-1]
word2phns[current_word] = phn
index += 1
elif len(splited_line) == 4:
word2phns[current_word] += ' ' + phn
i += 1
return intervals, word2phns
此差异已折叠。
此差异已折叠。
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
def parse_args():
# parse args and config and redirect to train_sp
parser = argparse.ArgumentParser(
description="Synthesize with acoustic model & vocoder")
# acoustic model
parser.add_argument(
'--am',
type=str,
default='fastspeech2_csmsc',
choices=[
'speedyspeech_csmsc', 'fastspeech2_csmsc', 'fastspeech2_ljspeech',
'fastspeech2_aishell3', 'fastspeech2_vctk', 'tacotron2_csmsc',
'tacotron2_ljspeech', 'tacotron2_aishell3'
],
help='Choose acoustic model type of tts task.')
parser.add_argument(
'--am_config',
type=str,
default=None,
help='Config of acoustic model. Use deault config when it is None.')
parser.add_argument(
'--am_ckpt',
type=str,
default=None,
help='Checkpoint file of acoustic model.')
parser.add_argument(
"--am_stat",
type=str,
default=None,
help="mean and standard deviation used to normalize spectrogram when training acoustic model."
)
parser.add_argument(
"--phones_dict", type=str, default=None, help="phone vocabulary file.")
parser.add_argument(
"--tones_dict", type=str, default=None, help="tone vocabulary file.")
parser.add_argument(
"--speaker_dict", type=str, default=None, help="speaker id map file.")
# vocoder
parser.add_argument(
'--voc',
type=str,
default='pwgan_aishell3',
choices=[
'pwgan_csmsc', 'pwgan_ljspeech', 'pwgan_aishell3', 'pwgan_vctk',
'mb_melgan_csmsc', 'wavernn_csmsc', 'hifigan_csmsc',
'hifigan_ljspeech', 'hifigan_aishell3', 'hifigan_vctk',
'style_melgan_csmsc'
],
help='Choose vocoder type of tts task.')
parser.add_argument(
'--voc_config',
type=str,
default=None,
help='Config of voc. Use deault config when it is None.')
parser.add_argument(
'--voc_ckpt', type=str, default=None, help='Checkpoint file of voc.')
parser.add_argument(
"--voc_stat",
type=str,
default=None,
help="mean and standard deviation used to normalize spectrogram when training voc."
)
# other
parser.add_argument(
"--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.")
parser.add_argument("--model_name", type=str, help="model name")
parser.add_argument("--uid", type=str, help="uid")
parser.add_argument("--new_str", type=str, help="new string")
parser.add_argument("--prefix", type=str, help="prefix")
parser.add_argument(
"--source_lang", type=str, default="english", help="source language")
parser.add_argument(
"--target_lang", type=str, default="english", help="target language")
parser.add_argument("--output_name", type=str, help="output name")
parser.add_argument("--task_name", type=str, help="task name")
# pre
args = parser.parse_args()
return args
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path
from typing import Dict
from typing import List
from typing import Union
import numpy as np
import paddle
import yaml
from sedit_arg_parser import parse_args
from yacs.config import CfgNode
from paddlespeech.t2s.exps.syn_utils import get_am_inference
from paddlespeech.t2s.exps.syn_utils import get_voc_inference
def read_2col_text(path: Union[Path, str]) -> Dict[str, str]:
"""Read a text file having 2 column as dict object.
Examples:
wav.scp:
key1 /some/path/a.wav
key2 /some/path/b.wav
>>> read_2col_text('wav.scp')
{'key1': '/some/path/a.wav', 'key2': '/some/path/b.wav'}
"""
data = {}
with Path(path).open("r", encoding="utf-8") as f:
for linenum, line in enumerate(f, 1):
sps = line.rstrip().split(maxsplit=1)
if len(sps) == 1:
k, v = sps[0], ""
else:
k, v = sps
if k in data:
raise RuntimeError(f"{k} is duplicated ({path}:{linenum})")
data[k] = v
return data
def load_num_sequence_text(path: Union[Path, str], loader_type: str="csv_int"
) -> Dict[str, List[Union[float, int]]]:
"""Read a text file indicating sequences of number
Examples:
key1 1 2 3
key2 34 5 6
>>> d = load_num_sequence_text('text')
>>> np.testing.assert_array_equal(d["key1"], np.array([1, 2, 3]))
"""
if loader_type == "text_int":
delimiter = " "
dtype = int
elif loader_type == "text_float":
delimiter = " "
dtype = float
elif loader_type == "csv_int":
delimiter = ","
dtype = int
elif loader_type == "csv_float":
delimiter = ","
dtype = float
else:
raise ValueError(f"Not supported loader_type={loader_type}")
# path looks like:
# utta 1,0
# uttb 3,4,5
# -> return {'utta': np.ndarray([1, 0]),
# 'uttb': np.ndarray([3, 4, 5])}
d = read_2column_text(path)
# Using for-loop instead of dict-comprehension for debuggability
retval = {}
for k, v in d.items():
try:
retval[k] = [dtype(i) for i in v.split(delimiter)]
except TypeError:
print(f'Error happened with path="{path}", id="{k}", value="{v}"')
raise
return retval
def is_chinese(ch):
if u'\u4e00' <= ch <= u'\u9fff':
return True
else:
return False
def get_voc_out(mel):
# vocoder
args = parse_args()
with open(args.voc_config) as f:
voc_config = CfgNode(yaml.safe_load(f))
voc_inference = get_voc_inference(
voc=args.voc,
voc_config=voc_config,
voc_ckpt=args.voc_ckpt,
voc_stat=args.voc_stat)
with paddle.no_grad():
wav = voc_inference(mel)
return np.squeeze(wav)
def eval_durs(phns, target_lang="chinese", fs=24000, hop_length=300):
args = parse_args()
if target_lang == 'english':
args.am = "fastspeech2_ljspeech"
args.am_config = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/default.yaml"
args.am_ckpt = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/snapshot_iter_100000.pdz"
args.am_stat = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/speech_stats.npy"
args.phones_dict = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/phone_id_map.txt"
elif target_lang == 'chinese':
args.am = "fastspeech2_csmsc"
args.am_config = "download/fastspeech2_conformer_baker_ckpt_0.5/conformer.yaml"
args.am_ckpt = "download/fastspeech2_conformer_baker_ckpt_0.5/snapshot_iter_76000.pdz"
args.am_stat = "download/fastspeech2_conformer_baker_ckpt_0.5/speech_stats.npy"
args.phones_dict = "download/fastspeech2_conformer_baker_ckpt_0.5/phone_id_map.txt"
if args.ngpu == 0:
paddle.set_device("cpu")
elif args.ngpu > 0:
paddle.set_device("gpu")
else:
print("ngpu should >= 0 !")
# Init body.
with open(args.am_config) as f:
am_config = CfgNode(yaml.safe_load(f))
am_inference, am = get_am_inference(
am=args.am,
am_config=am_config,
am_ckpt=args.am_ckpt,
am_stat=args.am_stat,
phones_dict=args.phones_dict,
tones_dict=args.tones_dict,
speaker_dict=args.speaker_dict,
return_am=True)
vocab_phones = {}
with open(args.phones_dict, "r") as f:
phn_id = [line.strip().split() for line in f.readlines()]
for tone, id in phn_id:
vocab_phones[tone] = int(id)
vocab_size = len(vocab_phones)
phonemes = [phn if phn in vocab_phones else "sp" for phn in phns]
phone_ids = [vocab_phones[item] for item in phonemes]
phone_ids.append(vocab_size - 1)
phone_ids = paddle.to_tensor(np.array(phone_ids, np.int64))
_, d_outs, _, _ = am.inference(phone_ids, spk_id=None, spk_emb=None)
pre_d_outs = d_outs
phu_durs_new = pre_d_outs * hop_length / fs
phu_durs_new = phu_durs_new.tolist()[:-1]
return phu_durs_new
#!/bin/bash
export MAIN_ROOT=`realpath ${PWD}/../../`
export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
export LC_ALL=C
export PYTHONDONTWRITEBYTECODE=1
# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}
MODEL=ernie_sat
export BIN_DIR=${MAIN_ROOT}/paddlespeech/t2s/exps/${MODEL}
\ No newline at end of file
p243_new For that reason cover should not be given.
Prompt_003_new This was not the show for me.
p299_096 We are trying to establish a date.
p243_new ../../prompt_wav/p243_313.wav
Prompt_003_new ../../prompt_wav/this_was_not_the_show_for_me.wav
p299_096 ../../prompt_wav/p299_096.wav
#!/bin/bash
set -e
source path.sh
# en --> zh 的 语音合成
# 根据 Prompt_003_new 作为提示语音: This was not the show for me. 来合成: '今天天气很好'
# 注: 输入的 new_str 需为中文汉字, 否则会通过预处理只保留中文汉字, 即合成预处理后的中文语音。
python local/inference.py \
--task_name=cross-lingual_clone \
--model_name=paddle_checkpoint_dual_mask_enzh \
--uid=Prompt_003_new \
--new_str='今天天气很好.' \
--prefix='./prompt/dev/' \
--source_lang=english \
--target_lang=chinese \
--output_name=pred_clone.wav \
--voc=pwgan_aishell3 \
--voc_config=download/pwg_aishell3_ckpt_0.5/default.yaml \
--voc_ckpt=download/pwg_aishell3_ckpt_0.5/snapshot_iter_1000000.pdz \
--voc_stat=download/pwg_aishell3_ckpt_0.5/feats_stats.npy \
--am=fastspeech2_csmsc \
--am_config=download/fastspeech2_conformer_baker_ckpt_0.5/conformer.yaml \
--am_ckpt=download/fastspeech2_conformer_baker_ckpt_0.5/snapshot_iter_76000.pdz \
--am_stat=download/fastspeech2_conformer_baker_ckpt_0.5/speech_stats.npy \
--phones_dict=download/fastspeech2_conformer_baker_ckpt_0.5/phone_id_map.txt
#!/bin/bash
set -e
source path.sh
# en --> zh 的 语音合成
# 根据 Prompt_003_new 作为提示语音: This was not the show for me. 来合成: '今天天气很好'
# 注: 输入的 new_str 需为中文汉字, 否则会通过预处理只保留中文汉字, 即合成预处理后的中文语音。
python local/inference_new.py \
--task_name=cross-lingual_clone \
--model_name=paddle_checkpoint_dual_mask_enzh \
--uid=Prompt_003_new \
--new_str='今天天气很好.' \
--prefix='./prompt/dev/' \
--source_lang=english \
--target_lang=chinese \
--output_name=pred_clone.wav \
--voc=pwgan_aishell3 \
--voc_config=download/pwg_aishell3_ckpt_0.5/default.yaml \
--voc_ckpt=download/pwg_aishell3_ckpt_0.5/snapshot_iter_1000000.pdz \
--voc_stat=download/pwg_aishell3_ckpt_0.5/feats_stats.npy \
--am=fastspeech2_csmsc \
--am_config=download/fastspeech2_conformer_baker_ckpt_0.5/conformer.yaml \
--am_ckpt=download/fastspeech2_conformer_baker_ckpt_0.5/snapshot_iter_76000.pdz \
--am_stat=download/fastspeech2_conformer_baker_ckpt_0.5/speech_stats.npy \
--phones_dict=download/fastspeech2_conformer_baker_ckpt_0.5/phone_id_map.txt
#!/bin/bash
set -e
source path.sh
# 纯英文的语音合成
# 样例为根据 p299_096 对应的语音作为提示语音: This was not the show for me. 来合成: 'I enjoy my life.'
python local/inference.py \
--task_name=synthesize \
--model_name=paddle_checkpoint_en \
--uid=p299_096 \
--new_str='I enjoy my life, do you?' \
--prefix='./prompt/dev/' \
--source_lang=english \
--target_lang=english \
--output_name=pred_gen.wav \
--voc=pwgan_aishell3 \
--voc_config=download/pwg_aishell3_ckpt_0.5/default.yaml \
--voc_ckpt=download/pwg_aishell3_ckpt_0.5/snapshot_iter_1000000.pdz \
--voc_stat=download/pwg_aishell3_ckpt_0.5/feats_stats.npy \
--am=fastspeech2_ljspeech \
--am_config=download/fastspeech2_nosil_ljspeech_ckpt_0.5/default.yaml \
--am_ckpt=download/fastspeech2_nosil_ljspeech_ckpt_0.5/snapshot_iter_100000.pdz \
--am_stat=download/fastspeech2_nosil_ljspeech_ckpt_0.5/speech_stats.npy \
--phones_dict=download/fastspeech2_nosil_ljspeech_ckpt_0.5/phone_id_map.txt
#!/bin/bash
set -e
source path.sh
# 纯英文的语音合成
# 样例为根据 p299_096 对应的语音作为提示语音: This was not the show for me. 来合成: 'I enjoy my life.'
python local/inference_new.py \
--task_name=synthesize \
--model_name=paddle_checkpoint_en \
--uid=p299_096 \
--new_str='I enjoy my life, do you?' \
--prefix='./prompt/dev/' \
--source_lang=english \
--target_lang=english \
--output_name=pred_gen.wav \
--voc=pwgan_aishell3 \
--voc_config=download/pwg_aishell3_ckpt_0.5/default.yaml \
--voc_ckpt=download/pwg_aishell3_ckpt_0.5/snapshot_iter_1000000.pdz \
--voc_stat=download/pwg_aishell3_ckpt_0.5/feats_stats.npy \
--am=fastspeech2_ljspeech \
--am_config=download/fastspeech2_nosil_ljspeech_ckpt_0.5/default.yaml \
--am_ckpt=download/fastspeech2_nosil_ljspeech_ckpt_0.5/snapshot_iter_100000.pdz \
--am_stat=download/fastspeech2_nosil_ljspeech_ckpt_0.5/speech_stats.npy \
--phones_dict=download/fastspeech2_nosil_ljspeech_ckpt_0.5/phone_id_map.txt
#!/bin/bash
set -e
source path.sh
# 纯英文的语音编辑
# 样例为把 p243_new 对应的原始语音: For that reason cover should not be given.编辑成 'for that reason cover is impossible to be given.' 对应的语音
# NOTE: 语音编辑任务暂支持句子中 1 个位置的替换或者插入文本操作
python local/inference.py \
--task_name=edit \
--model_name=paddle_checkpoint_en \
--uid=p243_new \
--new_str='for that reason cover is impossible to be given.' \
--prefix='./prompt/dev/' \
--source_lang=english \
--target_lang=english \
--output_name=pred_edit.wav \
--voc=pwgan_aishell3 \
--voc_config=download/pwg_aishell3_ckpt_0.5/default.yaml \
--voc_ckpt=download/pwg_aishell3_ckpt_0.5/snapshot_iter_1000000.pdz \
--voc_stat=download/pwg_aishell3_ckpt_0.5/feats_stats.npy \
--am=fastspeech2_ljspeech \
--am_config=download/fastspeech2_nosil_ljspeech_ckpt_0.5/default.yaml \
--am_ckpt=download/fastspeech2_nosil_ljspeech_ckpt_0.5/snapshot_iter_100000.pdz \
--am_stat=download/fastspeech2_nosil_ljspeech_ckpt_0.5/speech_stats.npy \
--phones_dict=download/fastspeech2_nosil_ljspeech_ckpt_0.5/phone_id_map.txt
#!/bin/bash
set -e
source path.sh
# 纯英文的语音编辑
# 样例为把 p243_new 对应的原始语音: For that reason cover should not be given.编辑成 'for that reason cover is impossible to be given.' 对应的语音
# NOTE: 语音编辑任务暂支持句子中 1 个位置的替换或者插入文本操作
python local/inference_new.py \
--task_name=edit \
--model_name=paddle_checkpoint_en \
--uid=p243_new \
--new_str='for that reason cover is impossible to be given.' \
--prefix='./prompt/dev/' \
--source_lang=english \
--target_lang=english \
--output_name=pred_edit.wav \
--voc=pwgan_aishell3 \
--voc_config=download/pwg_aishell3_ckpt_0.5/default.yaml \
--voc_ckpt=download/pwg_aishell3_ckpt_0.5/snapshot_iter_1000000.pdz \
--voc_stat=download/pwg_aishell3_ckpt_0.5/feats_stats.npy \
--am=fastspeech2_ljspeech \
--am_config=download/fastspeech2_nosil_ljspeech_ckpt_0.5/default.yaml \
--am_ckpt=download/fastspeech2_nosil_ljspeech_ckpt_0.5/snapshot_iter_100000.pdz \
--am_stat=download/fastspeech2_nosil_ljspeech_ckpt_0.5/speech_stats.npy \
--phones_dict=download/fastspeech2_nosil_ljspeech_ckpt_0.5/phone_id_map.txt
#!/bin/bash
rm -rf *.wav
./run_sedit_en.sh # 语音编辑任务(英文)
./run_gen_en.sh # 个性化语音合成任务(英文)
./run_clone_en_to_zh.sh # 跨语言语音合成任务(英文到中文的语音克隆)
\ No newline at end of file
#!/bin/bash
rm -rf *.wav
./run_sedit_en_new.sh # 语音编辑任务(英文)
./run_gen_en_new.sh # 个性化语音合成任务(英文)
./run_clone_en_to_zh_new.sh # 跨语言语音合成任务(英文到中文的语音克隆)
\ No newline at end of file
......@@ -11,19 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Collection
from typing import Dict
from typing import List
from typing import Tuple
import numpy as np
import paddle
from paddlespeech.t2s.datasets.batch import batch_sequences
from paddlespeech.t2s.datasets.get_feats import LogMelFBank
from paddlespeech.t2s.modules.nets_utils import get_seg_pos
from paddlespeech.t2s.modules.nets_utils import make_non_pad_mask
from paddlespeech.t2s.modules.nets_utils import pad_list
from paddlespeech.t2s.modules.nets_utils import phones_masking
from paddlespeech.t2s.modules.nets_utils import phones_text_masking
......@@ -490,182 +483,3 @@ def vits_single_spk_batch_fn(examples):
"speech": speech
}
return batch
# for ERNIE SAT
class MLMCollateFn:
"""Functor class of common_collate_fn()"""
def __init__(
self,
feats_extract,
mlm_prob: float=0.8,
mean_phn_span: int=8,
seg_emb: bool=False,
text_masking: bool=False,
attention_window: int=0,
not_sequence: Collection[str]=(), ):
self.mlm_prob = mlm_prob
self.mean_phn_span = mean_phn_span
self.feats_extract = feats_extract
self.not_sequence = set(not_sequence)
self.attention_window = attention_window
self.seg_emb = seg_emb
self.text_masking = text_masking
def __call__(self, data: Collection[Tuple[str, Dict[str, np.ndarray]]]
) -> Tuple[List[str], Dict[str, paddle.Tensor]]:
return mlm_collate_fn(
data,
feats_extract=self.feats_extract,
mlm_prob=self.mlm_prob,
mean_phn_span=self.mean_phn_span,
seg_emb=self.seg_emb,
text_masking=self.text_masking,
not_sequence=self.not_sequence)
def mlm_collate_fn(
data: Collection[Tuple[str, Dict[str, np.ndarray]]],
feats_extract=None,
mlm_prob: float=0.8,
mean_phn_span: int=8,
seg_emb: bool=False,
text_masking: bool=False,
pad_value: int=0,
not_sequence: Collection[str]=(),
) -> Tuple[List[str], Dict[str, paddle.Tensor]]:
uttids = [u for u, _ in data]
data = [d for _, d in data]
assert all(set(data[0]) == set(d) for d in data), "dict-keys mismatching"
assert all(not k.endswith("_lens")
for k in data[0]), f"*_lens is reserved: {list(data[0])}"
output = {}
for key in data[0]:
array_list = [d[key] for d in data]
# Assume the first axis is length:
# tensor_list: Batch x (Length, ...)
tensor_list = [paddle.to_tensor(a) for a in array_list]
# tensor: (Batch, Length, ...)
tensor = pad_list(tensor_list, pad_value)
output[key] = tensor
# lens: (Batch,)
if key not in not_sequence:
lens = paddle.to_tensor(
[d[key].shape[0] for d in data], dtype=paddle.int64)
output[key + "_lens"] = lens
feats = feats_extract.get_log_mel_fbank(np.array(output["speech"][0]))
feats = paddle.to_tensor(feats)
print("feats.shape:", feats.shape)
feats_lens = paddle.shape(feats)[0]
feats = paddle.unsqueeze(feats, 0)
text = output["text"]
text_lens = output["text_lens"]
align_start = output["align_start"]
align_start_lens = output["align_start_lens"]
align_end = output["align_end"]
max_tlen = max(text_lens)
max_slen = max(feats_lens)
speech_pad = feats[:, :max_slen]
text_pad = text
text_mask = make_non_pad_mask(
text_lens, text_pad, length_dim=1).unsqueeze(-2)
speech_mask = make_non_pad_mask(
feats_lens, speech_pad[:, :, 0], length_dim=1).unsqueeze(-2)
span_bdy = None
if 'span_bdy' in output.keys():
span_bdy = output['span_bdy']
# dual_mask 的是混合中英时候同时 mask 语音和文本
# ernie sat 在实现跨语言的时候都 mask 了
if text_masking:
masked_pos, text_masked_pos = phones_text_masking(
xs_pad=speech_pad,
src_mask=speech_mask,
text_pad=text_pad,
text_mask=text_mask,
align_start=align_start,
align_end=align_end,
align_start_lens=align_start_lens,
mlm_prob=mlm_prob,
mean_phn_span=mean_phn_span,
span_bdy=span_bdy)
# 训练纯中文和纯英文的 -> a3t 没有对 phoneme 做 mask, 只对语音 mask 了
# a3t 和 ernie sat 的区别主要在于做 mask 的时候
else:
masked_pos = phones_masking(
xs_pad=speech_pad,
src_mask=speech_mask,
align_start=align_start,
align_end=align_end,
align_start_lens=align_start_lens,
mlm_prob=mlm_prob,
mean_phn_span=mean_phn_span,
span_bdy=span_bdy)
text_masked_pos = paddle.zeros(paddle.shape(text_pad))
output_dict = {}
speech_seg_pos, text_seg_pos = get_seg_pos(
speech_pad=speech_pad,
text_pad=text_pad,
align_start=align_start,
align_end=align_end,
align_start_lens=align_start_lens,
seg_emb=seg_emb)
output_dict['speech'] = speech_pad
output_dict['text'] = text_pad
output_dict['masked_pos'] = masked_pos
output_dict['text_masked_pos'] = text_masked_pos
output_dict['speech_mask'] = speech_mask
output_dict['text_mask'] = text_mask
output_dict['speech_seg_pos'] = speech_seg_pos
output_dict['text_seg_pos'] = text_seg_pos
output = (uttids, output_dict)
return output
def build_mlm_collate_fn(
sr: int=24000,
n_fft: int=2048,
hop_length: int=300,
win_length: int=None,
n_mels: int=80,
fmin: int=80,
fmax: int=7600,
mlm_prob: float=0.8,
mean_phn_span: int=8,
seg_emb: bool=False,
epoch: int=-1, ):
feats_extract_class = LogMelFBank
feats_extract = feats_extract_class(
sr=sr,
n_fft=n_fft,
hop_length=hop_length,
win_length=win_length,
n_mels=n_mels,
fmin=fmin,
fmax=fmax)
if epoch == -1:
mlm_prob_factor = 1
else:
mlm_prob_factor = 0.8
return MLMCollateFn(
feats_extract=feats_extract,
mlm_prob=mlm_prob * mlm_prob_factor,
mean_phn_span=mean_phn_span,
seg_emb=seg_emb)
......@@ -13,4 +13,3 @@
# limitations under the License.
from .ernie_sat import *
from .ernie_sat_updater import *
from .mlm import *
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册