PaddlePaddle
DeepSpeech
Commit 7b864e8f (unverified)
Authored Aug 26, 2022 by 小湉湉; committed Aug 26, 2022 by GitHub
clean old ernie sat inference scripts (#2316)
Parent: d21e03c0
Showing 24 changed files with 16 additions and 3055 deletions
Changed files (additions, deletions):
README.md (+8, -2)
README_cn.md (+8, -2)
examples/ernie_sat/.meta/framework.png (+0, -0)
examples/ernie_sat/README.md (+0, -137)
examples/ernie_sat/local/align.py (+0, -454)
examples/ernie_sat/local/inference.py (+0, -609)
examples/ernie_sat/local/inference_new.py (+0, -622)
examples/ernie_sat/local/sedit_arg_parser.py (+0, -97)
examples/ernie_sat/local/utils.py (+0, -175)
examples/ernie_sat/path.sh (+0, -13)
examples/ernie_sat/prompt/dev/text (+0, -3)
examples/ernie_sat/prompt/dev/wav.scp (+0, -3)
examples/ernie_sat/run_clone_en_to_zh.sh (+0, -27)
examples/ernie_sat/run_clone_en_to_zh_new.sh (+0, -27)
examples/ernie_sat/run_gen_en.sh (+0, -26)
examples/ernie_sat/run_gen_en_new.sh (+0, -26)
examples/ernie_sat/run_sedit_en.sh (+0, -27)
examples/ernie_sat/run_sedit_en_new.sh (+0, -27)
examples/ernie_sat/test_run.sh (+0, -6)
examples/ernie_sat/test_run_new.sh (+0, -6)
examples/ernie_sat/tools/.gitkeep (+0, -0)
paddlespeech/t2s/datasets/am_batch_fn.py (+0, -186)
paddlespeech/t2s/models/ernie_sat/__init__.py (+0, -1)
paddlespeech/t2s/models/ernie_sat/mlm.py (+0, -579)
README.md
View file @ 7b864e8f
([简体中文](./README_cn.md)|English)
<p align="center">
<img src="./docs/images/PaddleSpeech_logo.png" />
...
@@ -535,7 +534,7 @@ PaddleSpeech supports a series of most popular models. They are summarized in [r
 </td>
 </tr>
 <tr>
-<td rowspan="4">Acoustic Model</td>
+<td rowspan="5">Acoustic Model</td>
 <td>Tacotron2</td>
 <td>LJSpeech / CSMSC</td>
 <td>
...
@@ -563,6 +562,13 @@ PaddleSpeech supports a series of most popular models. They are summarized in [r
 <a href = "./examples/ljspeech/tts3">fastspeech2-ljspeech</a> / <a href = "./examples/vctk/tts3">fastspeech2-vctk</a> / <a href = "./examples/csmsc/tts3">fastspeech2-csmsc</a> / <a href = "./examples/aishell3/tts3">fastspeech2-aishell3</a> / <a href = "./examples/zh_en_tts/tts3">fastspeech2-zh_en</a>
 </td>
 </tr>
+<tr>
+<td>ERNIE-SAT</td>
+<td>VCTK / AISHELL-3 / ZH_EN</td>
+<td>
+<a href = "./examples/vctk/ernie_sat">ERNIE-SAT-vctk</a> / <a href = "./examples/aishell3/ernie_sat">ERNIE-SAT-aishell3</a> / <a href = "./examples/aishell3_vctk/ernie_sat">ERNIE-SAT-zh_en</a>
+</td>
+</tr>
 <tr>
 <td rowspan="6">Vocoder</td>
 <td >WaveFlow</td>
...
README_cn.md
View file @ 7b864e8f
(简体中文|[English](./README.md))
<p align="center">
<img src="./docs/images/PaddleSpeech_logo.png" />
...
@@ -530,7 +529,7 @@ PaddleSpeech 的 **语音合成** 主要包含三个模块:文本前端、声
 </td>
 </tr>
 <tr>
-<td rowspan="4">声学模型</td>
+<td rowspan="5">声学模型</td>
 <td>Tacotron2</td>
 <td>LJSpeech / CSMSC</td>
 <td>
...
@@ -558,6 +557,13 @@ PaddleSpeech 的 **语音合成** 主要包含三个模块:文本前端、声
 <a href = "./examples/ljspeech/tts3">fastspeech2-ljspeech</a> / <a href = "./examples/vctk/tts3">fastspeech2-vctk</a> / <a href = "./examples/csmsc/tts3">fastspeech2-csmsc</a> / <a href = "./examples/aishell3/tts3">fastspeech2-aishell3</a> / <a href = "./examples/zh_en_tts/tts3">fastspeech2-zh_en</a>
 </td>
 </tr>
+<tr>
+<td>ERNIE-SAT</td>
+<td>VCTK / AISHELL-3 / ZH_EN</td>
+<td>
+<a href = "./examples/vctk/ernie_sat">ERNIE-SAT-vctk</a> / <a href = "./examples/aishell3/ernie_sat">ERNIE-SAT-aishell3</a> / <a href = "./examples/aishell3_vctk/ernie_sat">ERNIE-SAT-zh_en</a>
+</td>
+</tr>
 <tr>
 <td rowspan="6">声码器</td>
 <td >WaveFlow</td>
...
examples/ernie_sat/.meta/framework.png
Deleted (100644 → 0), view file @ d21e03c0
139.9 KB
examples/ernie_sat/README.md
Deleted (100644 → 0), view file @ d21e03c0
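ERNIE-SAT is a cross-lingual, cross-modal (speech and language) large model that handles both Chinese and English. It achieves leading results on several tasks, including speech editing, personalized speech synthesis, and cross-lingual speech synthesis. It can be applied to speech editing, personalized synthesis, voice cloning, simultaneous interpretation, and similar scenarios. This project is provided for research use.
## Model Framework
ERNIE-SAT introduces two innovations:
- During pre-training, the phonemes of both Chinese and English are used as input, which yields a cross-lingual, personalized soft phoneme mapping.
- Joint masked learning over language and speech aligns the two modalities.

![framework](.meta/framework.png)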
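## Usage
### 1. Install PaddlePaddle and the environment dependencies
- The code in this project is based on Paddle (version>=2.0).
- This project can also load a torch version of the vocoder.
  - torch version>=1.8
- Install htk: after registering at the [official site](https://htk.eng.cam.ac.uk/), you can download a recent version of htk (for example 3.4.1). A [download location for older htk versions](https://htk.eng.cam.ac.uk/ftp/software/) is also available.
  - 1. Register an account and download htk.
  - 2. Extract the htk archive and **place it in the tools folder under the project root, using htk as the folder name**.
  - 3. **Note**: if you downloaded version 3.4.1 or later, open HTKLib/HRec.c and **edit lines 1626 and 1650**, **changing dur<=0 to dur<0 on both lines**, as shown below:
```bash
Using htk 3.4.1 as an example:
(1) Line 1626: if (dur<=0 && labid != splabid) HError(8522,"LatFromPaths: Align have dur<=0");
change to: if (dur<0 && labid != splabid) HError(8522,"LatFromPaths: Align have dur<0");
(2) Line 1650: if (dur<=0 && labid != splabid) HError(8522,"LatFromPaths: Align have dur<=0 ");
change to: if (dur<0 && labid != splabid) HError(8522,"LatFromPaths: Align have dur<0 ");
```
  - 4. **Compile**: see the README file in the extracted htk folder (htk will not run unless it has been compiled).
- Install ParallelWaveGAN: see the [official repository](https://github.com/kan-bayashi/ParallelWaveGAN). Follow its installation steps: git clone the ParallelWaveGAN project **directly in the project root** and install its dependencies.
- Install the other dependencies: **sox, libsndfile**, etc.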
### 2. Pretrained models
The pretrained ERNIE-SAT models are listed below:
- [ERNIE-SAT_ZH](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/ernie_sat/old/model-ernie-sat-base-zh.tar.gz)
- [ERNIE-SAT_EN](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/ernie_sat/old/model-ernie-sat-base-en.tar.gz)
- [ERNIE-SAT_ZH_and_EN](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/ernie_sat/old/model-ernie-sat-base-en_zh.tar.gz)

Create a pretrained_model folder, download the ERNIE-SAT pretrained models above, and extract them:
```bash
mkdir pretrained_model
cd pretrained_model
tar -zxvf model-ernie-sat-base-en.tar.gz
tar -zxvf model-ernie-sat-base-zh.tar.gz
tar -zxvf model-ernie-sat-base-en_zh.tar.gz
```
### 3. Downloads
1. This project uses parallel wavegan as the vocoder:
    - [pwg_aishell3_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_aishell3_ckpt_0.5.zip)

Create a download folder, download the pretrained vocoder model above, and extract it:
```bash
mkdir download
cd download
unzip pwg_aishell3_ckpt_0.5.zip
```
2. This project uses [FastSpeech2](https://arxiv.org/abs/2006.04558) as the phoneme duration predictor:
    - [fastspeech2_conformer_baker_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_conformer_baker_ckpt_0.5.zip), used for Chinese
    - [fastspeech2_nosil_ljspeech_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_ljspeech_ckpt_0.5.zip), used for English

Download the pretrained fastspeech2 models above and extract them:
```bash
cd download
unzip fastspeech2_conformer_baker_ckpt_0.5.zip
unzip fastspeech2_nosil_ljspeech_ckpt_0.5.zip
```
3. This project uses HTK to align the input audio with its text:
    - [aligner.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/ernie_sat/old/aligner.zip)

Download the file above into the tools folder and extract it:
```bash
cd tools
unzip aligner.zip
```
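### 4. Inference
This project currently open-sources the inference code for speech editing, personalized speech synthesis, and cross-lingual speech synthesis; more will be released over time.
Note: for English, the synthesized speech uses the vocoder vctk_parallel_wavegan.v1.long by default, which can be found at [this link](https://github.com/kan-bayashi/ParallelWaveGAN). If the use_pt_vocoder parameter is set to False, the paddle version of the vocoder is used for English.

We provide specific audio files together with their text and phoneme files:
- prompt_wav: the provided audio files
- prompt/dev: the text and phoneme files that correspond to those audio files

```text
prompt_wav
├── p299_096.wav      # sample audio file 1
├── p243_313.wav      # sample audio file 2
└── ...
```
```text
prompt/dev
├── text     # text of the sample audio
├── wav.scp  # paths of the sample audio
├── mfa_text  # phonemes of the sample audio
├── mfa_start # start time of each phoneme in the sample audio
└── mfa_end  # end time of each phoneme in the sample audio
```
1. `--am`: the acoustic model, in the format {model_name}_{dataset}.
2. `--am_config`, `--am_checkpoint`, `--am_stat`, and `--phones_dict`: acoustic model arguments, corresponding to the 4 files in the fastspeech2 pretrained model.
3. `--voc`: the vocoder, in the format {model_name}_{dataset}.
4. `--voc_config`, `--voc_checkpoint`, `--voc_stat`: vocoder arguments, corresponding to the 3 files in the parallel wavegan pretrained model.
5. `--lang`: the language of the model, either `zh` or `en`.
6. `--ngpu`: the number of GPUs to use; if ngpu==0, the CPU is used.
7. `--model_name`: the model name.
8. `--uid`: the id of the specific prompt audio.
9. `--new_str`: the input text (fixed to specific texts in this release).
10. `--prefix`: the location of the text and phoneme files for the specific audio.
11. `--source_lang`: the source language.
12. `--target_lang`: the target language.
13. `--output_name`: the name of the synthesized audio.
14. `--task_name`: the task name, one of: speech editing, personalized speech synthesis, cross-lingual speech synthesis.

Run the following scripts to try the tasks:
```shell
./run_sedit_en.sh          # speech editing (English)
./run_gen_en.sh            # personalized speech synthesis (English)
./run_clone_en_to_zh.sh    # cross-lingual speech synthesis (voice cloning from English to Chinese)
```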
examples/ernie_sat/local/align.py
Deleted (100755 → 0), view file @ d21e03c0
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Usage:
align.py wavfile trsfile outwordfile outphonefile
"""
import os
import sys

PHONEME = 'tools/aligner/english_envir/english2phoneme/phoneme'
MODEL_DIR_EN = 'tools/aligner/english'
MODEL_DIR_ZH = 'tools/aligner/mandarin'
HVITE = 'tools/htk/HTKTools/HVite'
HCOPY = 'tools/htk/HTKTools/HCopy'


def get_unk_phns(word_str: str):
    tmpbase = '/tmp/tp.'
    f = open(tmpbase + 'temp.words', 'w')
    f.write(word_str)
    f.close()
    os.system(PHONEME + ' ' + tmpbase + 'temp.words' + ' ' + tmpbase +
              'temp.phons')
    f = open(tmpbase + 'temp.phons', 'r')
    lines2 = f.readline().strip().split()
    f.close()
    phns = []
    for phn in lines2:
        phons = phn.replace('\n', '').replace(' ', '')
        seq = []
        j = 0
        while (j < len(phons)):
            if (phons[j] > 'Z'):
                if (phons[j] == 'j'):
                    seq.append('JH')
                elif (phons[j] == 'h'):
                    seq.append('HH')
                else:
                    seq.append(phons[j].upper())
                j += 1
            else:
                p = phons[j:j + 2]
                if (p == 'WH'):
                    seq.append('W')
                elif (p in ['TH', 'SH', 'HH', 'DH', 'CH', 'ZH', 'NG']):
                    seq.append(p)
                elif (p == 'AX'):
                    seq.append('AH0')
                else:
                    seq.append(p + '1')
                j += 2
        phns.extend(seq)
    return phns


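The conversion loop inside get_unk_phns can be tried without the external phoneme binary. The following is a minimal sketch, assuming the letter-to-sound tool emits a string that mixes lowercase single-letter codes with uppercase two-letter codes; the helper name l2s_to_arpabet and the sample inputs are hypothetical and not part of the deleted file.

```python
from typing import List

# Two-letter codes that are copied through unchanged (no digit is appended).
_PASSTHROUGH = {'TH', 'SH', 'HH', 'DH', 'CH', 'ZH', 'NG'}


def l2s_to_arpabet(phons: str) -> List[str]:
    """Convert one letter-to-sound string into ARPAbet-style phones."""
    seq = []
    j = 0
    while j < len(phons):
        ch = phons[j]
        if ch > 'Z':
            # Lowercase letter: a one-character code.
            if ch == 'j':
                seq.append('JH')
            elif ch == 'h':
                seq.append('HH')
            else:
                seq.append(ch.upper())
            j += 1
        else:
            # Uppercase letter: consume a two-character code.
            p = phons[j:j + 2]
            if p == 'WH':
                seq.append('W')
            elif p in _PASSTHROUGH:
                seq.append(p)
            elif p == 'AX':
                seq.append('AH0')
            else:
                # Any other two-character code gets a '1' suffix.
                seq.append(p + '1')
            j += 2
    return seq


if __name__ == '__main__':
    # Made-up input strings, used only to exercise each branch.
    print(l2s_to_arpabet('kAXt'))
    print(l2s_to_arpabet('hIYj'))
```

Keeping the mapping in a pure function like this would also let get_unk_phns and prep_txt_en share one implementation instead of repeating the loop.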
def words2phns(line: str):
    '''
    Args:
        line (str): input text.
        eg: for that reason cover is impossible to be given.
    Returns:
        List[str]: phones of input text.
            eg:
            ['F', 'AO1', 'R', 'DH', 'AE1', 'T', 'R', 'IY1', 'Z', 'AH0', 'N', 'K', 'AH1', 'V', 'ER0',
            'IH1', 'Z', 'IH2', 'M', 'P', 'AA1', 'S', 'AH0', 'B', 'AH0', 'L', 'T', 'UW1', 'B', 'IY1',
            'G', 'IH1', 'V', 'AH0', 'N']

        Dict(str, str): key - idx_word
                        value - phones
            eg:
            {'0_FOR': ['F', 'AO1', 'R'], '1_THAT': ['DH', 'AE1', 'T'], '2_REASON': ['R', 'IY1', 'Z', 'AH0', 'N'],
            '3_COVER': ['K', 'AH1', 'V', 'ER0'], '4_IS': ['IH1', 'Z'], '5_IMPOSSIBLE': ['IH2', 'M', 'P', 'AA1', 'S', 'AH0', 'B', 'AH0', 'L'],
            '6_TO': ['T', 'UW1'], '7_BE': ['B', 'IY1'], '8_GIVEN': ['G', 'IH1', 'V', 'AH0', 'N']}
    '''
    dictfile = MODEL_DIR_EN + '/dict'
    line = line.strip()
    words = []
    for pun in [',', '.', ':', ';', '!', '?', '"', '(', ')', '--', '---']:
        line = line.replace(pun, ' ')
    for wrd in line.split():
        if (wrd[-1] == '-'):
            wrd = wrd[:-1]
        if (wrd[0] == "'"):
            wrd = wrd[1:]
        if wrd:
            words.append(wrd)
    ds = set([])
    word2phns_dict = {}
    with open(dictfile, 'r') as fid:
        for line in fid:
            word = line.split()[0]
            ds.add(word)
            if word not in word2phns_dict.keys():
                word2phns_dict[word] = " ".join(line.split()[1:])

    phns = []
    wrd2phns = {}
    for index, wrd in enumerate(words):
        if wrd == '[MASK]':
            wrd2phns[str(index) + "_" + wrd] = [wrd]
            phns.append(wrd)
        elif (wrd.upper() not in ds):
            wrd2phns[str(index) + "_" + wrd.upper()] = get_unk_phns(wrd)
            phns.extend(get_unk_phns(wrd))
        else:
            wrd2phns[str(index) + "_" + wrd.upper()] = word2phns_dict[
                wrd.upper()].split()
            phns.extend(word2phns_dict[wrd.upper()].split())
    return phns, wrd2phns


def words2phns_zh(line: str):
    dictfile = MODEL_DIR_ZH + '/dict'
    line = line.strip()
    words = []
    for pun in [
            ',', '.', ':', ';', '!', '?', '"', '(', ')', '--', '---', u',',
            u'。', u':', u';', u'!', u'?', u'(', u')'
    ]:
        line = line.replace(pun, ' ')
    for wrd in line.split():
        if (wrd[-1] == '-'):
            wrd = wrd[:-1]
        if (wrd[0] == "'"):
            wrd = wrd[1:]
        if wrd:
            words.append(wrd)

    ds = set([])
    word2phns_dict = {}
    with open(dictfile, 'r') as fid:
        for line in fid:
            word = line.split()[0]
            ds.add(word)
            if word not in word2phns_dict.keys():
                word2phns_dict[word] = " ".join(line.split()[1:])

    phns = []
    wrd2phns = {}
    for index, wrd in enumerate(words):
        if wrd == '[MASK]':
            wrd2phns[str(index) + "_" + wrd] = [wrd]
            phns.append(wrd)
        elif (wrd.upper() not in ds):
            # Message: "Illegal word error, please enter the correct text..."
            print("出现非法词错误,请输入正确的文本...")
        else:
            wrd2phns[str(index) + "_" + wrd] = word2phns_dict[wrd].split()
            phns.extend(word2phns_dict[wrd].split())

    return phns, wrd2phns


def prep_txt_zh(line: str, tmpbase: str, dictfile: str):

    words = []
    line = line.strip()
    for pun in [
            ',', '.', ':', ';', '!', '?', '"', '(', ')', '--', '---', u',',
            u'。', u':', u';', u'!', u'?', u'(', u')'
    ]:
        line = line.replace(pun, ' ')
    for wrd in line.split():
        if (wrd[-1] == '-'):
            wrd = wrd[:-1]
        if (wrd[0] == "'"):
            wrd = wrd[1:]
        if wrd:
            words.append(wrd)

    ds = set([])
    with open(dictfile, 'r') as fid:
        for line in fid:
            ds.add(line.split()[0])

    unk_words = set([])
    with open(tmpbase + '.txt', 'w') as fwid:
        for wrd in words:
            if (wrd not in ds):
                unk_words.add(wrd)
            fwid.write(wrd + ' ')
        fwid.write('\n')
    return unk_words


def prep_txt_en(line: str, tmpbase, dictfile):

    words = []

    line = line.strip()
    for pun in [',', '.', ':', ';', '!', '?', '"', '(', ')', '--', '---']:
        line = line.replace(pun, ' ')
    for wrd in line.split():
        if (wrd[-1] == '-'):
            wrd = wrd[:-1]
        if (wrd[0] == "'"):
            wrd = wrd[1:]
        if wrd:
            words.append(wrd)
    ds = set([])
    with open(dictfile, 'r') as fid:
        for line in fid:
            ds.add(line.split()[0])

    unk_words = set([])
    with open(tmpbase + '.txt', 'w') as fwid:
        for wrd in words:
            if (wrd.upper() not in ds):
                unk_words.add(wrd.upper())
            fwid.write(wrd + ' ')
        fwid.write('\n')

    # generate pronunciations for unknown words using 'letter to sound'
    with open(tmpbase + '_unk.words', 'w') as fwid:
        for unk in unk_words:
            fwid.write(unk + '\n')
    try:
        os.system(PHONEME + ' ' + tmpbase + '_unk.words' + ' ' + tmpbase +
                  '_unk.phons')
    except Exception:
        print('english2phoneme error!')
        sys.exit(1)

    # add unknown words to the standard dictionary, generate a tmp dictionary for alignment
    fw = open(tmpbase + '.dict', 'w')
    with open(dictfile, 'r') as fid:
        for line in fid:
            fw.write(line)
    f = open(tmpbase + '_unk.words', 'r')
    lines1 = f.readlines()
    f.close()
    f = open(tmpbase + '_unk.phons', 'r')
    lines2 = f.readlines()
    f.close()
    for i in range(len(lines1)):
        wrd = lines1[i].replace('\n', '')
        phons = lines2[i].replace('\n', '').replace(' ', '')
        seq = []
        j = 0
        while (j < len(phons)):
            if (phons[j] > 'Z'):
                if (phons[j] == 'j'):
                    seq.append('JH')
                elif (phons[j] == 'h'):
                    seq.append('HH')
                else:
                    seq.append(phons[j].upper())
                j += 1
            else:
                p = phons[j:j + 2]
                if (p == 'WH'):
                    seq.append('W')
                elif (p in ['TH', 'SH', 'HH', 'DH', 'CH', 'ZH', 'NG']):
                    seq.append(p)
                elif (p == 'AX'):
                    seq.append('AH0')
                else:
                    seq.append(p + '1')
                j += 2

        fw.write(wrd + ' ')
        for s in seq:
            fw.write(' ' + s)
        fw.write('\n')
    fw.close()


def prep_mlf(txt: str, tmpbase: str):

    with open(tmpbase + '.mlf', 'w') as fwid:
        fwid.write('#!MLF!#\n')
        fwid.write('"' + tmpbase + '.lab"\n')
        fwid.write('sp\n')
        wrds = txt.split()
        for wrd in wrds:
            fwid.write(wrd.upper() + '\n')
            fwid.write('sp\n')
        fwid.write('.\n')


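To see what prep_mlf writes, the same layout can be produced as a string. This is a small sketch; the function name build_mlf and the example label path are hypothetical, and only the line layout is taken from prep_mlf above.

```python
def build_mlf(transcript: str, lab_name: str) -> str:
    """Return the text of an HTK master label file for one utterance.

    The layout mirrors prep_mlf: a header, the quoted label name, then
    every upper-cased word surrounded by short-pause ('sp') entries,
    and a closing '.' line.
    """
    lines = ['#!MLF!#', '"' + lab_name + '"', 'sp']
    for word in transcript.split():
        lines.append(word.upper())
        lines.append('sp')
    lines.append('.')
    return '\n'.join(lines) + '\n'


if __name__ == '__main__':
    # '/tmp/demo.lab' is a made-up path used only for illustration.
    print(build_mlf('for that reason', '/tmp/demo.lab'), end='')
```

The 'sp' entry between every pair of words lets the aligner place an optional pause at each word boundary, which is why 'sp' keys later show up in the word-to-phone dictionaries.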
def _get_user():
    return os.path.expanduser('~').split("/")[-1]


def alignment(wav_path: str, text: str):
    '''
    intervals: List[phn, start, end]
    '''
    tmpbase = '/tmp/' + _get_user() + '_' + str(os.getpid())

    #prepare wav and trs files
    try:
        os.system('sox ' + wav_path + ' -r 16000 ' + tmpbase + '.wav remix -')
    except Exception:
        print('sox error!')
        return None

    #prepare clean_transcript file
    try:
        prep_txt_en(line=text, tmpbase=tmpbase, dictfile=MODEL_DIR_EN + '/dict')
    except Exception:
        print('prep_txt error!')
        return None

    #prepare mlf file
    try:
        with open(tmpbase + '.txt', 'r') as fid:
            txt = fid.readline()
        prep_mlf(txt, tmpbase)
    except Exception:
        print('prep_mlf error!')
        return None

    #prepare scp
    try:
        os.system(HCOPY + ' -C ' + MODEL_DIR_EN + '/16000/config ' + tmpbase +
                  '.wav' + ' ' + tmpbase + '.plp')
    except Exception:
        print('HCopy error!')
        return None

    #run alignment
    try:
        os.system(HVITE + ' -a -m -t 10000.0 10000.0 100000.0 -I ' + tmpbase +
                  '.mlf -H ' + MODEL_DIR_EN + '/16000/macros -H ' + MODEL_DIR_EN
                  + '/16000/hmmdefs -i ' + tmpbase + '.aligned ' + tmpbase +
                  '.dict ' + MODEL_DIR_EN + '/monophones ' + tmpbase +
                  '.plp 2>&1 > /dev/null')
    except Exception:
        print('HVite error!')
        return None

    with open(tmpbase + '.txt', 'r') as fid:
        words = fid.readline().strip().split()
    words = txt.strip().split()
    words.reverse()

    with open(tmpbase + '.aligned', 'r') as fid:
        lines = fid.readlines()
    i = 2
    intervals = []
    word2phns = {}
    current_word = ''
    index = 0
    while (i < len(lines)):
        splited_line = lines[i].strip().split()
        if (len(splited_line) >= 4) and (splited_line[0] != splited_line[1]):
            phn = splited_line[2]
            pst = (int(splited_line[0]) / 1000 + 125) / 10000
            pen = (int(splited_line[1]) / 1000 + 125) / 10000
            intervals.append([phn, pst, pen])
            # splited_line[-1]!='sp'
            if len(splited_line) == 5:
                current_word = str(index) + '_' + splited_line[-1]
                word2phns[current_word] = phn
                index += 1
            elif len(splited_line) == 4:
                word2phns[current_word] += ' ' + phn
        i += 1
    return intervals, word2phns


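The parsing loop at the end of alignment can be exercised on its own. HTK writes label times in units of 100 ns, so the expression (t / 1000 + 125) / 10000 converts a time to seconds and adds a fixed 0.0125 s offset. The sketch below assumes label lines of the form "start end phone score [word]"; the function names and the sample lines are hypothetical.

```python
def htk_time_to_seconds(t: str) -> float:
    """Convert an HTK time stamp (100 ns units) to seconds, plus 12.5 ms."""
    return (int(t) / 1000 + 125) / 10000


def parse_aligned(lines):
    """Collect [phone, start, end] intervals and an indexed word-to-phones map."""
    intervals = []
    word2phns = {}
    current_word = ''
    index = 0
    for raw in lines:
        fields = raw.strip().split()
        # Skip malformed lines and zero-length segments.
        if len(fields) >= 4 and fields[0] != fields[1]:
            phn = fields[2]
            intervals.append([
                phn,
                htk_time_to_seconds(fields[0]),
                htk_time_to_seconds(fields[1]),
            ])
            if len(fields) == 5:
                # A fifth field marks the first phone of a new word.
                current_word = str(index) + '_' + fields[-1]
                word2phns[current_word] = phn
                index += 1
            elif len(fields) == 4:
                # Otherwise the phone continues the current word.
                word2phns[current_word] += ' ' + phn
    return intervals, word2phns


if __name__ == '__main__':
    # Made-up label lines, not real HVite output.
    sample = [
        '0 1000000 sp -50.0 sp',
        '1000000 2000000 F -60.0 FOR',
        '2000000 3000000 AO1 -61.0',
        '3000000 3000000 sp -1.0 sp',
    ]
    print(parse_aligned(sample))
```

Unlike the original loop, this sketch does not skip the first two lines of the file; a caller reading a real .aligned file would pass lines[2:] to match the original behavior.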
def alignment_zh(wav_path: str, text: str):
    tmpbase = '/tmp/' + _get_user() + '_' + str(os.getpid())

    #prepare wav and trs files
    try:
        os.system('sox ' + wav_path + ' -r 16000 -b 16 ' + tmpbase +
                  '.wav remix -')

    except Exception:
        print('sox error!')
        return None

    #prepare clean_transcript file
    try:
        unk_words = prep_txt_zh(
            line=text, tmpbase=tmpbase, dictfile=MODEL_DIR_ZH + '/dict')
        if unk_words:
            print('Error! Please add the following words to dictionary:')
            for unk in unk_words:
                # Message prefix: "illegal words: "
                print("非法words: ", unk)
    except Exception:
        print('prep_txt error!')
        return None

    #prepare mlf file
    try:
        with open(tmpbase + '.txt', 'r') as fid:
            txt = fid.readline()
        prep_mlf(txt, tmpbase)
    except Exception:
        print('prep_mlf error!')
        return None

    #prepare scp
    try:
        os.system(HCOPY + ' -C ' + MODEL_DIR_ZH + '/16000/config ' + tmpbase +
                  '.wav' + ' ' + tmpbase + '.plp')
    except Exception:
        print('HCopy error!')
        return None

    #run alignment
    try:
        os.system(HVITE + ' -a -m -t 10000.0 10000.0 100000.0 -I ' + tmpbase +
                  '.mlf -H ' + MODEL_DIR_ZH + '/16000/macros -H ' + MODEL_DIR_ZH
                  + '/16000/hmmdefs -i ' + tmpbase + '.aligned ' + MODEL_DIR_ZH
                  + '/dict ' + MODEL_DIR_ZH + '/monophones ' + tmpbase +
                  '.plp 2>&1 > /dev/null')

    except Exception:
        print('HVite error!')
        return None

    with open(tmpbase + '.txt', 'r') as fid:
        words = fid.readline().strip().split()
    words = txt.strip().split()
    words.reverse()

    with open(tmpbase + '.aligned', 'r') as fid:
        lines = fid.readlines()

    i = 2
    intervals = []
    word2phns = {}
    current_word = ''
    index = 0
    while (i < len(lines)):
        splited_line = lines[i].strip().split()
        if (len(splited_line) >= 4) and (splited_line[0] != splited_line[1]):
            phn = splited_line[2]
            pst = (int(splited_line[0]) / 1000 + 125) / 10000
            pen = (int(splited_line[1]) / 1000 + 125) / 10000
            intervals.append([phn, pst, pen])
            # splited_line[-1]!='sp'
            if len(splited_line) == 5:
                current_word = str(index) + '_' + splited_line[-1]
                word2phns[current_word] = phn
                index += 1
            elif len(splited_line) == 4:
                word2phns[current_word] += ' ' + phn
        i += 1
    return intervals, word2phns
examples/ernie_sat/local/inference.py
Deleted (100644 → 0), view file @ d21e03c0
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import random
from typing import Dict
from typing import List

import librosa
import numpy as np
import paddle
import soundfile as sf
from align import alignment
from align import alignment_zh
from align import words2phns
from align import words2phns_zh
from paddle import nn
from sedit_arg_parser import parse_args
from utils import eval_durs
from utils import get_voc_out
from utils import is_chinese
from utils import load_num_sequence_text
from utils import read_2col_text

from paddlespeech.t2s.datasets.am_batch_fn import build_mlm_collate_fn
from paddlespeech.t2s.models.ernie_sat.mlm import build_model_from_file

random.seed(0)
np.random.seed(0)


def get_wav(wav_path: str,
            source_lang: str='english',
            target_lang: str='english',
            model_name: str="paddle_checkpoint_en",
            old_str: str="",
            new_str: str="",
            non_autoreg: bool=True):
    wav_org, output_feat, old_span_bdy, new_span_bdy, fs, hop_length = get_mlm_output(
        source_lang=source_lang,
        target_lang=target_lang,
        model_name=model_name,
        wav_path=wav_path,
        old_str=old_str,
        new_str=new_str,
        use_teacher_forcing=non_autoreg)

    masked_feat = output_feat[new_span_bdy[0]:new_span_bdy[1]]

    alt_wav = get_voc_out(masked_feat)

    old_time_bdy = [hop_length * x for x in old_span_bdy]

    wav_replaced = np.concatenate(
        [wav_org[:old_time_bdy[0]], alt_wav, wav_org[old_time_bdy[1]:]])

    data_dict = {"origin": wav_org, "output": wav_replaced}

    return data_dict


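The last step of get_wav splices the vocoder output into the original waveform by converting mel-frame boundaries to sample indices. A minimal sketch of that step, using plain Python lists so it runs without numpy (the original uses np.concatenate on arrays); the function name splice_waveform and the toy values are hypothetical.

```python
def splice_waveform(wav_org, alt_wav, old_span_bdy, hop_length):
    """Replace the samples covered by a mel-frame span with new audio.

    old_span_bdy holds [start_frame, end_frame]; multiplying by
    hop_length turns frame indices into sample indices.
    """
    start, end = [hop_length * x for x in old_span_bdy]
    return list(wav_org[:start]) + list(alt_wav) + list(wav_org[end:])


if __name__ == '__main__':
    # Toy numbers: 10 "samples", a hop of 3 samples per frame, frames 1..2 replaced.
    original = list(range(10))
    generated = [100, 101, 102, 103]
    print(splice_waveform(original, generated, [1, 2], 3))
```

Because the generated segment may be longer or shorter than the span it replaces, the output length generally differs from the input length.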
def load_model(model_name: str="paddle_checkpoint_en"):
    config_path = './pretrained_model/{}/config.yaml'.format(model_name)
    model_path = './pretrained_model/{}/model.pdparams'.format(model_name)
    mlm_model, conf = build_model_from_file(
        config_file=config_path, model_file=model_path)
    return mlm_model, conf


def read_data(uid: str, prefix: os.PathLike):
    # get the text for this uid
    mfa_text = read_2col_text(prefix + '/text')[uid]
    # get the audio path for this uid
    mfa_wav_path = read_2col_text(prefix + '/wav.scp')[uid]
    if not os.path.isabs(mfa_wav_path):
        mfa_wav_path = prefix + mfa_wav_path
    return mfa_text, mfa_wav_path


def get_align_data(uid: str, prefix: os.PathLike):
    mfa_path = prefix + "mfa_"
    mfa_text = read_2col_text(mfa_path + 'text')[uid]
    mfa_start = load_num_sequence_text(
        mfa_path + 'start', loader_type='text_float')[uid]
    mfa_end = load_num_sequence_text(
        mfa_path + 'end', loader_type='text_float')[uid]
    mfa_wav_path = read_2col_text(mfa_path + 'wav.scp')[uid]
    return mfa_text, mfa_start, mfa_end, mfa_wav_path


# get the range of mel frames that need to be masked
def get_masked_mel_bdy(mfa_start: List[float],
                       mfa_end: List[float],
                       fs: int,
                       hop_length: int,
                       span_to_repl: List[List[int]]):
    align_start = np.array(mfa_start)
    align_end = np.array(mfa_end)
    align_start = np.floor(fs * align_start / hop_length).astype('int')
    align_end = np.floor(fs * align_end / hop_length).astype('int')
    if span_to_repl[0] >= len(mfa_start):
        span_bdy = [align_end[-1], align_end[-1]]
    else:
        span_bdy = [
            align_start[span_to_repl[0]], align_end[span_to_repl[1] - 1]
        ]
    return span_bdy, align_start, align_end


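get_masked_mel_bdy maps alignment times in seconds to mel-frame indices with floor(fs * t / hop_length). The sketch below redoes the same arithmetic with the standard library only; the function names and the sample times are hypothetical, while fs=24000 and hop_length=300 are the defaults that appear in prep_feats_with_dur.

```python
import math


def seconds_to_frames(times, fs, hop_length):
    """Map times in seconds to mel-frame indices: floor(fs * t / hop_length)."""
    return [math.floor(fs * t / hop_length) for t in times]


def masked_mel_bdy(mfa_start, mfa_end, fs, hop_length, span_to_repl):
    """Return the [start_frame, end_frame] range for a phone-index span."""
    align_start = seconds_to_frames(mfa_start, fs, hop_length)
    align_end = seconds_to_frames(mfa_end, fs, hop_length)
    if span_to_repl[0] >= len(mfa_start):
        # The span starts past the last phone: an empty range at the end.
        span_bdy = [align_end[-1], align_end[-1]]
    else:
        span_bdy = [align_start[span_to_repl[0]], align_end[span_to_repl[1] - 1]]
    return span_bdy, align_start, align_end


if __name__ == '__main__':
    # With fs=24000 and hop_length=300 there are 80 frames per second.
    starts = [0.0, 0.5, 1.0]
    ends = [0.5, 1.0, 1.5]
    print(masked_mel_bdy(starts, ends, 24000, 300, [1, 2]))
    print(masked_mel_bdy(starts, ends, 24000, 300, [3, 3]))
```

The second call shows the append case, where text is added after the last aligned phone and the masked range collapses to the final frame.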
def recover_dict(word2phns: Dict[str, str], tp_word2phns: Dict[str, str]):
    dic = {}
    keys_to_del = []
    exist_idx = []
    sp_count = 0
    add_sp_count = 0
    for key in word2phns.keys():
        idx, wrd = key.split('_')
        if wrd == 'sp':
            sp_count += 1
            exist_idx.append(int(idx))
        else:
            keys_to_del.append(key)

    for key in keys_to_del:
        del word2phns[key]

    cur_id = 0
    for key in tp_word2phns.keys():
        if cur_id in exist_idx:
            dic[str(cur_id) + "_sp"] = 'sp'
            cur_id += 1
            add_sp_count += 1
        idx, wrd = key.split('_')
        dic[str(cur_id) + "_" + wrd] = tp_word2phns[key]
        cur_id += 1

    if add_sp_count + 1 == sp_count:
        dic[str(cur_id) + "_sp"] = 'sp'
        add_sp_count += 1

    assert add_sp_count == sp_count, "sp are not added in dic"
    return dic


def get_max_idx(dic):
    return sorted([int(key.split('_')[0]) for key in dic.keys()])[-1]


def get_phns_and_spans(wav_path: str,
                       old_str: str="",
                       new_str: str="",
                       source_lang: str="english",
                       target_lang: str="english"):
    is_append = (old_str == new_str[:len(old_str)])
    old_phns, mfa_start, mfa_end = [], [], []
    # source
    if source_lang == "english":
        intervals, word2phns = alignment(wav_path, old_str)
    elif source_lang == "chinese":
        intervals, word2phns = alignment_zh(wav_path, old_str)
        _, tp_word2phns = words2phns_zh(old_str)

        for key, value in tp_word2phns.items():
            idx, wrd = key.split('_')
            cur_val = " ".join(value)
            tp_word2phns[key] = cur_val

        word2phns = recover_dict(word2phns, tp_word2phns)
    else:
        assert source_lang == "chinese" or source_lang == "english", \
            "source_lang is wrong..."

    for item in intervals:
        old_phns.append(item[0])
        mfa_start.append(float(item[1]))
        mfa_end.append(float(item[2]))
    # target
    if is_append and (source_lang != target_lang):
        cross_lingual_clone = True
    else:
        cross_lingual_clone = False

    if cross_lingual_clone:
        str_origin = new_str[:len(old_str)]
        str_append = new_str[len(old_str):]

        if target_lang == "chinese":
            phns_origin, origin_word2phns = words2phns(str_origin)
            phns_append, append_word2phns_tmp = words2phns_zh(str_append)

        elif target_lang == "english":
            # original sentence
            phns_origin, origin_word2phns = words2phns_zh(str_origin)
            # cloned sentence
            phns_append, append_word2phns_tmp = words2phns(str_append)
        else:
            assert target_lang == "chinese" or target_lang == "english", \
                "cloning is not support for this language, please check it."

        new_phns = phns_origin + phns_append

        append_word2phns = {}
        length = len(origin_word2phns)
        for key, value in append_word2phns_tmp.items():
            idx, wrd = key.split('_')
            append_word2phns[str(int(idx) + length) + '_' + wrd] = value
        new_word2phns = origin_word2phns.copy()
        new_word2phns.update(append_word2phns)

    else:
        if source_lang == target_lang and target_lang == "english":
            new_phns, new_word2phns = words2phns(new_str)
        elif source_lang == target_lang and target_lang == "chinese":
            new_phns, new_word2phns = words2phns_zh(new_str)
        else:
            assert source_lang == target_lang, \
                "source language is not same with target language..."

    span_to_repl = [0, len(old_phns) - 1]
    span_to_add = [0, len(new_phns) - 1]
    left_idx = 0
    new_phns_left = []
    sp_count = 0
    # find the left different index
    for key in word2phns.keys():
        idx, wrd = key.split('_')
        if wrd == 'sp':
            sp_count += 1
            new_phns_left.append('sp')
        else:
            idx = str(int(idx) - sp_count)
            if idx + '_' + wrd in new_word2phns:
                left_idx += len(new_word2phns[idx + '_' + wrd])
                new_phns_left.extend(word2phns[key].split())
            else:
                span_to_repl[0] = len(new_phns_left)
                span_to_add[0] = len(new_phns_left)
                break

    # reverse word2phns and new_word2phns
    right_idx = 0
    new_phns_right = []
    sp_count = 0
    word2phns_max_idx = get_max_idx(word2phns)
    new_word2phns_max_idx = get_max_idx(new_word2phns)
    new_phns_mid = []
    if is_append:
        new_phns_right = []
        new_phns_mid = new_phns[left_idx:]
        span_to_repl[0] = len(new_phns_left)
        span_to_add[0] = len(new_phns_left)
        span_to_add[1] = len(new_phns_left) + len(new_phns_mid)
        span_to_repl[1] = len(old_phns) - len(new_phns_right)
    # speech edit
    else:
        for key in list(word2phns.keys())[::-1]:
            idx, wrd = key.split('_')
            if wrd == 'sp':
                sp_count += 1
                new_phns_right = ['sp'] + new_phns_right
            else:
                idx = str(new_word2phns_max_idx - (word2phns_max_idx - int(idx)
                                                   - sp_count))
                if idx + '_' + wrd in new_word2phns:
                    right_idx -= len(new_word2phns[idx + '_' + wrd])
                    new_phns_right = word2phns[key].split() + new_phns_right
                else:
                    span_to_repl[1] = len(old_phns) - len(new_phns_right)
                    new_phns_mid = new_phns[left_idx:right_idx]
                    span_to_add[1] = len(new_phns_left) + len(new_phns_mid)
                    if len(new_phns_mid) == 0:
                        span_to_add[1] = min(span_to_add[1] + 1, len(new_phns))
                        span_to_add[0] = max(0, span_to_add[0] - 1)
                        span_to_repl[0] = max(0, span_to_repl[0] - 1)
                        span_to_repl[1] = min(span_to_repl[1] + 1,
                                              len(old_phns))
                    break
    new_phns = new_phns_left + new_phns_mid + new_phns_right
    '''
    For that reason cover should not be given.
    For that reason cover is impossible to be given.
    span_to_repl: [17, 23] "should not"
    span_to_add: [17, 30]  "is impossible to"
    '''
    return mfa_start, mfa_end, old_phns, new_phns, span_to_repl, span_to_add


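The span search in get_phns_and_spans walks the words from the left and then from the right until the old and new sentences differ. The sketch below illustrates the same idea in a deliberately simplified form: it compares flat phone lists by common prefix and suffix, and ignores the word-level keys and the 'sp' bookkeeping of the original. The function name edit_spans and the toy phone lists are hypothetical.

```python
def edit_spans(old_phns, new_phns):
    """Find the differing middle region of two phone sequences.

    Returns (span_to_repl, span_to_add): half-open index ranges into
    old_phns and new_phns that remain after stripping the common
    prefix and the common suffix.
    """
    limit = min(len(old_phns), len(new_phns))
    left = 0
    while left < limit and old_phns[left] == new_phns[left]:
        left += 1
    right = 0
    # Do not let the suffix overlap the prefix that was already matched.
    while right < limit - left and old_phns[-1 - right] == new_phns[-1 - right]:
        right += 1
    span_to_repl = [left, len(old_phns) - right]
    span_to_add = [left, len(new_phns) - right]
    return span_to_repl, span_to_add


if __name__ == '__main__':
    # Toy sequences: the middle "SH UH1 D" is replaced by "IH1 Z".
    old = ['F', 'AO1', 'R', 'SH', 'UH1', 'D', 'G']
    new = ['F', 'AO1', 'R', 'IH1', 'Z', 'G']
    print(edit_spans(old, new))
    # Append case: everything after the shared prefix is new.
    print(edit_spans(['A', 'B'], ['A', 'B', 'C']))
```

Matching at the word level, as the original does, avoids splitting a word when two different words happen to share leading or trailing phones; this phone-level version does not have that property.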
# The durations obtained from mfa and those from the fs2 duration_predictor may differ.
# Compute a scale factor here to rescale the predicted values toward the real ones.
def get_dur_adj_factor(orig_dur: List[int],
                       pred_dur: List[int],
                       phns: List[str]):
    length = 0
    factor_list = []
    for orig, pred, phn in zip(orig_dur, pred_dur, phns):
        if pred == 0 or phn == 'sp':
            continue
        else:
            factor_list.append(orig / pred)
    factor_list = np.array(factor_list)
    factor_list.sort()
    if len(factor_list) < 5:
        return 1
    length = 2
    avg = np.average(factor_list[length:-length])
    return avg


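get_dur_adj_factor is a trimmed mean of per-phone ratios: it drops the two smallest and two largest ratios before averaging, and falls back to 1 when fewer than five ratios are available. A sketch of the same computation without numpy; the function name dur_adj_factor and the toy durations are hypothetical.

```python
def dur_adj_factor(orig_dur, pred_dur, phns):
    """Trimmed mean of orig/pred duration ratios, skipping 'sp' and zero predictions."""
    factors = sorted(
        orig / pred
        for orig, pred, phn in zip(orig_dur, pred_dur, phns)
        if pred != 0 and phn != 'sp')
    if len(factors) < 5:
        # Too few samples for a stable estimate: leave durations unscaled.
        return 1
    trimmed = factors[2:-2]  # drop the two smallest and two largest ratios
    return sum(trimmed) / len(trimmed)


if __name__ == '__main__':
    # Toy durations: the last phone has a zero prediction and is skipped.
    orig = [10, 20, 30, 40, 50, 60]
    pred = [10, 10, 10, 10, 10, 0]
    phones = ['a', 'b', 'c', 'd', 'e', 'f']
    print(dur_adj_factor(orig, pred, phones))
```

Trimming the extremes keeps a single badly aligned phone from skewing the scale that is later applied to every predicted duration.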
def prep_feats_with_dur(wav_path: str,
                        source_lang: str="English",
                        target_lang: str="English",
                        old_str: str="",
                        new_str: str="",
                        mask_reconstruct: bool=False,
                        duration_adjust: bool=True,
                        start_end_sp: bool=False,
                        fs: int=24000,
                        hop_length: int=300):
    '''
    Returns:
        np.ndarray: new wav, replace the part to be edited in original wav with 0
        List[str]: new phones
        List[float]: mfa start of new wav
        List[float]: mfa end of new wav
        List[int]: masked mel boundary of original wav
        List[int]: masked mel boundary of new wav
    '''
    wav_org, _ = librosa.load(wav_path, sr=fs)

    mfa_start, mfa_end, old_phns, new_phns, span_to_repl, span_to_add = get_phns_and_spans(
        wav_path=wav_path,
        old_str=old_str,
        new_str=new_str,
        source_lang=source_lang,
        target_lang=target_lang)

    if start_end_sp:
        if new_phns[-1] != 'sp':
            new_phns = new_phns + ['sp']
    # Chinese phns are not all in the fastspeech2 dictionary; use sp instead
    if target_lang == "english" or target_lang == "chinese":
        old_durs = eval_durs(old_phns, target_lang=source_lang)
    else:
        assert target_lang == "chinese" or target_lang == "english", \
            "calculate duration_predict is not support for this language..."

    orig_old_durs = [e - s for e, s in zip(mfa_end, mfa_start)]
    if '[MASK]' in new_str:
        new_phns = old_phns
        span_to_add = span_to_repl
        d_factor_left = get_dur_adj_factor(
            orig_dur=orig_old_durs[:span_to_repl[0]],
            pred_dur=old_durs[:span_to_repl[0]],
            phns=old_phns[:span_to_repl[0]])
        d_factor_right = get_dur_adj_factor(
            orig_dur=orig_old_durs[span_to_repl[1]:],
            pred_dur=old_durs[span_to_repl[1]:],
            phns=old_phns[span_to_repl[1]:])
        d_factor = (d_factor_left + d_factor_right) / 2
        new_durs_adjusted = [d_factor * i for i in old_durs]
    else:
        if duration_adjust:
            d_factor = get_dur_adj_factor(
                orig_dur=orig_old_durs, pred_dur=old_durs, phns=old_phns)
            d_factor = d_factor * 1.25
        else:
            d_factor = 1

        if target_lang == "english" or target_lang == "chinese":
            new_durs = eval_durs(new_phns, target_lang=target_lang)
        else:
            assert target_lang == "chinese" or target_lang == "english", \
                "calculate duration_predict is not support for this language..."

        new_durs_adjusted = [d_factor * i for i in new_durs]

    new_span_dur_sum = sum(new_durs_adjusted[span_to_add[0]:span_to_add[1]])
    old_span_dur_sum = sum(orig_old_durs[span_to_repl[0]:span_to_repl[1]])
    dur_offset = new_span_dur_sum - old_span_dur_sum
    new_mfa_start = mfa_start[:span_to_repl[0]]
new_mfa_end
=
mfa_end
[:
span_to_repl
[
0
]]
for
i
in
new_durs_adjusted
[
span_to_add
[
0
]:
span_to_add
[
1
]]:
if
len
(
new_mfa_end
)
==
0
:
new_mfa_start
.
append
(
0
)
new_mfa_end
.
append
(
i
)
else
:
new_mfa_start
.
append
(
new_mfa_end
[
-
1
])
new_mfa_end
.
append
(
new_mfa_end
[
-
1
]
+
i
)
new_mfa_start
+=
[
i
+
dur_offset
for
i
in
mfa_start
[
span_to_repl
[
1
]:]]
new_mfa_end
+=
[
i
+
dur_offset
for
i
in
mfa_end
[
span_to_repl
[
1
]:]]
# 3. get new wav
# 在原始句子后拼接
if
span_to_repl
[
0
]
>=
len
(
mfa_start
):
left_idx
=
len
(
wav_org
)
right_idx
=
left_idx
# 在原始句子中间替换
else
:
left_idx
=
int
(
np
.
floor
(
mfa_start
[
span_to_repl
[
0
]]
*
fs
))
right_idx
=
int
(
np
.
ceil
(
mfa_end
[
span_to_repl
[
1
]
-
1
]
*
fs
))
blank_wav
=
np
.
zeros
(
(
int
(
np
.
ceil
(
new_span_dur_sum
*
fs
)),
),
dtype
=
wav_org
.
dtype
)
# 原始音频,需要编辑的部分替换成空音频,空音频的时间由 fs2 的 duration_predictor 决定
new_wav
=
np
.
concatenate
(
[
wav_org
[:
left_idx
],
blank_wav
,
wav_org
[
right_idx
:]])
# 4. get old and new mel span to be mask
# [92, 92]
old_span_bdy
,
mfa_start
,
mfa_end
=
get_masked_mel_bdy
(
mfa_start
=
mfa_start
,
mfa_end
=
mfa_end
,
fs
=
fs
,
hop_length
=
hop_length
,
span_to_repl
=
span_to_repl
)
# [92, 174]
# new_mfa_start, new_mfa_end 时间级别的开始和结束时间 -> 帧级别
new_span_bdy
,
new_mfa_start
,
new_mfa_end
=
get_masked_mel_bdy
(
mfa_start
=
new_mfa_start
,
mfa_end
=
new_mfa_end
,
fs
=
fs
,
hop_length
=
hop_length
,
span_to_repl
=
span_to_add
)
# old_span_bdy, new_span_bdy 是帧级别的范围
return
new_wav
,
new_phns
,
new_mfa_start
,
new_mfa_end
,
old_span_bdy
,
new_span_bdy
def prep_feats(wav_path: str,
               source_lang: str = "english",
               target_lang: str = "english",
               old_str: str = "",
               new_str: str = "",
               duration_adjust: bool = True,
               start_end_sp: bool = False,
               mask_reconstruct: bool = False,
               fs: int = 24000,
               hop_length: int = 300,
               token_list: List[str] = []):
    wav, phns, mfa_start, mfa_end, old_span_bdy, new_span_bdy = prep_feats_with_dur(
        source_lang=source_lang,
        target_lang=target_lang,
        old_str=old_str,
        new_str=new_str,
        wav_path=wav_path,
        duration_adjust=duration_adjust,
        start_end_sp=start_end_sp,
        mask_reconstruct=mask_reconstruct,
        fs=fs,
        hop_length=hop_length)

    token_to_id = {item: i for i, item in enumerate(token_list)}
    text = np.array(
        list(map(lambda x: token_to_id.get(x, token_to_id['<unk>']), phns)))
    span_bdy = np.array(new_span_bdy)

    batch = [('1', {
        "speech": wav,
        "align_start": mfa_start,
        "align_end": mfa_end,
        "text": text,
        "span_bdy": span_bdy
    })]

    return batch, old_span_bdy, new_span_bdy
def decode_with_model(mlm_model: nn.Layer,
                      collate_fn,
                      wav_path: str,
                      source_lang: str = "english",
                      target_lang: str = "english",
                      old_str: str = "",
                      new_str: str = "",
                      use_teacher_forcing: bool = False,
                      duration_adjust: bool = True,
                      start_end_sp: bool = False,
                      fs: int = 24000,
                      hop_length: int = 300,
                      token_list: List[str] = []):
    batch, old_span_bdy, new_span_bdy = prep_feats(
        source_lang=source_lang,
        target_lang=target_lang,
        wav_path=wav_path,
        old_str=old_str,
        new_str=new_str,
        duration_adjust=duration_adjust,
        start_end_sp=start_end_sp,
        fs=fs,
        hop_length=hop_length,
        token_list=token_list)

    feats = collate_fn(batch)[1]

    if 'text_masked_pos' in feats.keys():
        feats.pop('text_masked_pos')

    output = mlm_model.inference(
        text=feats['text'],
        speech=feats['speech'],
        masked_pos=feats['masked_pos'],
        speech_mask=feats['speech_mask'],
        text_mask=feats['text_mask'],
        speech_seg_pos=feats['speech_seg_pos'],
        text_seg_pos=feats['text_seg_pos'],
        span_bdy=new_span_bdy,
        use_teacher_forcing=use_teacher_forcing)

    # concatenate the output segments
    output_feat = paddle.concat(x=output, axis=0)
    wav_org, _ = librosa.load(wav_path, sr=fs)
    return wav_org, output_feat, old_span_bdy, new_span_bdy, fs, hop_length
def get_mlm_output(wav_path: str,
                   model_name: str = "paddle_checkpoint_en",
                   source_lang: str = "english",
                   target_lang: str = "english",
                   old_str: str = "",
                   new_str: str = "",
                   use_teacher_forcing: bool = False,
                   duration_adjust: bool = True,
                   start_end_sp: bool = False):
    mlm_model, train_conf = load_model(model_name)
    mlm_model.eval()

    collate_fn = build_mlm_collate_fn(
        sr=train_conf.feats_extract_conf['fs'],
        n_fft=train_conf.feats_extract_conf['n_fft'],
        hop_length=train_conf.feats_extract_conf['hop_length'],
        win_length=train_conf.feats_extract_conf['win_length'],
        n_mels=train_conf.feats_extract_conf['n_mels'],
        fmin=train_conf.feats_extract_conf['fmin'],
        fmax=train_conf.feats_extract_conf['fmax'],
        mlm_prob=train_conf['mlm_prob'],
        mean_phn_span=train_conf['mean_phn_span'],
        seg_emb=train_conf.encoder_conf['input_layer'] == 'sega_mlm')

    return decode_with_model(
        source_lang=source_lang,
        target_lang=target_lang,
        mlm_model=mlm_model,
        collate_fn=collate_fn,
        wav_path=wav_path,
        old_str=old_str,
        new_str=new_str,
        use_teacher_forcing=use_teacher_forcing,
        duration_adjust=duration_adjust,
        start_end_sp=start_end_sp,
        fs=train_conf.feats_extract_conf['fs'],
        hop_length=train_conf.feats_extract_conf['hop_length'],
        token_list=train_conf.token_list)
def evaluate(uid: str,
             source_lang: str = "english",
             target_lang: str = "english",
             prefix: os.PathLike = "./prompt/dev/",
             model_name: str = "paddle_checkpoint_en",
             new_str: str = "",
             prompt_decoding: bool = False,
             task_name: str = None):

    # get origin text and path of origin wav
    old_str, wav_path = read_data(uid=uid, prefix=prefix)

    if task_name == 'edit':
        new_str = new_str
    elif task_name == 'synthesize':
        new_str = old_str + new_str
    else:
        new_str = old_str + ' '.join([ch for ch in new_str if is_chinese(ch)])

    print('new_str is ', new_str)

    results_dict = get_wav(
        source_lang=source_lang,
        target_lang=target_lang,
        model_name=model_name,
        wav_path=wav_path,
        old_str=old_str,
        new_str=new_str)
    return results_dict
if __name__ == "__main__":
    # parse config and args
    args = parse_args()

    data_dict = evaluate(
        uid=args.uid,
        source_lang=args.source_lang,
        target_lang=args.target_lang,
        prefix=args.prefix,
        model_name=args.model_name,
        new_str=args.new_str,
        task_name=args.task_name)
    sf.write(args.output_name, data_dict['output'], samplerate=24000)
    print("finished...")
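The duration scaling used in `get_dur_adj_factor` above is in effect a trimmed mean of the per-phone ratio between aligned and predicted durations. The following is a minimal standalone sketch of that idea in plain Python, not the repository's implementation: the function name `trimmed_ratio_mean` and the `trim` / `min_count` parameters are hypothetical, chosen to mirror the constants 2 and 5 in the deleted code.

```python
from typing import List


def trimmed_ratio_mean(orig_dur: List[float],
                       pred_dur: List[float],
                       phns: List[str],
                       trim: int = 2,
                       min_count: int = 5) -> float:
    """Average orig/pred ratios after dropping the extremes.

    Hypothetical helper: skips silence phones ('sp') and zero predictions,
    then discards the `trim` smallest and `trim` largest ratios.
    """
    ratios = sorted(o / p for o, p, phn in zip(orig_dur, pred_dur, phns)
                    if p != 0 and phn != 'sp')
    # Too few samples to estimate a reliable factor: leave durations unscaled.
    if len(ratios) < min_count:
        return 1.0
    kept = ratios[trim:len(ratios) - trim]
    # Guard against a caller-chosen `trim` that removes every ratio.
    if not kept:
        return 1.0
    return sum(kept) / len(kept)


if __name__ == "__main__":
    orig = [2, 2, 2, 2, 2, 100, 0.02]
    pred = [1] * 7
    print(trimmed_ratio_mean(orig, pred, ['a'] * 7))
```

Trimming both ends keeps a single badly aligned phone from dominating the factor, which is presumably why the original sorts the ratios before averaging.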
examples/ernie_sat/local/inference_new.py
Deleted
100644 → 0
View file @
d21e03c0
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import random
from typing import Dict
from typing import List

import librosa
import numpy as np
import paddle
import soundfile as sf
import yaml
from align import alignment
from align import alignment_zh
from align import words2phns
from align import words2phns_zh
from paddle import nn
from sedit_arg_parser import parse_args
from utils import eval_durs
from utils import get_voc_out
from utils import is_chinese
from utils import load_num_sequence_text
from utils import read_2col_text
from yacs.config import CfgNode

from paddlespeech.t2s.datasets.am_batch_fn import build_mlm_collate_fn
from paddlespeech.t2s.models.ernie_sat.ernie_sat import ErnieSAT

random.seed(0)
np.random.seed(0)
def get_wav(wav_path: str,
            source_lang: str = 'english',
            target_lang: str = 'english',
            model_name: str = "paddle_checkpoint_en",
            old_str: str = "",
            new_str: str = "",
            non_autoreg: bool = True):
    wav_org, output_feat, old_span_bdy, new_span_bdy, fs, hop_length = get_mlm_output(
        source_lang=source_lang,
        target_lang=target_lang,
        model_name=model_name,
        wav_path=wav_path,
        old_str=old_str,
        new_str=new_str,
        use_teacher_forcing=non_autoreg)

    masked_feat = output_feat[new_span_bdy[0]:new_span_bdy[1]]

    alt_wav = get_voc_out(masked_feat)

    old_time_bdy = [hop_length * x for x in old_span_bdy]

    wav_replaced = np.concatenate(
        [wav_org[:old_time_bdy[0]], alt_wav, wav_org[old_time_bdy[1]:]])

    data_dict = {"origin": wav_org, "output": wav_replaced}

    return data_dict
def load_model(model_name: str = "paddle_checkpoint_en"):
    config_path = './pretrained_model/{}/default.yaml'.format(model_name)
    model_path = './pretrained_model/{}/model.pdparams'.format(model_name)
    with open(config_path) as f:
        conf = CfgNode(yaml.safe_load(f))
    token_list = list(conf.token_list)
    vocab_size = len(token_list)
    odim = conf.n_mels
    mlm_model = ErnieSAT(idim=vocab_size, odim=odim, **conf["model"])
    state_dict = paddle.load(model_path)
    new_state_dict = {}
    for key, value in state_dict.items():
        new_key = "model." + key
        new_state_dict[new_key] = value
    mlm_model.set_state_dict(new_state_dict)
    mlm_model.eval()

    return mlm_model, conf
def read_data(uid: str, prefix: os.PathLike):
    # get the text for this uid
    mfa_text = read_2col_text(prefix + '/text')[uid]
    # get the audio path for this uid
    mfa_wav_path = read_2col_text(prefix + '/wav.scp')[uid]
    if not os.path.isabs(mfa_wav_path):
        mfa_wav_path = prefix + mfa_wav_path
    return mfa_text, mfa_wav_path
def get_align_data(uid: str, prefix: os.PathLike):
    mfa_path = prefix + "mfa_"
    mfa_text = read_2col_text(mfa_path + 'text')[uid]
    mfa_start = load_num_sequence_text(
        mfa_path + 'start', loader_type='text_float')[uid]
    mfa_end = load_num_sequence_text(
        mfa_path + 'end', loader_type='text_float')[uid]
    mfa_wav_path = read_2col_text(mfa_path + 'wav.scp')[uid]
    return mfa_text, mfa_start, mfa_end, mfa_wav_path
# get the range of mel frames that need to be masked
def get_masked_mel_bdy(mfa_start: List[float],
                       mfa_end: List[float],
                       fs: int,
                       hop_length: int,
                       span_to_repl: List[List[int]]):
    align_start = np.array(mfa_start)
    align_end = np.array(mfa_end)
    align_start = np.floor(fs * align_start / hop_length).astype('int')
    align_end = np.floor(fs * align_end / hop_length).astype('int')
    if span_to_repl[0] >= len(mfa_start):
        span_bdy = [align_end[-1], align_end[-1]]
    else:
        span_bdy = [
            align_start[span_to_repl[0]], align_end[span_to_repl[1] - 1]
        ]
    return span_bdy, align_start, align_end
def recover_dict(word2phns: Dict[str, str], tp_word2phns: Dict[str, str]):
    dic = {}
    keys_to_del = []
    exist_idx = []
    sp_count = 0
    add_sp_count = 0
    for key in word2phns.keys():
        idx, wrd = key.split('_')
        if wrd == 'sp':
            sp_count += 1
            exist_idx.append(int(idx))
        else:
            keys_to_del.append(key)

    for key in keys_to_del:
        del word2phns[key]

    cur_id = 0
    for key in tp_word2phns.keys():
        if cur_id in exist_idx:
            dic[str(cur_id) + "_sp"] = 'sp'
            cur_id += 1
            add_sp_count += 1
        idx, wrd = key.split('_')
        dic[str(cur_id) + "_" + wrd] = tp_word2phns[key]
        cur_id += 1

    if add_sp_count + 1 == sp_count:
        dic[str(cur_id) + "_sp"] = 'sp'
        add_sp_count += 1

    assert add_sp_count == sp_count, "sp are not added in dic"
    return dic
def get_max_idx(dic):
    return sorted([int(key.split('_')[0]) for key in dic.keys()])[-1]
def get_phns_and_spans(wav_path: str,
                       old_str: str = "",
                       new_str: str = "",
                       source_lang: str = "english",
                       target_lang: str = "english"):
    is_append = (old_str == new_str[:len(old_str)])
    old_phns, mfa_start, mfa_end = [], [], []
    # source
    if source_lang == "english":
        intervals, word2phns = alignment(wav_path, old_str)
    elif source_lang == "chinese":
        intervals, word2phns = alignment_zh(wav_path, old_str)
        _, tp_word2phns = words2phns_zh(old_str)

        for key, value in tp_word2phns.items():
            idx, wrd = key.split('_')
            cur_val = " ".join(value)
            tp_word2phns[key] = cur_val

        word2phns = recover_dict(word2phns, tp_word2phns)
    else:
        assert source_lang == "chinese" or source_lang == "english", \
            "source_lang is wrong..."

    for item in intervals:
        old_phns.append(item[0])
        mfa_start.append(float(item[1]))
        mfa_end.append(float(item[2]))
    # target
    if is_append and (source_lang != target_lang):
        cross_lingual_clone = True
    else:
        cross_lingual_clone = False

    if cross_lingual_clone:
        str_origin = new_str[:len(old_str)]
        str_append = new_str[len(old_str):]

        if target_lang == "chinese":
            phns_origin, origin_word2phns = words2phns(str_origin)
            phns_append, append_word2phns_tmp = words2phns_zh(str_append)
        elif target_lang == "english":
            # original sentence
            phns_origin, origin_word2phns = words2phns_zh(str_origin)
            # cloned sentence
            phns_append, append_word2phns_tmp = words2phns(str_append)
        else:
            assert target_lang == "chinese" or target_lang == "english", \
                "cloning is not supported for this language, please check it."

        new_phns = phns_origin + phns_append

        append_word2phns = {}
        length = len(origin_word2phns)
        for key, value in append_word2phns_tmp.items():
            idx, wrd = key.split('_')
            append_word2phns[str(int(idx) + length) + '_' + wrd] = value
        new_word2phns = origin_word2phns.copy()
        new_word2phns.update(append_word2phns)
    else:
        if source_lang == target_lang and target_lang == "english":
            new_phns, new_word2phns = words2phns(new_str)
        elif source_lang == target_lang and target_lang == "chinese":
            new_phns, new_word2phns = words2phns_zh(new_str)
        else:
            assert source_lang == target_lang, \
                "source language is not the same as target language..."

    span_to_repl = [0, len(old_phns) - 1]
    span_to_add = [0, len(new_phns) - 1]
    left_idx = 0
    new_phns_left = []
    sp_count = 0
    # find the left different index
    for key in word2phns.keys():
        idx, wrd = key.split('_')
        if wrd == 'sp':
            sp_count += 1
            new_phns_left.append('sp')
        else:
            idx = str(int(idx) - sp_count)
            if idx + '_' + wrd in new_word2phns:
                left_idx += len(new_word2phns[idx + '_' + wrd])
                new_phns_left.extend(word2phns[key].split())
            else:
                span_to_repl[0] = len(new_phns_left)
                span_to_add[0] = len(new_phns_left)
                break

    # reverse word2phns and new_word2phns
    right_idx = 0
    new_phns_right = []
    sp_count = 0
    word2phns_max_idx = get_max_idx(word2phns)
    new_word2phns_max_idx = get_max_idx(new_word2phns)
    new_phns_mid = []
    if is_append:
        new_phns_right = []
        new_phns_mid = new_phns[left_idx:]
        span_to_repl[0] = len(new_phns_left)
        span_to_add[0] = len(new_phns_left)
        span_to_add[1] = len(new_phns_left) + len(new_phns_mid)
        span_to_repl[1] = len(old_phns) - len(new_phns_right)
    # speech edit
    else:
        for key in list(word2phns.keys())[::-1]:
            idx, wrd = key.split('_')
            if wrd == 'sp':
                sp_count += 1
                new_phns_right = ['sp'] + new_phns_right
            else:
                idx = str(new_word2phns_max_idx - (word2phns_max_idx - int(idx)
                                                   - sp_count))
                if idx + '_' + wrd in new_word2phns:
                    right_idx -= len(new_word2phns[idx + '_' + wrd])
                    new_phns_right = word2phns[key].split() + new_phns_right
                else:
                    span_to_repl[1] = len(old_phns) - len(new_phns_right)
                    new_phns_mid = new_phns[left_idx:right_idx]
                    span_to_add[1] = len(new_phns_left) + len(new_phns_mid)
                    if len(new_phns_mid) == 0:
                        span_to_add[1] = min(span_to_add[1] + 1, len(new_phns))
                        span_to_add[0] = max(0, span_to_add[0] - 1)
                        span_to_repl[0] = max(0, span_to_repl[0] - 1)
                        span_to_repl[1] = min(span_to_repl[1] + 1,
                                              len(old_phns))
                    break
    new_phns = new_phns_left + new_phns_mid + new_phns_right
    '''
    For that reason cover should not be given.
    For that reason cover is impossible to be given.
    span_to_repl: [17, 23] "should not"
    span_to_add: [17, 30]  "is impossible to"
    '''
    return mfa_start, mfa_end, old_phns, new_phns, span_to_repl, span_to_add
# The durations obtained from MFA and those predicted by the fs2
# duration_predictor may differ. Compute a scaling factor here that maps
# predicted durations onto the real ones.
def get_dur_adj_factor(orig_dur: List[int],
                       pred_dur: List[int],
                       phns: List[str]):
    length = 0
    factor_list = []
    for orig, pred, phn in zip(orig_dur, pred_dur, phns):
        if pred == 0 or phn == 'sp':
            continue
        else:
            factor_list.append(orig / pred)
    factor_list = np.array(factor_list)
    factor_list.sort()
    if len(factor_list) < 5:
        return 1
    length = 2
    avg = np.average(factor_list[length:-length])
    return avg
def prep_feats_with_dur(wav_path: str,
                        source_lang: str = "english",
                        target_lang: str = "english",
                        old_str: str = "",
                        new_str: str = "",
                        mask_reconstruct: bool = False,
                        duration_adjust: bool = True,
                        start_end_sp: bool = False,
                        fs: int = 24000,
                        hop_length: int = 300):
    '''
    Returns:
        np.ndarray: new wav, replace the part to be edited in original wav with 0
        List[str]: new phones
        List[float]: mfa start of new wav
        List[float]: mfa end of new wav
        List[int]: masked mel boundary of original wav
        List[int]: masked mel boundary of new wav
    '''
    wav_org, _ = librosa.load(wav_path, sr=fs)

    mfa_start, mfa_end, old_phns, new_phns, span_to_repl, span_to_add = get_phns_and_spans(
        wav_path=wav_path,
        old_str=old_str,
        new_str=new_str,
        source_lang=source_lang,
        target_lang=target_lang)

    if start_end_sp:
        if new_phns[-1] != 'sp':
            new_phns = new_phns + ['sp']
    # Chinese phones are not necessarily all in the fastspeech2 dictionary;
    # missing ones are replaced with sp
    if target_lang == "english" or target_lang == "chinese":
        old_durs = eval_durs(old_phns, target_lang=source_lang)
    else:
        assert target_lang == "chinese" or target_lang == "english", \
            "duration prediction is not supported for this language..."

    orig_old_durs = [e - s for e, s in zip(mfa_end, mfa_start)]
    if '[MASK]' in new_str:
        new_phns = old_phns
        span_to_add = span_to_repl
        d_factor_left = get_dur_adj_factor(
            orig_dur=orig_old_durs[:span_to_repl[0]],
            pred_dur=old_durs[:span_to_repl[0]],
            phns=old_phns[:span_to_repl[0]])
        d_factor_right = get_dur_adj_factor(
            orig_dur=orig_old_durs[span_to_repl[1]:],
            pred_dur=old_durs[span_to_repl[1]:],
            phns=old_phns[span_to_repl[1]:])
        d_factor = (d_factor_left + d_factor_right) / 2
        new_durs_adjusted = [d_factor * i for i in old_durs]
    else:
        if duration_adjust:
            d_factor = get_dur_adj_factor(
                orig_dur=orig_old_durs, pred_dur=old_durs, phns=old_phns)
            d_factor = d_factor * 1.25
        else:
            d_factor = 1

        if target_lang == "english" or target_lang == "chinese":
            new_durs = eval_durs(new_phns, target_lang=target_lang)
        else:
            assert target_lang == "chinese" or target_lang == "english", \
                "duration prediction is not supported for this language..."

        new_durs_adjusted = [d_factor * i for i in new_durs]

    new_span_dur_sum = sum(new_durs_adjusted[span_to_add[0]:span_to_add[1]])
    old_span_dur_sum = sum(orig_old_durs[span_to_repl[0]:span_to_repl[1]])
    dur_offset = new_span_dur_sum - old_span_dur_sum
    new_mfa_start = mfa_start[:span_to_repl[0]]
    new_mfa_end = mfa_end[:span_to_repl[0]]
    for i in new_durs_adjusted[span_to_add[0]:span_to_add[1]]:
        if len(new_mfa_end) == 0:
            new_mfa_start.append(0)
            new_mfa_end.append(i)
        else:
            new_mfa_start.append(new_mfa_end[-1])
            new_mfa_end.append(new_mfa_end[-1] + i)
    new_mfa_start += [i + dur_offset for i in mfa_start[span_to_repl[1]:]]
    new_mfa_end += [i + dur_offset for i in mfa_end[span_to_repl[1]:]]

    # 3. get new wav
    # append after the original sentence
    if span_to_repl[0] >= len(mfa_start):
        left_idx = len(wav_org)
        right_idx = left_idx
    # replace in the middle of the original sentence
    else:
        left_idx = int(np.floor(mfa_start[span_to_repl[0]] * fs))
        right_idx = int(np.ceil(mfa_end[span_to_repl[1] - 1] * fs))
    blank_wav = np.zeros(
        (int(np.ceil(new_span_dur_sum * fs)), ), dtype=wav_org.dtype)
    # in the original audio, the part to be edited is replaced with blank
    # audio whose duration is decided by the fs2 duration_predictor
    new_wav = np.concatenate(
        [wav_org[:left_idx], blank_wav, wav_org[right_idx:]])

    # 4. get old and new mel span to be mask
    # [92, 92]
    old_span_bdy, mfa_start, mfa_end = get_masked_mel_bdy(
        mfa_start=mfa_start,
        mfa_end=mfa_end,
        fs=fs,
        hop_length=hop_length,
        span_to_repl=span_to_repl)
    # [92, 174]
    # convert new_mfa_start / new_mfa_end from time-level start and end
    # times to frame-level indices
    new_span_bdy, new_mfa_start, new_mfa_end = get_masked_mel_bdy(
        mfa_start=new_mfa_start,
        mfa_end=new_mfa_end,
        fs=fs,
        hop_length=hop_length,
        span_to_repl=span_to_add)

    # old_span_bdy and new_span_bdy are frame-level ranges
    return new_wav, new_phns, new_mfa_start, new_mfa_end, old_span_bdy, new_span_bdy
def prep_feats(wav_path: str,
               source_lang: str = "english",
               target_lang: str = "english",
               old_str: str = "",
               new_str: str = "",
               duration_adjust: bool = True,
               start_end_sp: bool = False,
               mask_reconstruct: bool = False,
               fs: int = 24000,
               hop_length: int = 300,
               token_list: List[str] = []):
    wav, phns, mfa_start, mfa_end, old_span_bdy, new_span_bdy = prep_feats_with_dur(
        source_lang=source_lang,
        target_lang=target_lang,
        old_str=old_str,
        new_str=new_str,
        wav_path=wav_path,
        duration_adjust=duration_adjust,
        start_end_sp=start_end_sp,
        mask_reconstruct=mask_reconstruct,
        fs=fs,
        hop_length=hop_length)

    token_to_id = {item: i for i, item in enumerate(token_list)}
    text = np.array(
        list(map(lambda x: token_to_id.get(x, token_to_id['<unk>']), phns)))
    span_bdy = np.array(new_span_bdy)

    batch = [('1', {
        "speech": wav,
        "align_start": mfa_start,
        "align_end": mfa_end,
        "text": text,
        "span_bdy": span_bdy
    })]

    return batch, old_span_bdy, new_span_bdy
def decode_with_model(mlm_model: nn.Layer,
                      collate_fn,
                      wav_path: str,
                      source_lang: str = "english",
                      target_lang: str = "english",
                      old_str: str = "",
                      new_str: str = "",
                      use_teacher_forcing: bool = False,
                      duration_adjust: bool = True,
                      start_end_sp: bool = False,
                      fs: int = 24000,
                      hop_length: int = 300,
                      token_list: List[str] = []):
    batch, old_span_bdy, new_span_bdy = prep_feats(
        source_lang=source_lang,
        target_lang=target_lang,
        wav_path=wav_path,
        old_str=old_str,
        new_str=new_str,
        duration_adjust=duration_adjust,
        start_end_sp=start_end_sp,
        fs=fs,
        hop_length=hop_length,
        token_list=token_list)

    feats = collate_fn(batch)[1]

    if 'text_masked_pos' in feats.keys():
        feats.pop('text_masked_pos')

    output = mlm_model.inference(
        text=feats['text'],
        speech=feats['speech'],
        masked_pos=feats['masked_pos'],
        speech_mask=feats['speech_mask'],
        text_mask=feats['text_mask'],
        speech_seg_pos=feats['speech_seg_pos'],
        text_seg_pos=feats['text_seg_pos'],
        span_bdy=new_span_bdy,
        use_teacher_forcing=use_teacher_forcing)

    # concatenate the output segments
    output_feat = paddle.concat(x=output, axis=0)
    wav_org, _ = librosa.load(wav_path, sr=fs)
    return wav_org, output_feat, old_span_bdy, new_span_bdy, fs, hop_length
def get_mlm_output(wav_path: str,
                   model_name: str = "paddle_checkpoint_en",
                   source_lang: str = "english",
                   target_lang: str = "english",
                   old_str: str = "",
                   new_str: str = "",
                   use_teacher_forcing: bool = False,
                   duration_adjust: bool = True,
                   start_end_sp: bool = False):
    mlm_model, train_conf = load_model(model_name)

    collate_fn = build_mlm_collate_fn(
        sr=train_conf.fs,
        n_fft=train_conf.n_fft,
        hop_length=train_conf.n_shift,
        win_length=train_conf.win_length,
        n_mels=train_conf.n_mels,
        fmin=train_conf.fmin,
        fmax=train_conf.fmax,
        mlm_prob=train_conf.mlm_prob,
        mean_phn_span=train_conf.mean_phn_span,
        seg_emb=train_conf.model['enc_input_layer'] == 'sega_mlm')

    return decode_with_model(
        source_lang=source_lang,
        target_lang=target_lang,
        mlm_model=mlm_model,
        collate_fn=collate_fn,
        wav_path=wav_path,
        old_str=old_str,
        new_str=new_str,
        use_teacher_forcing=use_teacher_forcing,
        duration_adjust=duration_adjust,
        start_end_sp=start_end_sp,
        fs=train_conf.fs,
        hop_length=train_conf.n_shift,
        token_list=train_conf.token_list)
def evaluate(uid: str,
             source_lang: str = "english",
             target_lang: str = "english",
             prefix: os.PathLike = "./prompt/dev/",
             model_name: str = "paddle_checkpoint_en",
             new_str: str = "",
             prompt_decoding: bool = False,
             task_name: str = None):

    # get origin text and path of origin wav
    old_str, wav_path = read_data(uid=uid, prefix=prefix)

    if task_name == 'edit':
        new_str = new_str
    elif task_name == 'synthesize':
        new_str = old_str + new_str
    else:
        new_str = old_str + ' '.join([ch for ch in new_str if is_chinese(ch)])

    print('new_str is ', new_str)

    results_dict = get_wav(
        source_lang=source_lang,
        target_lang=target_lang,
        model_name=model_name,
        wav_path=wav_path,
        old_str=old_str,
        new_str=new_str)
    return results_dict
if __name__ == "__main__":
    # parse config and args
    args = parse_args()

    data_dict = evaluate(
        uid=args.uid,
        source_lang=args.source_lang,
        target_lang=args.target_lang,
        prefix=args.prefix,
        model_name=args.model_name,
        new_str=args.new_str,
        task_name=args.task_name)
    sf.write(args.output_name, data_dict['output'], samplerate=24000)
    print("finished...")
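`get_masked_mel_bdy` above turns alignment times in seconds into mel-frame indices by multiplying by the sample rate and dividing by the hop length. The sketch below reproduces that arithmetic with the standard library only, as an illustration rather than the repository's code: the name `masked_frame_span` is hypothetical, and the defaults `fs=24000` and `hop_length=300` are taken from the default arguments in the deleted file.

```python
import math
from typing import List, Sequence


def masked_frame_span(starts: Sequence[float],
                      ends: Sequence[float],
                      span: Sequence[int],
                      fs: int = 24000,
                      hop_length: int = 300) -> List[int]:
    """Map a phone-index span [lo, hi) to a mel-frame span.

    Hypothetical helper: one frame covers hop_length / fs seconds, so a time
    t falls into frame floor(fs * t / hop_length).
    """
    frame_starts = [math.floor(fs * t / hop_length) for t in starts]
    frame_ends = [math.floor(fs * t / hop_length) for t in ends]
    # Appending after the last phone: the span is empty and sits at the end.
    if span[0] >= len(starts):
        return [frame_ends[-1], frame_ends[-1]]
    # Otherwise run from the first replaced phone to the last replaced phone.
    return [frame_starts[span[0]], frame_ends[span[1] - 1]]


if __name__ == "__main__":
    starts = [0.0, 0.5, 1.0]
    ends = [0.5, 1.0, 1.5]
    print(masked_frame_span(starts, ends, (1, 3)))
    print(masked_frame_span(starts, ends, (3, 3)))
```

With these defaults there are 80 frames per second, so a phone starting at 0.5 s begins at frame 40; the append case collapses to a zero-length span at the final frame, matching the `[92, 92]` style comment in `prep_feats_with_dur`.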
examples/ernie_sat/local/sedit_arg_parser.py
Deleted
100644 → 0
View file @
d21e03c0
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse


def parse_args():
    # parse args and config and redirect to train_sp
    parser = argparse.ArgumentParser(
        description="Synthesize with acoustic model & vocoder")
    # acoustic model
    parser.add_argument(
        '--am',
        type=str,
        default='fastspeech2_csmsc',
        choices=[
            'speedyspeech_csmsc', 'fastspeech2_csmsc', 'fastspeech2_ljspeech',
            'fastspeech2_aishell3', 'fastspeech2_vctk', 'tacotron2_csmsc',
            'tacotron2_ljspeech', 'tacotron2_aishell3'
        ],
        help='Choose acoustic model type of tts task.')
    parser.add_argument(
        '--am_config',
        type=str,
        default=None,
        help='Config of acoustic model. Use default config when it is None.')
    parser.add_argument(
        '--am_ckpt',
        type=str,
        default=None,
        help='Checkpoint file of acoustic model.')
    parser.add_argument(
        "--am_stat",
        type=str,
        default=None,
        help="mean and standard deviation used to normalize spectrogram when training acoustic model."
    )
    parser.add_argument(
        "--phones_dict", type=str, default=None, help="phone vocabulary file.")
    parser.add_argument(
        "--tones_dict", type=str, default=None, help="tone vocabulary file.")
    parser.add_argument(
        "--speaker_dict", type=str, default=None, help="speaker id map file.")

    # vocoder
    parser.add_argument(
        '--voc',
        type=str,
        default='pwgan_aishell3',
        choices=[
            'pwgan_csmsc', 'pwgan_ljspeech', 'pwgan_aishell3', 'pwgan_vctk',
            'mb_melgan_csmsc', 'wavernn_csmsc', 'hifigan_csmsc',
            'hifigan_ljspeech', 'hifigan_aishell3', 'hifigan_vctk',
            'style_melgan_csmsc'
        ],
        help='Choose vocoder type of tts task.')
    parser.add_argument(
        '--voc_config',
        type=str,
        default=None,
        help='Config of voc. Use default config when it is None.')
    parser.add_argument(
        '--voc_ckpt', type=str, default=None, help='Checkpoint file of voc.')
    parser.add_argument(
        "--voc_stat",
        type=str,
        default=None,
        help="mean and standard deviation used to normalize spectrogram when training voc."
    )
    # other
    parser.add_argument(
        "--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.")
    parser.add_argument("--model_name", type=str, help="model name")
    parser.add_argument("--uid", type=str, help="uid")
    parser.add_argument("--new_str", type=str, help="new string")
    parser.add_argument("--prefix", type=str, help="prefix")
    parser.add_argument(
        "--source_lang", type=str, default="english", help="source language")
    parser.add_argument(
        "--target_lang", type=str, default="english", help="target language")
    parser.add_argument("--output_name", type=str, help="output name")
    parser.add_argument("--task_name", type=str, help="task name")

    # pre
    args = parser.parse_args()
    return args
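`parse_args()` above always reads `sys.argv`, which makes it awkward to exercise outside the shell scripts. As a small sketch of how such a parser can be checked in isolation, the block below builds a reduced parser and feeds it an explicit argument list. `build_parser`, the subset of flags kept, and the sample values (`utt1`, `out.wav`) are assumptions made for illustration; only the flag names and defaults come from the deleted file.

```python
import argparse
from typing import List, Optional


def build_parser() -> argparse.ArgumentParser:
    """Reduced, hypothetical version of the deleted parser (subset of flags)."""
    parser = argparse.ArgumentParser(
        description="Synthesize with acoustic model & vocoder")
    parser.add_argument("--uid", type=str, help="uid")
    parser.add_argument("--new_str", type=str, help="new string")
    parser.add_argument(
        "--source_lang", type=str, default="english", help="source language")
    parser.add_argument(
        "--target_lang", type=str, default="english", help="target language")
    parser.add_argument("--task_name", type=str, help="task name")
    parser.add_argument("--output_name", type=str, help="output name")
    return parser


def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    # argv=None falls back to sys.argv[1:], so command-line use is unchanged.
    return build_parser().parse_args(argv)


if __name__ == "__main__":
    args = parse_args([
        "--uid", "utt1", "--new_str", "hello", "--task_name", "edit",
        "--output_name", "out.wav"
    ])
    print(args.uid, args.source_lang, args.task_name)
```

Accepting an optional `argv` keeps the script entry point intact while letting tests or notebooks pass their own arguments.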
examples/ernie_sat/local/utils.py
Deleted
100644 → 0
View file @
d21e03c0
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path
from typing import Dict
from typing import List
from typing import Union

import numpy as np
import paddle
import yaml
from sedit_arg_parser import parse_args
from yacs.config import CfgNode

from paddlespeech.t2s.exps.syn_utils import get_am_inference
from paddlespeech.t2s.exps.syn_utils import get_voc_inference
def read_2col_text(path: Union[Path, str]) -> Dict[str, str]:
    """Read a text file with 2 columns as a dict object.

    Examples:
        wav.scp:
            key1 /some/path/a.wav
            key2 /some/path/b.wav

        >>> read_2col_text('wav.scp')
        {'key1': '/some/path/a.wav', 'key2': '/some/path/b.wav'}

    """

    data = {}
    with Path(path).open("r", encoding="utf-8") as f:
        for linenum, line in enumerate(f, 1):
            sps = line.rstrip().split(maxsplit=1)
            if len(sps) == 1:
                k, v = sps[0], ""
            else:
                k, v = sps
            if k in data:
                raise RuntimeError(f"{k} is duplicated ({path}:{linenum})")
            data[k] = v
    return data
def load_num_sequence_text(path: Union[Path, str], loader_type: str="csv_int"
                           ) -> Dict[str, List[Union[float, int]]]:
    """Read a text file containing sequences of numbers.

    Examples:
        key1 1 2 3
        key2 34 5 6

        >>> d = load_num_sequence_text('text', loader_type='text_int')
        >>> np.testing.assert_array_equal(d["key1"], np.array([1, 2, 3]))
    """
    if loader_type == "text_int":
        delimiter = " "
        dtype = int
    elif loader_type == "text_float":
        delimiter = " "
        dtype = float
    elif loader_type == "csv_int":
        delimiter = ","
        dtype = int
    elif loader_type == "csv_float":
        delimiter = ","
        dtype = float
    else:
        raise ValueError(f"Not supported loader_type={loader_type}")

    # path looks like:
    #   utta 1,0
    #   uttb 3,4,5
    # -> return {'utta': [1, 0],
    #            'uttb': [3, 4, 5]}
    # The helper defined above is read_2col_text (not read_2column_text).
    d = read_2col_text(path)
    # Using for-loop instead of dict-comprehension for debuggability
    retval = {}
    for k, v in d.items():
        try:
            retval[k] = [dtype(i) for i in v.split(delimiter)]
        # int("x") / float("x") raise ValueError, so catch it as well
        except (TypeError, ValueError):
            print(f'Error happened with path="{path}", id="{k}", value="{v}"')
            raise
    return retval
def is_chinese(ch):
    if u'\u4e00' <= ch <= u'\u9fff':
        return True
    else:
        return False
def get_voc_out(mel):
    # vocoder
    args = parse_args()
    with open(args.voc_config) as f:
        voc_config = CfgNode(yaml.safe_load(f))
    voc_inference = get_voc_inference(
        voc=args.voc,
        voc_config=voc_config,
        voc_ckpt=args.voc_ckpt,
        voc_stat=args.voc_stat)

    with paddle.no_grad():
        wav = voc_inference(mel)
    return np.squeeze(wav)
def eval_durs(phns, target_lang="chinese", fs=24000, hop_length=300):
    args = parse_args()

    if target_lang == 'english':
        args.am = "fastspeech2_ljspeech"
        args.am_config = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/default.yaml"
        args.am_ckpt = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/snapshot_iter_100000.pdz"
        args.am_stat = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/speech_stats.npy"
        args.phones_dict = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/phone_id_map.txt"

    elif target_lang == 'chinese':
        args.am = "fastspeech2_csmsc"
        args.am_config = "download/fastspeech2_conformer_baker_ckpt_0.5/conformer.yaml"
        args.am_ckpt = "download/fastspeech2_conformer_baker_ckpt_0.5/snapshot_iter_76000.pdz"
        args.am_stat = "download/fastspeech2_conformer_baker_ckpt_0.5/speech_stats.npy"
        args.phones_dict = "download/fastspeech2_conformer_baker_ckpt_0.5/phone_id_map.txt"

    if args.ngpu == 0:
        paddle.set_device("cpu")
    elif args.ngpu > 0:
        paddle.set_device("gpu")
    else:
        print("ngpu should >= 0 !")

    # Init body.
    with open(args.am_config) as f:
        am_config = CfgNode(yaml.safe_load(f))

    am_inference, am = get_am_inference(
        am=args.am,
        am_config=am_config,
        am_ckpt=args.am_ckpt,
        am_stat=args.am_stat,
        phones_dict=args.phones_dict,
        tones_dict=args.tones_dict,
        speaker_dict=args.speaker_dict,
        return_am=True)

    vocab_phones = {}
    with open(args.phones_dict, "r") as f:
        phn_id = [line.strip().split() for line in f.readlines()]
    for tone, id in phn_id:
        vocab_phones[tone] = int(id)
    vocab_size = len(vocab_phones)
    phonemes = [phn if phn in vocab_phones else "sp" for phn in phns]

    phone_ids = [vocab_phones[item] for item in phonemes]
    phone_ids.append(vocab_size - 1)
    phone_ids = paddle.to_tensor(np.array(phone_ids, np.int64))
    _, d_outs, _, _ = am.inference(phone_ids, spk_id=None, spk_emb=None)
    pre_d_outs = d_outs
    phu_durs_new = pre_d_outs * hop_length / fs
    phu_durs_new = phu_durs_new.tolist()[:-1]
    return phu_durs_new
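The last step of `eval_durs` converts predicted durations from mel frames to seconds with `frames * hop_length / fs` and drops the entry for the extra token appended to `phone_ids`. A minimal sketch of just that arithmetic, assuming plain Python numbers instead of a `paddle.Tensor`; the name `frames_to_seconds` and its `drop_last` flag are hypothetical.

```python
from typing import List, Sequence


def frames_to_seconds(frame_counts: Sequence[float],
                      fs: int = 24000,
                      hop_length: int = 300,
                      drop_last: bool = True) -> List[float]:
    """Convert per-phone frame counts to seconds (one frame = hop_length / fs s)."""
    seconds = [n * hop_length / fs for n in frame_counts]
    # eval_durs discards the duration predicted for the appended final token.
    return seconds[:-1] if drop_last else seconds


durations = frames_to_seconds([8, 12, 4, 1])
print(durations)
```

With the defaults used above (24 kHz, hop length 300) one frame spans 12.5 ms, so 8 frames correspond to 0.1 s.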
examples/ernie_sat/path.sh
deleted
100755 → 0
View file @ d21e03c0
#!/bin/bash
export MAIN_ROOT=`realpath ${PWD}/../../`

export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
export LC_ALL=C

export PYTHONDONTWRITEBYTECODE=1
# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}

MODEL=ernie_sat
export BIN_DIR=${MAIN_ROOT}/paddlespeech/t2s/exps/${MODEL}
\ No newline at end of file
examples/ernie_sat/prompt/dev/text
deleted
100644 → 0
View file @ d21e03c0
p243_new For that reason cover should not be given.
Prompt_003_new This was not the show for me.
p299_096 We are trying to establish a date.
examples/ernie_sat/prompt/dev/wav.scp
deleted
100644 → 0
View file @ d21e03c0
p243_new ../../prompt_wav/p243_313.wav
Prompt_003_new ../../prompt_wav/this_was_not_the_show_for_me.wav
p299_096 ../../prompt_wav/p299_096.wav
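The two prompt files above follow the Kaldi-style two-column layout (`<utterance id> <value>`) that `read_2col_text` expects. A minimal sketch of that parsing, assuming an in-memory stream instead of a path; `parse_two_columns` is a hypothetical name, and skipping blank lines is an addition of this sketch (the deleted helper does not handle them).

```python
import io
from typing import Dict, TextIO


def parse_two_columns(stream: TextIO) -> Dict[str, str]:
    """Map the first whitespace-separated token of each line to the rest of the line."""
    table: Dict[str, str] = {}
    for linenum, line in enumerate(stream, 1):
        parts = line.rstrip().split(maxsplit=1)
        if not parts:
            continue  # assumption: ignore blank lines
        key = parts[0]
        value = parts[1] if len(parts) == 2 else ""
        if key in table:
            raise RuntimeError(f"{key} is duplicated (line {linenum})")
        table[key] = value
    return table


sample = io.StringIO(
    "p243_new For that reason cover should not be given.\n"
    "Prompt_003_new This was not the show for me.\n"
    "p299_096 We are trying to establish a date.\n")
prompts = parse_two_columns(sample)
print(prompts["p299_096"])
```

Splitting with `maxsplit=1` is what keeps a whole sentence, spaces included, as the value of its utterance id.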
examples/ernie_sat/run_clone_en_to_zh.sh
deleted
100755 → 0
View file @ d21e03c0
#!/bin/bash
set -e
source path.sh

# en --> zh speech synthesis
# Uses Prompt_003_new as the prompt speech ("This was not the show for me.") to synthesize '今天天气很好'
# Note: new_str must consist of Chinese characters; otherwise preprocessing keeps only the Chinese characters and the preprocessed Chinese text is synthesized.

python local/inference.py \
    --task_name=cross-lingual_clone \
    --model_name=paddle_checkpoint_dual_mask_enzh \
    --uid=Prompt_003_new \
    --new_str='今天天气很好.' \
    --prefix='./prompt/dev/' \
    --source_lang=english \
    --target_lang=chinese \
    --output_name=pred_clone.wav \
    --voc=pwgan_aishell3 \
    --voc_config=download/pwg_aishell3_ckpt_0.5/default.yaml \
    --voc_ckpt=download/pwg_aishell3_ckpt_0.5/snapshot_iter_1000000.pdz \
    --voc_stat=download/pwg_aishell3_ckpt_0.5/feats_stats.npy \
    --am=fastspeech2_csmsc \
    --am_config=download/fastspeech2_conformer_baker_ckpt_0.5/conformer.yaml \
    --am_ckpt=download/fastspeech2_conformer_baker_ckpt_0.5/snapshot_iter_76000.pdz \
    --am_stat=download/fastspeech2_conformer_baker_ckpt_0.5/speech_stats.npy \
    --phones_dict=download/fastspeech2_conformer_baker_ckpt_0.5/phone_id_map.txt
examples/ernie_sat/run_clone_en_to_zh_new.sh
deleted
100755 → 0
View file @ d21e03c0
#!/bin/bash
set -e
source path.sh

# en --> zh speech synthesis
# Uses Prompt_003_new as the prompt speech ("This was not the show for me.") to synthesize '今天天气很好'
# Note: new_str must consist of Chinese characters; otherwise preprocessing keeps only the Chinese characters and the preprocessed Chinese text is synthesized.

python local/inference_new.py \
    --task_name=cross-lingual_clone \
    --model_name=paddle_checkpoint_dual_mask_enzh \
    --uid=Prompt_003_new \
    --new_str='今天天气很好.' \
    --prefix='./prompt/dev/' \
    --source_lang=english \
    --target_lang=chinese \
    --output_name=pred_clone.wav \
    --voc=pwgan_aishell3 \
    --voc_config=download/pwg_aishell3_ckpt_0.5/default.yaml \
    --voc_ckpt=download/pwg_aishell3_ckpt_0.5/snapshot_iter_1000000.pdz \
    --voc_stat=download/pwg_aishell3_ckpt_0.5/feats_stats.npy \
    --am=fastspeech2_csmsc \
    --am_config=download/fastspeech2_conformer_baker_ckpt_0.5/conformer.yaml \
    --am_ckpt=download/fastspeech2_conformer_baker_ckpt_0.5/snapshot_iter_76000.pdz \
    --am_stat=download/fastspeech2_conformer_baker_ckpt_0.5/speech_stats.npy \
    --phones_dict=download/fastspeech2_conformer_baker_ckpt_0.5/phone_id_map.txt
examples/ernie_sat/run_gen_en.sh
deleted
100755 → 0
View file @ d21e03c0
#!/bin/bash
set -e
source path.sh

# English-only speech synthesis
# Example: uses the speech for p299_096 ("We are trying to establish a date." in prompt/dev/text) as the prompt speech to synthesize 'I enjoy my life, do you?'

python local/inference.py \
    --task_name=synthesize \
    --model_name=paddle_checkpoint_en \
    --uid=p299_096 \
    --new_str='I enjoy my life, do you?' \
    --prefix='./prompt/dev/' \
    --source_lang=english \
    --target_lang=english \
    --output_name=pred_gen.wav \
    --voc=pwgan_aishell3 \
    --voc_config=download/pwg_aishell3_ckpt_0.5/default.yaml \
    --voc_ckpt=download/pwg_aishell3_ckpt_0.5/snapshot_iter_1000000.pdz \
    --voc_stat=download/pwg_aishell3_ckpt_0.5/feats_stats.npy \
    --am=fastspeech2_ljspeech \
    --am_config=download/fastspeech2_nosil_ljspeech_ckpt_0.5/default.yaml \
    --am_ckpt=download/fastspeech2_nosil_ljspeech_ckpt_0.5/snapshot_iter_100000.pdz \
    --am_stat=download/fastspeech2_nosil_ljspeech_ckpt_0.5/speech_stats.npy \
    --phones_dict=download/fastspeech2_nosil_ljspeech_ckpt_0.5/phone_id_map.txt
examples/ernie_sat/run_gen_en_new.sh
deleted
100755 → 0
View file @ d21e03c0
#!/bin/bash
set -e
source path.sh

# English-only speech synthesis
# Example: uses the speech for p299_096 ("We are trying to establish a date." in prompt/dev/text) as the prompt speech to synthesize 'I enjoy my life, do you?'

python local/inference_new.py \
    --task_name=synthesize \
    --model_name=paddle_checkpoint_en \
    --uid=p299_096 \
    --new_str='I enjoy my life, do you?' \
    --prefix='./prompt/dev/' \
    --source_lang=english \
    --target_lang=english \
    --output_name=pred_gen.wav \
    --voc=pwgan_aishell3 \
    --voc_config=download/pwg_aishell3_ckpt_0.5/default.yaml \
    --voc_ckpt=download/pwg_aishell3_ckpt_0.5/snapshot_iter_1000000.pdz \
    --voc_stat=download/pwg_aishell3_ckpt_0.5/feats_stats.npy \
    --am=fastspeech2_ljspeech \
    --am_config=download/fastspeech2_nosil_ljspeech_ckpt_0.5/default.yaml \
    --am_ckpt=download/fastspeech2_nosil_ljspeech_ckpt_0.5/snapshot_iter_100000.pdz \
    --am_stat=download/fastspeech2_nosil_ljspeech_ckpt_0.5/speech_stats.npy \
    --phones_dict=download/fastspeech2_nosil_ljspeech_ckpt_0.5/phone_id_map.txt
examples/ernie_sat/run_sedit_en.sh
deleted
100755 → 0
View file @ d21e03c0
#!/bin/bash
set -e
source path.sh

# English-only speech editing
# Example: edits the original speech for p243_new ("For that reason cover should not be given.") into the speech for 'for that reason cover is impossible to be given.'
# NOTE: the speech editing task currently supports replacing or inserting text at only one position in a sentence

python local/inference.py \
    --task_name=edit \
    --model_name=paddle_checkpoint_en \
    --uid=p243_new \
    --new_str='for that reason cover is impossible to be given.' \
    --prefix='./prompt/dev/' \
    --source_lang=english \
    --target_lang=english \
    --output_name=pred_edit.wav \
    --voc=pwgan_aishell3 \
    --voc_config=download/pwg_aishell3_ckpt_0.5/default.yaml \
    --voc_ckpt=download/pwg_aishell3_ckpt_0.5/snapshot_iter_1000000.pdz \
    --voc_stat=download/pwg_aishell3_ckpt_0.5/feats_stats.npy \
    --am=fastspeech2_ljspeech \
    --am_config=download/fastspeech2_nosil_ljspeech_ckpt_0.5/default.yaml \
    --am_ckpt=download/fastspeech2_nosil_ljspeech_ckpt_0.5/snapshot_iter_100000.pdz \
    --am_stat=download/fastspeech2_nosil_ljspeech_ckpt_0.5/speech_stats.npy \
    --phones_dict=download/fastspeech2_nosil_ljspeech_ckpt_0.5/phone_id_map.txt
examples/ernie_sat/run_sedit_en_new.sh
deleted
100755 → 0
View file @ d21e03c0
#!/bin/bash
set -e
source path.sh

# English-only speech editing
# Example: edits the original speech for p243_new ("For that reason cover should not be given.") into the speech for 'for that reason cover is impossible to be given.'
# NOTE: the speech editing task currently supports replacing or inserting text at only one position in a sentence

python local/inference_new.py \
    --task_name=edit \
    --model_name=paddle_checkpoint_en \
    --uid=p243_new \
    --new_str='for that reason cover is impossible to be given.' \
    --prefix='./prompt/dev/' \
    --source_lang=english \
    --target_lang=english \
    --output_name=pred_edit.wav \
    --voc=pwgan_aishell3 \
    --voc_config=download/pwg_aishell3_ckpt_0.5/default.yaml \
    --voc_ckpt=download/pwg_aishell3_ckpt_0.5/snapshot_iter_1000000.pdz \
    --voc_stat=download/pwg_aishell3_ckpt_0.5/feats_stats.npy \
    --am=fastspeech2_ljspeech \
    --am_config=download/fastspeech2_nosil_ljspeech_ckpt_0.5/default.yaml \
    --am_ckpt=download/fastspeech2_nosil_ljspeech_ckpt_0.5/snapshot_iter_100000.pdz \
    --am_stat=download/fastspeech2_nosil_ljspeech_ckpt_0.5/speech_stats.npy \
    --phones_dict=download/fastspeech2_nosil_ljspeech_ckpt_0.5/phone_id_map.txt
examples/ernie_sat/test_run.sh
deleted
100755 → 0
View file @ d21e03c0
#!/bin/bash

rm -rf *.wav
./run_sedit_en.sh           # speech editing task (English)
./run_gen_en.sh             # personalized speech synthesis task (English)
./run_clone_en_to_zh.sh     # cross-lingual speech synthesis task (English-to-Chinese voice cloning)
\ No newline at end of file
examples/ernie_sat/test_run_new.sh
deleted
100755 → 0
View file @ d21e03c0
#!/bin/bash

rm -rf *.wav
./run_sedit_en_new.sh           # speech editing task (English)
./run_gen_en_new.sh             # personalized speech synthesis task (English)
./run_clone_en_to_zh_new.sh     # cross-lingual speech synthesis task (English-to-Chinese voice cloning)
\ No newline at end of file
examples/ernie_sat/tools/.gitkeep
deleted
100644 → 0
View file @ d21e03c0
paddlespeech/t2s/datasets/am_batch_fn.py
View file @ 7b864e8f
...
...
@@ -11,19 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Collection
from typing import Dict
from typing import List
from typing import Tuple

import numpy as np
import paddle

from paddlespeech.t2s.datasets.batch import batch_sequences
from paddlespeech.t2s.datasets.get_feats import LogMelFBank
from paddlespeech.t2s.modules.nets_utils import get_seg_pos
from paddlespeech.t2s.modules.nets_utils import make_non_pad_mask
from paddlespeech.t2s.modules.nets_utils import pad_list
from paddlespeech.t2s.modules.nets_utils import phones_masking
from paddlespeech.t2s.modules.nets_utils import phones_text_masking
...
...
@@ -490,182 +483,3 @@ def vits_single_spk_batch_fn(examples):
        "speech": speech
    }
    return batch


# for ERNIE SAT
class MLMCollateFn:
    """Functor class of mlm_collate_fn()"""

    def __init__(
            self,
            feats_extract,
            mlm_prob: float = 0.8,
            mean_phn_span: int = 8,
            seg_emb: bool = False,
            text_masking: bool = False,
            attention_window: int = 0,
            not_sequence: Collection[str] = (), ):
        self.mlm_prob = mlm_prob
        self.mean_phn_span = mean_phn_span
        self.feats_extract = feats_extract
        self.not_sequence = set(not_sequence)
        self.attention_window = attention_window
        self.seg_emb = seg_emb
        self.text_masking = text_masking

    def __call__(self, data: Collection[Tuple[str, Dict[str, np.ndarray]]]
                 ) -> Tuple[List[str], Dict[str, paddle.Tensor]]:
        return mlm_collate_fn(
            data,
            feats_extract=self.feats_extract,
            mlm_prob=self.mlm_prob,
            mean_phn_span=self.mean_phn_span,
            seg_emb=self.seg_emb,
            text_masking=self.text_masking,
            not_sequence=self.not_sequence)
def mlm_collate_fn(
        data: Collection[Tuple[str, Dict[str, np.ndarray]]],
        feats_extract=None,
        mlm_prob: float = 0.8,
        mean_phn_span: int = 8,
        seg_emb: bool = False,
        text_masking: bool = False,
        pad_value: int = 0,
        not_sequence: Collection[str] = (),
) -> Tuple[List[str], Dict[str, paddle.Tensor]]:
    uttids = [u for u, _ in data]
    data = [d for _, d in data]

    assert all(set(data[0]) == set(d) for d in data), "dict-keys mismatching"
    assert all(not k.endswith("_lens")
               for k in data[0]), f"*_lens is reserved: {list(data[0])}"

    output = {}
    for key in data[0]:

        array_list = [d[key] for d in data]

        # Assume the first axis is length:
        # tensor_list: Batch x (Length, ...)
        tensor_list = [paddle.to_tensor(a) for a in array_list]
        # tensor: (Batch, Length, ...)
        tensor = pad_list(tensor_list, pad_value)
        output[key] = tensor

        # lens: (Batch,)
        if key not in not_sequence:
            lens = paddle.to_tensor(
                [d[key].shape[0] for d in data], dtype=paddle.int64)
            output[key + "_lens"] = lens

    feats = feats_extract.get_log_mel_fbank(np.array(output["speech"][0]))
    feats = paddle.to_tensor(feats)
    print("feats.shape:", feats.shape)
    feats_lens = paddle.shape(feats)[0]
    feats = paddle.unsqueeze(feats, 0)

    text = output["text"]
    text_lens = output["text_lens"]
    align_start = output["align_start"]
    align_start_lens = output["align_start_lens"]
    align_end = output["align_end"]

    max_tlen = max(text_lens)
    max_slen = max(feats_lens)

    speech_pad = feats[:, :max_slen]

    text_pad = text
    text_mask = make_non_pad_mask(
        text_lens, text_pad, length_dim=1).unsqueeze(-2)
    speech_mask = make_non_pad_mask(
        feats_lens, speech_pad[:, :, 0], length_dim=1).unsqueeze(-2)

    span_bdy = None
    if 'span_bdy' in output.keys():
        span_bdy = output['span_bdy']

    # dual_mask: with mixed Chinese-English data, speech and text are masked at the same time
    # ERNIE SAT masks both when implementing cross-lingual synthesis
    if text_masking:
        masked_pos, text_masked_pos = phones_text_masking(
            xs_pad=speech_pad,
            src_mask=speech_mask,
            text_pad=text_pad,
            text_mask=text_mask,
            align_start=align_start,
            align_end=align_end,
            align_start_lens=align_start_lens,
            mlm_prob=mlm_prob,
            mean_phn_span=mean_phn_span,
            span_bdy=span_bdy)
    # Chinese-only or English-only training -> A3T does not mask phonemes, only speech
    # The main difference between A3T and ERNIE SAT lies in how masking is done
    else:
        masked_pos = phones_masking(
            xs_pad=speech_pad,
            src_mask=speech_mask,
            align_start=align_start,
            align_end=align_end,
            align_start_lens=align_start_lens,
            mlm_prob=mlm_prob,
            mean_phn_span=mean_phn_span,
            span_bdy=span_bdy)
        text_masked_pos = paddle.zeros(paddle.shape(text_pad))

    output_dict = {}

    speech_seg_pos, text_seg_pos = get_seg_pos(
        speech_pad=speech_pad,
        text_pad=text_pad,
        align_start=align_start,
        align_end=align_end,
        align_start_lens=align_start_lens,
        seg_emb=seg_emb)
    output_dict['speech'] = speech_pad
    output_dict['text'] = text_pad
    output_dict['masked_pos'] = masked_pos
    output_dict['text_masked_pos'] = text_masked_pos
    output_dict['speech_mask'] = speech_mask
    output_dict['text_mask'] = text_mask
    output_dict['speech_seg_pos'] = speech_seg_pos
    output_dict['text_seg_pos'] = text_seg_pos
    output = (uttids, output_dict)
    return output
def build_mlm_collate_fn(
        sr: int = 24000,
        n_fft: int = 2048,
        hop_length: int = 300,
        win_length: int = None,
        n_mels: int = 80,
        fmin: int = 80,
        fmax: int = 7600,
        mlm_prob: float = 0.8,
        mean_phn_span: int = 8,
        seg_emb: bool = False,
        epoch: int = -1, ):
    feats_extract_class = LogMelFBank

    feats_extract = feats_extract_class(
        sr=sr,
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        n_mels=n_mels,
        fmin=fmin,
        fmax=fmax)

    if epoch == -1:
        mlm_prob_factor = 1
    else:
        mlm_prob_factor = 0.8

    return MLMCollateFn(
        feats_extract=feats_extract,
        mlm_prob=mlm_prob * mlm_prob_factor,
        mean_phn_span=mean_phn_span,
        seg_emb=seg_emb)
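The first loop of `mlm_collate_fn` pads every field to the longest item in the batch and records the true lengths under `<key>_lens`; those lengths later drive the non-padding masks. A minimal sketch of that idea on plain Python lists, assuming 1-D integer sequences; `pad_batch` and `non_pad_mask` are hypothetical stand-ins for illustration, not the `pad_list` and `make_non_pad_mask` helpers in `nets_utils`.

```python
from typing import List, Sequence, Tuple


def pad_batch(seqs: Sequence[Sequence[int]],
              pad_value: int = 0) -> Tuple[List[List[int]], List[int]]:
    """Right-pad each sequence to the batch maximum and return the true lengths."""
    lens = [len(s) for s in seqs]
    max_len = max(lens)
    padded = [list(s) + [pad_value] * (max_len - len(s)) for s in seqs]
    return padded, lens


def non_pad_mask(lens: Sequence[int]) -> List[List[bool]]:
    """True where a position holds real data, False where it is padding."""
    max_len = max(lens)
    return [[pos < n for pos in range(max_len)] for n in lens]


text, text_lens = pad_batch([[7, 8, 9], [4, 5]])
mask = non_pad_mask(text_lens)
print(text, text_lens, mask)
```

Keeping the lengths next to the padded batch is what lets downstream code tell a genuine `0` token from padding that happens to use the same value.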
paddlespeech/t2s/models/ernie_sat/__init__.py
View file @ 7b864e8f
...
...
@@ -13,4 +13,3 @@
# limitations under the License.
from .ernie_sat import *
from .ernie_sat_updater import *
from .mlm import *
paddlespeech/t2s/models/ernie_sat/mlm.py
deleted
100644 → 0
View file @ d21e03c0
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from typing import Dict
from typing import List
from typing import Optional

import paddle
import yaml
from paddle import nn
from yacs.config import CfgNode

from paddlespeech.t2s.modules.activation import get_activation
from paddlespeech.t2s.modules.conformer.convolution import ConvolutionModule
from paddlespeech.t2s.modules.conformer.encoder_layer import EncoderLayer
from paddlespeech.t2s.modules.layer_norm import LayerNorm
from paddlespeech.t2s.modules.masked_fill import masked_fill
from paddlespeech.t2s.modules.nets_utils import initialize
from paddlespeech.t2s.modules.tacotron2.decoder import Postnet
from paddlespeech.t2s.modules.transformer.attention import LegacyRelPositionMultiHeadedAttention
from paddlespeech.t2s.modules.transformer.attention import MultiHeadedAttention
from paddlespeech.t2s.modules.transformer.attention import RelPositionMultiHeadedAttention
from paddlespeech.t2s.modules.transformer.embedding import LegacyRelPositionalEncoding
from paddlespeech.t2s.modules.transformer.embedding import PositionalEncoding
from paddlespeech.t2s.modules.transformer.embedding import RelPositionalEncoding
from paddlespeech.t2s.modules.transformer.embedding import ScaledPositionalEncoding
from paddlespeech.t2s.modules.transformer.multi_layer_conv import Conv1dLinear
from paddlespeech.t2s.modules.transformer.multi_layer_conv import MultiLayeredConv1d
from paddlespeech.t2s.modules.transformer.positionwise_feed_forward import PositionwiseFeedForward
from paddlespeech.t2s.modules.transformer.repeat import repeat
from paddlespeech.t2s.modules.transformer.subsampling import Conv2dSubsampling


# MLM -> Mask Language Model
class mySequential(nn.Sequential):
    def forward(self, *inputs):
        for module in self._sub_layers.values():
            if type(inputs) == tuple:
                inputs = module(*inputs)
            else:
                inputs = module(inputs)
        return inputs


class MaskInputLayer(nn.Layer):
    def __init__(self, out_features: int) -> None:
        super().__init__()
        self.mask_feature = paddle.create_parameter(
            shape=(1, 1, out_features),
            dtype=paddle.float32,
            default_initializer=paddle.nn.initializer.Assign(
                paddle.normal(shape=(1, 1, out_features))))

    def forward(self, input: paddle.Tensor,
                masked_pos: paddle.Tensor=None) -> paddle.Tensor:
        masked_pos = paddle.expand_as(paddle.unsqueeze(masked_pos, -1), input)
        masked_input = masked_fill(input, masked_pos, 0) + masked_fill(
            paddle.expand_as(self.mask_feature, input), ~masked_pos, 0)
        return masked_input
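`MaskInputLayer.forward` zeroes the input at masked positions, zeroes the learned mask vector everywhere else, and adds the two, which amounts to selecting the mask vector for masked frames and the original frame otherwise. A minimal sketch of that selection, assuming plain Python lists for a single utterance instead of `paddle.Tensor` batches; `blend_mask_feature` is a hypothetical name.

```python
from typing import List, Sequence


def blend_mask_feature(frames: Sequence[Sequence[float]],
                       masked_pos: Sequence[bool],
                       mask_feature: Sequence[float]) -> List[List[float]]:
    """Replace every masked frame with the shared mask vector; keep the rest."""
    return [
        list(mask_feature) if masked else list(frame)
        for frame, masked in zip(frames, masked_pos)
    ]


frames = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
blended = blend_mask_feature(frames, [False, True, False], [0.5, -0.5])
print(blended)
```

In the layer itself the mask vector is a trainable parameter of shape `(1, 1, out_features)`, so every masked frame shares the same learned placeholder.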


class MLMEncoder(nn.Layer):
    """Conformer encoder module.

    Args:
        idim (int): Input dimension.
        attention_dim (int): Dimension of attention.
        attention_heads (int): The number of heads of multi head attention.
        linear_units (int): The number of units of position-wise feed forward.
        num_blocks (int): The number of decoder blocks.
        dropout_rate (float): Dropout rate.
        positional_dropout_rate (float): Dropout rate after adding positional encoding.
        attention_dropout_rate (float): Dropout rate in attention.
        input_layer (Union[str, paddle.nn.Layer]): Input layer type.
        normalize_before (bool): Whether to use layer_norm before the first block.
        concat_after (bool): Whether to concat attention layer's input and output.
            if True, additional linear will be applied.
            i.e. x -> x + linear(concat(x, att(x)))
            if False, no additional linear will be applied. i.e. x -> x + att(x)
        positionwise_layer_type (str): "linear", "conv1d", or "conv1d-linear".
        positionwise_conv_kernel_size (int): Kernel size of positionwise conv1d layer.
        macaron_style (bool): Whether to use macaron style for positionwise layer.
        pos_enc_layer_type (str): Encoder positional encoding layer type.
        selfattention_layer_type (str): Encoder attention layer type.
        activation_type (str): Encoder activation function type.
        use_cnn_module (bool): Whether to use convolution module.
        zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
        cnn_module_kernel (int): Kernel size of convolution module.
        padding_idx (int): Padding idx for input_layer=embed.
        stochastic_depth_rate (float): Maximum probability to skip the encoder layer.

    """

    def __init__(self,
                 idim: int,
                 vocab_size: int = 0,
                 pre_speech_layer: int = 0,
                 attention_dim: int = 256,
                 attention_heads: int = 4,
                 linear_units: int = 2048,
                 num_blocks: int = 6,
                 dropout_rate: float = 0.1,
                 positional_dropout_rate: float = 0.1,
                 attention_dropout_rate: float = 0.0,
                 input_layer: str = "conv2d",
                 normalize_before: bool = True,
                 concat_after: bool = False,
                 positionwise_layer_type: str = "linear",
                 positionwise_conv_kernel_size: int = 1,
                 macaron_style: bool = False,
                 pos_enc_layer_type: str = "abs_pos",
                 selfattention_layer_type: str = "selfattn",
                 activation_type: str = "swish",
                 use_cnn_module: bool = False,
                 zero_triu: bool = False,
                 cnn_module_kernel: int = 31,
                 padding_idx: int = -1,
                 stochastic_depth_rate: float = 0.0,
                 text_masking: bool = False):
        """Construct an Encoder object."""
        super().__init__()
        self._output_size = attention_dim
        self.text_masking = text_masking
        if self.text_masking:
            self.text_masking_layer = MaskInputLayer(attention_dim)
        activation = get_activation(activation_type)
        if pos_enc_layer_type == "abs_pos":
            pos_enc_class = PositionalEncoding
        elif pos_enc_layer_type == "scaled_abs_pos":
            pos_enc_class = ScaledPositionalEncoding
        elif pos_enc_layer_type == "rel_pos":
            assert selfattention_layer_type == "rel_selfattn"
            pos_enc_class = RelPositionalEncoding
        elif pos_enc_layer_type == "legacy_rel_pos":
            pos_enc_class = LegacyRelPositionalEncoding
            assert selfattention_layer_type == "legacy_rel_selfattn"
        else:
            raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)

        self.conv_subsampling_factor = 1
        if input_layer == "linear":
            self.embed = nn.Sequential(
                nn.Linear(idim, attention_dim),
                nn.LayerNorm(attention_dim),
                nn.Dropout(dropout_rate),
                nn.ReLU(),
                pos_enc_class(attention_dim, positional_dropout_rate), )
        elif input_layer == "conv2d":
            self.embed = Conv2dSubsampling(
                idim,
                attention_dim,
                dropout_rate,
                pos_enc_class(attention_dim, positional_dropout_rate), )
            self.conv_subsampling_factor = 4
        elif input_layer == "embed":
            self.embed = nn.Sequential(
                nn.Embedding(idim, attention_dim, padding_idx=padding_idx),
                pos_enc_class(attention_dim, positional_dropout_rate), )
        elif input_layer == "mlm":
            self.segment_emb = None
            self.speech_embed = mySequential(
                MaskInputLayer(idim),
                nn.Linear(idim, attention_dim),
                nn.LayerNorm(attention_dim),
                nn.ReLU(),
                pos_enc_class(attention_dim, positional_dropout_rate))
            self.text_embed = nn.Sequential(
                nn.Embedding(
                    vocab_size, attention_dim, padding_idx=padding_idx),
                pos_enc_class(attention_dim, positional_dropout_rate), )
        elif input_layer == "sega_mlm":
            self.segment_emb = nn.Embedding(
                500, attention_dim, padding_idx=padding_idx)
            self.speech_embed = mySequential(
                MaskInputLayer(idim),
                nn.Linear(idim, attention_dim),
                nn.LayerNorm(attention_dim),
                nn.ReLU(),
                pos_enc_class(attention_dim, positional_dropout_rate))
            self.text_embed = nn.Sequential(
                nn.Embedding(
                    vocab_size, attention_dim, padding_idx=padding_idx),
                pos_enc_class(attention_dim, positional_dropout_rate), )
        elif isinstance(input_layer, nn.Layer):
            self.embed = nn.Sequential(
                input_layer,
                pos_enc_class(attention_dim, positional_dropout_rate), )
        elif input_layer is None:
            self.embed = nn.Sequential(
                pos_enc_class(attention_dim, positional_dropout_rate))
        else:
            raise ValueError("unknown input_layer: " + input_layer)
        self.normalize_before = normalize_before

        # self-attention module definition
        if selfattention_layer_type == "selfattn":
            encoder_selfattn_layer = MultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                attention_dim,
                attention_dropout_rate, )
        elif selfattention_layer_type == "legacy_rel_selfattn":
            assert pos_enc_layer_type == "legacy_rel_pos"
            encoder_selfattn_layer = LegacyRelPositionMultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                attention_dim,
                attention_dropout_rate, )
        elif selfattention_layer_type == "rel_selfattn":
            assert pos_enc_layer_type == "rel_pos"
            encoder_selfattn_layer = RelPositionMultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                attention_dim,
                attention_dropout_rate,
                zero_triu, )
        else:
            raise ValueError("unknown encoder_attn_layer: " +
                             selfattention_layer_type)

        # feed-forward module definition
        if positionwise_layer_type == "linear":
            positionwise_layer = PositionwiseFeedForward
            positionwise_layer_args = (
                attention_dim,
                linear_units,
                dropout_rate,
                activation, )
        elif positionwise_layer_type == "conv1d":
            positionwise_layer = MultiLayeredConv1d
            positionwise_layer_args = (
                attention_dim,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate, )
        elif positionwise_layer_type == "conv1d-linear":
            positionwise_layer = Conv1dLinear
            positionwise_layer_args = (
                attention_dim,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate, )
        else:
            raise NotImplementedError("Support only linear or conv1d.")

        # convolution module definition
        convolution_layer = ConvolutionModule
        convolution_layer_args = (attention_dim, cnn_module_kernel, activation)

        self.encoders = repeat(
            num_blocks,
            lambda lnum: EncoderLayer(
                attention_dim,
                encoder_selfattn_layer(*encoder_selfattn_layer_args),
                positionwise_layer(*positionwise_layer_args),
                positionwise_layer(*positionwise_layer_args) if macaron_style else None,
                convolution_layer(*convolution_layer_args) if use_cnn_module else None,
                dropout_rate,
                normalize_before,
                concat_after,
                stochastic_depth_rate * float(1 + lnum) / num_blocks, ), )
        self.pre_speech_layer = pre_speech_layer
        self.pre_speech_encoders = repeat(
            self.pre_speech_layer,
            lambda lnum: EncoderLayer(
                attention_dim,
                encoder_selfattn_layer(*encoder_selfattn_layer_args),
                positionwise_layer(*positionwise_layer_args),
                positionwise_layer(*positionwise_layer_args) if macaron_style else None,
                convolution_layer(*convolution_layer_args) if use_cnn_module else None,
                dropout_rate,
                normalize_before,
                concat_after,
                stochastic_depth_rate * float(1 + lnum) / self.pre_speech_layer, ), )
        if self.normalize_before:
            self.after_norm = LayerNorm(attention_dim)

    def forward(self,
                speech: paddle.Tensor,
                text: paddle.Tensor,
                masked_pos: paddle.Tensor,
                speech_mask: paddle.Tensor = None,
                text_mask: paddle.Tensor = None,
                speech_seg_pos: paddle.Tensor = None,
                text_seg_pos: paddle.Tensor = None):
        """Encode input sequence.

        """
        if masked_pos is not None:
            speech = self.speech_embed(speech, masked_pos)
        else:
            speech = self.speech_embed(speech)
        if text is not None:
            text = self.text_embed(text)
        if speech_seg_pos is not None and text_seg_pos is not None and self.segment_emb:
            speech_seg_emb = self.segment_emb(speech_seg_pos)
            text_seg_emb = self.segment_emb(text_seg_pos)
            text = (text[0] + text_seg_emb, text[1])
            speech = (speech[0] + speech_seg_emb, speech[1])
        if self.pre_speech_encoders:
            speech, _ = self.pre_speech_encoders(speech, speech_mask)

        if text is not None:
            xs = paddle.concat([speech[0], text[0]], axis=1)
            xs_pos_emb = paddle.concat([speech[1], text[1]], axis=1)
            masks = paddle.concat([speech_mask, text_mask], axis=-1)
        else:
            xs = speech[0]
            xs_pos_emb = speech[1]
            masks = speech_mask

        xs, masks = self.encoders((xs, xs_pos_emb), masks)

        if isinstance(xs, tuple):
            xs = xs[0]
        if self.normalize_before:
            xs = self.after_norm(xs)

        return xs, masks


class MLMDecoder(MLMEncoder):
    def forward(self, xs: paddle.Tensor, masks: paddle.Tensor):
        """Encode input sequence.

        Args:
            xs (paddle.Tensor): Input tensor (#batch, time, idim).
            masks (paddle.Tensor): Mask tensor (#batch, time).

        Returns:
            paddle.Tensor: Output tensor (#batch, time, attention_dim).
            paddle.Tensor: Mask tensor (#batch, time).

        """
        xs = self.embed(xs)
        xs, masks = self.encoders(xs, masks)

        if isinstance(xs, tuple):
            xs = xs[0]
        if self.normalize_before:
            xs = self.after_norm(xs)

        return xs, masks
# encoder and decoder is nn.Layer, not str
class MLM(nn.Layer):
    def __init__(self,
                 odim: int,
                 encoder: nn.Layer,
                 decoder: Optional[nn.Layer],
                 postnet_layers: int=0,
                 postnet_chans: int=0,
                 postnet_filts: int=0,
                 text_masking: bool=False):
        super().__init__()
        self.odim = odim
        self.encoder = encoder
        self.decoder = decoder
        self.vocab_size = encoder.text_embed[0]._num_embeddings

        if self.decoder is None or not (
                hasattr(self.decoder, 'output_layer') and
                self.decoder.output_layer is not None):
            self.sfc = nn.Linear(self.encoder._output_size, odim)
        else:
            self.sfc = None
        if text_masking:
            self.text_sfc = nn.Linear(
                self.encoder.text_embed[0]._embedding_dim,
                self.vocab_size,
                weight_attr=self.encoder.text_embed[0]._weight_attr)
        else:
            self.text_sfc = None

        self.postnet = (None if postnet_layers == 0 else Postnet(
            idim=self.encoder._output_size,
            odim=odim,
            n_layers=postnet_layers,
            n_chans=postnet_chans,
            n_filts=postnet_filts,
            use_batch_norm=True,
            dropout_rate=0.5, ))

    def inference(
            self,
            speech: paddle.Tensor,
            text: paddle.Tensor,
            masked_pos: paddle.Tensor,
            speech_mask: paddle.Tensor,
            text_mask: paddle.Tensor,
            speech_seg_pos: paddle.Tensor,
            text_seg_pos: paddle.Tensor,
            span_bdy: List[int],
            use_teacher_forcing: bool=False, ) -> Optional[List[paddle.Tensor]]:
        '''
        Args:
            speech (paddle.Tensor): input speech (1, Tmax, D).
            text (paddle.Tensor): input text (1, Tmax2).
            masked_pos (paddle.Tensor): masked position of input speech (1, Tmax)
            speech_mask (paddle.Tensor): mask of speech (1, 1, Tmax).
            text_mask (paddle.Tensor): mask of text (1, 1, Tmax2).
            speech_seg_pos (paddle.Tensor): n-th phone of each mel, 0<=n<=Tmax2 (1, Tmax).
            text_seg_pos (paddle.Tensor): n-th phone of each phone, 0<=n<=Tmax2 (1, Tmax2).
            span_bdy (List[int]): masked mel boundary of input speech (2,)
            use_teacher_forcing (bool): whether to use teacher forcing
        Returns:
            List[Tensor]: [frames before the span, predicted span, frames after the span],
                or None when use_teacher_forcing is False.
                eg:
                [Tensor(shape=[1, 181, 80]), Tensor(shape=[80, 80]), Tensor(shape=[1, 67, 80])]
        '''
        if use_teacher_forcing:
            before_outs, zs, *_ = self.forward(
                speech=speech,
                text=text,
                masked_pos=masked_pos,
                speech_mask=speech_mask,
                text_mask=text_mask,
                speech_seg_pos=speech_seg_pos,
                text_seg_pos=text_seg_pos)
            # forward() returns None as its second output when there is no
            # postnet, so fall back to the pre-postnet output.
            if zs is None:
                zs = before_outs

            speech = speech.squeeze(0)
            # Keep the original frames outside the span and the predicted
            # frames inside it.
            outs = [speech[:span_bdy[0]]]
            outs += [zs[0][span_bdy[0]:span_bdy[1]]]
            outs += [speech[span_bdy[1]:]]
            return outs
        return None
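
The span splice at the end of `inference` can be illustrated without PaddlePaddle. The following is a minimal sketch using plain Python lists in place of tensors; `splice_span` and the frame labels are hypothetical names introduced only for this illustration and are not part of the deleted script. It assumes, as in `inference`, that the predicted sequence is indexed with the same frame positions as the original speech.

```python
from typing import List, Sequence


def splice_span(original: Sequence, predicted: Sequence,
                span_bdy: Sequence[int]) -> List[list]:
    """Return [prefix, predicted span, suffix], mirroring the order used in inference()."""
    start, end = span_bdy
    if not 0 <= start <= end <= len(original):
        raise ValueError("span_bdy must satisfy 0 <= start <= end <= len(original)")
    return [
        list(original[:start]),      # untouched frames before the masked span
        list(predicted[start:end]),  # model output for the masked span
        list(original[end:]),        # untouched frames after the masked span
    ]


# Five original frames and five predicted frames; frames 1 and 2 are masked.
original_frames = ["o0", "o1", "o2", "o3", "o4"]
predicted_frames = ["p0", "p1", "p2", "p3", "p4"]
parts = splice_span(original_frames, predicted_frames, (1, 3))
```

Returning the three pieces separately, rather than one concatenated sequence, matches the list that `inference` returns; a caller would join them along the time axis if a single sequence is needed.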


class MLMEncAsDecoder(MLM):
    def forward(self,
                speech: paddle.Tensor,
                text: paddle.Tensor,
                masked_pos: paddle.Tensor,
                speech_mask: paddle.Tensor,
                text_mask: paddle.Tensor,
                speech_seg_pos: paddle.Tensor,
                text_seg_pos: paddle.Tensor):
        # feats: (Batch, Length, Dim)
        # -> encoder_out: (Batch, Length2, Dim2)
        encoder_out, h_masks = self.encoder(
            speech=speech,
            text=text,
            masked_pos=masked_pos,
            speech_mask=speech_mask,
            text_mask=text_mask,
            speech_seg_pos=speech_seg_pos,
            text_seg_pos=text_seg_pos)
        if self.decoder is not None:
            zs, _ = self.decoder(encoder_out, h_masks)
        else:
            zs = encoder_out
        speech_hidden_states = zs[:, :paddle.shape(speech)[1], :]
        if self.sfc is not None:
            before_outs = paddle.reshape(
                self.sfc(speech_hidden_states),
                (paddle.shape(speech_hidden_states)[0], -1, self.odim))
        else:
            before_outs = speech_hidden_states
        if self.postnet is not None:
            after_outs = before_outs + paddle.transpose(
                self.postnet(paddle.transpose(before_outs, [0, 2, 1])),
                [0, 2, 1])
        else:
            after_outs = None
        return before_outs, after_outs, None


class MLMDualMaksing(MLM):
    def forward(self,
                speech: paddle.Tensor,
                text: paddle.Tensor,
                masked_pos: paddle.Tensor,
                speech_mask: paddle.Tensor,
                text_mask: paddle.Tensor,
                speech_seg_pos: paddle.Tensor,
                text_seg_pos: paddle.Tensor):
        # feats: (Batch, Length, Dim)
        # -> encoder_out: (Batch, Length2, Dim2)
        encoder_out, h_masks = self.encoder(
            speech=speech,
            text=text,
            masked_pos=masked_pos,
            speech_mask=speech_mask,
            text_mask=text_mask,
            speech_seg_pos=speech_seg_pos,
            text_seg_pos=text_seg_pos)
        if self.decoder is not None:
            zs, _ = self.decoder(encoder_out, h_masks)
        else:
            zs = encoder_out
        # The speech frames come first in the joint sequence, the text tokens after them.
        speech_hidden_states = zs[:, :paddle.shape(speech)[1], :]
        # Default to None so the return below is defined when text_sfc is absent.
        text_outs = None
        if self.text_sfc is not None:
            text_hidden_states = zs[:, paddle.shape(speech)[1]:, :]
            text_outs = paddle.reshape(
                self.text_sfc(text_hidden_states),
                (paddle.shape(text_hidden_states)[0], -1, self.vocab_size))
        if self.sfc is not None:
            before_outs = paddle.reshape(
                self.sfc(speech_hidden_states),
                (paddle.shape(speech_hidden_states)[0], -1, self.odim))
        else:
            before_outs = speech_hidden_states
        if self.postnet is not None:
            after_outs = before_outs + paddle.transpose(
                self.postnet(paddle.transpose(before_outs, [0, 2, 1])),
                [0, 2, 1])
        else:
            after_outs = None
        return before_outs, after_outs, text_outs
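
Both `forward` methods rely on the encoder having concatenated speech frames and text tokens along the time axis, speech first, so the hidden states can be cut apart again at the speech length. The sketch below shows that split with nested Python lists standing in for a `(batch, time, dim)` tensor; `split_joint_sequence` and the sample values are hypothetical and only meant to illustrate the slicing, not to reproduce the model.

```python
from typing import List, Tuple

Sequence3D = List[List[List[float]]]  # (batch, time, dim) as nested lists


def split_joint_sequence(hidden: Sequence3D,
                         n_speech: int) -> Tuple[Sequence3D, Sequence3D]:
    """Split each sequence into its speech part (first n_speech steps) and its text part."""
    speech_part = [seq[:n_speech] for seq in hidden]
    text_part = [seq[n_speech:] for seq in hidden]
    return speech_part, text_part


# One utterance: three speech frames followed by two text tokens, one feature each.
hidden = [[[0.1], [0.2], [0.3], [1.0], [2.0]]]
speech_part, text_part = split_joint_sequence(hidden, n_speech=3)
```

In the model, the speech part feeds the mel projection and postnet, while the text part is only projected to vocabulary logits when a text projection layer exists.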


def build_model_from_file(config_file, model_file):

    state_dict = paddle.load(model_file)
    model_class = MLMDualMaksing if 'conformer_combine_vctk_aishell3_dual_masking' in config_file \
        else MLMEncAsDecoder

    # Build the model
    with open(config_file) as f:
        conf = CfgNode(yaml.safe_load(f))
    model = build_model(conf, model_class)
    model.set_state_dict(state_dict)
    return model, conf


# select encoder and decoder here
def build_model(args: argparse.Namespace, model_class=MLMEncAsDecoder) -> MLM:
    if isinstance(args.token_list, str):
        with open(args.token_list, encoding="utf-8") as f:
            token_list = [line.rstrip() for line in f]

        # Overwriting token_list to keep it as "portable".
        args.token_list = list(token_list)
    elif isinstance(args.token_list, (tuple, list)):
        token_list = list(args.token_list)
    else:
        raise RuntimeError("token_list must be str or list")

    vocab_size = len(token_list)
    odim = 80

    # Encoder
    encoder_class = MLMEncoder

    if 'text_masking' in args.model_conf.keys() and args.model_conf[
            'text_masking']:
        args.encoder_conf['text_masking'] = True
    else:
        args.encoder_conf['text_masking'] = False

    encoder = encoder_class(
        args.input_size, vocab_size=vocab_size, **args.encoder_conf)

    # Decoder
    if args.decoder != 'no_decoder':
        decoder_class = MLMDecoder
        decoder = decoder_class(
            idim=0,
            input_layer=None,
            **args.decoder_conf, )
    else:
        decoder = None

    # Build model
    model = model_class(
        odim=odim,
        encoder=encoder,
        decoder=decoder,
        **args.model_conf, )

    # Initialize
    if args.init is not None:
        initialize(model, args.init)

    return model
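
The `token_list` handling at the top of `build_model` accepts either a path to a vocabulary file or an in-memory sequence. As a minimal sketch of that branch in isolation, the helper below reproduces the same three cases using only the standard library; `load_token_list`, the temporary file name, and the sample tokens are hypothetical and chosen purely for illustration.

```python
import os
import tempfile
from typing import List, Sequence, Union


def load_token_list(token_list: Union[str, Sequence[str]]) -> List[str]:
    """Return the vocabulary as a list, reading it from a file when given a path."""
    if isinstance(token_list, str):
        # One token per line; trailing whitespace and newlines are stripped.
        with open(token_list, encoding="utf-8") as f:
            return [line.rstrip() for line in f]
    if isinstance(token_list, (tuple, list)):
        return list(token_list)
    raise RuntimeError("token_list must be str or list")


# Usage: a file path and an in-memory tuple yield the same vocabulary.
with tempfile.TemporaryDirectory() as tmp_dir:
    vocab_path = os.path.join(tmp_dir, "tokens.txt")
    with open(vocab_path, "w", encoding="utf-8") as f:
        f.write("<blank>\nsil\nAH0\n")
    from_file = load_token_list(vocab_path)

from_list = load_token_list(("<blank>", "sil", "AH0"))
vocab_size = len(from_file)
```

The resulting length is what `build_model` passes to the encoder as `vocab_size`, so the token file and the checkpoint's text embedding are expected to agree in size.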