PaddlePaddle / ERNIE
Commit 9224659c, authored Jun 15, 2022 by 小湉湉
Commit message: add docstring
Parent: 76b654cb

Showing 12 changed files with 1143 additions and 1275 deletions (+1143, -1275)
ernie-sat/README.md                                            +4    -5
ernie-sat/align.py                                             +160  -24
ernie-sat/collect_fn.py                                        +217  -0
ernie-sat/dataset.py                                           +147  -134
ernie-sat/inference.py                                         +242  -625
ernie-sat/mlm.py                                               +148  -462
ernie-sat/mlm_loss.py                                          +53   -0
ernie-sat/paddlespeech/t2s/modules/transformer/attention.py    +96   -0
ernie-sat/paddlespeech/t2s/modules/transformer/embedding.py    +60   -0
ernie-sat/read_text.py                                         +2    -2
ernie-sat/sedit_arg_parser.py                                  +0    -6
ernie-sat/utils.py                                             +14   -17
ernie-sat/README.md

````diff
@@ -39,9 +39,9 @@ In ERNIE-SAT we propose two innovations:
 ### 2. Pretrained models
 The pretrained ERNIE-SAT models are as follows:
-- [ERNIE-SAT_ZH](http://bj.bcebos.com/wenxin-models/model-ernie-sat-base-zh.tar.gz)
+- [ERNIE-SAT_ZH](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/ernie_sat/old/model-ernie-sat-base-zh.tar.gz)
-- [ERNIE-SAT_EN](http://bj.bcebos.com/wenxin-models/model-ernie-sat-base-en.tar.gz)
+- [ERNIE-SAT_EN](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/ernie_sat/old/model-ernie-sat-base-en.tar.gz)
-- [ERNIE-SAT_ZH_and_EN](http://bj.bcebos.com/wenxin-models/model-ernie-sat-base-en_zh.tar.gz)
+- [ERNIE-SAT_ZH_and_EN](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/ernie_sat/old/model-ernie-sat-base-en_zh.tar.gz)
 Create a pretrained_model folder, download the ERNIE-SAT pretrained models above and extract them:
@@ -108,7 +108,7 @@ prompt/dev
 3. `--voc` whether the vocoder format matches {model_name}_{dataset}
 4. `--voc_config`, `--voc_checkpoint`, `--voc_stat` are the vocoder parameters, corresponding to the 3 files of the parallel wavegan pretrained model.
 5. `--lang` the language of the model, either `zh` or `en`.
 6. `--ngpu` the number of GPUs to use; if ngpu == 0, the CPU is used.
 7. `--model_name` the model name
 8. `--uid` the id of the specific prompt speech
 9. `--new_str` the input text (fixed to specific texts in this open-source release)
@@ -125,4 +125,3 @@ sh run_sedit_en.sh # speech editing task (English)
 sh run_gen_en.sh          # personalized speech synthesis task (English)
 sh run_clone_en_to_zh.sh  # cross-lingual speech synthesis task (English-to-Chinese voice cloning)
 ```
````
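Item 6 above (fall back to the CPU when `ngpu == 0`) amounts to a simple device-selection rule. The sketch below is an illustrative, framework-free version of that rule; the function name `choose_device` is hypothetical and not part of the repo:

```python
def choose_device(ngpu: int) -> str:
    """Mirror the README rule: cpu when ngpu == 0, gpu otherwise."""
    if ngpu == 0:
        return "cpu"
    elif ngpu > 0:
        return "gpu"
    raise ValueError(f"ngpu should be >= 0, got {ngpu}")
```

In the actual scripts the chosen string would be handed to the framework's device setter; here it is kept as a plain string so the rule stands on its own.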
ernie-sat/align.py

```python
#!/usr/bin/env python
""" Usage:
      align.py wavfile trsfile outwordfile outphonefile
"""
import multiprocessing as mp
import os
import sys

from tqdm import tqdm

PHONEME = 'tools/aligner/english_envir/english2phoneme/phoneme'
MODEL_DIR_EN = 'tools/aligner/english'
MODEL_DIR_ZH = 'tools/aligner/mandarin'
# ... @@ -15,6 +11,142 @@
HVITE = 'tools/htk/HTKTools/HVite'
HCOPY = 'tools/htk/HTKTools/HCopy'


def get_unk_phns(word_str: str):
    tmpbase = '/tmp/tp.'
    f = open(tmpbase + 'temp.words', 'w')
    f.write(word_str)
    f.close()
    os.system(PHONEME + ' ' + tmpbase + 'temp.words' + ' ' + tmpbase +
              'temp.phons')
    f = open(tmpbase + 'temp.phons', 'r')
    lines2 = f.readline().strip().split()
    f.close()
    phns = []
    for phn in lines2:
        phons = phn.replace('\n', '').replace(' ', '')
        seq = []
        j = 0
        while (j < len(phons)):
            if (phons[j] > 'Z'):
                if (phons[j] == 'j'):
                    seq.append('JH')
                elif (phons[j] == 'h'):
                    seq.append('HH')
                else:
                    seq.append(phons[j].upper())
                j += 1
            else:
                p = phons[j:j + 2]
                if (p == 'WH'):
                    seq.append('W')
                elif (p in ['TH', 'SH', 'HH', 'DH', 'CH', 'ZH', 'NG']):
                    seq.append(p)
                elif (p == 'AX'):
                    seq.append('AH0')
                else:
                    seq.append(p + '1')
                j += 2
        phns.extend(seq)
    return phns


def words2phns(line: str):
    '''
    Args:
        line (str): input text.
            eg: for that reason cover is impossible to be given.
    Returns:
        List[str]: phones of input text.
            eg:
            ['F', 'AO1', 'R', 'DH', 'AE1', 'T', 'R', 'IY1', 'Z', 'AH0', 'N', 'K', 'AH1', 'V', 'ER0',
             'IH1', 'Z', 'IH2', 'M', 'P', 'AA1', 'S', 'AH0', 'B', 'AH0', 'L', 'T', 'UW1', 'B', 'IY1',
             'G', 'IH1', 'V', 'AH0', 'N']
        Dict(str, str): key - idx_word
                        value - phones
            eg:
            {'0_FOR': ['F', 'AO1', 'R'], '1_THAT': ['DH', 'AE1', 'T'], '2_REASON': ['R', 'IY1', 'Z', 'AH0', 'N'],
             '3_COVER': ['K', 'AH1', 'V', 'ER0'], '4_IS': ['IH1', 'Z'], '5_IMPOSSIBLE': ['IH2', 'M', 'P', 'AA1', 'S', 'AH0', 'B', 'AH0', 'L'],
             '6_TO': ['T', 'UW1'], '7_BE': ['B', 'IY1'], '8_GIVEN': ['G', 'IH1', 'V', 'AH0', 'N']}
    '''
    dictfile = MODEL_DIR_EN + '/dict'
    line = line.strip()
    words = []
    for pun in [',', '.', ':', ';', '!', '?', '"', '(', ')', '--', '---']:
        line = line.replace(pun, ' ')
    for wrd in line.split():
        if (wrd[-1] == '-'):
            wrd = wrd[:-1]
        if (wrd[0] == "'"):
            wrd = wrd[1:]
        if wrd:
            words.append(wrd)
    ds = set([])
    word2phns_dict = {}
    with open(dictfile, 'r') as fid:
        for line in fid:
            word = line.split()[0]
            ds.add(word)
            if word not in word2phns_dict.keys():
                word2phns_dict[word] = " ".join(line.split()[1:])
    phns = []
    wrd2phns = {}
    for index, wrd in enumerate(words):
        if wrd == '[MASK]':
            wrd2phns[str(index) + "_" + wrd] = [wrd]
            phns.append(wrd)
        elif (wrd.upper() not in ds):
            wrd2phns[str(index) + "_" + wrd.upper()] = get_unk_phns(wrd)
            phns.extend(get_unk_phns(wrd))
        else:
            wrd2phns[str(index) + "_" + wrd.upper()] = word2phns_dict[
                wrd.upper()].split()
            phns.extend(word2phns_dict[wrd.upper()].split())
    return phns, wrd2phns


def words2phns_zh(line: str):
    dictfile = MODEL_DIR_ZH + '/dict'
    line = line.strip()
    words = []
    for pun in [
            ',', '.', ':', ';', '!', '?', '"', '(', ')', '--', '---', u',',
            u'。', u':', u';', u'!', u'?', u'(', u')'
    ]:
        line = line.replace(pun, ' ')
    for wrd in line.split():
        if (wrd[-1] == '-'):
            wrd = wrd[:-1]
        if (wrd[0] == "'"):
            wrd = wrd[1:]
        if wrd:
            words.append(wrd)
    ds = set([])
    word2phns_dict = {}
    with open(dictfile, 'r') as fid:
        for line in fid:
            word = line.split()[0]
            ds.add(word)
            if word not in word2phns_dict.keys():
                word2phns_dict[word] = " ".join(line.split()[1:])
    phns = []
    wrd2phns = {}
    for index, wrd in enumerate(words):
        if wrd == '[MASK]':
            wrd2phns[str(index) + "_" + wrd] = [wrd]
            phns.append(wrd)
        elif (wrd.upper() not in ds):
            # illegal word encountered, please input correct text
            print("出现非法词错误,请输入正确的文本...")
        else:
            wrd2phns[str(index) + "_" + wrd] = word2phns_dict[wrd].split()
            phns.extend(word2phns_dict[wrd].split())
    return phns, wrd2phns


def prep_txt_zh(line: str, tmpbase: str, dictfile: str):
    words = []
    # ... @@ -82,7 +214,7 @@ def prep_txt_en(line: str, tmpbase, dictfile):
    try:
        os.system(PHONEME + ' ' + tmpbase + '_unk.words' + ' ' + tmpbase +
                  '_unk.phons')
    except Exception:
        print('english2phoneme error!')
        sys.exit(1)

    # ... @@ -148,19 +280,22 @@ def _get_user():


def alignment(wav_path: str, text: str):
    '''
    intervals: List[phn, start, end]
    '''
    tmpbase = '/tmp/' + _get_user() + '_' + str(os.getpid())

    #prepare wav and trs files
    try:
        os.system('sox ' + wav_path + ' -r 16000 ' + tmpbase + '.wav remix -')
    except Exception:
        print('sox error!')
        return None

    #prepare clean_transcript file
    try:
        prep_txt_en(line=text, tmpbase=tmpbase, dictfile=MODEL_DIR_EN + '/dict')
    except Exception:
        print('prep_txt error!')
        return None

    # ... @@ -169,7 +304,7 @@
        with open(tmpbase + '.txt', 'r') as fid:
            txt = fid.readline()
        prep_mlf(txt, tmpbase)
    except Exception:
        print('prep_mlf error!')
        return None

    # ... @@ -177,7 +312,7 @@
    try:
        os.system(HCOPY + ' -C ' + MODEL_DIR_EN + '/16000/config ' + tmpbase +
                  '.wav' + ' ' + tmpbase + '.plp')
    except Exception:
        print('HCopy error!')
        return None

    # ... @@ -188,7 +323,7 @@ (continuation of the HVite command)
                  + '/16000/hmmdefs -i ' + tmpbase + '.aligned ' + tmpbase +
                  '.dict ' + MODEL_DIR_EN + '/monophones ' + tmpbase +
                  '.plp 2>&1 > /dev/null')
    except Exception:
        print('HVite error!')
        return None

    # ... @@ -200,7 +335,7 @@
    with open(tmpbase + '.aligned', 'r') as fid:
        lines = fid.readlines()
    i = 2
    intervals = []
    word2phns = {}
    current_word = ''
    index = 0

    # ... @@ -210,7 +345,7 @@
            phn = splited_line[2]
            pst = (int(splited_line[0]) / 1000 + 125) / 10000
            pen = (int(splited_line[1]) / 1000 + 125) / 10000
            intervals.append([phn, pst, pen])
            # splited_line[-1]!='sp'
            if len(splited_line) == 5:
                current_word = str(index) + '_' + splited_line[-1]

    # ... @@ -219,10 +354,10 @@
            elif len(splited_line) == 4:
                word2phns[current_word] += ' ' + phn
        i += 1
    return intervals, word2phns


def alignment_zh(wav_path: str, text: str):
    tmpbase = '/tmp/' + _get_user() + '_' + str(os.getpid())

    #prepare wav and trs files
    # ... @@ -230,18 +365,19 @@
        os.system('sox ' + wav_path + ' -r 16000 -b 16 ' + tmpbase +
                  '.wav remix -')
    except Exception:
        print('sox error!')
        return None

    #prepare clean_transcript file
    try:
        unk_words = prep_txt_zh(
            line=text, tmpbase=tmpbase, dictfile=MODEL_DIR_ZH + '/dict')
        if unk_words:
            print('Error! Please add the following words to dictionary:')
            for unk in unk_words:
                print("非法words: ", unk)
    except Exception:
        print('prep_txt error!')
        return None

    # ... @@ -250,7 +386,7 @@
        with open(tmpbase + '.txt', 'r') as fid:
            txt = fid.readline()
        prep_mlf(txt, tmpbase)
    except Exception:
        print('prep_mlf error!')
        return None

    # ... @@ -258,7 +394,7 @@
    try:
        os.system(HCOPY + ' -C ' + MODEL_DIR_ZH + '/16000/config ' + tmpbase +
                  '.wav' + ' ' + tmpbase + '.plp')
    except Exception:
        print('HCopy error!')
        return None

    # ... @@ -270,7 +406,7 @@ (continuation of the HVite command)
                  + '/dict ' + MODEL_DIR_ZH + '/monophones ' + tmpbase +
                  '.plp 2>&1 > /dev/null')
    except Exception:
        print('HVite error!')
        return None

    # ... @@ -283,7 +419,7 @@
        lines = fid.readlines()
    i = 2
    intervals = []
    word2phns = {}
    current_word = ''
    index = 0

    # ... @@ -293,7 +429,7 @@
            phn = splited_line[2]
            pst = (int(splited_line[0]) / 1000 + 125) / 10000
            pen = (int(splited_line[1]) / 1000 + 125) / 10000
            intervals.append([phn, pst, pen])
            # splited_line[-1]!='sp'
            if len(splited_line) == 5:
                current_word = str(index) + '_' + splited_line[-1]

    # ... @@ -302,4 +438,4 @@
            elif len(splited_line) == 4:
                word2phns[current_word] += ' ' + phn
        i += 1
    return intervals, word2phns
```
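The inner normalization loop of `get_unk_phns` above maps the english2phoneme output to ARPABET-style symbols (lowercase letters are single-letter phones, uppercase pairs are two-letter phones). The following is a minimal standalone sketch of that loop, with the temp files and the external `phoneme` tool replaced by a plain string argument; `normalize_phones` is a hypothetical name used only for illustration:

```python
def normalize_phones(phons: str):
    """Map one english2phoneme token to ARPABET-style symbols,
    following the same rules as the loop in get_unk_phns."""
    seq = []
    j = 0
    while j < len(phons):
        if phons[j] > 'Z':  # lowercase: a single-letter phone
            if phons[j] == 'j':
                seq.append('JH')
            elif phons[j] == 'h':
                seq.append('HH')
            else:
                seq.append(phons[j].upper())
            j += 1
        else:  # uppercase: a two-letter phone
            p = phons[j:j + 2]
            if p == 'WH':
                seq.append('W')
            elif p in ['TH', 'SH', 'HH', 'DH', 'CH', 'ZH', 'NG']:
                seq.append(p)
            elif p == 'AX':
                seq.append('AH0')  # schwa gets stress marker 0
            else:
                seq.append(p + '1')  # other vowels get stress marker 1
            j += 2
    return seq
```

For example, `normalize_phones('SHj')` yields `['SH', 'JH']`, and `'AX'` collapses to `['AH0']`, matching the behavior of the original loop.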
ernie-sat/collect_fn.py (new file, mode 100644)

```python
from typing import Collection
from typing import Dict
from typing import List
from typing import Tuple
from typing import Union

import numpy as np
import paddle
from dataset import get_seg_pos
from dataset import phones_masking
from dataset import phones_text_masking
from paddlespeech.t2s.datasets.get_feats import LogMelFBank
from paddlespeech.t2s.modules.nets_utils import make_non_pad_mask
from paddlespeech.t2s.modules.nets_utils import pad_list


class MLMCollateFn:
    """Functor class of common_collate_fn()"""

    def __init__(self,
                 feats_extract,
                 float_pad_value: Union[float, int]=0.0,
                 int_pad_value: int=-32768,
                 not_sequence: Collection[str]=(),
                 mlm_prob: float=0.8,
                 mean_phn_span: int=8,
                 attention_window: int=0,
                 pad_speech: bool=False,
                 seg_emb: bool=False,
                 text_masking: bool=False):
        self.mlm_prob = mlm_prob
        self.mean_phn_span = mean_phn_span
        self.feats_extract = feats_extract
        self.float_pad_value = float_pad_value
        self.int_pad_value = int_pad_value
        self.not_sequence = set(not_sequence)
        self.attention_window = attention_window
        self.pad_speech = pad_speech
        self.seg_emb = seg_emb
        self.text_masking = text_masking

    def __repr__(self):
        return (f"{self.__class__}(float_pad_value={self.float_pad_value}, "
                f"int_pad_value={self.int_pad_value})")

    def __call__(self, data: Collection[Tuple[str, Dict[str, np.ndarray]]]
                 ) -> Tuple[List[str], Dict[str, paddle.Tensor]]:
        return mlm_collate_fn(
            data,
            float_pad_value=self.float_pad_value,
            int_pad_value=self.int_pad_value,
            not_sequence=self.not_sequence,
            mlm_prob=self.mlm_prob,
            mean_phn_span=self.mean_phn_span,
            feats_extract=self.feats_extract,
            attention_window=self.attention_window,
            pad_speech=self.pad_speech,
            seg_emb=self.seg_emb,
            text_masking=self.text_masking)


def mlm_collate_fn(
        data: Collection[Tuple[str, Dict[str, np.ndarray]]],
        float_pad_value: Union[float, int]=0.0,
        int_pad_value: int=-32768,
        not_sequence: Collection[str]=(),
        mlm_prob: float=0.8,
        mean_phn_span: int=8,
        feats_extract=None,
        attention_window: int=0,
        pad_speech: bool=False,
        seg_emb: bool=False,
        text_masking: bool=False
) -> Tuple[List[str], Dict[str, paddle.Tensor]]:
    uttids = [u for u, _ in data]
    data = [d for _, d in data]

    assert all(set(data[0]) == set(d) for d in data), "dict-keys mismatching"
    assert all(not k.endswith("_lens")
               for k in data[0]), f"*_lens is reserved: {list(data[0])}"

    output = {}
    for key in data[0]:
        # Each models, which accepts these values finally, are responsible
        # to repaint the pad_value to the desired value for each tasks.
        if data[0][key].dtype.kind == "i":
            pad_value = int_pad_value
        else:
            pad_value = float_pad_value

        array_list = [d[key] for d in data]

        # Assume the first axis is length:
        # tensor_list: Batch x (Length, ...)
        tensor_list = [paddle.to_tensor(a) for a in array_list]
        # tensor: (Batch, Length, ...)
        tensor = pad_list(tensor_list, pad_value)
        output[key] = tensor

        # lens: (Batch,)
        if key not in not_sequence:
            lens = paddle.to_tensor(
                [d[key].shape[0] for d in data], dtype=paddle.int64)
            output[key + "_lens"] = lens

    feats = feats_extract.get_log_mel_fbank(np.array(output["speech"][0]))
    feats = paddle.to_tensor(feats)
    feats_lens = paddle.shape(feats)[0]
    feats = paddle.unsqueeze(feats, 0)

    text = output["text"]
    text_lens = output["text_lens"]
    align_start = output["align_start"]
    align_start_lens = output["align_start_lens"]
    align_end = output["align_end"]

    max_tlen = max(text_lens)
    max_slen = max(feats_lens)
    speech_pad = feats[:, :max_slen]
    text_pad = text
    text_mask = make_non_pad_mask(
        text_lens, text_pad, length_dim=1).unsqueeze(-2)
    speech_mask = make_non_pad_mask(
        feats_lens, speech_pad[:, :, 0], length_dim=1).unsqueeze(-2)
    span_bdy = None
    if 'span_bdy' in output.keys():
        span_bdy = output['span_bdy']

    # dual mask: for mixed Chinese-English, mask speech and text at the same time
    # ERNIE-SAT masks both when doing cross-lingual synthesis
    if text_masking:
        masked_pos, text_masked_pos = phones_text_masking(
            xs_pad=speech_pad,
            src_mask=speech_mask,
            text_pad=text_pad,
            text_mask=text_mask,
            align_start=align_start,
            align_end=align_end,
            align_start_lens=align_start_lens,
            mlm_prob=mlm_prob,
            mean_phn_span=mean_phn_span,
            span_bdy=span_bdy)
    # pure-Chinese / pure-English training -> A3T does not mask phonemes, only speech
    # the main difference between A3T and ERNIE-SAT lies in how masking is done
    else:
        masked_pos = phones_masking(
            xs_pad=speech_pad,
            src_mask=speech_mask,
            align_start=align_start,
            align_end=align_end,
            align_start_lens=align_start_lens,
            mlm_prob=mlm_prob,
            mean_phn_span=mean_phn_span,
            span_bdy=span_bdy)
        text_masked_pos = paddle.zeros(paddle.shape(text_pad))

    output_dict = {}

    speech_seg_pos, text_seg_pos = get_seg_pos(
        speech_pad=speech_pad,
        text_pad=text_pad,
        align_start=align_start,
        align_end=align_end,
        align_start_lens=align_start_lens,
        seg_emb=seg_emb)
    output_dict['speech'] = speech_pad
    output_dict['text'] = text_pad
    output_dict['masked_pos'] = masked_pos
    output_dict['text_masked_pos'] = text_masked_pos
    output_dict['speech_mask'] = speech_mask
    output_dict['text_mask'] = text_mask
    output_dict['speech_seg_pos'] = speech_seg_pos
    output_dict['text_seg_pos'] = text_seg_pos
    output = (uttids, output_dict)
    return output


def build_collate_fn(
        sr: int=24000,
        n_fft: int=2048,
        hop_length: int=300,
        win_length: int=None,
        n_mels: int=80,
        fmin: int=80,
        fmax: int=7600,
        mlm_prob: float=0.8,
        mean_phn_span: int=8,
        train: bool=False,
        seg_emb: bool=False,
        epoch: int=-1, ):
    feats_extract_class = LogMelFBank
    feats_extract = feats_extract_class(
        sr=sr,
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        n_mels=n_mels,
        fmin=fmin,
        fmax=fmax)

    pad_speech = False
    if epoch == -1:
        mlm_prob_factor = 1
    else:
        mlm_prob_factor = 0.8

    return MLMCollateFn(
        feats_extract=feats_extract,
        float_pad_value=0.0,
        int_pad_value=0,
        mlm_prob=mlm_prob * mlm_prob_factor,
        mean_phn_span=mean_phn_span,
        pad_speech=pad_speech,
        seg_emb=seg_emb)
```
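The per-key loop in `mlm_collate_fn` pads variable-length arrays to a common length and records a `*_lens` vector alongside each padded batch. A minimal pure-Python sketch of that bookkeeping follows; `pad_and_collect` is a hypothetical stand-in for Paddle's `pad_list` plus the `_lens` entries, not the actual implementation:

```python
def pad_and_collect(arrays, pad_value=0.0):
    """Pad sequences of differing lengths to a (B, Lmax) batch and
    record the original lengths, as mlm_collate_fn does per key."""
    lens = [len(a) for a in arrays]
    max_len = max(lens)
    batch = [list(a) + [pad_value] * (max_len - len(a)) for a in arrays]
    return batch, lens
```

Downstream models then use the lengths (via `make_non_pad_mask` in the real code) to ignore the padded tail of each row.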
ernie-sat/dataset.py
浏览文件 @
9224659c
...
@@ -4,6 +4,68 @@ import numpy as np
...
@@ -4,6 +4,68 @@ import numpy as np
import
paddle
import
paddle
# mask phones
def
phones_masking
(
xs_pad
:
paddle
.
Tensor
,
src_mask
:
paddle
.
Tensor
,
align_start
:
paddle
.
Tensor
,
align_end
:
paddle
.
Tensor
,
align_start_lens
:
paddle
.
Tensor
,
mlm_prob
:
float
=
0.8
,
mean_phn_span
:
int
=
8
,
span_bdy
:
paddle
.
Tensor
=
None
):
'''
Args:
xs_pad (paddle.Tensor): input speech (B, Tmax, D).
src_mask (paddle.Tensor): mask of speech (B, 1, Tmax).
align_start (paddle.Tensor): frame level phone alignment start (B, Tmax2).
align_end (paddle.Tensor): frame level phone alignment end (B, Tmax2).
align_start_lens (paddle.Tensor): length of align_start (B, ).
mlm_prob (float):
mean_phn_span (int):
span_bdy (paddle.Tensor): masked mel boundary of input speech (B, 2).
Returns:
paddle.Tensor[bool]: masked position of input speech (B, Tmax).
'''
bz
,
sent_len
,
_
=
paddle
.
shape
(
xs_pad
)
masked_pos
=
paddle
.
zeros
((
bz
,
sent_len
))
if
mlm_prob
==
1.0
:
masked_pos
+=
1
elif
mean_phn_span
==
0
:
# only speech
length
=
sent_len
mean_phn_span
=
min
(
length
*
mlm_prob
//
3
,
50
)
masked_phn_idxs
=
random_spans_noise_mask
(
length
=
length
,
mlm_prob
=
mlm_prob
,
mean_phn_span
=
mean_phn_span
).
nonzero
()
masked_pos
[:,
masked_phn_idxs
]
=
1
else
:
for
idx
in
range
(
bz
):
# for inference
if
span_bdy
is
not
None
:
for
s
,
e
in
zip
(
span_bdy
[
idx
][::
2
],
span_bdy
[
idx
][
1
::
2
]):
masked_pos
[
idx
,
s
:
e
]
=
1
# for training
else
:
length
=
align_start_lens
[
idx
]
if
length
<
2
:
continue
masked_phn_idxs
=
random_spans_noise_mask
(
length
=
length
,
mlm_prob
=
mlm_prob
,
mean_phn_span
=
mean_phn_span
).
nonzero
()
masked_start
=
align_start
[
idx
][
masked_phn_idxs
].
tolist
()
masked_end
=
align_end
[
idx
][
masked_phn_idxs
].
tolist
()
for
s
,
e
in
zip
(
masked_start
,
masked_end
):
masked_pos
[
idx
,
s
:
e
]
=
1
non_eos_mask
=
paddle
.
reshape
(
src_mask
,
paddle
.
shape
(
xs_pad
)[:
2
])
masked_pos
=
masked_pos
*
non_eos_mask
masked_pos
=
paddle
.
cast
(
masked_pos
,
'bool'
)
return
masked_pos
# mask speech and phones
def
phones_text_masking
(
xs_pad
:
paddle
.
Tensor
,
def
phones_text_masking
(
xs_pad
:
paddle
.
Tensor
,
src_mask
:
paddle
.
Tensor
,
src_mask
:
paddle
.
Tensor
,
text_pad
:
paddle
.
Tensor
,
text_pad
:
paddle
.
Tensor
,
...
@@ -11,37 +73,56 @@ def phones_text_masking(xs_pad: paddle.Tensor,
...
@@ -11,37 +73,56 @@ def phones_text_masking(xs_pad: paddle.Tensor,
align_start
:
paddle
.
Tensor
,
align_start
:
paddle
.
Tensor
,
align_end
:
paddle
.
Tensor
,
align_end
:
paddle
.
Tensor
,
align_start_lens
:
paddle
.
Tensor
,
align_start_lens
:
paddle
.
Tensor
,
mlm_prob
:
float
,
mlm_prob
:
float
=
0.8
,
mean_phn_span
:
float
,
mean_phn_span
:
int
=
8
,
span_bdy
:
paddle
.
Tensor
=
None
):
span_bdy
:
paddle
.
Tensor
=
None
):
'''
Args:
xs_pad (paddle.Tensor): input speech (B, Tmax, D).
src_mask (paddle.Tensor): mask of speech (B, 1, Tmax).
text_pad (paddle.Tensor): input text (B, Tmax2).
text_mask (paddle.Tensor): mask of text (B, 1, Tmax2).
align_start (paddle.Tensor): frame level phone alignment start (B, Tmax2).
align_end (paddle.Tensor): frame level phone alignment end (B, Tmax2).
align_start_lens (paddle.Tensor): length of align_start (B, ).
mlm_prob (float):
mean_phn_span (int):
span_bdy (paddle.Tensor): masked mel boundary of input speech (B, 2).
Returns:
paddle.Tensor[bool]: masked position of input speech (B, Tmax).
paddle.Tensor[bool]: masked position of input text (B, Tmax2).
'''
bz
,
sent_len
,
_
=
paddle
.
shape
(
xs_pad
)
bz
,
sent_len
,
_
=
paddle
.
shape
(
xs_pad
)
masked_pos
=
paddle
.
zeros
((
bz
,
sent_len
))
masked_pos
=
paddle
.
zeros
((
bz
,
sent_len
))
_
,
text_len
=
paddle
.
shape
(
text_pad
)
_
,
text_len
=
paddle
.
shape
(
text_pad
)
text_mask_num_lower
=
math
.
ceil
(
text_len
*
(
1
-
mlm_prob
)
*
0.5
)
text_mask_num_lower
=
math
.
ceil
(
text_len
*
(
1
-
mlm_prob
)
*
0.5
)
text_masked_pos
=
paddle
.
zeros
((
bz
,
text_len
))
text_masked_pos
=
paddle
.
zeros
((
bz
,
text_len
))
y_masks
=
None
if
mlm_prob
==
1.0
:
if
mlm_prob
==
1.0
:
masked_pos
+=
1
masked_pos
+=
1
# y_masks = tril_masks
elif
mean_phn_span
==
0
:
elif
mean_phn_span
==
0
:
# only speech
# only speech
length
=
sent_len
length
=
sent_len
mean_phn_span
=
min
(
length
*
mlm_prob
//
3
,
50
)
mean_phn_span
=
min
(
length
*
mlm_prob
//
3
,
50
)
masked_phn_idxs
=
random_spans_noise_mask
(
length
,
mlm_prob
,
masked_phn_idxs
=
random_spans_noise_mask
(
mean_phn_span
).
nonzero
()
length
=
length
,
mlm_prob
=
mlm_prob
,
mean_phn_span
=
mean_phn_span
).
nonzero
()
masked_pos
[:,
masked_phn_idxs
]
=
1
masked_pos
[:,
masked_phn_idxs
]
=
1
else
:
else
:
for
idx
in
range
(
bz
):
for
idx
in
range
(
bz
):
# for inference
if
span_bdy
is
not
None
:
if
span_bdy
is
not
None
:
for
s
,
e
in
zip
(
span_bdy
[
idx
][::
2
],
span_bdy
[
idx
][
1
::
2
]):
for
s
,
e
in
zip
(
span_bdy
[
idx
][::
2
],
span_bdy
[
idx
][
1
::
2
]):
masked_pos
[
idx
,
s
:
e
]
=
1
masked_pos
[
idx
,
s
:
e
]
=
1
# for training
else
:
else
:
length
=
align_start_lens
[
idx
]
length
=
align_start_lens
[
idx
]
if
length
<
2
:
if
length
<
2
:
continue
continue
masked_phn_idxs
=
random_spans_noise_mask
(
masked_phn_idxs
=
random_spans_noise_mask
(
length
,
mlm_prob
,
mean_phn_span
).
nonzero
()
length
=
length
,
mlm_prob
=
mlm_prob
,
mean_phn_span
=
mean_phn_span
).
nonzero
()
unmasked_phn_idxs
=
list
(
unmasked_phn_idxs
=
list
(
set
(
range
(
length
))
-
set
(
masked_phn_idxs
[
0
].
tolist
()))
set
(
range
(
length
))
-
set
(
masked_phn_idxs
[
0
].
tolist
()))
np
.
random
.
shuffle
(
unmasked_phn_idxs
)
np
.
random
.
shuffle
(
unmasked_phn_idxs
)
...
@@ -58,60 +139,76 @@ def phones_text_masking(xs_pad: paddle.Tensor,
...
@@ -58,60 +139,76 @@ def phones_text_masking(xs_pad: paddle.Tensor,
masked_pos
=
paddle
.
cast
(
masked_pos
,
'bool'
)
masked_pos
=
paddle
.
cast
(
masked_pos
,
'bool'
)
text_masked_pos
=
paddle
.
cast
(
text_masked_pos
,
'bool'
)
text_masked_pos
=
paddle
.
cast
(
text_masked_pos
,
'bool'
)
return
masked_pos
,
text_masked_pos
,
y_masks
return
masked_pos
,
text_masked_pos
def
get_seg_pos_reduce_duration
(
def
get_seg_pos
(
speech_pad
:
paddle
.
Tensor
,
speech_pad
:
paddle
.
Tensor
,
text_pad
:
paddle
.
Tensor
,
text_pad
:
paddle
.
Tensor
,
align_start
:
paddle
.
Tensor
,
align_start
:
paddle
.
Tensor
,
align_end
:
paddle
.
Tensor
,
align_end
:
paddle
.
Tensor
,
align_start_lens
:
paddle
.
Tensor
,
align_start_lens
:
paddle
.
Tensor
,
sega_emb
:
bool
,
seg_emb
:
bool
=
False
):
masked_pos
:
paddle
.
Tensor
,
'''
feats_lens
:
paddle
.
Tensor
,
):
Args:
speech_pad (paddle.Tensor): input speech (B, Tmax, D).
text_pad (paddle.Tensor): input text (B, Tmax2).
align_start (paddle.Tensor): frame level phone alignment start (B, Tmax2).
align_end (paddle.Tensor): frame level phone alignment end (B, Tmax2).
align_start_lens (paddle.Tensor): length of align_start (B, ).
seg_emb (bool): whether to use segment embedding.
Returns:
paddle.Tensor[int]: n-th phone of each mel, 0<=n<=Tmax2 (B, Tmax).
eg:
Tensor(shape=[1, 328], dtype=int64, place=Place(gpu:0), stop_gradient=True,
[[0 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 ,
1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 ,
1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 ,
1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 ,
1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 2 , 2 , 2 , 3 , 3 , 3 , 4 , 4 , 4 ,
5 , 5 , 5 , 6 , 6 , 6 , 6 , 6 , 6 , 6 , 6 , 7 , 7 , 7 , 7 , 7 , 7 , 7 ,
7 , 8 , 8 , 8 , 8 , 9 , 9 , 9 , 9 , 9 , 9 , 9 , 9 , 10, 10, 10, 10, 10,
10, 10, 10, 11, 11, 11, 11, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 13,
13, 13, 13, 13, 13, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 15, 15,
15, 15, 15, 15, 15, 16, 16, 16, 16, 16, 16, 17, 17, 17, 17, 17, 17, 17,
17, 18, 18, 18, 18, 18, 18, 19, 19, 19, 19, 19, 19, 19, 20, 20, 20, 20,
20, 20, 21, 21, 21, 21, 21, 21, 21, 22, 22, 22, 22, 22, 22, 22, 23, 23,
23, 23, 23, 23, 23, 23, 24, 24, 24, 24, 24, 24, 24, 24, 24, 25, 25, 25,
25, 26, 26, 26, 27, 27, 27, 27, 27, 28, 28, 28, 28, 28, 28, 28, 28, 29,
29, 29, 29, 29, 29, 30, 30, 30, 30, 31, 31, 31, 31, 31, 31, 31, 31, 32,
32, 32, 32, 32, 33, 33, 33, 33, 33, 33, 33, 33, 34, 34, 34, 34, 35, 35,
35, 35, 35, 35, 35, 35, 36, 36, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37,
37, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38,
38, 38, 0 , 0 ]])
paddle.Tensor[int]: n-th phone of each phone, 0<=n<=Tmax2 (B, Tmax2).
eg:
Tensor(shape=[1, 38], dtype=int64, place=Place(gpu:0), stop_gradient=True,
[[1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10, 11, 12, 13, 14, 15, 16, 17,
18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
36, 37, 38]])
'''
    bz, speech_len, _ = paddle.shape(speech_pad)
    _, text_len = paddle.shape(text_pad)
    text_seg_pos = paddle.zeros((bz, text_len), dtype='int64')
    speech_seg_pos = paddle.zeros((bz, speech_len), dtype='int64')
    if not sega_emb:
        return speech_seg_pos, text_seg_pos
    for idx in range(bz):
        align_length = align_start_lens[idx]
        for j in range(align_length):
            s, e = align_start[idx][j], align_end[idx][j]
            speech_seg_pos[idx, s:e] = j + 1
            text_seg_pos[idx, j] = j + 1
    return speech_seg_pos, text_seg_pos


def get_seg_pos_reduce_duration(speech_pad, text_pad, align_start, align_end,
                                align_start_lens, sega_emb, masked_pos,
                                feats_lens):
    bz, speech_len, _ = paddle.shape(speech_pad)
    text_seg_pos = paddle.zeros(paddle.shape(text_pad))
    speech_seg_pos = paddle.zeros((bz, speech_len), dtype=text_pad.dtype)
    reordered_idx = paddle.zeros((bz, speech_len), dtype=align_start_lens.dtype)
    durations = paddle.ones((bz, speech_len), dtype=align_start_lens.dtype)
    max_reduced_length = 0
    if not sega_emb:
        return speech_pad, masked_pos, speech_seg_pos, text_seg_pos, durations
    for idx in range(bz):
        first_idx = []
        last_idx = []
        align_length = align_start_lens[idx]
        for j in range(align_length):
            s, e = align_start[idx][j], align_end[idx][j]
            if j == 0:
                if paddle.sum(masked_pos[idx][0:s]) == 0:
                    first_idx.extend(range(0, s))
                else:
                    first_idx.extend([0])
                    last_idx.extend(range(1, s))
            if paddle.sum(masked_pos[idx][s:e]) == 0:
                first_idx.extend(range(s, e))
            else:
                first_idx.extend([s])
                last_idx.extend(range(s + 1, e))
                durations[idx][s] = e - s
            speech_seg_pos[idx][s:e] = j + 1
            text_seg_pos[idx][j] = j + 1
        max_reduced_length = max(
            len(first_idx) + feats_lens[idx] - e, max_reduced_length)
        first_idx.extend(range(e, speech_len))
        reordered_idx[idx] = paddle.to_tensor(
            (first_idx + last_idx), dtype=align_start_lens.dtype)
        feats_lens[idx] = len(first_idx)
    reordered_idx = reordered_idx[:, :max_reduced_length]
    return reordered_idx, speech_seg_pos, text_seg_pos, durations, feats_lens
# randomly select the range of speech and text to mask during training
def random_spans_noise_mask(length: int,
                            mlm_prob: float=0.8,
                            mean_phn_span: float=8):
    """This function is copy of `random_spans_helper
    <https://github.com/google-research/text-to-text-transfer-transformer/blob/84f8bcc14b5f2c03de51bd3587609ba8f6bbd1cd/t5/data/preprocessors.py#L2682>`__ .
    Noise mask consisting of random spans of noise tokens.
...
@@ -126,7 +223,7 @@ def random_spans_noise_mask(length: int, mlm_prob: float, mean_phn_span: float):
        noise_density: a float - approximate density of output mask
        mean_noise_span_length: a number
    Returns:
        np.ndarray:
            a boolean tensor with shape [length]
    """
    orig_length = length
...
@@ -171,87 +268,3 @@ def random_spans_noise_mask(length: int, mlm_prob: float, mean_phn_span: float):
    is_noise = np.equal(span_num % 2, 1)
    return is_noise[:orig_length]
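The span-masking scheme above follows T5's `random_spans_helper`: derive the number of noise tokens from `mlm_prob`, split both the noise and non-noise tokens into the same number of spans, then interleave the spans so the noise spans become the mask. A minimal NumPy sketch of that idea (the function name `toy_random_spans_noise_mask` and the partitioning helper are illustrative, not from the repo):

```python
import numpy as np

def toy_random_spans_noise_mask(length, mlm_prob=0.8, mean_phn_span=8, rng=None):
    # pick how many positions to mask, split them into spans of roughly
    # mean_phn_span tokens, and interleave them with unmasked spans
    rng = np.random.default_rng(0) if rng is None else rng
    num_noise = min(max(int(round(length * mlm_prob)), 1), length - 1)
    num_spans = max(int(round(num_noise / mean_phn_span)), 1)

    def split(total, parts):
        # random partition of `total` into `parts` positive integers
        if parts == 1:
            return np.array([total])
        cuts = np.sort(rng.choice(np.arange(1, total), parts - 1, replace=False))
        return np.diff(np.concatenate(([0], cuts, [total])))

    noise_spans = split(num_noise, num_spans)
    plain_spans = split(length - num_noise, num_spans)
    mask = np.zeros(length, dtype=bool)
    pos = 0
    for plain, noise in zip(plain_spans, noise_spans):
        pos += plain
        mask[pos:pos + noise] = True
        pos += noise
    return mask
```

By construction the mask always contains exactly `num_noise` True entries, so the realized masking ratio matches `mlm_prob` up to rounding.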
def pad_to_longformer_att_window(text: paddle.Tensor,
                                 max_len: int,
                                 max_tlen: int,
                                 attention_window: int=0):
    round = max_len % attention_window
    if round != 0:
        max_tlen += (attention_window - round)
        n_batch = paddle.shape(text)[0]
        text_pad = paddle.zeros(
            (n_batch, max_tlen, *paddle.shape(text[0])[1:]), dtype=text.dtype)
        for i in range(n_batch):
            text_pad[i, :paddle.shape(text[i])[0]] = text[i]
    else:
        text_pad = text[:, :max_tlen]
    return text_pad, max_tlen
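Longformer-style attention requires the sequence length to be a multiple of the attention window, which is what the `max_len % attention_window` arithmetic above enforces before zero-padding. A self-contained NumPy sketch of the same round-up-and-pad step (names are illustrative):

```python
import numpy as np

def pad_len_to_window(max_len: int, attention_window: int) -> int:
    # round max_len up to the next multiple of attention_window
    rem = max_len % attention_window
    return max_len if rem == 0 else max_len + (attention_window - rem)

# zero-pad a (batch, time, feat) array up to the rounded length
x = np.ones((1, 100, 4), dtype=np.float32)
tlen = pad_len_to_window(x.shape[1], 32)
padded = np.zeros((x.shape[0], tlen, x.shape[2]), dtype=x.dtype)
padded[:, :x.shape[1]] = x
```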
def phones_masking(xs_pad: paddle.Tensor,
                   src_mask: paddle.Tensor,
                   align_start: paddle.Tensor,
                   align_end: paddle.Tensor,
                   align_start_lens: paddle.Tensor,
                   mlm_prob: float,
                   mean_phn_span: int,
                   span_bdy: paddle.Tensor=None):
    bz, sent_len, _ = paddle.shape(xs_pad)
    masked_pos = paddle.zeros((bz, sent_len))
    y_masks = None
    if mlm_prob == 1.0:
        masked_pos += 1
    elif mean_phn_span == 0:
        # only speech
        length = sent_len
        mean_phn_span = min(length * mlm_prob // 3, 50)
        masked_phn_idxs = random_spans_noise_mask(length, mlm_prob,
                                                  mean_phn_span).nonzero()
        masked_pos[:, masked_phn_idxs] = 1
    else:
        for idx in range(bz):
            if span_bdy is not None:
                for s, e in zip(span_bdy[idx][::2], span_bdy[idx][1::2]):
                    masked_pos[idx, s:e] = 1
            else:
                length = align_start_lens[idx]
                if length < 2:
                    continue
                masked_phn_idxs = random_spans_noise_mask(
                    length, mlm_prob, mean_phn_span).nonzero()
                masked_start = align_start[idx][masked_phn_idxs].tolist()
                masked_end = align_end[idx][masked_phn_idxs].tolist()
                for s, e in zip(masked_start, masked_end):
                    masked_pos[idx, s:e] = 1
    non_eos_mask = paddle.reshape(src_mask, paddle.shape(xs_pad)[:2])
    masked_pos = masked_pos * non_eos_mask
    masked_pos = paddle.cast(masked_pos, 'bool')
    return masked_pos, y_masks
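Note how `phones_masking` accepts `span_bdy` as a flat list of alternating start/end indices: pairing `span_bdy[::2]` with `span_bdy[1::2]` recovers the `(s, e)` spans. A minimal sketch of that pairing (the helper name `spans_to_mask` is mine):

```python
import numpy as np

def spans_to_mask(sent_len, span_bdy):
    # span_bdy = [s0, e0, s1, e1, ...]; each (s, e) pair marks frames [s, e)
    mask = np.zeros(sent_len, dtype=bool)
    for s, e in zip(span_bdy[::2], span_bdy[1::2]):
        mask[s:e] = True
    return mask
```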
def get_seg_pos(speech_pad: paddle.Tensor,
                text_pad: paddle.Tensor,
                align_start: paddle.Tensor,
                align_end: paddle.Tensor,
                align_start_lens: paddle.Tensor,
                sega_emb: bool):
    bz, speech_len, _ = paddle.shape(speech_pad)
    _, text_len = paddle.shape(text_pad)
    text_seg_pos = paddle.zeros((bz, text_len), dtype='int64')
    speech_seg_pos = paddle.zeros((bz, speech_len), dtype='int64')
    if not sega_emb:
        return speech_seg_pos, text_seg_pos
    for idx in range(bz):
        align_length = align_start_lens[idx]
        for j in range(align_length):
            s, e = align_start[idx][j], align_end[idx][j]
            speech_seg_pos[idx, s:e] = j + 1
            text_seg_pos[idx, j] = j + 1
    return speech_seg_pos, text_seg_pos
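The segment-position logic above is easy to check with plain NumPy: the frame range `[s, e)` of the j-th phone gets segment id `j + 1` on the speech side, position `j` gets `j + 1` on the text side, and 0 is reserved for padding. A sketch (`seg_pos_numpy` is an illustrative name, not from the repo):

```python
import numpy as np

def seg_pos_numpy(speech_len, text_len, align_start, align_end):
    # phone j covers speech frames [align_start[j], align_end[j]) and
    # gets segment id j + 1; id 0 marks padding
    speech_seg_pos = np.zeros(speech_len, dtype=np.int64)
    text_seg_pos = np.zeros(text_len, dtype=np.int64)
    for j, (s, e) in enumerate(zip(align_start, align_end)):
        speech_seg_pos[s:e] = j + 1
        text_seg_pos[j] = j + 1
    return speech_seg_pos, text_seg_pos
```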
ernie-sat/inference.py View file @ 9224659c
#!/usr/bin/env python3
import os
import random
from pathlib import Path
from typing import Dict
from typing import List

import librosa
import numpy as np
...
@@ -15,60 +11,42 @@ import paddle
import soundfile as sf
import torch
from paddle import nn

from align import alignment
from align import alignment_zh
from align import words2phns
from align import words2phns_zh
from collect_fn import build_collate_fn
from mlm import build_model_from_file
from paddlespeech.t2s.datasets.get_feats import LogMelFBank
from paddlespeech.t2s.modules.nets_utils import make_non_pad_mask
from paddlespeech.t2s.modules.nets_utils import pad_list
from read_text import load_num_sequence_text
from read_text import read_2col_text
from sedit_arg_parser import parse_args
from utils import build_vocoder_from_file
from utils import evaluate_durations
from utils import get_voc_out
from utils import is_chinese
# from ParallelWaveGAN.parallel_wavegan.utils.utils import download_pretrained_model

random.seed(0)
np.random.seed(0)

PHONEME = 'tools/aligner/english_envir/english2phoneme/phoneme'
MODEL_DIR_EN = 'tools/aligner/english'
MODEL_DIR_ZH = 'tools/aligner/mandarin'
def plot_mel_and_vocode_wav(wav_path: str,
                            source_lang: str='english',
                            target_lang: str='english',
                            model_name: str="paddle_checkpoint_en",
                            old_str: str="",
                            new_str: str="",
                            use_pt_vocoder: bool=False,
                            non_autoreg: bool=True):
    wav_org, output_feat, old_span_bdy, new_span_bdy, fs, hop_length = get_mlm_output(
        source_lang=source_lang,
        target_lang=target_lang,
        model_name=model_name,
        wav_path=wav_path,
        old_str=old_str,
        new_str=new_str,
        use_teacher_forcing=non_autoreg)

    masked_feat = output_feat[new_span_bdy[0]:new_span_bdy[1]]
...
@@ -79,10 +57,10 @@ def plot_mel_and_vocode_wav(uid: str,
            vocoder = load_vocoder('vctk_parallel_wavegan.v1.long')
            replaced_wav = vocoder(output_feat).cpu().numpy()
        else:
            replaced_wav = get_voc_out(output_feat)

    elif target_lang == 'chinese':
        replaced_wav_only_mask_fst2_voc = get_voc_out(masked_feat)

    old_time_bdy = [hop_length * x for x in old_span_bdy]
    new_time_bdy = [hop_length * x for x in new_span_bdy]
...
@@ -109,125 +87,6 @@ def plot_mel_and_vocode_wav(uid: str,
    return data_dict, old_span_bdy
def get_unk_phns(word_str: str):
    tmpbase = '/tmp/tp.'
    f = open(tmpbase + 'temp.words', 'w')
    f.write(word_str)
    f.close()
    os.system(PHONEME + ' ' + tmpbase + 'temp.words' + ' ' + tmpbase +
              'temp.phons')
    f = open(tmpbase + 'temp.phons', 'r')
    lines2 = f.readline().strip().split()
    f.close()
    phns = []
    for phn in lines2:
        phons = phn.replace('\n', '').replace(' ', '')
        seq = []
        j = 0
        while (j < len(phons)):
            if (phons[j] > 'Z'):
                if (phons[j] == 'j'):
                    seq.append('JH')
                elif (phons[j] == 'h'):
                    seq.append('HH')
                else:
                    seq.append(phons[j].upper())
                j += 1
            else:
                p = phons[j:j + 2]
                if (p == 'WH'):
                    seq.append('W')
                elif (p in ['TH', 'SH', 'HH', 'DH', 'CH', 'ZH', 'NG']):
                    seq.append(p)
                elif (p == 'AX'):
                    seq.append('AH0')
                else:
                    seq.append(p + '1')
                j += 2
        phns.extend(seq)
    return phns
def words2phns(line: str):
    dictfile = MODEL_DIR_EN + '/dict'
    line = line.strip()
    words = []
    for pun in [',', '.', ':', ';', '!', '?', '"', '(', ')', '--', '---']:
        line = line.replace(pun, ' ')
    for wrd in line.split():
        if (wrd[-1] == '-'):
            wrd = wrd[:-1]
        if (wrd[0] == "'"):
            wrd = wrd[1:]
        if wrd:
            words.append(wrd)
    ds = set([])
    word2phns_dict = {}
    with open(dictfile, 'r') as fid:
        for line in fid:
            word = line.split()[0]
            ds.add(word)
            if word not in word2phns_dict.keys():
                word2phns_dict[word] = " ".join(line.split()[1:])
    phns = []
    wrd2phns = {}
    for index, wrd in enumerate(words):
        if wrd == '[MASK]':
            wrd2phns[str(index) + "_" + wrd] = [wrd]
            phns.append(wrd)
        elif (wrd.upper() not in ds):
            wrd2phns[str(index) + "_" + wrd.upper()] = get_unk_phns(wrd)
            phns.extend(get_unk_phns(wrd))
        else:
            wrd2phns[str(index) + "_" + wrd.upper()] = word2phns_dict[
                wrd.upper()].split()
            phns.extend(word2phns_dict[wrd.upper()].split())
    return phns, wrd2phns
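`words2phns` strips punctuation, then maps each word to its phones through the aligner's pronunciation dictionary, keying results by `"<index>_<WORD>"`. A minimal sketch of that lookup with an in-memory lexicon in place of the `MODEL_DIR_EN + '/dict'` file (the function name and the sample lexicon are assumptions for illustration):

```python
def toy_words2phns(line, lexicon):
    # strip punctuation, then map each word to its phone list via the
    # pronunciation dict, keying results by "<index>_<WORD>"
    for pun in [',', '.', ':', ';', '!', '?', '"', '(', ')']:
        line = line.replace(pun, ' ')
    phns, wrd2phns = [], {}
    for i, wrd in enumerate(line.split()):
        pron = lexicon[wrd.upper()].split()
        wrd2phns[str(i) + '_' + wrd.upper()] = pron
        phns.extend(pron)
    return phns, wrd2phns
```

Unlike the real function, this sketch has no `[MASK]` or out-of-vocabulary handling; it only shows the keying scheme.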
def words2phns_zh(line: str):
    dictfile = MODEL_DIR_ZH + '/dict'
    line = line.strip()
    words = []
    for pun in [
            ',', '.', ':', ';', '!', '?', '"', '(', ')', '--', '---', u',',
            u'。', u':', u';', u'!', u'?', u'(', u')'
    ]:
        line = line.replace(pun, ' ')
    for wrd in line.split():
        if (wrd[-1] == '-'):
            wrd = wrd[:-1]
        if (wrd[0] == "'"):
            wrd = wrd[1:]
        if wrd:
            words.append(wrd)
    ds = set([])
    word2phns_dict = {}
    with open(dictfile, 'r') as fid:
        for line in fid:
            word = line.split()[0]
            ds.add(word)
            if word not in word2phns_dict.keys():
                word2phns_dict[word] = " ".join(line.split()[1:])
    phns = []
    wrd2phns = {}
    for index, wrd in enumerate(words):
        if wrd == '[MASK]':
            wrd2phns[str(index) + "_" + wrd] = [wrd]
            phns.append(wrd)
        elif (wrd.upper() not in ds):
            # "invalid word detected, please input correct text..."
            print("出现非法词错误,请输入正确的文本...")
        else:
            wrd2phns[str(index) + "_" + wrd] = word2phns_dict[wrd].split()
            phns.extend(word2phns_dict[wrd].split())
    return phns, wrd2phns
def load_vocoder(vocoder_tag: str="vctk_parallel_wavegan.v1.long"):
    vocoder_tag = vocoder_tag.replace("parallel_wavegan/", "")
    vocoder_file = download_pretrained_model(vocoder_tag)
...
@@ -236,50 +95,52 @@ def load_vocoder(vocoder_tag: str="vctk_parallel_wavegan.v1.long"):
    return vocoder
def load_model(model_name: str="paddle_checkpoint_en"):
    config_path = './pretrained_model/{}/config.yaml'.format(model_name)
    model_path = './pretrained_model/{}/model.pdparams'.format(model_name)
    mlm_model, conf = build_model_from_file(
        config_file=config_path, model_file=model_path)
    return mlm_model, conf
def read_data(uid: str, prefix: os.PathLike):
    # get the text corresponding to uid
    mfa_text = read_2col_text(prefix + '/text')[uid]
    # get the wav path corresponding to uid
    mfa_wav_path = read_2col_text(prefix + '/wav.scp')[uid]
    if not os.path.isabs(mfa_wav_path):
        mfa_wav_path = prefix + mfa_wav_path
    return mfa_text, mfa_wav_path
def get_align_data(uid: str, prefix: os.PathLike):
    mfa_path = prefix + "mfa_"
    mfa_text = read_2col_text(mfa_path + 'text')[uid]
    mfa_start = load_num_sequence_text(
        mfa_path + 'start', loader_type='text_float')[uid]
    mfa_end = load_num_sequence_text(
        mfa_path + 'end', loader_type='text_float')[uid]
    mfa_wav_path = read_2col_text(mfa_path + 'wav.scp')[uid]
    return mfa_text, mfa_start, mfa_end, mfa_wav_path
# get the range of mel frames to be masked
def get_masked_mel_bdy(mfa_start: List[float],
                       mfa_end: List[float],
                       fs: int,
                       hop_length: int,
                       span_to_repl: List[List[int]]):
    align_start = np.array(mfa_start)
    align_end = np.array(mfa_end)
    align_start = np.floor(fs * align_start / hop_length).astype('int')
    align_end = np.floor(fs * align_end / hop_length).astype('int')
    if span_to_repl[0] >= len(mfa_start):
        span_bdy = [align_end[-1], align_end[-1]]
    else:
        span_bdy = [
            align_start[span_to_repl[0]], align_end[span_to_repl[1] - 1]
        ]
    return span_bdy, align_start, align_end
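The boundary computation above converts MFA timestamps in seconds to mel-frame indices with `np.floor(fs * t / hop_length)`; at fs=24000 and hop_length=300 one mel frame covers 12.5 ms. A sketch of just that conversion (`seconds_to_frames` is an illustrative name):

```python
import numpy as np

def seconds_to_frames(times_s, fs=24000, hop_length=300):
    # one mel frame spans hop_length / fs seconds (12.5 ms here)
    return np.floor(fs * np.array(times_s) / hop_length).astype('int')
```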
def recover_dict(word2phns: Dict[str, str], tp_word2phns: Dict[str, str]):
...
@@ -317,18 +178,22 @@ def recover_dict(word2phns: Dict[str, str], tp_word2phns: Dict[str, str]):
    return dic
def get_max_idx(dic):
    return sorted([int(key.split('_')[0]) for key in dic.keys()])[-1]
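`get_max_idx` finds the largest word index among dictionary keys of the form `<index>_<WORD>`. Sorting the parsed integers rather than the raw key strings is what keeps index 10 above index 2 (sample dictionary below is illustrative):

```python
# keys follow the "<index>_<WORD>" convention used by words2phns
word2phns = {'0_HELLO': ['HH', 'AH0', 'L', 'OW1'],
             '2_A': ['AH0'],
             '10_END': ['EH1', 'N', 'D']}
# numeric sort: '10' correctly outranks '2'
max_idx = sorted([int(key.split('_')[0]) for key in word2phns.keys()])[-1]
```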
def get_phns_and_spans(wav_path: str,
                       old_str: str="",
                       new_str: str="",
                       source_lang: str="english",
                       target_lang: str="english"):
    is_append = (old_str == new_str[:len(old_str)])
    old_phns, mfa_start, mfa_end = [], [], []
    # source
    if source_lang == "english":
        intervals, word2phns = alignment(wav_path, old_str)
    elif source_lang == "chinese":
        intervals, word2phns = alignment_zh(wav_path, old_str)
        _, tp_word2phns = words2phns_zh(old_str)
        for key, value in tp_word2phns.items():
...
@@ -337,51 +202,46 @@ def get_phns_and_spans(wav_path: str,
            tp_word2phns[key] = cur_val
        word2phns = recover_dict(word2phns, tp_word2phns)
    else:
        assert source_lang == "chinese" or source_lang == "english", \
            "source_lang is wrong..."

    for item in intervals:
        old_phns.append(item[0])
        mfa_start.append(float(item[1]))
        mfa_end.append(float(item[2]))
    # target
    if is_append and (source_lang != target_lang):
        cross_lingual_clone = True
    else:
        cross_lingual_clone = False

    if cross_lingual_clone:
        str_origin = new_str[:len(old_str)]
        str_append = new_str[len(old_str):]

        if target_lang == "chinese":
            phns_origin, origin_word2phns = words2phns(str_origin)
            phns_append, append_word2phns_tmp = words2phns_zh(str_append)
        elif target_lang == "english":
            # original sentence
            phns_origin, origin_word2phns = words2phns_zh(str_origin)
            # cloned sentence
            phns_append, append_word2phns_tmp = words2phns(str_append)
        else:
            assert target_lang == "chinese" or target_lang == "english", \
                "cloning is not support for this language, please check it."

        new_phns = phns_origin + phns_append

        append_word2phns = {}
        length = len(origin_word2phns)
        for key, value in append_word2phns_tmp.items():
            idx, wrd = key.split('_')
            append_word2phns[str(int(idx) + length) + '_' + wrd] = value
        new_word2phns = origin_word2phns.copy()
        new_word2phns.update(append_word2phns)
    else:
        if source_lang == target_lang and target_lang == "english":
...
@@ -417,16 +277,17 @@ def get_phns_and_spans(wav_path: str,
    right_idx = 0
    new_phns_right = []
    sp_count = 0
    word2phns_max_idx = get_max_idx(word2phns)
    new_word2phns_max_idx = get_max_idx(new_word2phns)
    new_phns_mid = []
    if is_append:
        new_phns_right = []
        new_phns_mid = new_phns[left_idx:]
        span_to_repl[0] = len(new_phns_left)
        span_to_add[0] = len(new_phns_left)
        span_to_add[1] = len(new_phns_left) + len(new_phns_mid)
        span_to_repl[1] = len(old_phns) - len(new_phns_right)
    # speech edit
    else:
        for key in list(word2phns.keys())[::-1]:
            idx, wrd = key.split('_')
...
@@ -451,47 +312,57 @@ def get_phns_and_spans(wav_path: str,
                           len(old_phns))
                break
    new_phns = new_phns_left + new_phns_mid + new_phns_right
    '''
    For that reason cover should not be given.
    For that reason cover is impossible to be given.
    span_to_repl: [17, 23] "should not"
    span_to_add: [17, 30] "is impossible to"
    '''
    return mfa_start, mfa_end, old_phns, new_phns, span_to_repl, span_to_add
# the duration from MFA and the duration predicted by fs2's duration_predictor
# may differ; compute a scaling factor between predicted and ground-truth values
def duration_adjust_factor(orig_dur: List[int],
                           pred_dur: List[int],
                           phns: List[str]):
    length = 0
    factor_list = []
    for orig, pred, phn in zip(orig_dur, pred_dur, phns):
        if pred == 0 or phn == 'sp':
            continue
        else:
            factor_list.append(orig / pred)
    factor_list = np.array(factor_list)
    factor_list.sort()
    if len(factor_list) < 5:
        return 1
    length = 2
    avg = np.average(factor_list[length:-length])
    return avg
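`duration_adjust_factor` averages per-phone ratios of ground-truth (MFA) to predicted duration, skipping silences and zero predictions, and trims the two smallest and two largest ratios before averaging so outliers do not skew the factor. A self-contained NumPy sketch (`trimmed_ratio_factor` is an illustrative name):

```python
import numpy as np

def trimmed_ratio_factor(orig_dur, pred_dur, phns):
    # per-phone ratio of ground-truth to predicted duration,
    # skipping 'sp' phones and zero predictions
    ratios = [o / p for o, p, ph in zip(orig_dur, pred_dur, phns)
              if p != 0 and ph != 'sp']
    ratios = np.sort(np.array(ratios))
    if len(ratios) < 5:
        return 1
    # drop the two smallest and two largest ratios before averaging
    return float(np.average(ratios[2:-2]))
```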
def
prepare_features_with_duration
(
uid
:
str
,
def
prep_feats_with_dur
(
wav_path
:
str
,
prefix
:
str
,
wav_path
:
str
,
mlm_model
:
nn
.
Layer
,
mlm_model
:
nn
.
Layer
,
source_lang
:
str
=
"English"
,
source_lang
:
str
=
"English"
,
target_lang
:
str
=
"English"
,
target_lang
:
str
=
"English"
,
old_str
:
str
=
""
,
old_str
:
str
=
""
,
new_str
:
str
=
""
,
new_str
:
str
=
""
,
duration_preditor_path
:
str
=
None
,
sid
:
str
=
None
,
mask_reconstruct
:
bool
=
False
,
mask_reconstruct
:
bool
=
False
,
duration_adjust
:
bool
=
True
,
duration_adjust
:
bool
=
True
,
start_end_sp
:
bool
=
False
,
start_end_sp
:
bool
=
False
,
train_args
=
None
):
fs
:
int
=
24000
,
wav_org
,
rate
=
librosa
.
load
(
hop_length
:
int
=
300
):
wav_path
,
sr
=
train_args
.
feats_extract_conf
[
'fs'
])
'''
fs
=
train_args
.
feats_extract_conf
[
'fs'
]
Returns:
hop_length
=
train_args
.
feats_extract_conf
[
'hop_length'
]
np.ndarray: new wav, replace the part to be edited in original wav with 0
List[str]: new phones
List[float]: mfa start of new wav
List[float]: mfa end of new wav
List[int]: masked mel boundary of original wav
List[int]: masked mel boundary of new wav
'''
wav_org
,
_
=
librosa
.
load
(
wav_path
,
sr
=
fs
)
mfa_start
,
mfa_end
,
old_phns
,
new_phns
,
span_to_repl
,
span_to_add
=
get_phns_and_spans
(
mfa_start
,
mfa_end
,
old_phns
,
new_phns
,
span_to_repl
,
span_to_add
=
get_phns_and_spans
(
wav_path
=
wav_path
,
wav_path
=
wav_path
,
...
@@ -503,144 +374,130 @@ def prepare_features_with_duration(uid: str,
...
@@ -503,144 +374,130 @@ def prepare_features_with_duration(uid: str,
if
start_end_sp
:
if
start_end_sp
:
if
new_phns
[
-
1
]
!=
'sp'
:
if
new_phns
[
-
1
]
!=
'sp'
:
new_phns
=
new_phns
+
[
'sp'
]
new_phns
=
new_phns
+
[
'sp'
]
# 中文的 phns 不一定都在 fastspeech2 的字典里, 用 sp 代替
if
target_lang
==
"english"
:
if
target_lang
==
"english"
or
target_lang
==
"chinese"
:
old_durations
=
evaluate_durations
(
old_phns
,
target_lang
=
target_lang
)
old_durs
=
evaluate_durations
(
old_phns
,
target_lang
=
source_lang
)
elif
target_lang
==
"chinese"
:
if
source_lang
==
"english"
:
old_durations
=
evaluate_durations
(
old_phns
,
target_lang
=
source_lang
)
elif
source_lang
==
"chinese"
:
old_durations
=
evaluate_durations
(
old_phns
,
target_lang
=
source_lang
)
else
:
else
:
assert
target_lang
==
"chinese"
or
target_lang
==
"english"
,
"calculate duration_predict is not support for this language..."
assert
target_lang
==
"chinese"
or
target_lang
==
"english"
,
\
"calculate duration_predict is not support for this language..."
orig
inal_old_duration
s
=
[
e
-
s
for
e
,
s
in
zip
(
mfa_end
,
mfa_start
)]
orig
_old_dur
s
=
[
e
-
s
for
e
,
s
in
zip
(
mfa_end
,
mfa_start
)]
if
'[MASK]'
in
new_str
:
if
'[MASK]'
in
new_str
:
new_phns
=
old_phns
new_phns
=
old_phns
span_to_add
=
span_to_repl
span_to_add
=
span_to_repl
d_factor_left
=
duration_adjust_factor
(
d_factor_left
=
duration_adjust_factor
(
original_old_durations
[:
span_to_repl
[
0
]],
orig_dur
=
orig_old_durs
[:
span_to_repl
[
0
]],
old_durations
[:
span_to_repl
[
0
]],
old_phns
[:
span_to_repl
[
0
]])
pred_dur
=
old_durs
[:
span_to_repl
[
0
]],
phns
=
old_phns
[:
span_to_repl
[
0
]])
d_factor_right
=
duration_adjust_factor
(
d_factor_right
=
duration_adjust_factor
(
original_old_durations
[
span_to_repl
[
1
]:],
orig_dur
=
orig_old_durs
[
span_to_repl
[
1
]:],
old_durations
[
span_to_repl
[
1
]:],
old_phns
[
span_to_repl
[
1
]:])
pred_dur
=
old_durs
[
span_to_repl
[
1
]:],
phns
=
old_phns
[
span_to_repl
[
1
]:])
d_factor
=
(
d_factor_left
+
d_factor_right
)
/
2
d_factor
=
(
d_factor_left
+
d_factor_right
)
/
2
new_dur
ations_adjusted
=
[
d_factor
*
i
for
i
in
old_duration
s
]
new_dur
s_adjusted
=
[
d_factor
*
i
for
i
in
old_dur
s
]
else
:
else
:
if
duration_adjust
:
if
duration_adjust
:
d_factor
=
duration_adjust_factor
(
original_old_durations
,
d_factor
=
duration_adjust_factor
(
old_durations
,
old_phns
)
orig_dur
=
orig_old_durs
,
pred_dur
=
old_durs
,
phns
=
old_phns
)
print
(
"d_factor:"
,
d_factor
)
d_factor
=
d_factor
*
1.25
d_factor
=
d_factor
*
1.25
else
:
else
:
d_factor
=
1
d_factor
=
1
if
target_lang
==
"english"
:
if
target_lang
==
"english"
or
target_lang
==
"chinese"
:
new_durations
=
evaluate_durations
(
new_durs
=
evaluate_durations
(
new_phns
,
target_lang
=
target_lang
)
new_phns
,
target_lang
=
target_lang
)
else
:
assert
target_lang
==
"chinese"
or
target_lang
==
"english"
,
\
elif
target_lang
==
"chinese"
:
"calculate duration_predict is not support for this language..."
new_durations
=
evaluate_durations
(
new_phns
,
target_lang
=
target_lang
)
new_durs_adjusted
=
[
d_factor
*
i
for
i
in
new_durs
]
new_durations_adjusted
=
[
d_factor
*
i
for
i
in
new_durations
]
new_span_dur_sum
=
sum
(
new_durs_adjusted
[
span_to_add
[
0
]:
span_to_add
[
1
]])
old_span_dur_sum
=
sum
(
orig_old_durs
[
            span_to_repl[0]:span_to_repl[1]])
    if span_to_repl[0] < len(old_phns) and \
            old_phns[span_to_repl[0]] == new_phns[span_to_add[0]]:
        new_durs_adjusted[span_to_add[0]] = original_old_durations[
            span_to_repl[0]]
    if span_to_repl[1] < len(old_phns) and span_to_add[1] < len(new_phns):
        if old_phns[span_to_repl[1]] == new_phns[span_to_add[1]]:
            new_durs_adjusted[span_to_add[1]] = original_old_durations[
                span_to_repl[1]]
    new_span_dur_sum = sum(new_durs_adjusted[span_to_add[0]:span_to_add[1]])
    old_span_dur_sum = sum(
        original_old_durations[span_to_repl[0]:span_to_repl[1]])
    dur_offset = new_span_dur_sum - old_span_dur_sum
    new_mfa_start = mfa_start[:span_to_repl[0]]
    new_mfa_end = mfa_end[:span_to_repl[0]]
    for i in new_durs_adjusted[span_to_add[0]:span_to_add[1]]:
        if len(new_mfa_end) == 0:
            new_mfa_start.append(0)
            new_mfa_end.append(i)
        else:
            new_mfa_start.append(new_mfa_end[-1])
            new_mfa_end.append(new_mfa_end[-1] + i)
    new_mfa_start += [i + dur_offset for i in mfa_start[span_to_repl[1]:]]
    new_mfa_end += [i + dur_offset for i in mfa_end[span_to_repl[1]:]]

    # 3. get new wav
    # append at the end of the original utterance
    if span_to_repl[0] >= len(mfa_start):
        left_idx = len(wav_org)
        right_idx = left_idx
    # replace in the middle of the original utterance
    else:
        left_idx = int(np.floor(mfa_start[span_to_repl[0]] * fs))
        right_idx = int(np.ceil(mfa_end[span_to_repl[1] - 1] * fs))
    blank_wav = np.zeros(
        (int(np.ceil(new_span_dur_sum * fs)), ), dtype=wav_org.dtype)
    # in the original audio, the span to be edited is replaced with blank
    # audio whose length is decided by fs2's duration_predictor
    new_wav = np.concatenate(
        [wav_org[:left_idx], blank_wav, wav_org[right_idx:]])

    # 4. get old and new mel spans to be masked
    # [92, 92]
    old_span_bdy, mfa_start, mfa_end = get_masked_mel_bdy(
        mfa_start=mfa_start,
        mfa_end=mfa_end,
        fs=fs,
        hop_length=hop_length,
        span_to_repl=span_to_repl)
    # [92, 174]
    # convert new_mfa_start / new_mfa_end from time level to frame level
    new_span_bdy, new_mfa_start, new_mfa_end = get_masked_mel_bdy(
        mfa_start=new_mfa_start,
        mfa_end=new_mfa_end,
        fs=fs,
        hop_length=hop_length,
        span_to_repl=span_to_add)
    # old_span_bdy and new_span_bdy are frame-level ranges
    return new_wav, new_phns, new_mfa_start, new_mfa_end, old_span_bdy, new_span_bdy
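`get_masked_mel_bdy` turns the second-level MFA boundaries into frame-level ones; the core of that mapping is just `time * fs / hop_length`. A minimal sketch under the defaults used here (`fs=24000`, `hop_length=300`); the helper name `time_to_frame` is illustrative, not from the source:

```python
import math

def time_to_frame(t: float, fs: int = 24000, hop_length: int = 300) -> int:
    # a phone boundary at t seconds lands on mel frame floor(t * fs / hop)
    return int(math.floor(t * fs / hop_length))

# a 0.5 s phone covers 40 frames at fs=24000, hop_length=300
print(time_to_frame(0.5) - time_to_frame(0.0))  # 40
```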
def prep_feats(mlm_model: nn.Layer,
               wav_path: str,
               source_lang: str="english",
               target_lang: str="english",
               old_str: str="",
               new_str: str="",
               duration_adjust: bool=True,
               start_end_sp: bool=False,
               mask_reconstruct: bool=False,
               fs: int=24000,
               hop_length: int=300,
               token_list: List[str]=[]):
    wav, phns, mfa_start, mfa_end, old_span_bdy, new_span_bdy = prep_feats_with_dur(
        source_lang=source_lang,
        target_lang=target_lang,
        mlm_model=mlm_model,
        old_str=old_str,
        new_str=new_str,
        wav_path=wav_path,
        duration_adjust=duration_adjust,
        start_end_sp=start_end_sp,
        mask_reconstruct=mask_reconstruct,
        fs=fs,
        hop_length=hop_length)

    token_to_id = {item: i for i, item in enumerate(token_list)}
    text = np.array(
        list(map(lambda x: token_to_id.get(x, token_to_id['<unk>']), phns)))
    span_bdy = np.array(new_span_bdy)

    batch = [('1', {
        "speech": wav,
        "align_start": mfa_start,
        "align_end": mfa_end,
        "text": text,
        "span_bdy": span_bdy
    })]
...
@@ -648,375 +505,135 @@ def prepare_features(uid: str,
    return batch, old_span_bdy, new_span_bdy
def decode_with_model(mlm_model: nn.Layer,
                      collate_fn,
                      wav_path: str,
                      source_lang: str="english",
                      target_lang: str="english",
                      old_str: str="",
                      new_str: str="",
                      use_teacher_forcing: bool=False,
                      duration_adjust: bool=True,
                      start_end_sp: bool=False,
                      fs: int=24000,
                      hop_length: int=300,
                      token_list: List[str]=[]):
    batch, old_span_bdy, new_span_bdy = prep_feats(
        source_lang=source_lang,
        target_lang=target_lang,
        mlm_model=mlm_model,
        wav_path=wav_path,
        old_str=old_str,
        new_str=new_str,
        duration_adjust=duration_adjust,
        start_end_sp=start_end_sp,
        fs=fs,
        hop_length=hop_length,
        token_list=token_list)

    feats = collate_fn(batch)[1]
    if 'text_masked_pos' in feats.keys():
        feats.pop('text_masked_pos')
    for k, v in feats.items():
        feats[k] = paddle.to_tensor(v)
    rtn = mlm_model.inference(
        **feats,
        span_bdy=new_span_bdy,
        use_teacher_forcing=use_teacher_forcing)
    output = rtn['feat_gen']
    if 0 in output[0].shape and 0 not in output[-1].shape:
        output_feat = paddle.concat(
            output[1:-1] + [output[-1].squeeze()], axis=0)
    elif 0 not in output[0].shape and 0 in output[-1].shape:
        output_feat = paddle.concat(
            [output[0].squeeze()] + output[1:-1], axis=0)
    elif 0 in output[0].shape and 0 in output[-1].shape:
        output_feat = paddle.concat(output[1:-1], axis=0)
    else:
        output_feat = paddle.concat(
            [output[0].squeeze(0)] + output[1:-1] + [output[-1].squeeze(0)],
            axis=0)
    wav_org, _ = librosa.load(wav_path, sr=fs)
    return wav_org, output_feat, old_span_bdy, new_span_bdy, fs, hop_length
class MLMCollateFn:
    """Functor class of common_collate_fn()"""

    def __init__(self,
                 feats_extract,
                 float_pad_value: Union[float, int]=0.0,
                 int_pad_value: int=-32768,
                 not_sequence: Collection[str]=(),
                 mlm_prob: float=0.8,
                 mean_phn_span: int=8,
                 attention_window: int=0,
                 pad_speech: bool=False,
                 sega_emb: bool=False,
                 duration_collect: bool=False,
                 text_masking: bool=False):
        self.mlm_prob = mlm_prob
        self.mean_phn_span = mean_phn_span
        self.feats_extract = feats_extract
        self.float_pad_value = float_pad_value
        self.int_pad_value = int_pad_value
        self.not_sequence = set(not_sequence)
        self.attention_window = attention_window
        self.pad_speech = pad_speech
        self.sega_emb = sega_emb
        self.duration_collect = duration_collect
        self.text_masking = text_masking

    def __repr__(self):
        # note: the original repr printed float_pad_value twice; fixed here
        return (f"{self.__class__}(float_pad_value={self.float_pad_value}, "
                f"int_pad_value={self.int_pad_value})")

    def __call__(self, data: Collection[Tuple[str, Dict[str, np.ndarray]]]
                 ) -> Tuple[List[str], Dict[str, paddle.Tensor]]:
        return mlm_collate_fn(
            data,
            float_pad_value=self.float_pad_value,
            int_pad_value=self.int_pad_value,
            not_sequence=self.not_sequence,
            mlm_prob=self.mlm_prob,
            mean_phn_span=self.mean_phn_span,
            feats_extract=self.feats_extract,
            attention_window=self.attention_window,
            pad_speech=self.pad_speech,
            sega_emb=self.sega_emb,
            duration_collect=self.duration_collect,
            text_masking=self.text_masking)
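MLMCollateFn follows the functor pattern: the configuration is bound once in `__init__`, and the instance is then used like a plain function (e.g. by a DataLoader). A generic, illustrative example of the same pattern (not code from this repo):

```python
class Scale:
    """Bind a factor once; call the instance like a function."""

    def __init__(self, factor):
        self.factor = factor

    def __call__(self, xs):
        return [x * self.factor for x in xs]

double = Scale(2)
print(double([1, 2, 3]))  # [2, 4, 6]
```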
def mlm_collate_fn(
        data: Collection[Tuple[str, Dict[str, np.ndarray]]],
        float_pad_value: Union[float, int]=0.0,
        int_pad_value: int=-32768,
        not_sequence: Collection[str]=(),
        mlm_prob: float=0.8,
        mean_phn_span: int=8,
        feats_extract=None,
        attention_window: int=0,
        pad_speech: bool=False,
        sega_emb: bool=False,
        duration_collect: bool=False,
        text_masking: bool=False
) -> Tuple[List[str], Dict[str, paddle.Tensor]]:
    uttids = [u for u, _ in data]
    data = [d for _, d in data]

    assert all(set(data[0]) == set(d) for d in data), "dict-keys mismatching"
    assert all(not k.endswith("_lens")
               for k in data[0]), f"*_lens is reserved: {list(data[0])}"

    output = {}
    for key in data[0]:
        # Each model which finally accepts these values is responsible
        # for repainting the pad_value to the desired value for its task.
        if data[0][key].dtype.kind == "i":
            pad_value = int_pad_value
        else:
            pad_value = float_pad_value

        array_list = [d[key] for d in data]
        # Assume the first axis is length:
        # tensor_list: Batch x (Length, ...)
        tensor_list = [paddle.to_tensor(a) for a in array_list]
        # tensor: (Batch, Length, ...)
        tensor = pad_list(tensor_list, pad_value)
        output[key] = tensor

        # lens: (Batch,)
        if key not in not_sequence:
            lens = paddle.to_tensor(
                [d[key].shape[0] for d in data], dtype=paddle.int64)
            output[key + "_lens"] = lens

    feats = feats_extract.get_log_mel_fbank(np.array(output["speech"][0]))
    feats = paddle.to_tensor(feats)
    feats_lens = paddle.shape(feats)[0]
    feats = paddle.unsqueeze(feats, 0)

    if 'text' not in output:
        text = paddle.zeros(paddle.shape(feats_lens.unsqueeze(-1))) - 2
        text_lens = paddle.zeros(paddle.shape(feats_lens)) + 1
        max_tlen = 1
        align_start = paddle.zeros(paddle.shape(text))
        align_end = paddle.zeros(paddle.shape(text))
        align_start_lens = paddle.zeros(paddle.shape(feats_lens))
        sega_emb = False
        mean_phn_span = 0
        mlm_prob = 0.15
    else:
        text = output["text"]
        text_lens = output["text_lens"]
        align_start = output["align_start"]
        align_start_lens = output["align_start_lens"]
        align_end = output["align_end"]
        align_start = paddle.floor(feats_extract.sr * align_start /
                                   feats_extract.hop_length).int()
        align_end = paddle.floor(feats_extract.sr * align_end /
                                 feats_extract.hop_length).int()
        max_tlen = max(text_lens)

    max_slen = max(feats_lens)
    speech_pad = feats[:, :max_slen]
    if attention_window > 0 and pad_speech:
        speech_pad, max_slen = pad_to_longformer_att_window(
            speech_pad, max_slen, max_slen, attention_window)
    max_len = max_slen + max_tlen
    if attention_window > 0:
        text_pad, max_tlen = pad_to_longformer_att_window(
            text, max_len, max_tlen, attention_window)
    else:
        text_pad = text
    text_mask = make_non_pad_mask(
        text_lens, text_pad, length_dim=1).unsqueeze(-2)
    if attention_window > 0:
        text_mask = text_mask * 2
    speech_mask = make_non_pad_mask(
        feats_lens, speech_pad[:, :, 0], length_dim=1).unsqueeze(-2)
    span_bdy = None
    if 'span_bdy' in output.keys():
        span_bdy = output['span_bdy']

    if text_masking:
        masked_pos, text_masked_pos, _ = phones_text_masking(
            speech_pad, speech_mask, text_pad, text_mask, align_start,
            align_end, align_start_lens, mlm_prob, mean_phn_span, span_bdy)
    else:
        text_masked_pos = paddle.zeros(paddle.shape(text_pad))
        masked_pos, _ = phones_masking(speech_pad, speech_mask, align_start,
                                       align_end, align_start_lens, mlm_prob,
                                       mean_phn_span, span_bdy)

    output_dict = {}
    if duration_collect and 'text' in output:
        reordered_idx, speech_seg_pos, text_seg_pos, durations, feats_lens = \
            get_seg_pos_reduce_duration(speech_pad, text_pad, align_start,
                                        align_end, align_start_lens, sega_emb,
                                        masked_pos, feats_lens)
        speech_mask = make_non_pad_mask(
            feats_lens, speech_pad[:, :reordered_idx.shape[1], 0],
            length_dim=1).unsqueeze(-2)
        output_dict['durations'] = durations
        output_dict['reordered_idx'] = reordered_idx
    else:
        speech_seg_pos, text_seg_pos = get_seg_pos(
            speech_pad, text_pad, align_start, align_end, align_start_lens,
            sega_emb)
    output_dict['speech'] = speech_pad
    output_dict['text'] = text_pad
    output_dict['masked_pos'] = masked_pos
    output_dict['text_masked_pos'] = text_masked_pos
    output_dict['speech_mask'] = speech_mask
    output_dict['text_mask'] = text_mask
    output_dict['speech_seg_pos'] = speech_seg_pos
    output_dict['text_seg_pos'] = text_seg_pos
    output_dict['speech_lens'] = output["speech_lens"]
    output_dict['text_lens'] = text_lens
    output = (uttids, output_dict)
    return output
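`mlm_collate_fn` relies on a `pad_list` helper to right-pad variable-length items to the batch maximum before stacking. A list-based sketch of the same idea (the real helper in this repo operates on paddle tensors):

```python
def pad_list(xs, pad_value):
    """Right-pad a list of 1-D sequences (plain Python lists) to equal length."""
    max_len = max(len(x) for x in xs)
    return [x + [pad_value] * (max_len - len(x)) for x in xs]

print(pad_list([[1, 2, 3], [4, 5]], 0))  # [[1, 2, 3], [4, 5, 0]]
```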
def build_collate_fn(args: argparse.Namespace, train: bool, epoch=-1):
    # -> Callable[
    #     [Collection[Tuple[str, Dict[str, np.ndarray]]]],
    #     Tuple[List[str], Dict[str, Tensor]],
    # ]:
    # assert check_argument_types()
    # return CommonCollateFn(float_pad_value=0.0, int_pad_value=0)
    feats_extract_class = LogMelFBank
    if args.feats_extract_conf['win_length'] is None:
        args.feats_extract_conf['win_length'] = args.feats_extract_conf['n_fft']
    args_dic = {}
    for k, v in args.feats_extract_conf.items():
        if k == 'fs':
            args_dic['sr'] = v
        else:
            args_dic[k] = v
    feats_extract = feats_extract_class(**args_dic)

    sega_emb = True if args.encoder_conf['input_layer'] == 'sega_mlm' else False
    if args.encoder_conf['selfattention_layer_type'] == 'longformer':
        attention_window = args.encoder_conf['attention_window']
        pad_speech = True if 'pre_speech_layer' in args.encoder_conf and \
            args.encoder_conf['pre_speech_layer'] > 0 else False
    else:
        attention_window = 0
        pad_speech = False
    if epoch == -1:
        mlm_prob_factor = 1
    else:
        mlm_prob_factor = 0.8
    if 'duration_predictor_layers' in args.model_conf.keys() and \
            args.model_conf['duration_predictor_layers'] > 0:
        duration_collect = True
    else:
        duration_collect = False
    return MLMCollateFn(
        feats_extract,
        float_pad_value=0.0,
        int_pad_value=0,
        mlm_prob=args.model_conf['mlm_prob'] * mlm_prob_factor,
        mean_phn_span=args.model_conf['mean_phn_span'],
        attention_window=attention_window,
        pad_speech=pad_speech,
        sega_emb=sega_emb,
        duration_collect=duration_collect)
def get_mlm_output(wav_path: str,
                   model_name: str="paddle_checkpoint_en",
                   source_lang: str="english",
                   target_lang: str="english",
                   old_str: str="",
                   new_str: str="",
                   use_teacher_forcing: bool=False,
                   duration_adjust: bool=True,
                   start_end_sp: bool=False):
    mlm_model, train_conf = load_model(model_name)
    mlm_model.eval()

    collate_fn = build_collate_fn(
        sr=train_conf.feats_extract_conf['fs'],
        n_fft=train_conf.feats_extract_conf['n_fft'],
        hop_length=train_conf.feats_extract_conf['hop_length'],
        win_length=train_conf.feats_extract_conf['win_length'],
        n_mels=train_conf.feats_extract_conf['n_mels'],
        fmin=train_conf.feats_extract_conf['fmin'],
        fmax=train_conf.feats_extract_conf['fmax'],
        mlm_prob=train_conf['mlm_prob'],
        mean_phn_span=train_conf['mean_phn_span'],
        train=False,
        seg_emb=train_conf.encoder_conf['input_layer'] == 'sega_mlm')
    return decode_with_model(
        source_lang=source_lang,
        target_lang=target_lang,
        mlm_model=mlm_model,
        collate_fn=collate_fn,
        wav_path=wav_path,
        old_str=old_str,
        new_str=new_str,
        use_teacher_forcing=use_teacher_forcing,
        duration_adjust=duration_adjust,
        start_end_sp=start_end_sp,
        fs=train_conf.feats_extract_conf['fs'],
        hop_length=train_conf.feats_extract_conf['hop_length'],
        token_list=train_conf.token_list)
def evaluate(uid: str,
             source_lang: str="english",
             target_lang: str="english",
             use_pt_vocoder: bool=False,
             prefix: os.PathLike="./prompt/dev/",
             model_name: str="paddle_checkpoint_en",
             new_str: str="",
             prompt_decoding: bool=False,
             task_name: str=None):
    # get origin text and path of origin wav
    old_str, wav_path = read_data(uid=uid, prefix=prefix)

    if task_name == 'edit':
        new_str = new_str
    elif task_name == 'synthesize':
        new_str = old_str + new_str
    else:
        new_str = old_str + ' '.join(
            [ch for ch in new_str if is_chinese(ch)])

    print('new_str is ', new_str)

    results_dict, old_span = plot_mel_and_vocode_wav(
        uid=uid,
        prefix=prefix,
        source_lang=source_lang,
        target_lang=target_lang,
        model_name=model_name,
        wav_path=wav_path,
        old_str=old_str,
        new_str=new_str,
        use_pt_vocoder=use_pt_vocoder)
    return results_dict
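`evaluate` builds the target text differently per task: `edit` keeps `new_str` as-is, `synthesize` appends it to the prompt text, and any other task appends only its Chinese characters. The branching can be sketched as follows; the inline `is_chinese` stand-in simply checks the basic CJK Unified Ideographs block and may differ from the project's helper:

```python
def build_new_str(old_str: str, new_str: str, task_name: str) -> str:
    def is_chinese(ch: str) -> bool:
        # illustrative stand-in: basic CJK Unified Ideographs block only
        return '\u4e00' <= ch <= '\u9fff'

    if task_name == 'edit':
        return new_str
    if task_name == 'synthesize':
        return old_str + new_str
    # cross-lingual case: keep only Chinese characters of new_str
    return old_str + ' '.join(ch for ch in new_str if is_chinese(ch))

print(build_new_str("prompt ", "target", "synthesize"))  # prompt target
```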
...

ernie-sat/model_paddle.py → ernie-sat/mlm.py View file @ 9224659c
import argparse
import math
import os
import sys
from typing import Dict
from typing import List
from typing import Optional
...
@@ -20,17 +17,18 @@ for dir_name in os.listdir(pypath):
    if os.path.isdir(dir_path):
        sys.path.append(dir_path)

from paddlespeech.t2s.modules.activation import get_activation
from paddlespeech.t2s.modules.conformer.convolution import ConvolutionModule
from paddlespeech.t2s.modules.conformer.encoder_layer import EncoderLayer
from paddlespeech.t2s.modules.masked_fill import masked_fill
from paddlespeech.t2s.modules.nets_utils import initialize
from paddlespeech.t2s.modules.tacotron2.decoder import Postnet
from paddlespeech.t2s.modules.transformer.embedding import LegacyRelPositionalEncoding
from paddlespeech.t2s.modules.transformer.embedding import PositionalEncoding
from paddlespeech.t2s.modules.transformer.embedding import ScaledPositionalEncoding
from paddlespeech.t2s.modules.transformer.embedding import RelPositionalEncoding
from paddlespeech.t2s.modules.transformer.subsampling import Conv2dSubsampling
from paddlespeech.t2s.modules.transformer.attention import LegacyRelPositionMultiHeadedAttention
from paddlespeech.t2s.modules.transformer.attention import MultiHeadedAttention
from paddlespeech.t2s.modules.transformer.attention import RelPositionMultiHeadedAttention
from paddlespeech.t2s.modules.transformer.positionwise_feed_forward import PositionwiseFeedForward
...
@@ -39,65 +37,10 @@ from paddlespeech.t2s.modules.transformer.multi_layer_conv import MultiLayeredCo
from paddlespeech.t2s.modules.transformer.repeat import repeat
from paddlespeech.t2s.modules.layer_norm import LayerNorm
from yacs.config import CfgNode
class LegacyRelPositionalEncoding(PositionalEncoding):
    """Relative positional encoding module (old version).

    Details can be found in https://github.com/espnet/espnet/pull/2816.
    See: Appendix B in https://arxiv.org/abs/1901.02860

    Args:
        d_model (int): Embedding dimension.
        dropout_rate (float): Dropout rate.
        max_len (int): Maximum input length.
    """

    def __init__(self, d_model: int, dropout_rate: float, max_len: int=5000):
        """
        Args:
            d_model (int): Embedding dimension.
            dropout_rate (float): Dropout rate.
            max_len (int, optional): Maximum input length. Defaults to 5000.
        """
        super().__init__(d_model, dropout_rate, max_len, reverse=True)

    def extend_pe(self, x):
        """Reset the positional encodings."""
        if self.pe is not None:
            if paddle.shape(self.pe)[1] >= paddle.shape(x)[1]:
                return
        pe = paddle.zeros((paddle.shape(x)[1], self.d_model))
        if self.reverse:
            position = paddle.arange(
                paddle.shape(x)[1] - 1, -1, -1.0,
                dtype=paddle.float32).unsqueeze(1)
        else:
            position = paddle.arange(
                0, paddle.shape(x)[1], dtype=paddle.float32).unsqueeze(1)
        div_term = paddle.exp(
            paddle.arange(0, self.d_model, 2, dtype=paddle.float32) *
            -(math.log(10000.0) / self.d_model))
        pe[:, 0::2] = paddle.sin(position * div_term)
        pe[:, 1::2] = paddle.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.pe = pe

    def forward(self, x: paddle.Tensor) -> Tuple[paddle.Tensor, paddle.Tensor]:
        """Compute positional encoding.

        Args:
            x (paddle.Tensor): Input tensor (batch, time, `*`).
        Returns:
            paddle.Tensor: Encoded tensor (batch, time, `*`).
            paddle.Tensor: Positional embedding tensor (1, time, `*`).
        """
        self.extend_pe(x)
        x = x * self.xscale
        pos_emb = self.pe[:, :paddle.shape(x)[1]]
        return self.dropout(x), self.dropout(pos_emb)
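The sinusoidal table built by `extend_pe` can be checked with a dependency-free sketch. With `reverse=True`, positions run from `seq_len - 1` down to `0`, so the last row is position 0 (`sin(0)=0`, `cos(0)=1`). The function name `legacy_rel_pe` is illustrative:

```python
import math

def legacy_rel_pe(seq_len: int, d_model: int):
    """Sinusoidal table with reversed positions (seq_len-1 .. 0), as in the
    legacy relative positional encoding (d_model assumed even)."""
    pe = [[0.0] * d_model for _ in range(seq_len)]
    for row, pos in enumerate(range(seq_len - 1, -1, -1)):
        for i in range(0, d_model, 2):
            div = math.exp(-i * math.log(10000.0) / d_model)
            pe[row][i] = math.sin(pos * div)
            pe[row][i + 1] = math.cos(pos * div)
    return pe

pe = legacy_rel_pe(4, 8)
# the last row corresponds to position 0: sin(0)=0, cos(0)=1
print(pe[-1][0], pe[-1][1])  # 0.0 1.0
```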
# MLM -> Mask Language Model
class mySequential(nn.Sequential):
    def forward(self, *inputs):
        for module in self._sub_layers.values():
...
@@ -108,12 +51,8 @@ class mySequential(nn.Sequential):
            return inputs


class MaskInputLayer(nn.Layer):
    def __init__(self, out_features: int) -> None:
        super().__init__()
        self.mask_feature = paddle.create_parameter(
            shape=(1, 1, out_features),
...
@@ -121,109 +60,14 @@ class NewMaskInputLayer(nn.Layer):
            default_initializer=paddle.nn.initializer.Assign(
                paddle.normal(shape=(1, 1, out_features))))

    def forward(self, input: paddle.Tensor,
                masked_pos: paddle.Tensor=None) -> paddle.Tensor:
        masked_pos = paddle.expand_as(paddle.unsqueeze(masked_pos, -1), input)
        masked_input = masked_fill(input, masked_pos, 0) + masked_fill(
            paddle.expand_as(self.mask_feature, input), ~masked_pos, 0)
        return masked_input
class LegacyRelPositionMultiHeadedAttention(MultiHeadedAttention):
    """Multi-Head Attention layer with relative position encoding (old version).

    Details can be found in https://github.com/espnet/espnet/pull/2816.
    Paper: https://arxiv.org/abs/1901.02860

    Args:
        n_head (int): The number of heads.
        n_feat (int): The number of features.
        dropout_rate (float): Dropout rate.
        zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
    """

    def __init__(self, n_head, n_feat, dropout_rate, zero_triu=False):
        """Construct an RelPositionMultiHeadedAttention object."""
        super().__init__(n_head, n_feat, dropout_rate)
        self.zero_triu = zero_triu
        # linear transformation for positional encoding
        self.linear_pos = nn.Linear(n_feat, n_feat, bias_attr=False)
        # these two learnable biases are used in matrix c and matrix d
        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
        self.pos_bias_u = paddle.create_parameter(
            shape=(self.h, self.d_k),
            dtype='float32',
            default_initializer=paddle.nn.initializer.XavierUniform())
        self.pos_bias_v = paddle.create_parameter(
            shape=(self.h, self.d_k),
            dtype='float32',
            default_initializer=paddle.nn.initializer.XavierUniform())

    def rel_shift(self, x):
        """Compute relative positional encoding.

        Args:
            x (Tensor): Input tensor (batch, head, time1, time2).
        Returns:
            Tensor: Output tensor.
        """
        b, h, t1, t2 = paddle.shape(x)
        zero_pad = paddle.zeros((b, h, t1, 1))
        x_padded = paddle.concat([zero_pad, x], axis=-1)
        x_padded = paddle.reshape(x_padded, [b, h, t2 + 1, t1])
        # only keep the positions from 0 to time2
        x = paddle.reshape(x_padded[:, :, 1:], [b, h, t1, t2])

        if self.zero_triu:
            ones = paddle.ones((t1, t2))
            x = x * paddle.tril(ones, t2 - 1)[None, None, :, :]

        return x

    def forward(self, query, key, value, pos_emb, mask):
        """Compute 'Scaled Dot Product Attention' with rel. positional encoding.

        Args:
            query (Tensor): Query tensor (#batch, time1, size).
            key (Tensor): Key tensor (#batch, time2, size).
            value (Tensor): Value tensor (#batch, time2, size).
            pos_emb (Tensor): Positional embedding tensor (#batch, time1, size).
            mask (Tensor): Mask tensor (#batch, 1, time2) or (#batch, time1, time2).
        Returns:
            Tensor: Output tensor (#batch, time1, d_model).
        """
        q, k, v = self.forward_qkv(query, key, value)
        # (batch, time1, head, d_k)
        q = paddle.transpose(q, [0, 2, 1, 3])

        n_batch_pos = paddle.shape(pos_emb)[0]
        p = paddle.reshape(
            self.linear_pos(pos_emb), [n_batch_pos, -1, self.h, self.d_k])
        # (batch, head, time1, d_k)
        p = paddle.transpose(p, [0, 2, 1, 3])
        # (batch, head, time1, d_k)
        q_with_bias_u = paddle.transpose((q + self.pos_bias_u), [0, 2, 1, 3])
        # (batch, head, time1, d_k)
        q_with_bias_v = paddle.transpose((q + self.pos_bias_v), [0, 2, 1, 3])

        # compute attention score
        # first compute matrix a and matrix c
        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
        # (batch, head, time1, time2)
        matrix_ac = paddle.matmul(q_with_bias_u,
                                  paddle.transpose(k, [0, 1, 3, 2]))
        # compute matrix b and matrix d
        # (batch, head, time1, time1)
        matrix_bd = paddle.matmul(q_with_bias_v,
                                  paddle.transpose(p, [0, 1, 3, 2]))
        matrix_bd = self.rel_shift(matrix_bd)
        # (batch, head, time1, time2)
        scores = (matrix_ac + matrix_bd) / math.sqrt(self.d_k)

        return self.forward_attention(v, scores, mask)
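The `rel_shift` above uses the Transformer-XL pad-reshape-slice trick to realign relative-position scores. The same index gymnastics on a single (t1, t2) matrix of plain nested lists, for illustration only:

```python
def rel_shift(x):
    """Relative-shift trick on a (t1, t2) matrix of plain lists:
    left-pad each row with a zero, flatten, drop the first t1 entries
    (i.e. reshape to (t2+1, t1) and discard row 0), re-split as (t1, t2)."""
    t1, t2 = len(x), len(x[0])
    flat = []
    for row in x:
        flat.extend([0] + row)  # zero column prepended to every row
    flat = flat[t1:]            # discard the first row of the (t2+1, t1) view
    return [flat[i * t2:(i + 1) * t2] for i in range(t1)]

print(rel_shift([[1, 2], [3, 4]]))  # [[2, 0], [3, 4]]
```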
class MLMEncoder(nn.Layer):
    """Conformer encoder module.
...
@@ -253,47 +97,42 @@ class MLMEncoder(nn.Layer):
        cnn_module_kernel (int): Kernel size of convolution module.
        padding_idx (int): Padding idx for input_layer=embed.
        stochastic_depth_rate (float): Maximum probability to skip the encoder layer.
    """
     def __init__(self,
-                 idim,
-                 vocab_size=0,
+                 idim: int,
+                 vocab_size: int=0,
                  pre_speech_layer: int=0,
-                 attention_dim=256,
-                 attention_heads=4,
-                 linear_units=2048,
-                 num_blocks=6,
-                 dropout_rate=0.1,
-                 positional_dropout_rate=0.1,
-                 attention_dropout_rate=0.0,
-                 input_layer="conv2d",
-                 normalize_before=True,
-                 concat_after=False,
-                 positionwise_layer_type="linear",
-                 positionwise_conv_kernel_size=1,
-                 macaron_style=False,
-                 pos_enc_layer_type="abs_pos",
+                 attention_dim: int=256,
+                 attention_heads: int=4,
+                 linear_units: int=2048,
+                 num_blocks: int=6,
+                 dropout_rate: float=0.1,
+                 positional_dropout_rate: float=0.1,
+                 attention_dropout_rate: float=0.0,
+                 input_layer: str="conv2d",
+                 normalize_before: bool=True,
+                 concat_after: bool=False,
+                 positionwise_layer_type: str="linear",
+                 positionwise_conv_kernel_size: int=1,
+                 macaron_style: bool=False,
+                 pos_enc_layer_type: str="abs_pos",
                  pos_enc_class=None,
-                 selfattention_layer_type="selfattn",
-                 activation_type="swish",
-                 use_cnn_module=False,
-                 zero_triu=False,
-                 cnn_module_kernel=31,
-                 padding_idx=-1,
-                 stochastic_depth_rate=0.0,
-                 intermediate_layers=None,
-                 text_masking=False):
+                 selfattention_layer_type: str="selfattn",
+                 activation_type: str="swish",
+                 use_cnn_module: bool=False,
+                 zero_triu: bool=False,
+                 cnn_module_kernel: int=31,
+                 padding_idx: int=-1,
+                 stochastic_depth_rate: float=0.0,
+                 text_masking: bool=False):
         """Construct an Encoder object."""
         super().__init__()
         self._output_size = attention_dim
         self.text_masking = text_masking
         if self.text_masking:
-            self.text_masking_layer = NewMaskInputLayer(attention_dim)
+            self.text_masking_layer = MaskInputLayer(attention_dim)
         activation = get_activation(activation_type)
         if pos_enc_layer_type == "abs_pos":
             pos_enc_class = PositionalEncoding
 ...
@@ -330,7 +169,7 @@ class MLMEncoder(nn.Layer):
         elif input_layer == "mlm":
             self.segment_emb = None
             self.speech_embed = mySequential(
-                NewMaskInputLayer(idim),
+                MaskInputLayer(idim),
                 nn.Linear(idim, attention_dim),
                 nn.LayerNorm(attention_dim),
                 nn.ReLU(),
 ...
@@ -343,7 +182,7 @@ class MLMEncoder(nn.Layer):
             self.segment_emb = nn.Embedding(
                 500, attention_dim, padding_idx=padding_idx)
             self.speech_embed = mySequential(
-                NewMaskInputLayer(idim),
+                MaskInputLayer(idim),
                 nn.Linear(idim, attention_dim),
                 nn.LayerNorm(attention_dim),
                 nn.ReLU(),
 ...
@@ -365,7 +204,6 @@ class MLMEncoder(nn.Layer):
         # self-attention module definition
         if selfattention_layer_type == "selfattn":
-            logging.info("encoder self-attention layer type = self-attention")
             encoder_selfattn_layer = MultiHeadedAttention
             encoder_selfattn_layer_args = (attention_heads, attention_dim,
                                            attention_dropout_rate, )
 ...
@@ -375,8 +213,6 @@ class MLMEncoder(nn.Layer):
             encoder_selfattn_layer_args = (attention_heads, attention_dim,
                                            attention_dropout_rate, )
         elif selfattention_layer_type == "rel_selfattn":
-            logging.info(
-                "encoder self-attention layer type = relative self-attention")
             assert pos_enc_layer_type == "rel_pos"
             encoder_selfattn_layer = RelPositionMultiHeadedAttention
             encoder_selfattn_layer_args = (attention_heads, attention_dim,
 ...
@@ -436,49 +272,38 @@ class MLMEncoder(nn.Layer):
         if self.normalize_before:
             self.after_norm = LayerNorm(attention_dim)
-        self.intermediate_layers = intermediate_layers

     def forward(self,
-                speech_pad,
-                text_pad,
-                masked_pos,
-                speech_mask=None,
-                text_mask=None,
-                speech_seg_pos=None,
-                text_seg_pos=None):
+                speech: paddle.Tensor,
+                text: paddle.Tensor,
+                masked_pos: paddle.Tensor,
+                speech_mask: paddle.Tensor=None,
+                text_mask: paddle.Tensor=None,
+                speech_seg_pos: paddle.Tensor=None,
+                text_seg_pos: paddle.Tensor=None):
         """Encode input sequence.
         """
         if masked_pos is not None:
-            speech_pad = self.speech_embed(speech_pad, masked_pos)
+            speech = self.speech_embed(speech, masked_pos)
         else:
-            speech_pad = self.speech_embed(speech_pad)
-        # pure speech input
-        if -2 in np.array(text_pad):
-            text_pad = text_pad + 3
-            text_mask = paddle.unsqueeze(bool(text_pad), 1)
-            text_seg_pos = paddle.zeros_like(text_pad)
-            text_pad = self.text_embed(text_pad)
-            text_pad = (text_pad[0] + self.segment_emb(text_seg_pos),
-                        text_pad[1])
-            text_seg_pos = None
-        elif text_pad is not None:
-            text_pad = self.text_embed(text_pad)
+            speech = self.speech_embed(speech)
+        if text is not None:
+            text = self.text_embed(text)
         if speech_seg_pos is not None and text_seg_pos is not None and self.segment_emb:
             speech_seg_emb = self.segment_emb(speech_seg_pos)
             text_seg_emb = self.segment_emb(text_seg_pos)
-            text_pad = (text_pad[0] + text_seg_emb, text_pad[1])
-            speech_pad = (speech_pad[0] + speech_seg_emb, speech_pad[1])
+            text = (text[0] + text_seg_emb, text[1])
+            speech = (speech[0] + speech_seg_emb, speech[1])
         if self.pre_speech_encoders:
-            speech_pad, _ = self.pre_speech_encoders(speech_pad, speech_mask)
+            speech, _ = self.pre_speech_encoders(speech, speech_mask)

-        if text_pad is not None:
-            xs = paddle.concat([speech_pad[0], text_pad[0]], axis=1)
-            xs_pos_emb = paddle.concat([speech_pad[1], text_pad[1]], axis=1)
+        if text is not None:
+            xs = paddle.concat([speech[0], text[0]], axis=1)
+            xs_pos_emb = paddle.concat([speech[1], text[1]], axis=1)
             masks = paddle.concat([speech_mask, text_mask], axis=-1)
         else:
-            xs = speech_pad[0]
-            xs_pos_emb = speech_pad[1]
+            xs = speech[0]
+            xs_pos_emb = speech[1]
             masks = speech_mask

         xs, masks = self.encoders((xs, xs_pos_emb), masks)
 ...
@@ -492,7 +317,7 @@ class MLMEncoder(nn.Layer):


 class MLMDecoder(MLMEncoder):
-    def forward(self, xs, masks, masked_pos=None, segment_emb=None):
+    def forward(self, xs: paddle.Tensor, masks: paddle.Tensor):
         """Encode input sequence.
         Args:
 ...
@@ -504,51 +329,19 @@ class MLMDecoder(MLMEncoder):
             paddle.Tensor: Mask tensor (#batch, time).
         """
-        if not self.training:
-            masked_pos = None
         xs = self.embed(xs)
-        if segment_emb:
-            xs = (xs[0] + segment_emb, xs[1])
-        if self.intermediate_layers is None:
-            xs, masks = self.encoders(xs, masks)
-        else:
-            intermediate_outputs = []
-            for layer_idx, encoder_layer in enumerate(self.encoders):
-                xs, masks = encoder_layer(xs, masks)
-                if (self.intermediate_layers is not None and
-                        layer_idx + 1 in self.intermediate_layers):
-                    encoder_output = xs
-                    # intermediate branches also require normalization.
-                    if self.normalize_before:
-                        encoder_output = self.after_norm(encoder_output)
-                    intermediate_outputs.append(encoder_output)
+        xs, masks = self.encoders(xs, masks)
         if isinstance(xs, tuple):
             xs = xs[0]
         if self.normalize_before:
             xs = self.after_norm(xs)
-        if self.intermediate_layers is not None:
-            return xs, masks, intermediate_outputs
         return xs, masks


-def pad_to_longformer_att_window(text, max_len, max_tlen, attention_window):
-    round = max_len % attention_window
-    if round != 0:
-        max_tlen += (attention_window - round)
-        n_batch = paddle.shape(text)[0]
-        text_pad = paddle.zeros(
-            shape=(n_batch, max_tlen, *paddle.shape(text[0])[1:]),
-            dtype=text.dtype)
-        for i in range(n_batch):
-            text_pad[i, :paddle.shape(text[i])[0]] = text[i]
-    else:
-        text_pad = text[:, :max_tlen]
-    return text_pad, max_tlen


-class MLMModel(nn.Layer):
+# encoder and decoder is nn.Layer, not str
+class MLM(nn.Layer):
     def __init__(self,
                  token_list: Union[Tuple[str, ...], List[str]],
                  odim: int,
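The removed pad_to_longformer_att_window helper rounds the padded length up to the next multiple of the attention window before zero-padding the batch. Its length arithmetic can be sketched in isolation (the standalone function name below is mine, not from the repo):

```python
def pad_len_to_window(max_tlen: int, max_len: int, attention_window: int) -> int:
    # round the target length up so that it becomes a multiple of
    # attention_window, as pad_to_longformer_att_window did
    remainder = max_len % attention_window
    if remainder != 0:
        max_tlen += attention_window - remainder
    return max_tlen
```

With attention_window=4, a length of 10 is padded to 12 while 8 is left unchanged.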
 ...
@@ -557,44 +350,15 @@ class MLMModel(nn.Layer):
                 postnet_layers: int=0,
                 postnet_chans: int=0,
                 postnet_filts: int=0,
-                 ignore_id: int=-1,
-                 lsm_weight: float=0.0,
-                 length_normalized_loss: bool=False,
-                 report_cer: bool=True,
-                 report_wer: bool=True,
-                 sym_space: str="<space>",
-                 sym_blank: str="<blank>",
-                 masking_schema: str="span",
-                 mean_phn_span: int=3,
-                 mlm_prob: float=0.25,
-                 dynamic_mlm_prob=False,
-                 decoder_seg_pos=False,
-                 text_masking=False):
+                 text_masking: bool=False):
         super().__init__()
-        # note that eos is the same as sos (equivalent ID)
         self.odim = odim
-        self.ignore_id = ignore_id
         self.token_list = token_list.copy()
         self.encoder = encoder
         self.decoder = decoder
         self.vocab_size = encoder.text_embed[0]._num_embeddings
-        if report_cer or report_wer:
-            self.error_calculator = ErrorCalculator(token_list, sym_space,
-                                                    sym_blank, report_cer,
-                                                    report_wer)
-        else:
-            self.error_calculator = None
-        self.mlm_weight = 1.0
-        self.mlm_prob = mlm_prob
-        self.mlm_layer = 12
-        self.finetune_wo_mlm = True
-        self.max_span = 50
-        self.min_span = 4
-        self.mean_phn_span = mean_phn_span
-        self.masking_schema = masking_schema
         if self.decoder is None or not (hasattr(self.decoder,
                                                 'output_layer') and
                                         self.decoder.output_layer is not None):
@@ -606,15 +370,9 @@ class MLMModel(nn.Layer):
...
@@ -606,15 +370,9 @@ class MLMModel(nn.Layer):
self
.
encoder
.
text_embed
[
0
].
_embedding_dim
,
self
.
encoder
.
text_embed
[
0
].
_embedding_dim
,
self
.
vocab_size
,
self
.
vocab_size
,
weight_attr
=
self
.
encoder
.
text_embed
[
0
].
_weight_attr
)
weight_attr
=
self
.
encoder
.
text_embed
[
0
].
_weight_attr
)
self
.
text_mlm_loss
=
nn
.
CrossEntropyLoss
(
ignore_index
=
ignore_id
)
else
:
else
:
self
.
text_sfc
=
None
self
.
text_sfc
=
None
self
.
text_mlm_loss
=
None
self
.
decoder_seg_pos
=
decoder_seg_pos
if
lsm_weight
>
50
:
self
.
l1_loss_func
=
nn
.
MSELoss
()
else
:
self
.
l1_loss_func
=
nn
.
L1Loss
(
reduction
=
'none'
)
self
.
postnet
=
(
None
if
postnet_layers
==
0
else
Postnet
(
self
.
postnet
=
(
None
if
postnet_layers
==
0
else
Postnet
(
idim
=
self
.
encoder
.
_output_size
,
idim
=
self
.
encoder
.
_output_size
,
odim
=
odim
,
odim
=
odim
,
 ...
@@ -624,119 +382,77 @@ class MLMModel(nn.Layer):
             use_batch_norm=True,
             dropout_rate=0.5, ))

-    def collect_feats(self,
-                      speech,
-                      speech_lens,
-                      text,
-                      text_lens,
-                      masked_pos,
-                      speech_mask,
-                      text_mask,
-                      speech_seg_pos,
-                      text_seg_pos,
-                      y_masks=None) -> Dict[str, paddle.Tensor]:
-        return {"feats": speech, "feats_lens": speech_lens}
-
-    def forward(self, batch, speech_seg_pos, y_masks=None):
-        # feats: (Batch, Length, Dim)
-        # -> encoder_out: (Batch, Length2, Dim2)
-        speech_pad_placeholder = batch['speech_pad']
-        if self.decoder is not None:
-            ys_in = self._add_first_frame_and_remove_last_frame(
-                batch['speech_pad'])
-        encoder_out, h_masks = self.encoder(**batch)
-        if self.decoder is not None:
-            zs, _ = self.decoder(ys_in, y_masks, encoder_out, bool(h_masks),
-                                 self.encoder.segment_emb(speech_seg_pos))
-            speech_hidden_states = zs
-        else:
-            speech_hidden_states = encoder_out[:, :paddle.shape(batch[
-                'speech_pad'])[1], :]
-        if self.sfc is not None:
-            before_outs = paddle.reshape(
-                self.sfc(speech_hidden_states),
-                (paddle.shape(speech_hidden_states)[0], -1, self.odim))
-        else:
-            before_outs = speech_hidden_states
-        if self.postnet is not None:
-            after_outs = before_outs + paddle.transpose(
-                self.postnet(paddle.transpose(before_outs, [0, 2, 1])),
-                (0, 2, 1))
-        else:
-            after_outs = None
-        return before_outs, after_outs, speech_pad_placeholder, batch[
-            'masked_pos']
-
     def inference(
             self,
-            speech,
-            text,
-            masked_pos,
-            speech_mask,
-            text_mask,
-            speech_seg_pos,
-            text_seg_pos,
-            span_bdy,
-            y_masks=None,
-            speech_lens=None,
-            text_lens=None,
-            feats: Optional[paddle.Tensor]=None,
-            spembs: Optional[paddle.Tensor]=None,
-            sids: Optional[paddle.Tensor]=None,
-            lids: Optional[paddle.Tensor]=None,
-            threshold: float=0.5,
-            minlenratio: float=0.0,
-            maxlenratio: float=10.0,
+            speech: paddle.Tensor,
+            text: paddle.Tensor,
+            masked_pos: paddle.Tensor,
+            speech_mask: paddle.Tensor,
+            text_mask: paddle.Tensor,
+            speech_seg_pos: paddle.Tensor,
+            text_seg_pos: paddle.Tensor,
+            span_bdy: List[int],
             use_teacher_forcing: bool=False,
     ) -> Dict[str, paddle.Tensor]:
+        '''
+        Args:
+            speech (paddle.Tensor): input speech (B, Tmax, D).
+            text (paddle.Tensor): input text (B, Tmax2).
+            masked_pos (paddle.Tensor): masked position of input speech (B, Tmax)
+            speech_mask (paddle.Tensor): mask of speech (B, 1, Tmax).
+            text_mask (paddle.Tensor): mask of text (B, 1, Tmax2).
+            speech_seg_pos (paddle.Tensor): n-th phone of each mel, 0<=n<=Tmax2 (B, Tmax).
+            text_seg_pos (paddle.Tensor): n-th phone of each phone, 0<=n<=Tmax2 (B, Tmax2).
+            span_bdy (List[int]): masked mel boundary of input speech (2,)
+            use_teacher_forcing (bool): whether to use teacher forcing
+        Returns:
+            List[Tensor]:
+                eg:
+                [Tensor(shape=[1, 181, 80]), Tensor(shape=[80, 80]), Tensor(shape=[1, 67, 80])]
+        '''
-        batch = dict(
-            speech_pad=speech,
-            text_pad=text,
-            masked_pos=masked_pos,
-            speech_mask=speech_mask,
-            text_mask=text_mask,
-            speech_seg_pos=speech_seg_pos,
-            text_seg_pos=text_seg_pos, )
-
-        # # inference with teacher forcing
-        # hs, h_masks = self.encoder(**batch)
-
-        outs = [batch['speech_pad'][:, :span_bdy[0]]]
+        outs = [speech[:, :span_bdy[0]]]
         z_cache = None
         if use_teacher_forcing:
-            before, zs, _, _ = self.forward(
-                batch, speech_seg_pos, y_masks=y_masks)
+            before_outs, zs, *_ = self.forward(
+                speech=speech,
+                text=text,
+                masked_pos=masked_pos,
+                speech_mask=speech_mask,
+                text_mask=text_mask,
+                speech_seg_pos=speech_seg_pos,
+                text_seg_pos=text_seg_pos)
             if zs is None:
-                zs = before
+                zs = before_outs
             outs += [zs[0][span_bdy[0]:span_bdy[1]]]
-            outs += [batch['speech_pad'][:, span_bdy[1]:]]
-            return dict(feat_gen=outs)
+            outs += [speech[:, span_bdy[1]:]]
+            return outs
         return None

-    def _add_first_frame_and_remove_last_frame(
-            self, ys: paddle.Tensor) -> paddle.Tensor:
-        ys_in = paddle.concat(
-            [
-                paddle.zeros(
-                    shape=(paddle.shape(ys)[0], 1, paddle.shape(ys)[2]),
-                    dtype=ys.dtype), ys[:, :-1]
-            ],
-            axis=1)
-        return ys_in
-class MLMEncAsDecoderModel(MLMModel):
-    def forward(self, batch, speech_seg_pos, y_masks=None):
+class MLMEncAsDecoder(MLM):
+    def forward(self,
+                speech: paddle.Tensor,
+                text: paddle.Tensor,
+                masked_pos: paddle.Tensor,
+                speech_mask: paddle.Tensor,
+                text_mask: paddle.Tensor,
+                speech_seg_pos: paddle.Tensor,
+                text_seg_pos: paddle.Tensor):
         # feats: (Batch, Length, Dim)
         # -> encoder_out: (Batch, Length2, Dim2)
-        speech_pad_placeholder = batch['speech_pad']
-        encoder_out, h_masks = self.encoder(**batch)  # segment_emb
+        encoder_out, h_masks = self.encoder(
+            speech=speech,
+            text=text,
+            masked_pos=masked_pos,
+            speech_mask=speech_mask,
+            text_mask=text_mask,
+            speech_seg_pos=speech_seg_pos,
+            text_seg_pos=text_seg_pos)
         if self.decoder is not None:
             zs, _ = self.decoder(encoder_out, h_masks)
         else:
             zs = encoder_out
-        speech_hidden_states = zs[:, :paddle.shape(batch['speech_pad'])[1], :]
+        speech_hidden_states = zs[:, :paddle.shape(speech)[1], :]
         if self.sfc is not None:
             before_outs = paddle.reshape(
                 self.sfc(speech_hidden_states),
 ...
@@ -749,53 +465,35 @@ class MLMEncAsDecoderModel(MLMModel):
                 [0, 2, 1])
         else:
             after_outs = None
-        return before_outs, after_outs, speech_pad_placeholder, batch[
-            'masked_pos']
+        return before_outs, after_outs, None


-class MLMDualMaksingModel(MLMModel):
-    def _calc_mlm_loss(self,
-                       before_outs: paddle.Tensor,
-                       after_outs: paddle.Tensor,
-                       text_outs: paddle.Tensor,
-                       batch):
-        xs_pad = batch['speech_pad']
-        text_pad = batch['text_pad']
-        masked_pos = batch['masked_pos']
-        text_masked_pos = batch['text_masked_pos']
-        mlm_loss_pos = masked_pos > 0
-        loss = paddle.sum(
-            self.l1_loss_func(
-                paddle.reshape(before_outs, (-1, self.odim)),
-                paddle.reshape(xs_pad, (-1, self.odim))),
-            axis=-1)
-        if after_outs is not None:
-            loss += paddle.sum(
-                self.l1_loss_func(
-                    paddle.reshape(after_outs, (-1, self.odim)),
-                    paddle.reshape(xs_pad, (-1, self.odim))),
-                axis=-1)
-        loss_mlm = paddle.sum((loss * paddle.reshape(
-            mlm_loss_pos, [-1]))) / paddle.sum((mlm_loss_pos) + 1e-10)
-        loss_text = paddle.sum((self.text_mlm_loss(
-            paddle.reshape(text_outs, (-1, self.vocab_size)),
-            paddle.reshape(text_pad, (-1))) * paddle.reshape(
-                text_masked_pos, (-1)))) / paddle.sum(
-                    (text_masked_pos) + 1e-10)
-        return loss_mlm, loss_text
-
-    def forward(self, batch, speech_seg_pos, y_masks=None):
+class MLMDualMaksing(MLM):
+    def forward(self,
+                speech: paddle.Tensor,
+                text: paddle.Tensor,
+                masked_pos: paddle.Tensor,
+                speech_mask: paddle.Tensor,
+                text_mask: paddle.Tensor,
+                speech_seg_pos: paddle.Tensor,
+                text_seg_pos: paddle.Tensor):
         # feats: (Batch, Length, Dim)
         # -> encoder_out: (Batch, Length2, Dim2)
-        encoder_out, h_masks = self.encoder(**batch)  # segment_emb
+        encoder_out, h_masks = self.encoder(
+            speech=speech,
+            text=text,
+            masked_pos=masked_pos,
+            speech_mask=speech_mask,
+            text_mask=text_mask,
+            speech_seg_pos=speech_seg_pos,
+            text_seg_pos=text_seg_pos)
         if self.decoder is not None:
             zs, _ = self.decoder(encoder_out, h_masks)
         else:
             zs = encoder_out
-        speech_hidden_states = zs[:, :paddle.shape(batch['speech_pad'])[1], :]
+        speech_hidden_states = zs[:, :paddle.shape(speech)[1], :]
         if self.text_sfc:
-            text_hiddent_states = zs[:, paddle.shape(batch['speech_pad'])[
-                1]:, :]
+            text_hiddent_states = zs[:, paddle.shape(speech)[1]:, :]
             text_outs = paddle.reshape(
                 self.text_sfc(text_hiddent_states),
                 (paddle.shape(text_hiddent_states)[0], -1, self.vocab_size))
 ...
@@ -811,27 +509,25 @@ class MLMDualMaksingModel(MLMModel):
                 [0, 2, 1])
         else:
             after_outs = None
-        return before_outs, after_outs, text_outs, None  #, speech_pad_placeholder, batch['masked_pos'], batch['text_masked_pos']
+        return before_outs, after_outs, text_outs
build_model_from_file
(
config_file
,
model_file
):
def
build_model_from_file
(
config_file
,
model_file
):
state_dict
=
paddle
.
load
(
model_file
)
state_dict
=
paddle
.
load
(
model_file
)
model_class
=
MLMDualMaksing
Model
if
'conformer_combine_vctk_aishell3_dual_masking'
in
config_file
\
model_class
=
MLMDualMaksing
if
'conformer_combine_vctk_aishell3_dual_masking'
in
config_file
\
else
MLMEncAsDecoder
Model
else
MLMEncAsDecoder
# 构建模型
# 构建模型
args
=
yaml
.
safe_load
(
Path
(
config_file
).
open
(
"r"
,
encoding
=
"utf-8"
))
with
open
(
config_file
)
as
f
:
args
=
argparse
.
Namespace
(
**
args
)
conf
=
CfgNode
(
yaml
.
safe_load
(
f
))
model
=
build_model
(
conf
,
model_class
)
model
=
build_model
(
args
,
model_class
)
model
.
set_state_dict
(
state_dict
)
model
.
set_state_dict
(
state_dict
)
return
model
,
args
return
model
,
conf
def
build_model
(
args
:
argparse
.
Namespace
,
# select encoder and decoder here
model_class
=
MLMEncAsDecoderModel
)
->
MLMModel
:
def
build_model
(
args
:
argparse
.
Namespace
,
model_class
=
MLMEncAsDecoder
)
->
MLM
:
if
isinstance
(
args
.
token_list
,
str
):
if
isinstance
(
args
.
token_list
,
str
):
with
open
(
args
.
token_list
,
encoding
=
"utf-8"
)
as
f
:
with
open
(
args
.
token_list
,
encoding
=
"utf-8"
)
as
f
:
token_list
=
[
line
.
rstrip
()
for
line
in
f
]
token_list
=
[
line
.
rstrip
()
for
line
in
f
]
...
@@ -842,9 +538,8 @@ def build_model(args: argparse.Namespace,
...
@@ -842,9 +538,8 @@ def build_model(args: argparse.Namespace,
token_list
=
list
(
args
.
token_list
)
token_list
=
list
(
args
.
token_list
)
else
:
else
:
raise
RuntimeError
(
"token_list must be str or list"
)
raise
RuntimeError
(
"token_list must be str or list"
)
vocab_size
=
len
(
token_list
)
logging
.
info
(
f
"Vocabulary size:
{
vocab_size
}
"
)
vocab_size
=
len
(
token_list
)
odim
=
80
odim
=
80
pos_enc_class
=
ScaledPositionalEncoding
if
args
.
use_scaled_pos_enc
else
PositionalEncoding
pos_enc_class
=
ScaledPositionalEncoding
if
args
.
use_scaled_pos_enc
else
PositionalEncoding
...
@@ -857,17 +552,8 @@ def build_model(args: argparse.Namespace,
...
@@ -857,17 +552,8 @@ def build_model(args: argparse.Namespace,
if
conformer_rel_pos_type
==
"legacy"
:
if
conformer_rel_pos_type
==
"legacy"
:
if
conformer_pos_enc_layer_type
==
"rel_pos"
:
if
conformer_pos_enc_layer_type
==
"rel_pos"
:
conformer_pos_enc_layer_type
=
"legacy_rel_pos"
conformer_pos_enc_layer_type
=
"legacy_rel_pos"
logging
.
warning
(
"Fallback to conformer_pos_enc_layer_type = 'legacy_rel_pos' "
"due to the compatibility. If you want to use the new one, "
"please use conformer_pos_enc_layer_type = 'latest'."
)
if
conformer_self_attn_layer_type
==
"rel_selfattn"
:
if
conformer_self_attn_layer_type
==
"rel_selfattn"
:
conformer_self_attn_layer_type
=
"legacy_rel_selfattn"
conformer_self_attn_layer_type
=
"legacy_rel_selfattn"
logging
.
warning
(
"Fallback to "
"conformer_self_attn_layer_type = 'legacy_rel_selfattn' "
"due to the compatibility. If you want to use the new one, "
"please use conformer_pos_enc_layer_type = 'latest'."
)
elif
conformer_rel_pos_type
==
"latest"
:
elif
conformer_rel_pos_type
==
"latest"
:
assert
conformer_pos_enc_layer_type
!=
"legacy_rel_pos"
assert
conformer_pos_enc_layer_type
!=
"legacy_rel_pos"
assert
conformer_self_attn_layer_type
!=
"legacy_rel_selfattn"
assert
conformer_self_attn_layer_type
!=
"legacy_rel_selfattn"
...
...
ernie-sat/mlm_loss.py  (new file, mode 100644)  View file @ 9224659c

import paddle
from paddle import nn


class MLMLoss(nn.Layer):
    def __init__(self,
                 odim: int,
                 vocab_size: int=0,
                 lsm_weight: float=0.1,
                 ignore_id: int=-1,
                 text_masking: bool=False):
        super().__init__()
        if text_masking:
            self.text_mlm_loss = nn.CrossEntropyLoss(ignore_index=ignore_id)
        if lsm_weight > 50:
            self.l1_loss_func = nn.MSELoss()
        else:
            self.l1_loss_func = nn.L1Loss(reduction='none')
        self.text_masking = text_masking
        # forward() reads self.odim and self.vocab_size, so they must be set
        # here (the committed version referenced them without ever assigning them)
        self.odim = odim
        self.vocab_size = vocab_size

    def forward(self,
                speech: paddle.Tensor,
                before_outs: paddle.Tensor,
                after_outs: paddle.Tensor,
                masked_pos: paddle.Tensor,
                text: paddle.Tensor=None,
                text_outs: paddle.Tensor=None,
                text_masked_pos: paddle.Tensor=None):
        xs_pad = speech
        mlm_loss_pos = masked_pos > 0
        loss = paddle.sum(
            self.l1_loss_func(
                paddle.reshape(before_outs, (-1, self.odim)),
                paddle.reshape(xs_pad, (-1, self.odim))),
            axis=-1)
        if after_outs is not None:
            loss += paddle.sum(
                self.l1_loss_func(
                    paddle.reshape(after_outs, (-1, self.odim)),
                    paddle.reshape(xs_pad, (-1, self.odim))),
                axis=-1)
        loss_mlm = paddle.sum((loss * paddle.reshape(
            mlm_loss_pos, [-1]))) / paddle.sum((mlm_loss_pos) + 1e-10)
        if self.text_masking:
            loss_text = paddle.sum((self.text_mlm_loss(
                paddle.reshape(text_outs, (-1, self.vocab_size)),
                paddle.reshape(text, (-1))) * paddle.reshape(
                    text_masked_pos, (-1)))) / paddle.sum(
                        (text_masked_pos) + 1e-10)
            return loss_mlm, loss_text
        return loss_mlm
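The speech branch of MLMLoss is an L1 distance summed over the feature dimension and averaged over masked frames only, with a small epsilon in the denominator. A minimal NumPy sketch of that masked average (the function name is mine, NumPy stands in for the Paddle ops):

```python
import numpy as np


def masked_l1(outs: np.ndarray, target: np.ndarray,
              masked_pos: np.ndarray) -> float:
    # per-frame L1 summed over the feature dim, then averaged over masked
    # frames only, mirroring MLMLoss: sum(loss * mask) / sum(mask + 1e-10)
    loss = np.abs(outs - target).sum(axis=-1).reshape(-1)
    mask = (masked_pos > 0).astype(np.float32).reshape(-1)
    return float((loss * mask).sum() / (mask + 1e-10).sum())
```

Frames with masked_pos == 0 contribute nothing, so reconstruction quality is measured only where the model actually had to infill.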
ernie-sat/paddlespeech/t2s/modules/transformer/attention.py  View file @ 9224659c
 ...
@@ -220,3 +220,99 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):

The hunk keeps the last three lines of RelPositionMultiHeadedAttention.forward as context and appends the following class (all lines added):

class LegacyRelPositionMultiHeadedAttention(MultiHeadedAttention):
    """Multi-Head Attention layer with relative position encoding (old version).
    Details can be found in https://github.com/espnet/espnet/pull/2816.
    Paper: https://arxiv.org/abs/1901.02860
    Args:
        n_head (int): The number of heads.
        n_feat (int): The number of features.
        dropout_rate (float): Dropout rate.
        zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
    """

    def __init__(self, n_head, n_feat, dropout_rate, zero_triu=False):
        """Construct an RelPositionMultiHeadedAttention object."""
        super().__init__(n_head, n_feat, dropout_rate)
        self.zero_triu = zero_triu
        # linear transformation for positional encoding
        self.linear_pos = nn.Linear(n_feat, n_feat, bias_attr=False)
        # these two learnable bias are used in matrix c and matrix d
        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
        self.pos_bias_u = paddle.create_parameter(
            shape=(self.h, self.d_k),
            dtype='float32',
            default_initializer=paddle.nn.initializer.XavierUniform())
        self.pos_bias_v = paddle.create_parameter(
            shape=(self.h, self.d_k),
            dtype='float32',
            default_initializer=paddle.nn.initializer.XavierUniform())

    def rel_shift(self, x):
        """Compute relative positional encoding.
        Args:
            x(Tensor): Input tensor (batch, head, time1, time2).
        Returns:
            Tensor: Output tensor.
        """
        b, h, t1, t2 = paddle.shape(x)
        zero_pad = paddle.zeros((b, h, t1, 1))
        x_padded = paddle.concat([zero_pad, x], axis=-1)
        x_padded = paddle.reshape(x_padded, [b, h, t2 + 1, t1])
        # only keep the positions from 0 to time2
        x = paddle.reshape(x_padded[:, :, 1:], [b, h, t1, t2])
        if self.zero_triu:
            ones = paddle.ones((t1, t2))
            x = x * paddle.tril(ones, t2 - 1)[None, None, :, :]
        return x

    def forward(self, query, key, value, pos_emb, mask):
        """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
        Args:
            query(Tensor): Query tensor (#batch, time1, size).
            key(Tensor): Key tensor (#batch, time2, size).
            value(Tensor): Value tensor (#batch, time2, size).
            pos_emb(Tensor): Positional embedding tensor (#batch, time1, size).
            mask(Tensor): Mask tensor (#batch, 1, time2) or (#batch, time1, time2).
        Returns:
            Tensor: Output tensor (#batch, time1, d_model).
        """
        q, k, v = self.forward_qkv(query, key, value)
        # (batch, time1, head, d_k)
        q = paddle.transpose(q, [0, 2, 1, 3])

        n_batch_pos = paddle.shape(pos_emb)[0]
        p = paddle.reshape(
            self.linear_pos(pos_emb), [n_batch_pos, -1, self.h, self.d_k])
        # (batch, head, time1, d_k)
        p = paddle.transpose(p, [0, 2, 1, 3])
        # (batch, head, time1, d_k)
        q_with_bias_u = paddle.transpose((q + self.pos_bias_u), [0, 2, 1, 3])
        # (batch, head, time1, d_k)
        q_with_bias_v = paddle.transpose((q + self.pos_bias_v), [0, 2, 1, 3])

        # compute attention score
        # first compute matrix a and matrix c
        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
        # (batch, head, time1, time2)
        matrix_ac = paddle.matmul(q_with_bias_u,
                                  paddle.transpose(k, [0, 1, 3, 2]))
        # compute matrix b and matrix d
        # (batch, head, time1, time1)
        matrix_bd = paddle.matmul(q_with_bias_v,
                                  paddle.transpose(p, [0, 1, 3, 2]))
        matrix_bd = self.rel_shift(matrix_bd)
        # (batch, head, time1, time2)
        scores = (matrix_ac + matrix_bd) / math.sqrt(self.d_k)

        return self.forward_attention(v, scores, mask)
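The pad-reshape-slice trick in rel_shift above converts scores indexed by absolute position into scores indexed by relative offset without any gather. A NumPy sketch of the same index manipulation (function name mine, NumPy in place of the Paddle ops):

```python
import numpy as np


def rel_shift_np(x: np.ndarray) -> np.ndarray:
    # x: (batch, head, time1, time2) scores against position embeddings
    b, h, t1, t2 = x.shape
    zero_pad = np.zeros((b, h, t1, 1), dtype=x.dtype)
    # prepend a zero column, then reinterpret the flat buffer with the
    # last two axes swapped in size: each row ends up shifted left by one
    # more slot than the row below it
    x_padded = np.concatenate([zero_pad, x], axis=-1)
    x_padded = x_padded.reshape(b, h, t2 + 1, t1)
    # dropping the first row and reshaping back realizes the relative shift
    return x_padded[:, :, 1:].reshape(b, h, t1, t2)
```

For a 2x3 input [[0, 1, 2], [3, 4, 5]] the first row becomes [1, 2, 0] while the last row is unchanged, i.e. row i is shifted left by (t1 - 1 - i) slots.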
ernie-sat/paddlespeech/t2s/modules/transformer/embedding.py  View file @ 9224659c
...
@@ -185,3 +185,63 @@ class RelPositionalEncoding(nn.Layer):
         pe_size = paddle.shape(self.pe)
         pos_emb = self.pe[:, pe_size[1] // 2 - T + 1:pe_size[1] // 2 + T, ]
         return self.dropout(x), self.dropout(pos_emb)
+
+
+class LegacyRelPositionalEncoding(PositionalEncoding):
+    """Relative positional encoding module (old version).
+    Details can be found in https://github.com/espnet/espnet/pull/2816.
+    See : Appendix B in https://arxiv.org/abs/1901.02860
+    Args:
+        d_model (int): Embedding dimension.
+        dropout_rate (float): Dropout rate.
+        max_len (int): Maximum input length.
+    """
+
+    def __init__(self, d_model: int, dropout_rate: float, max_len: int=5000):
+        """
+        Args:
+            d_model (int): Embedding dimension.
+            dropout_rate (float): Dropout rate.
+            max_len (int, optional): [Maximum input length.]. Defaults to 5000.
+        """
+        super().__init__(d_model, dropout_rate, max_len, reverse=True)
+
+    def extend_pe(self, x):
+        """Reset the positional encodings."""
+        if self.pe is not None:
+            if paddle.shape(self.pe)[1] >= paddle.shape(x)[1]:
+                return
+        pe = paddle.zeros((paddle.shape(x)[1], self.d_model))
+        if self.reverse:
+            position = paddle.arange(
+                paddle.shape(x)[1] - 1, -1, -1.0,
+                dtype=paddle.float32).unsqueeze(1)
+        else:
+            position = paddle.arange(
+                0, paddle.shape(x)[1], dtype=paddle.float32).unsqueeze(1)
+        div_term = paddle.exp(
+            paddle.arange(0, self.d_model, 2, dtype=paddle.float32) *
+            -(math.log(10000.0) / self.d_model))
+        pe[:, 0::2] = paddle.sin(position * div_term)
+        pe[:, 1::2] = paddle.cos(position * div_term)
+        pe = pe.unsqueeze(0)
+        self.pe = pe
+
+    def forward(self, x: paddle.Tensor):
+        """Compute positional encoding.
+        Args:
+            x (paddle.Tensor): Input tensor (batch, time, `*`).
+        Returns:
+            paddle.Tensor: Encoded tensor (batch, time, `*`).
+            paddle.Tensor: Positional embedding tensor (1, time, `*`).
+        """
+        self.extend_pe(x)
+        x = x * self.xscale
+        pos_emb = self.pe[:, :paddle.shape(x)[1]]
+        return self.dropout(x), self.dropout(pos_emb)
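The `extend_pe` body added above builds a standard sinusoidal table, with the `reverse=True` twist that positions run from `length-1` down to `0` for the legacy relative variant. The same arithmetic in plain NumPy (a sketch, not the PaddleSpeech API):

```python
import math

import numpy as np


def sinusoid_table(length: int, d_model: int, reverse: bool = False) -> np.ndarray:
    """Sinusoidal positional encodings of shape (1, length, d_model)."""
    if reverse:
        # positions length-1 .. 0, as in the legacy relative variant
        position = np.arange(length - 1, -1, -1.0, dtype=np.float32)[:, None]
    else:
        position = np.arange(0, length, dtype=np.float32)[:, None]
    # frequencies 10000^(-2i/d_model) for each even channel index 2i
    div_term = np.exp(
        np.arange(0, d_model, 2, dtype=np.float32) * -(math.log(10000.0) / d_model))
    pe = np.zeros((length, d_model), dtype=np.float32)
    pe[:, 0::2] = np.sin(position * div_term)  # even channels: sine
    pe[:, 1::2] = np.cos(position * div_term)  # odd channels: cosine
    return pe[None]  # add batch axis, matching pe.unsqueeze(0)


pe = sinusoid_table(10, 8)
print(pe.shape)  # (1, 10, 8)
```

At position 0 the table is `[0, 1, 0, 1, ...]` (sin 0 and cos 0), which is a quick sanity check on any reimplementation.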
ernie-sat/read_text.py  View file @ 9224659c
...
@@ -5,7 +5,7 @@ from typing import List
 from typing import Union


-def read_2column_text(path: Union[Path, str]) -> Dict[str, str]:
+def read_2col_text(path: Union[Path, str]) -> Dict[str, str]:
     """Read a text file having 2 column as dict object.
     Examples:
...
@@ -13,7 +13,7 @@ def read_2column_text(path: Union[Path, str]) -> Dict[str, str]:
         key1 /some/path/a.wav
         key2 /some/path/b.wav
-        >>> read_2column_text('wav.scp')
+        >>> read_2col_text('wav.scp')
     {'key1': '/some/path/a.wav', 'key2': '/some/path/b.wav'}
     """
...
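The renamed `read_2col_text` parses a Kaldi-style `wav.scp`: first whitespace-separated token is the key, the rest of the line is the value. A minimal self-contained equivalent (a sketch of the documented behavior, not the repo's exact implementation):

```python
from pathlib import Path
from typing import Dict, Union


def read_2col_text(path: Union[Path, str]) -> Dict[str, str]:
    """Read a two-column text file (e.g. a Kaldi-style wav.scp) as a dict.

    The first whitespace-separated token of each line is the key; the
    remainder of the line is the value. Blank lines are skipped.
    """
    data: Dict[str, str] = {}
    with open(path, encoding="utf-8") as f:
        for line in f:
            parts = line.rstrip("\n").split(maxsplit=1)
            if not parts:
                continue  # skip blank lines
            data[parts[0]] = parts[1] if len(parts) == 2 else ""
    return data
```

Feeding it the two-line `wav.scp` from the docstring above yields `{'key1': '/some/path/a.wav', 'key2': '/some/path/b.wav'}`.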
ernie-sat/sedit_arg_parser.py  View file @ 9224659c
...
@@ -65,12 +65,6 @@ def parse_args():
         help="mean and standard deviation used to normalize spectrogram when training voc."
     )
     # other
-    parser.add_argument(
-        '--lang',
-        type=str,
-        default='en',
-        help='Choose model language. zh or en')
     parser.add_argument(
         "--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.")
     # parser.add_argument("--test_metadata", type=str, help="test metadata.")
...
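The surviving `--ngpu` flag encodes the repo's device convention: zero means CPU, any positive count means GPU. A standalone sketch of that pattern (the `pick_device` helper is hypothetical, added here only for illustration):

```python
import argparse


def pick_device(ngpu: int) -> str:
    """Mirror the repo's convention: ngpu == 0 means CPU, otherwise GPU."""
    if ngpu < 0:
        raise ValueError("ngpu should >= 0 !")
    return "cpu" if ngpu == 0 else "gpu"


parser = argparse.ArgumentParser()
parser.add_argument(
    "--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.")
args = parser.parse_args(["--ngpu", "0"])
print(pick_device(args.ngpu))  # cpu
```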
ernie-sat/utils.py  View file @ 9224659c
...
@@ -32,7 +32,6 @@ model_alias = {
     "paddlespeech.t2s.models.parallel_wavegan:PWGInference",
 }


 def is_chinese(ch):
     if u'\u4e00' <= ch <= u'\u9fff':
         return True
...
@@ -55,12 +54,10 @@ def build_vocoder_from_file(
         raise ValueError(f"{vocoder_file} is not supported format.")


-def get_voc_out(mel, target_lang: str="chinese"):
+def get_voc_out(mel):
     # vocoder
     args = parse_args()
-    assert target_lang == "chinese" or target_lang == "english", "In get_voc_out function, target_lang is illegal..."
     # print("current vocoder: ", args.voc)
     with open(args.voc_config) as f:
         voc_config = CfgNode(yaml.safe_load(f))
...
@@ -167,19 +164,23 @@ def get_voc_inference(
     return voc_inference


-def evaluate_durations(phns,
-                       target_lang="chinese", fs=24000, hop_length=300):
+def evaluate_durations(phns: List[str],
+                       target_lang: str="chinese",
+                       fs: int=24000,
+                       hop_length: int=300):
     args = parse_args()
     if target_lang == 'english':
+        args.lang = 'en'
+        args.am = "fastspeech2_ljspeech"
+        args.am_config = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/default.yaml"
+        args.am_ckpt = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/snapshot_iter_100000.pdz"
+        args.am_stat = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/speech_stats.npy"
+        args.phones_dict = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/phone_id_map.txt"
     elif target_lang == 'chinese':
+        args.lang = 'zh'
+        args.am = "fastspeech2_csmsc"
+        args.am_config = "download/fastspeech2_conformer_baker_ckpt_0.5/conformer.yaml"
+        args.am_ckpt = "download/fastspeech2_conformer_baker_ckpt_0.5/snapshot_iter_76000.pdz"
+        args.am_stat = "download/fastspeech2_conformer_baker_ckpt_0.5/speech_stats.npy"
+        args.phones_dict = "download/fastspeech2_conformer_baker_ckpt_0.5/phone_id_map.txt"
-    # args = parser.parse_args(args=[])
     if args.ngpu == 0:
         paddle.set_device("cpu")
     elif args.ngpu > 0:
...
@@ -187,8 +188,6 @@ def evaluate_durations(phns: List[str],
     else:
         print("ngpu should >= 0 !")
-    assert target_lang == "chinese" or target_lang == "english", "In evaluate_durations function, target_lang is illegal..."
     # Init body.
     with open(args.am_config) as f:
         am_config = CfgNode(yaml.safe_load(f))
...
@@ -203,21 +202,19 @@ def evaluate_durations(phns: List[str],
         speaker_dict=args.speaker_dict,
         return_am=True)

-    torch_phns = phns
     vocab_phones = {}
     with open(args.phones_dict, "r") as f:
         phn_id = [line.strip().split() for line in f.readlines()]
     for tone, id in phn_id:
         vocab_phones[tone] = int(id)
     vocab_size = len(vocab_phones)
-    phonemes = [phn if phn in vocab_phones else "sp" for phn in torch_phns]
+    phonemes = [phn if phn in vocab_phones else "sp" for phn in phns]
     phone_ids = [vocab_phones[item] for item in phonemes]
     phone_ids_new = phone_ids
     phone_ids_new.append(vocab_size - 1)
     phone_ids_new = paddle.to_tensor(np.array(phone_ids_new, np.int64))
-    normalized_mel, d_outs, p_outs, e_outs = am.inference(
+    _, d_outs, _, _ = am.inference(
         phone_ids_new, spk_id=None, spk_emb=None)
     pre_d_outs = d_outs
     phoneme_durations_new = pre_d_outs * hop_length / fs
     phoneme_durations_new = phoneme_durations_new.tolist()[:-1]
...
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录