Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
ERNIE
提交
b81832ce
E
ERNIE
项目概览
PaddlePaddle
/
ERNIE
大约 1 年 前同步成功
通知
109
Star
5997
Fork
1270
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
29
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
E
ERNIE
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
29
Issue
29
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
b81832ce
编写于
6月 10, 2022
作者:
K
Kennycao123
提交者:
GitHub
6月 10, 2022
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #825 from yt605155624/format
[ernie sat]Format ernie sat
上级
437081f7
76b654cb
变更
13
隐藏空白更改
内联
并排
Showing
13 changed file
with
980 addition
and
1123 deletion
+980
-1123
ernie-sat/README.md
ernie-sat/README.md
+2
-2
ernie-sat/align.py
ernie-sat/align.py
+117
-46
ernie-sat/align_mandarin.py
ernie-sat/align_mandarin.py
+0
-186
ernie-sat/dataset.py
ernie-sat/dataset.py
+209
-285
ernie-sat/inference.py
ernie-sat/inference.py
+351
-416
ernie-sat/model_paddle.py
ernie-sat/model_paddle.py
+51
-59
ernie-sat/paddlespeech/t2s/modules/nets_utils.py
ernie-sat/paddlespeech/t2s/modules/nets_utils.py
+154
-12
ernie-sat/run_clone_en_to_zh.sh
ernie-sat/run_clone_en_to_zh.sh
+2
-2
ernie-sat/run_gen_en.sh
ernie-sat/run_gen_en.sh
+2
-2
ernie-sat/run_sedit_en.sh
ernie-sat/run_sedit_en.sh
+2
-2
ernie-sat/sedit_arg_parser.py
ernie-sat/sedit_arg_parser.py
+2
-4
ernie-sat/tools/torch_pwgan.py
ernie-sat/tools/torch_pwgan.py
+1
-1
ernie-sat/utils.py
ernie-sat/utils.py
+87
-106
未找到文件。
ernie-sat/README.md
浏览文件 @
b81832ce
...
...
@@ -113,8 +113,8 @@ prompt/dev
8.
` --uid`
特定提示(prompt)语音的 id
9.
` --new_str`
输入的文本(本次开源暂时先设置特定的文本)
10.
` --prefix`
特定音频对应的文本、音素相关文件的地址
11.
` --source_lang
uage
`
, 源语言
12.
` --target_lang
uage
`
, 目标语言
11.
` --source_lang`
, 源语言
12.
` --target_lang`
, 目标语言
13.
` --output_name`
, 合成语音名称
14.
` --task_name`
, 任务名称, 包括:语音编辑任务、个性化语音合成任务、跨语言语音合成任务
15.
` --use_pt_vocoder`
, 英文场景下是否使用 torch 版本的 vocoder, 默认情况下为 False; 设置为 False 则在英文场景下使用 paddle 版本 vocoder
...
...
ernie-sat/align
_english
.py
→
ernie-sat/align.py
浏览文件 @
b81832ce
#!/usr/bin/env python
""" Usage:
align
_english
.py wavfile trsfile outwordfile outphonefile
align.py wavfile trsfile outwordfile outphonefile
"""
import
multiprocessing
as
mp
import
os
...
...
@@ -9,12 +9,45 @@ import sys
from
tqdm
import
tqdm
PHONEME
=
'tools/aligner/english_envir/english2phoneme/phoneme'
MODEL_DIR
=
'tools/aligner/english'
MODEL_DIR_EN
=
'tools/aligner/english'
MODEL_DIR_ZH
=
'tools/aligner/mandarin'
HVITE
=
'tools/htk/HTKTools/HVite'
HCOPY
=
'tools/htk/HTKTools/HCopy'
def
prep_txt
(
line
,
tmpbase
,
dictfile
):
def
prep_txt_zh
(
line
:
str
,
tmpbase
:
str
,
dictfile
:
str
):
words
=
[]
line
=
line
.
strip
()
for
pun
in
[
','
,
'.'
,
':'
,
';'
,
'!'
,
'?'
,
'"'
,
'('
,
')'
,
'--'
,
'---'
,
u
','
,
u
'。'
,
u
':'
,
u
';'
,
u
'!'
,
u
'?'
,
u
'('
,
u
')'
]:
line
=
line
.
replace
(
pun
,
' '
)
for
wrd
in
line
.
split
():
if
(
wrd
[
-
1
]
==
'-'
):
wrd
=
wrd
[:
-
1
]
if
(
wrd
[
0
]
==
"'"
):
wrd
=
wrd
[
1
:]
if
wrd
:
words
.
append
(
wrd
)
ds
=
set
([])
with
open
(
dictfile
,
'r'
)
as
fid
:
for
line
in
fid
:
ds
.
add
(
line
.
split
()[
0
])
unk_words
=
set
([])
with
open
(
tmpbase
+
'.txt'
,
'w'
)
as
fwid
:
for
wrd
in
words
:
if
(
wrd
not
in
ds
):
unk_words
.
add
(
wrd
)
fwid
.
write
(
wrd
+
' '
)
fwid
.
write
(
'
\n
'
)
return
unk_words
def
prep_txt_en
(
line
:
str
,
tmpbase
,
dictfile
):
words
=
[]
...
...
@@ -97,7 +130,7 @@ def prep_txt(line, tmpbase, dictfile):
fw
.
close
()
def
prep_mlf
(
txt
,
tmpbase
):
def
prep_mlf
(
txt
:
str
,
tmpbase
:
str
):
with
open
(
tmpbase
+
'.mlf'
,
'w'
)
as
fwid
:
fwid
.
write
(
'#!MLF!#
\n
'
)
...
...
@@ -110,7 +143,55 @@ def prep_mlf(txt, tmpbase):
fwid
.
write
(
'.
\n
'
)
def
gen_res
(
tmpbase
,
outfile1
,
outfile2
):
def
_get_user
():
return
os
.
path
.
expanduser
(
'~'
).
split
(
"/"
)[
-
1
]
def
alignment
(
wav_path
:
str
,
text
:
str
):
tmpbase
=
'/tmp/'
+
_get_user
()
+
'_'
+
str
(
os
.
getpid
())
#prepare wav and trs files
try
:
os
.
system
(
'sox '
+
wav_path
+
' -r 16000 '
+
tmpbase
+
'.wav remix -'
)
except
:
print
(
'sox error!'
)
return
None
#prepare clean_transcript file
try
:
prep_txt_en
(
text
,
tmpbase
,
MODEL_DIR_EN
+
'/dict'
)
except
:
print
(
'prep_txt error!'
)
return
None
#prepare mlf file
try
:
with
open
(
tmpbase
+
'.txt'
,
'r'
)
as
fid
:
txt
=
fid
.
readline
()
prep_mlf
(
txt
,
tmpbase
)
except
:
print
(
'prep_mlf error!'
)
return
None
#prepare scp
try
:
os
.
system
(
HCOPY
+
' -C '
+
MODEL_DIR_EN
+
'/16000/config '
+
tmpbase
+
'.wav'
+
' '
+
tmpbase
+
'.plp'
)
except
:
print
(
'HCopy error!'
)
return
None
#run alignment
try
:
os
.
system
(
HVITE
+
' -a -m -t 10000.0 10000.0 100000.0 -I '
+
tmpbase
+
'.mlf -H '
+
MODEL_DIR_EN
+
'/16000/macros -H '
+
MODEL_DIR_EN
+
'/16000/hmmdefs -i '
+
tmpbase
+
'.aligned '
+
tmpbase
+
'.dict '
+
MODEL_DIR_EN
+
'/monophones '
+
tmpbase
+
'.plp 2>&1 > /dev/null'
)
except
:
print
(
'HVite error!'
)
return
None
with
open
(
tmpbase
+
'.txt'
,
'r'
)
as
fid
:
words
=
fid
.
readline
().
strip
().
split
()
words
=
txt
.
strip
().
split
()
...
...
@@ -119,59 +200,47 @@ def gen_res(tmpbase, outfile1, outfile2):
with
open
(
tmpbase
+
'.aligned'
,
'r'
)
as
fid
:
lines
=
fid
.
readlines
()
i
=
2
times1
=
[]
times2
=
[]
word2phns
=
{}
current_word
=
''
index
=
0
while
(
i
<
len
(
lines
)):
if
(
len
(
lines
[
i
].
split
())
>=
4
)
and
(
lines
[
i
].
split
()[
0
]
!=
lines
[
i
].
split
()
[
1
]):
phn
=
lines
[
i
].
split
()
[
2
]
pst
=
(
int
(
lines
[
i
].
split
()
[
0
])
/
1000
+
125
)
/
10000
pen
=
(
int
(
lines
[
i
].
split
()
[
1
])
/
1000
+
125
)
/
10000
splited_line
=
lines
[
i
].
strip
().
split
()
if
(
len
(
splited_line
)
>=
4
)
and
(
splited_line
[
0
]
!=
splited_line
[
1
]):
phn
=
splited_line
[
2
]
pst
=
(
int
(
splited_line
[
0
])
/
1000
+
125
)
/
10000
pen
=
(
int
(
splited_line
[
1
])
/
1000
+
125
)
/
10000
times2
.
append
([
phn
,
pst
,
pen
])
if
(
len
(
lines
[
i
].
split
())
==
5
):
if
(
lines
[
i
].
split
()[
0
]
!=
lines
[
i
].
split
()[
1
]):
wrd
=
lines
[
i
].
split
()[
-
1
].
strip
()
st
=
(
int
(
lines
[
i
].
split
()[
0
])
/
1000
+
125
)
/
10000
j
=
i
+
1
while
(
lines
[
j
]
!=
'.
\n
'
)
and
(
len
(
lines
[
j
].
split
())
!=
5
):
j
+=
1
en
=
(
int
(
lines
[
j
-
1
].
split
()[
1
])
/
1000
+
125
)
/
10000
times1
.
append
([
wrd
,
st
,
en
])
# splited_line[-1]!='sp'
if
len
(
splited_line
)
==
5
:
current_word
=
str
(
index
)
+
'_'
+
splited_line
[
-
1
]
word2phns
[
current_word
]
=
phn
index
+=
1
elif
len
(
splited_line
)
==
4
:
word2phns
[
current_word
]
+=
' '
+
phn
i
+=
1
with
open
(
outfile1
,
'w'
)
as
fwid
:
for
item
in
times1
:
if
(
item
[
0
]
==
'sp'
):
fwid
.
write
(
str
(
item
[
1
])
+
' '
+
str
(
item
[
2
])
+
' SIL
\n
'
)
else
:
wrd
=
words
.
pop
()
fwid
.
write
(
str
(
item
[
1
])
+
' '
+
str
(
item
[
2
])
+
' '
+
wrd
+
'
\n
'
)
if
words
:
print
(
'not matched::'
+
alignfile
)
sys
.
exit
(
1
)
with
open
(
outfile2
,
'w'
)
as
fwid
:
for
item
in
times2
:
fwid
.
write
(
str
(
item
[
1
])
+
' '
+
str
(
item
[
2
])
+
' '
+
item
[
0
]
+
'
\n
'
)
def
_get_user
():
return
os
.
path
.
expanduser
(
'~'
).
split
(
"/"
)[
-
1
]
return
times2
,
word2phns
def
alignment
(
wav_path
,
text_string
):
def
alignment
_zh
(
wav_path
,
text_string
):
tmpbase
=
'/tmp/'
+
_get_user
()
+
'_'
+
str
(
os
.
getpid
())
#prepare wav and trs files
try
:
os
.
system
(
'sox '
+
wav_path
+
' -r 16000 '
+
tmpbase
+
'.wav remix -'
)
os
.
system
(
'sox '
+
wav_path
+
' -r 16000 -b 16 '
+
tmpbase
+
'.wav remix -'
)
except
:
print
(
'sox error!'
)
return
None
#prepare clean_transcript file
try
:
prep_txt
(
text_string
,
tmpbase
,
MODEL_DIR
+
'/dict'
)
unk_words
=
prep_txt_zh
(
text_string
,
tmpbase
,
MODEL_DIR_ZH
+
'/dict'
)
if
unk_words
:
print
(
'Error! Please add the following words to dictionary:'
)
for
unk
in
unk_words
:
print
(
"非法words: "
,
unk
)
except
:
print
(
'prep_txt error!'
)
return
None
...
...
@@ -187,7 +256,7 @@ def alignment(wav_path, text_string):
#prepare scp
try
:
os
.
system
(
HCOPY
+
' -C '
+
MODEL_DIR
+
'/16000/config '
+
tmpbase
+
os
.
system
(
HCOPY
+
' -C '
+
MODEL_DIR
_ZH
+
'/16000/config '
+
tmpbase
+
'.wav'
+
' '
+
tmpbase
+
'.plp'
)
except
:
print
(
'HCopy error!'
)
...
...
@@ -196,10 +265,11 @@ def alignment(wav_path, text_string):
#run alignment
try
:
os
.
system
(
HVITE
+
' -a -m -t 10000.0 10000.0 100000.0 -I '
+
tmpbase
+
'.mlf -H '
+
MODEL_DIR
+
'/16000/macros -H '
+
MODEL_DIR
+
'/16000/hmmdefs -i '
+
tmpbase
+
'.aligned '
+
tmpbase
+
'.dict '
+
MODEL_DIR
+
'/monophones '
+
tmpbase
+
'.mlf -H '
+
MODEL_DIR
_ZH
+
'/16000/macros -H '
+
MODEL_DIR_ZH
+
'/16000/hmmdefs -i '
+
tmpbase
+
'.aligned '
+
MODEL_DIR_ZH
+
'/dict '
+
MODEL_DIR_ZH
+
'/monophones '
+
tmpbase
+
'.plp 2>&1 > /dev/null'
)
except
:
print
(
'HVite error!'
)
return
None
...
...
@@ -211,6 +281,7 @@ def alignment(wav_path, text_string):
with
open
(
tmpbase
+
'.aligned'
,
'r'
)
as
fid
:
lines
=
fid
.
readlines
()
i
=
2
times2
=
[]
word2phns
=
{}
...
...
ernie-sat/align_mandarin.py
已删除
100755 → 0
浏览文件 @
437081f7
#!/usr/bin/env python
""" Usage:
align_mandarin.py wavfile trsfile outwordfile putphonefile
"""
import
multiprocessing
as
mp
import
os
import
sys
from
tqdm
import
tqdm
MODEL_DIR
=
'tools/aligner/mandarin'
HVITE
=
'tools/htk/HTKTools/HVite'
HCOPY
=
'tools/htk/HTKTools/HCopy'
def
prep_txt
(
line
,
tmpbase
,
dictfile
):
words
=
[]
line
=
line
.
strip
()
for
pun
in
[
','
,
'.'
,
':'
,
';'
,
'!'
,
'?'
,
'"'
,
'('
,
')'
,
'--'
,
'---'
,
u
','
,
u
'。'
,
u
':'
,
u
';'
,
u
'!'
,
u
'?'
,
u
'('
,
u
')'
]:
line
=
line
.
replace
(
pun
,
' '
)
for
wrd
in
line
.
split
():
if
(
wrd
[
-
1
]
==
'-'
):
wrd
=
wrd
[:
-
1
]
if
(
wrd
[
0
]
==
"'"
):
wrd
=
wrd
[
1
:]
if
wrd
:
words
.
append
(
wrd
)
ds
=
set
([])
with
open
(
dictfile
,
'r'
)
as
fid
:
for
line
in
fid
:
ds
.
add
(
line
.
split
()[
0
])
unk_words
=
set
([])
with
open
(
tmpbase
+
'.txt'
,
'w'
)
as
fwid
:
for
wrd
in
words
:
if
(
wrd
not
in
ds
):
unk_words
.
add
(
wrd
)
fwid
.
write
(
wrd
+
' '
)
fwid
.
write
(
'
\n
'
)
return
unk_words
def
prep_mlf
(
txt
,
tmpbase
):
with
open
(
tmpbase
+
'.mlf'
,
'w'
)
as
fwid
:
fwid
.
write
(
'#!MLF!#
\n
'
)
fwid
.
write
(
'"'
+
tmpbase
+
'.lab"
\n
'
)
fwid
.
write
(
'sp
\n
'
)
wrds
=
txt
.
split
()
for
wrd
in
wrds
:
fwid
.
write
(
wrd
.
upper
()
+
'
\n
'
)
fwid
.
write
(
'sp
\n
'
)
fwid
.
write
(
'.
\n
'
)
def
gen_res
(
tmpbase
,
outfile1
,
outfile2
):
with
open
(
tmpbase
+
'.txt'
,
'r'
)
as
fid
:
words
=
fid
.
readline
().
strip
().
split
()
words
=
txt
.
strip
().
split
()
words
.
reverse
()
with
open
(
tmpbase
+
'.aligned'
,
'r'
)
as
fid
:
lines
=
fid
.
readlines
()
i
=
2
times1
=
[]
times2
=
[]
while
(
i
<
len
(
lines
)):
if
(
len
(
lines
[
i
].
split
())
>=
4
)
and
(
lines
[
i
].
split
()[
0
]
!=
lines
[
i
].
split
()[
1
]):
phn
=
lines
[
i
].
split
()[
2
]
pst
=
(
int
(
lines
[
i
].
split
()[
0
])
/
1000
+
125
)
/
10000
pen
=
(
int
(
lines
[
i
].
split
()[
1
])
/
1000
+
125
)
/
10000
times2
.
append
([
phn
,
pst
,
pen
])
if
(
len
(
lines
[
i
].
split
())
==
5
):
if
(
lines
[
i
].
split
()[
0
]
!=
lines
[
i
].
split
()[
1
]):
wrd
=
lines
[
i
].
split
()[
-
1
].
strip
()
st
=
(
int
(
lines
[
i
].
split
()[
0
])
/
1000
+
125
)
/
10000
j
=
i
+
1
while
(
lines
[
j
]
!=
'.
\n
'
)
and
(
len
(
lines
[
j
].
split
())
!=
5
):
j
+=
1
en
=
(
int
(
lines
[
j
-
1
].
split
()[
1
])
/
1000
+
125
)
/
10000
times1
.
append
([
wrd
,
st
,
en
])
i
+=
1
with
open
(
outfile1
,
'w'
)
as
fwid
:
for
item
in
times1
:
if
(
item
[
0
]
==
'sp'
):
fwid
.
write
(
str
(
item
[
1
])
+
' '
+
str
(
item
[
2
])
+
' SIL
\n
'
)
else
:
wrd
=
words
.
pop
()
fwid
.
write
(
str
(
item
[
1
])
+
' '
+
str
(
item
[
2
])
+
' '
+
wrd
+
'
\n
'
)
if
words
:
print
(
'not matched::'
+
alignfile
)
sys
.
exit
(
1
)
with
open
(
outfile2
,
'w'
)
as
fwid
:
for
item
in
times2
:
fwid
.
write
(
str
(
item
[
1
])
+
' '
+
str
(
item
[
2
])
+
' '
+
item
[
0
]
+
'
\n
'
)
def
alignment_zh
(
wav_path
,
text_string
):
tmpbase
=
'/tmp/'
+
os
.
environ
[
'USER'
]
+
'_'
+
str
(
os
.
getpid
())
#prepare wav and trs files
try
:
os
.
system
(
'sox '
+
wav_path
+
' -r 16000 -b 16 '
+
tmpbase
+
'.wav remix -'
)
except
:
print
(
'sox error!'
)
return
None
#prepare clean_transcript file
try
:
unk_words
=
prep_txt
(
text_string
,
tmpbase
,
MODEL_DIR
+
'/dict'
)
if
unk_words
:
print
(
'Error! Please add the following words to dictionary:'
)
for
unk
in
unk_words
:
print
(
"非法words: "
,
unk
)
except
:
print
(
'prep_txt error!'
)
return
None
#prepare mlf file
try
:
with
open
(
tmpbase
+
'.txt'
,
'r'
)
as
fid
:
txt
=
fid
.
readline
()
prep_mlf
(
txt
,
tmpbase
)
except
:
print
(
'prep_mlf error!'
)
return
None
#prepare scp
try
:
os
.
system
(
HCOPY
+
' -C '
+
MODEL_DIR
+
'/16000/config '
+
tmpbase
+
'.wav'
+
' '
+
tmpbase
+
'.plp'
)
except
:
print
(
'HCopy error!'
)
return
None
#run alignment
try
:
os
.
system
(
HVITE
+
' -a -m -t 10000.0 10000.0 100000.0 -I '
+
tmpbase
+
'.mlf -H '
+
MODEL_DIR
+
'/16000/macros -H '
+
MODEL_DIR
+
'/16000/hmmdefs -i '
+
tmpbase
+
'.aligned '
+
MODEL_DIR
+
'/dict '
+
MODEL_DIR
+
'/monophones '
+
tmpbase
+
'.plp 2>&1 > /dev/null'
)
except
:
print
(
'HVite error!'
)
return
None
with
open
(
tmpbase
+
'.txt'
,
'r'
)
as
fid
:
words
=
fid
.
readline
().
strip
().
split
()
words
=
txt
.
strip
().
split
()
words
.
reverse
()
with
open
(
tmpbase
+
'.aligned'
,
'r'
)
as
fid
:
lines
=
fid
.
readlines
()
i
=
2
times2
=
[]
word2phns
=
{}
current_word
=
''
index
=
0
while
(
i
<
len
(
lines
)):
splited_line
=
lines
[
i
].
strip
().
split
()
if
(
len
(
splited_line
)
>=
4
)
and
(
splited_line
[
0
]
!=
splited_line
[
1
]):
phn
=
splited_line
[
2
]
pst
=
(
int
(
splited_line
[
0
])
/
1000
+
125
)
/
10000
pen
=
(
int
(
splited_line
[
1
])
/
1000
+
125
)
/
10000
times2
.
append
([
phn
,
pst
,
pen
])
# splited_line[-1]!='sp'
if
len
(
splited_line
)
==
5
:
current_word
=
str
(
index
)
+
'_'
+
splited_line
[
-
1
]
word2phns
[
current_word
]
=
phn
index
+=
1
elif
len
(
splited_line
)
==
4
:
word2phns
[
current_word
]
+=
' '
+
phn
i
+=
1
return
times2
,
word2phns
ernie-sat/dataset.py
浏览文件 @
b81832ce
...
...
@@ -4,37 +4,180 @@ import numpy as np
import
paddle
def
pad_list
(
xs
,
pad_value
):
"""Perform padding for the list of tensors.
def
phones_text_masking
(
xs_pad
:
paddle
.
Tensor
,
src_mask
:
paddle
.
Tensor
,
text_pad
:
paddle
.
Tensor
,
text_mask
:
paddle
.
Tensor
,
align_start
:
paddle
.
Tensor
,
align_end
:
paddle
.
Tensor
,
align_start_lens
:
paddle
.
Tensor
,
mlm_prob
:
float
,
mean_phn_span
:
float
,
span_bdy
:
paddle
.
Tensor
=
None
):
bz
,
sent_len
,
_
=
paddle
.
shape
(
xs_pad
)
masked_pos
=
paddle
.
zeros
((
bz
,
sent_len
))
_
,
text_len
=
paddle
.
shape
(
text_pad
)
text_mask_num_lower
=
math
.
ceil
(
text_len
*
(
1
-
mlm_prob
)
*
0.5
)
text_masked_pos
=
paddle
.
zeros
((
bz
,
text_len
))
y_masks
=
None
if
mlm_prob
==
1.0
:
masked_pos
+=
1
# y_masks = tril_masks
elif
mean_phn_span
==
0
:
# only speech
length
=
sent_len
mean_phn_span
=
min
(
length
*
mlm_prob
//
3
,
50
)
masked_phn_idxs
=
random_spans_noise_mask
(
length
,
mlm_prob
,
mean_phn_span
).
nonzero
()
masked_pos
[:,
masked_phn_idxs
]
=
1
else
:
for
idx
in
range
(
bz
):
if
span_bdy
is
not
None
:
for
s
,
e
in
zip
(
span_bdy
[
idx
][::
2
],
span_bdy
[
idx
][
1
::
2
]):
masked_pos
[
idx
,
s
:
e
]
=
1
else
:
length
=
align_start_lens
[
idx
]
if
length
<
2
:
continue
masked_phn_idxs
=
random_spans_noise_mask
(
length
,
mlm_prob
,
mean_phn_span
).
nonzero
()
unmasked_phn_idxs
=
list
(
set
(
range
(
length
))
-
set
(
masked_phn_idxs
[
0
].
tolist
()))
np
.
random
.
shuffle
(
unmasked_phn_idxs
)
masked_text_idxs
=
unmasked_phn_idxs
[:
text_mask_num_lower
]
text_masked_pos
[
idx
][
masked_text_idxs
]
=
1
masked_start
=
align_start
[
idx
][
masked_phn_idxs
].
tolist
()
masked_end
=
align_end
[
idx
][
masked_phn_idxs
].
tolist
()
for
s
,
e
in
zip
(
masked_start
,
masked_end
):
masked_pos
[
idx
,
s
:
e
]
=
1
non_eos_mask
=
paddle
.
reshape
(
src_mask
,
paddle
.
shape
(
xs_pad
)[:
2
])
masked_pos
=
masked_pos
*
non_eos_mask
non_eos_text_mask
=
paddle
.
reshape
(
text_mask
,
paddle
.
shape
(
xs_pad
)[:
2
])
text_masked_pos
=
text_masked_pos
*
non_eos_text_mask
masked_pos
=
paddle
.
cast
(
masked_pos
,
'bool'
)
text_masked_pos
=
paddle
.
cast
(
text_masked_pos
,
'bool'
)
return
masked_pos
,
text_masked_pos
,
y_masks
def
get_seg_pos_reduce_duration
(
speech_pad
:
paddle
.
Tensor
,
text_pad
:
paddle
.
Tensor
,
align_start
:
paddle
.
Tensor
,
align_end
:
paddle
.
Tensor
,
align_start_lens
:
paddle
.
Tensor
,
sega_emb
:
bool
,
masked_pos
:
paddle
.
Tensor
,
feats_lens
:
paddle
.
Tensor
,
):
bz
,
speech_len
,
_
=
paddle
.
shape
(
speech_pad
)
text_seg_pos
=
paddle
.
zeros
(
paddle
.
shape
(
text_pad
))
speech_seg_pos
=
paddle
.
zeros
((
bz
,
speech_len
),
dtype
=
text_pad
.
dtype
)
reordered_idx
=
paddle
.
zeros
((
bz
,
speech_len
),
dtype
=
align_start_lens
.
dtype
)
durations
=
paddle
.
ones
((
bz
,
speech_len
),
dtype
=
align_start_lens
.
dtype
)
max_reduced_length
=
0
if
not
sega_emb
:
return
speech_pad
,
masked_pos
,
speech_seg_pos
,
text_seg_pos
,
durations
for
idx
in
range
(
bz
):
first_idx
=
[]
last_idx
=
[]
align_length
=
align_start_lens
[
idx
]
for
j
in
range
(
align_length
):
s
,
e
=
align_start
[
idx
][
j
],
align_end
[
idx
][
j
]
if
j
==
0
:
if
paddle
.
sum
(
masked_pos
[
idx
][
0
:
s
])
==
0
:
first_idx
.
extend
(
range
(
0
,
s
))
else
:
first_idx
.
extend
([
0
])
last_idx
.
extend
(
range
(
1
,
s
))
if
paddle
.
sum
(
masked_pos
[
idx
][
s
:
e
])
==
0
:
first_idx
.
extend
(
range
(
s
,
e
))
else
:
first_idx
.
extend
([
s
])
last_idx
.
extend
(
range
(
s
+
1
,
e
))
durations
[
idx
][
s
]
=
e
-
s
speech_seg_pos
[
idx
][
s
:
e
]
=
j
+
1
text_seg_pos
[
idx
][
j
]
=
j
+
1
max_reduced_length
=
max
(
len
(
first_idx
)
+
feats_lens
[
idx
]
-
e
,
max_reduced_length
)
first_idx
.
extend
(
range
(
e
,
speech_len
))
reordered_idx
[
idx
]
=
paddle
.
to_tensor
(
(
first_idx
+
last_idx
),
dtype
=
align_start_lens
.
dtype
)
feats_lens
[
idx
]
=
len
(
first_idx
)
reordered_idx
=
reordered_idx
[:,
:
max_reduced_length
]
return
reordered_idx
,
speech_seg_pos
,
text_seg_pos
,
durations
,
feats_lens
def
random_spans_noise_mask
(
length
:
int
,
mlm_prob
:
float
,
mean_phn_span
:
float
):
"""This function is copy of `random_spans_helper
<https://github.com/google-research/text-to-text-transfer-transformer/blob/84f8bcc14b5f2c03de51bd3587609ba8f6bbd1cd/t5/data/preprocessors.py#L2682>`__ .
Noise mask consisting of random spans of noise tokens.
The number of noise tokens and the number of noise spans and non-noise spans
are determined deterministically as follows:
num_noise_tokens = round(length * noise_density)
num_nonnoise_spans = num_noise_spans = round(num_noise_tokens / mean_noise_span_length)
Spans alternate between non-noise and noise, beginning with non-noise.
Subject to the above restrictions, all masks are equally likely.
Args:
xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
pad_value (float): Value for padding.
length: an int32 scalar (length of the incoming token sequence)
noise_density: a float - approximate density of output mask
mean_noise_span_length: a number
Returns:
Tensor: Padded tensor (B, Tmax, `*`).
Examples:
>>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
>>> x
[tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
>>> pad_list(x, 0)
tensor([[1., 1., 1., 1.],
[1., 1., 0., 0.],
[1., 0., 0., 0.]])
a boolean tensor with shape [length]
"""
n_batch
=
len
(
xs
)
max_len
=
max
(
paddle
.
shape
(
x
)[
0
]
for
x
in
xs
)
pad
=
paddle
.
full
((
n_batch
,
max_len
),
pad_value
,
dtype
=
xs
[
0
].
dtype
)
for
i
in
range
(
n_batch
):
pad
[
i
,
:
paddle
.
shape
(
xs
[
i
])[
0
]]
=
xs
[
i
]
return
pad
orig_length
=
length
num_noise_tokens
=
int
(
np
.
round
(
length
*
mlm_prob
))
# avoid degeneracy by ensuring positive numbers of noise and nonnoise tokens.
num_noise_tokens
=
min
(
max
(
num_noise_tokens
,
1
),
length
-
1
)
num_noise_spans
=
int
(
np
.
round
(
num_noise_tokens
/
mean_phn_span
))
# avoid degeneracy by ensuring positive number of noise spans
num_noise_spans
=
max
(
num_noise_spans
,
1
)
num_nonnoise_tokens
=
length
-
num_noise_tokens
# pick the lengths of the noise spans and the non-noise spans
def
_random_seg
(
num_items
,
num_segs
):
"""Partition a sequence of items randomly into non-empty segments.
Args:
num_items: an integer scalar > 0
num_segs: an integer scalar in [1, num_items]
Returns:
a Tensor with shape [num_segs] containing positive integers that add
up to num_items
"""
mask_idxs
=
np
.
arange
(
num_items
-
1
)
<
(
num_segs
-
1
)
np
.
random
.
shuffle
(
mask_idxs
)
first_in_seg
=
np
.
pad
(
mask_idxs
,
[[
1
,
0
]])
segment_id
=
np
.
cumsum
(
first_in_seg
)
# count length of sub segments assuming that list is sorted
_
,
segment_length
=
np
.
unique
(
segment_id
,
return_counts
=
True
)
return
segment_length
noise_span_lens
=
_random_seg
(
num_noise_tokens
,
num_noise_spans
)
nonnoise_span_lens
=
_random_seg
(
num_nonnoise_tokens
,
num_noise_spans
)
interleaved_span_lens
=
np
.
reshape
(
np
.
stack
([
nonnoise_span_lens
,
noise_span_lens
],
axis
=
1
),
[
num_noise_spans
*
2
])
span_starts
=
np
.
cumsum
(
interleaved_span_lens
)[:
-
1
]
span_start_indicator
=
np
.
zeros
((
length
,
),
dtype
=
np
.
int8
)
span_start_indicator
[
span_starts
]
=
True
span_num
=
np
.
cumsum
(
span_start_indicator
)
is_noise
=
np
.
equal
(
span_num
%
2
,
1
)
return
is_noise
[:
orig_length
]
def
pad_to_longformer_att_window
(
text
:
paddle
.
Tensor
,
max_len
:
int
,
max_tlen
:
int
,
attention_window
:
int
=
0
):
def
pad_to_longformer_att_window
(
text
,
max_len
,
max_tlen
,
attention_window
):
round
=
max_len
%
attention_window
if
round
!=
0
:
max_tlen
+=
(
attention_window
-
round
)
...
...
@@ -48,286 +191,67 @@ def pad_to_longformer_att_window(text, max_len, max_tlen, attention_window):
return
text_pad
,
max_tlen
def
make_pad_mask
(
lengths
,
xs
=
None
,
length_dim
=-
1
):
"""Make mask tensor containing indices of padded part.
Args:
lengths (LongTensor or List): Batch of lengths (B,).
xs (Tensor, optional): The reference tensor.
If set, masks will be the same shape as this tensor.
length_dim (int, optional): Dimension indicator of the above tensor.
See the example.
Returns:
Tensor: Mask tensor containing indices of padded part.
dtype=torch.uint8 in PyTorch 1.2-
dtype=torch.bool in PyTorch 1.2+ (including 1.2)
Examples:
With only lengths.
>>> lengths = [5, 3, 2]
>>> make_non_pad_mask(lengths)
masks = [[0, 0, 0, 0 ,0],
[0, 0, 0, 1, 1],
[0, 0, 1, 1, 1]]
With the reference tensor.
>>> xs = torch.zeros((3, 2, 4))
>>> make_pad_mask(lengths, xs)
tensor([[[0, 0, 0, 0],
[0, 0, 0, 0]],
[[0, 0, 0, 1],
[0, 0, 0, 1]],
[[0, 0, 1, 1],
[0, 0, 1, 1]]], dtype=torch.uint8)
>>> xs = torch.zeros((3, 2, 6))
>>> make_pad_mask(lengths, xs)
tensor([[[0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 1]],
[[0, 0, 0, 1, 1, 1],
[0, 0, 0, 1, 1, 1]],
[[0, 0, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)
With the reference tensor and dimension indicator.
>>> xs = torch.zeros((3, 6, 6))
>>> make_pad_mask(lengths, xs, 1)
tensor([[[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1]],
[[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1]],
[[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1]]], dtype=torch.uint8)
>>> make_pad_mask(lengths, xs, 2)
tensor([[[0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 1]],
[[0, 0, 0, 1, 1, 1],
[0, 0, 0, 1, 1, 1],
[0, 0, 0, 1, 1, 1],
[0, 0, 0, 1, 1, 1],
[0, 0, 0, 1, 1, 1],
[0, 0, 0, 1, 1, 1]],
[[0, 0, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)
"""
if
length_dim
==
0
:
raise
ValueError
(
"length_dim cannot be 0: {}"
.
format
(
length_dim
))
if
not
isinstance
(
lengths
,
list
):
lengths
=
list
(
lengths
)
bs
=
int
(
len
(
lengths
))
if
xs
is
None
:
maxlen
=
int
(
max
(
lengths
))
else
:
maxlen
=
paddle
.
shape
(
xs
)[
length_dim
]
seq_range
=
paddle
.
arange
(
0
,
maxlen
,
dtype
=
paddle
.
int64
)
seq_range_expand
=
paddle
.
expand
(
paddle
.
unsqueeze
(
seq_range
,
0
),
(
bs
,
maxlen
))
seq_length_expand
=
paddle
.
unsqueeze
(
paddle
.
to_tensor
(
lengths
),
-
1
)
mask
=
seq_range_expand
>=
seq_length_expand
if
xs
is
not
None
:
assert
paddle
.
shape
(
xs
)[
0
]
==
bs
,
(
paddle
.
shape
(
xs
)[
0
],
bs
)
if
length_dim
<
0
:
length_dim
=
len
(
paddle
.
shape
(
xs
))
+
length_dim
# ind = (:, None, ..., None, :, , None, ..., None)
ind
=
tuple
(
slice
(
None
)
if
i
in
(
0
,
length_dim
)
else
None
for
i
in
range
(
len
(
paddle
.
shape
(
xs
))))
mask
=
paddle
.
expand
(
mask
[
ind
],
paddle
.
shape
(
xs
))
return
mask
def
make_non_pad_mask
(
lengths
,
xs
=
None
,
length_dim
=-
1
):
"""Make mask tensor containing indices of non-padded part.
Args:
lengths (LongTensor or List): Batch of lengths (B,).
xs (Tensor, optional): The reference tensor.
If set, masks will be the same shape as this tensor.
length_dim (int, optional): Dimension indicator of the above tensor.
See the example.
Returns:
ByteTensor: mask tensor containing indices of padded part.
dtype=torch.uint8 in PyTorch 1.2-
dtype=torch.bool in PyTorch 1.2+ (including 1.2)
Examples:
With only lengths.
>>> lengths = [5, 3, 2]
>>> make_non_pad_mask(lengths)
masks = [[1, 1, 1, 1 ,1],
[1, 1, 1, 0, 0],
[1, 1, 0, 0, 0]]
With the reference tensor.
>>> xs = torch.zeros((3, 2, 4))
>>> make_non_pad_mask(lengths, xs)
tensor([[[1, 1, 1, 1],
[1, 1, 1, 1]],
[[1, 1, 1, 0],
[1, 1, 1, 0]],
[[1, 1, 0, 0],
[1, 1, 0, 0]]], dtype=torch.uint8)
>>> xs = torch.zeros((3, 2, 6))
>>> make_non_pad_mask(lengths, xs)
tensor([[[1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 0]],
[[1, 1, 1, 0, 0, 0],
[1, 1, 1, 0, 0, 0]],
[[1, 1, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0]]], dtype=torch.uint8)
With the reference tensor and dimension indicator.
>>> xs = torch.zeros((3, 6, 6))
>>> make_non_pad_mask(lengths, xs, 1)
tensor([[[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[0, 0, 0, 0, 0, 0]],
[[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0]],
[[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0]]], dtype=torch.uint8)
>>> make_non_pad_mask(lengths, xs, 2)
tensor([[[1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 0]],
[[1, 1, 1, 0, 0, 0],
[1, 1, 1, 0, 0, 0],
[1, 1, 1, 0, 0, 0],
[1, 1, 1, 0, 0, 0],
[1, 1, 1, 0, 0, 0],
[1, 1, 1, 0, 0, 0]],
[[1, 1, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0]]], dtype=torch.uint8)
"""
return
~
make_pad_mask
(
lengths
,
xs
,
length_dim
)
def
phones_masking
(
xs_pad
,
src_mask
,
align_start
,
align_end
,
align_start_lengths
,
mlm_prob
,
mean_phn_span
,
span_boundary
=
None
):
def
phones_masking
(
xs_pad
:
paddle
.
Tensor
,
src_mask
:
paddle
.
Tensor
,
align_start
:
paddle
.
Tensor
,
align_end
:
paddle
.
Tensor
,
align_start_lens
:
paddle
.
Tensor
,
mlm_prob
:
float
,
mean_phn_span
:
int
,
span_bdy
:
paddle
.
Tensor
=
None
):
bz
,
sent_len
,
_
=
paddle
.
shape
(
xs_pad
)
mask_num_lower
=
math
.
ceil
(
sent_len
*
mlm_prob
)
masked_position
=
np
.
zeros
((
bz
,
sent_len
))
masked_pos
=
paddle
.
zeros
((
bz
,
sent_len
))
y_masks
=
None
# y_masks = torch.ones(bz,sent_len,sent_len,device=xs_pad.device,dtype=xs_pad.dtype)
# tril_masks = torch.tril(y_masks)
if
mlm_prob
==
1.0
:
masked_position
+=
1
# y_masks = tril_masks
masked_pos
+=
1
elif
mean_phn_span
==
0
:
# only speech
length
=
sent_len
mean_phn_span
=
min
(
length
*
mlm_prob
//
3
,
50
)
masked_phn_i
ndice
s
=
random_spans_noise_mask
(
length
,
mlm_prob
,
mean_phn_span
).
nonzero
()
masked_pos
ition
[:,
masked_phn_indice
s
]
=
1
masked_phn_i
dx
s
=
random_spans_noise_mask
(
length
,
mlm_prob
,
mean_phn_span
).
nonzero
()
masked_pos
[:,
masked_phn_idx
s
]
=
1
else
:
for
idx
in
range
(
bz
):
if
span_boundary
is
not
None
:
for
s
,
e
in
zip
(
span_boundary
[
idx
][::
2
],
span_boundary
[
idx
][
1
::
2
]):
masked_position
[
idx
,
s
:
e
]
=
1
# y_masks[idx, :, s:e] = tril_masks[idx, :, s:e]
# y_masks[idx, e:, s:e ] = 0
if
span_bdy
is
not
None
:
for
s
,
e
in
zip
(
span_bdy
[
idx
][::
2
],
span_bdy
[
idx
][
1
::
2
]):
masked_pos
[
idx
,
s
:
e
]
=
1
else
:
length
=
align_start_len
gths
[
idx
].
item
()
length
=
align_start_len
s
[
idx
]
if
length
<
2
:
continue
masked_phn_i
ndice
s
=
random_spans_noise_mask
(
masked_phn_i
dx
s
=
random_spans_noise_mask
(
length
,
mlm_prob
,
mean_phn_span
).
nonzero
()
masked_start
=
align_start
[
idx
][
masked_phn_i
ndice
s
].
tolist
()
masked_end
=
align_end
[
idx
][
masked_phn_i
ndice
s
].
tolist
()
masked_start
=
align_start
[
idx
][
masked_phn_i
dx
s
].
tolist
()
masked_end
=
align_end
[
idx
][
masked_phn_i
dx
s
].
tolist
()
for
s
,
e
in
zip
(
masked_start
,
masked_end
):
masked_position
[
idx
,
s
:
e
]
=
1
# y_masks[idx, :, s:e] = tril_masks[idx, :, s:e]
# y_masks[idx, e:, s:e ] = 0
non_eos_mask
=
np
.
array
(
paddle
.
reshape
(
src_mask
,
paddle
.
shape
(
xs_pad
)[:
2
]).
float
().
cpu
())
masked_position
=
masked_position
*
non_eos_mask
# y_masks = src_mask & y_masks.bool()
masked_pos
[
idx
,
s
:
e
]
=
1
non_eos_mask
=
paddle
.
reshape
(
src_mask
,
paddle
.
shape
(
xs_pad
)[:
2
])
masked_pos
=
masked_pos
*
non_eos_mask
masked_pos
=
paddle
.
cast
(
masked_pos
,
'bool'
)
return
paddle
.
cast
(
paddle
.
to_tensor
(
masked_position
),
paddle
.
bool
)
,
y_masks
return
masked_pos
,
y_masks
def
get_segment_pos
(
speech_pad
,
text_pad
,
align_start
,
align_end
,
align_start_lengths
,
sega_emb
):
bz
,
speech_len
,
_
=
speech_pad
.
size
()
_
,
text_len
=
text_pad
.
size
()
def
get_seg_pos
(
speech_pad
:
paddle
.
Tensor
,
text_pad
:
paddle
.
Tensor
,
align_start
:
paddle
.
Tensor
,
align_end
:
paddle
.
Tensor
,
align_start_lens
:
paddle
.
Tensor
,
sega_emb
:
bool
):
bz
,
speech_len
,
_
=
paddle
.
shape
(
speech_pad
)
_
,
text_len
=
paddle
.
shape
(
text_pad
)
# text_segment_pos = paddle.zeros_like(text_pad)
# speech_segment_pos = paddle.zeros((bz, speech_len),dtype=text_pad.dtype)
text_segment_pos
=
np
.
zeros
((
bz
,
text_len
)).
astype
(
'int64'
)
speech_segment_pos
=
np
.
zeros
((
bz
,
speech_len
)).
astype
(
'int64'
)
text_seg_pos
=
paddle
.
zeros
((
bz
,
text_len
),
dtype
=
'int64'
)
speech_seg_pos
=
paddle
.
zeros
((
bz
,
speech_len
),
dtype
=
'int64'
)
if
not
sega_emb
:
text_segment_pos
=
paddle
.
to_tensor
(
text_segment_pos
)
speech_segment_pos
=
paddle
.
to_tensor
(
speech_segment_pos
)
return
speech_segment_pos
,
text_segment_pos
return
speech_seg_pos
,
text_seg_pos
for
idx
in
range
(
bz
):
align_length
=
align_start_len
gths
[
idx
].
item
()
align_length
=
align_start_len
s
[
idx
]
for
j
in
range
(
align_length
):
s
,
e
=
align_start
[
idx
][
j
].
item
(),
align_end
[
idx
][
j
].
item
()
speech_segment_pos
[
idx
][
s
:
e
]
=
j
+
1
text_segment_pos
[
idx
][
j
]
=
j
+
1
text_segment_pos
=
paddle
.
to_tensor
(
text_segment_pos
)
speech_segment_pos
=
paddle
.
to_tensor
(
speech_segment_pos
)
s
,
e
=
align_start
[
idx
][
j
],
align_end
[
idx
][
j
]
speech_seg_pos
[
idx
,
s
:
e
]
=
j
+
1
text_seg_pos
[
idx
,
j
]
=
j
+
1
return
speech_seg
ment_pos
,
text_segment
_pos
return
speech_seg
_pos
,
text_seg
_pos
ernie-sat/inference.py
浏览文件 @
b81832ce
#!/usr/bin/env python3
import
argparse
import
math
import
os
import
pickle
import
random
import
string
import
sys
from
pathlib
import
Path
from
typing
import
Collection
from
typing
import
Dict
...
...
@@ -18,17 +14,17 @@ import numpy as np
import
paddle
import
soundfile
as
sf
import
torch
from
paddle
import
nn
from
ParallelWaveGAN.parallel_wavegan.utils.utils
import
download_pretrained_model
from
align_english
import
alignment
from
align_mandarin
import
alignment_zh
from
dataset
import
get_segment_pos
from
dataset
import
make_non_pad_mask
from
dataset
import
make_pad_mask
from
dataset
import
pad_list
from
align
import
alignment
from
align
import
alignment_zh
from
dataset
import
get_seg_pos
from
dataset
import
get_seg_pos_reduce_duration
from
dataset
import
pad_to_longformer_att_window
from
dataset
import
phones_masking
from
dataset
import
phones_text_masking
from
model_paddle
import
build_model_from_file
from
read_text
import
load_num_sequence_text
from
read_text
import
read_2column_text
...
...
@@ -37,8 +33,9 @@ from utils import build_vocoder_from_file
from
utils
import
evaluate_durations
from
utils
import
get_voc_out
from
utils
import
is_chinese
from
utils
import
sentence2phns
from
paddlespeech.t2s.datasets.get_feats
import
LogMelFBank
from
paddlespeech.t2s.modules.nets_utils
import
pad_list
from
paddlespeech.t2s.modules.nets_utils
import
make_non_pad_mask
random
.
seed
(
0
)
np
.
random
.
seed
(
0
)
...
...
@@ -47,81 +44,72 @@ MODEL_DIR_EN = 'tools/aligner/english'
MODEL_DIR_ZH
=
'tools/aligner/mandarin'
def
plot_mel_and_vocode_wav
(
uid
,
prefix
,
clone_uid
,
clone_prefix
,
source_language
,
target_language
,
model_name
,
wav_path
,
full_origin_str
,
old_str
,
new_str
,
use_pt_vocoder
,
duration_preditor_path
,
sid
=
None
,
non_autoreg
=
True
):
wav_org
,
input_feat
,
output_feat
,
old_span_boundary
,
new_span_boundary
,
fs
,
hop_length
=
get_mlm_output
(
uid
,
prefix
,
clone_uid
,
clone_prefix
,
source_language
,
target_language
,
model_name
,
wav_path
,
old_str
,
new_str
,
duration_preditor_path
,
def
plot_mel_and_vocode_wav
(
uid
:
str
,
wav_path
:
str
,
prefix
:
str
=
"./prompt/dev/"
,
source_lang
:
str
=
'english'
,
target_lang
:
str
=
'english'
,
model_name
:
str
=
"conformer"
,
full_origin_str
:
str
=
""
,
old_str
:
str
=
""
,
new_str
:
str
=
""
,
duration_preditor_path
:
str
=
None
,
use_pt_vocoder
:
bool
=
False
,
sid
:
str
=
None
,
non_autoreg
:
bool
=
True
):
wav_org
,
input_feat
,
output_feat
,
old_span_bdy
,
new_span_bdy
,
fs
,
hop_length
=
get_mlm_output
(
uid
=
uid
,
prefix
=
prefix
,
source_lang
=
source_lang
,
target_lang
=
target_lang
,
model_name
=
model_name
,
wav_path
=
wav_path
,
old_str
=
old_str
,
new_str
=
new_str
,
duration_preditor_path
=
duration_preditor_path
,
use_teacher_forcing
=
non_autoreg
,
sid
=
sid
)
masked_feat
=
output_feat
[
new_span_boundary
[
0
]:
new_span_boundary
[
1
]].
detach
().
float
().
cpu
().
numpy
()
masked_feat
=
output_feat
[
new_span_bdy
[
0
]:
new_span_bdy
[
1
]]
if
target_lang
uage
==
'english'
:
if
target_lang
==
'english'
:
if
use_pt_vocoder
:
output_feat
=
output_feat
.
detach
().
float
().
cpu
().
numpy
()
output_feat
=
output_feat
.
cpu
().
numpy
()
output_feat
=
torch
.
tensor
(
output_feat
,
dtype
=
torch
.
float
)
vocoder
=
load_vocoder
(
'vctk_parallel_wavegan.v1.long'
)
replaced_wav
=
vocoder
(
output_feat
).
detach
().
float
().
data
.
cpu
().
numpy
()
replaced_wav
=
vocoder
(
output_feat
).
cpu
().
numpy
()
else
:
output_feat_np
=
output_feat
.
detach
().
float
().
cpu
().
numpy
()
replaced_wav
=
get_voc_out
(
output_feat_np
,
target_language
)
replaced_wav
=
get_voc_out
(
output_feat
,
target_lang
)
elif
target_language
==
'chinese'
:
output_feat_np
=
output_feat
.
detach
().
float
().
cpu
().
numpy
()
replaced_wav_only_mask_fst2_voc
=
get_voc_out
(
masked_feat
,
target_language
)
elif
target_lang
==
'chinese'
:
replaced_wav_only_mask_fst2_voc
=
get_voc_out
(
masked_feat
,
target_lang
)
old_time_b
oundary
=
[
hop_length
*
x
for
x
in
old_span_boundar
y
]
new_time_b
oundary
=
[
hop_length
*
x
for
x
in
new_span_boundar
y
]
old_time_b
dy
=
[
hop_length
*
x
for
x
in
old_span_bd
y
]
new_time_b
dy
=
[
hop_length
*
x
for
x
in
new_span_bd
y
]
if
target_lang
uage
==
'english'
:
if
target_lang
==
'english'
:
wav_org_replaced_paddle_voc
=
np
.
concatenate
([
wav_org
[:
old_time_b
oundar
y
[
0
]],
replaced_wav
[
new_time_b
oundary
[
0
]:
new_time_boundar
y
[
1
]],
wav_org
[
old_time_b
oundar
y
[
1
]:]
wav_org
[:
old_time_b
d
y
[
0
]],
replaced_wav
[
new_time_b
dy
[
0
]:
new_time_bd
y
[
1
]],
wav_org
[
old_time_b
d
y
[
1
]:]
])
data_dict
=
{
"origin"
:
wav_org
,
"output"
:
wav_org_replaced_paddle_voc
}
elif
target_lang
uage
==
'chinese'
:
elif
target_lang
==
'chinese'
:
wav_org_replaced_only_mask_fst2_voc
=
np
.
concatenate
([
wav_org
[:
old_time_b
oundar
y
[
0
]],
replaced_wav_only_mask_fst2_voc
,
wav_org
[
old_time_b
oundar
y
[
1
]:]
wav_org
[:
old_time_b
d
y
[
0
]],
replaced_wav_only_mask_fst2_voc
,
wav_org
[
old_time_b
d
y
[
1
]:]
])
data_dict
=
{
"origin"
:
wav_org
,
"output"
:
wav_org_replaced_only_mask_fst2_voc
,
}
return
data_dict
,
old_span_b
oundar
y
return
data_dict
,
old_span_b
d
y
def
get_unk_phns
(
word_str
):
def
get_unk_phns
(
word_str
:
str
):
tmpbase
=
'/tmp/tp.'
f
=
open
(
tmpbase
+
'temp.words'
,
'w'
)
f
.
write
(
word_str
)
...
...
@@ -160,9 +148,8 @@ def get_unk_phns(word_str):
return
phns
def
words2phns
(
line
):
def
words2phns
(
line
:
str
):
dictfile
=
MODEL_DIR_EN
+
'/dict'
tmpbase
=
'/tmp/tp.'
line
=
line
.
strip
()
words
=
[]
for
pun
in
[
','
,
'.'
,
':'
,
';'
,
'!'
,
'?'
,
'"'
,
'('
,
')'
,
'--'
,
'---'
]:
...
...
@@ -200,9 +187,8 @@ def words2phns(line):
return
phns
,
wrd2phns
def
words2phns_zh
(
line
):
def
words2phns_zh
(
line
:
str
):
dictfile
=
MODEL_DIR_ZH
+
'/dict'
tmpbase
=
'/tmp/tp.'
line
=
line
.
strip
()
words
=
[]
for
pun
in
[
...
...
@@ -242,7 +228,7 @@ def words2phns_zh(line):
return
phns
,
wrd2phns
def
load_vocoder
(
vocoder_tag
=
"vctk_parallel_wavegan.v1.long"
):
def
load_vocoder
(
vocoder_tag
:
str
=
"vctk_parallel_wavegan.v1.long"
):
vocoder_tag
=
vocoder_tag
.
replace
(
"parallel_wavegan/"
,
""
)
vocoder_file
=
download_pretrained_model
(
vocoder_tag
)
vocoder_config
=
Path
(
vocoder_file
).
parent
/
"config.yml"
...
...
@@ -250,7 +236,7 @@ def load_vocoder(vocoder_tag="vctk_parallel_wavegan.v1.long"):
return
vocoder
def
load_model
(
model_name
):
def
load_model
(
model_name
:
str
):
config_path
=
'./pretrained_model/{}/config.yaml'
.
format
(
model_name
)
model_path
=
'./pretrained_model/{}/model.pdparams'
.
format
(
model_name
)
mlm_model
,
args
=
build_model_from_file
(
...
...
@@ -258,7 +244,7 @@ def load_model(model_name):
return
mlm_model
,
args
def
read_data
(
uid
,
prefix
):
def
read_data
(
uid
:
str
,
prefix
:
str
):
mfa_text
=
read_2column_text
(
prefix
+
'/text'
)[
uid
]
mfa_wav_path
=
read_2column_text
(
prefix
+
'/wav.scp'
)[
uid
]
if
'mnt'
not
in
mfa_wav_path
:
...
...
@@ -266,7 +252,7 @@ def read_data(uid, prefix):
return
mfa_text
,
mfa_wav_path
def
get_align_data
(
uid
,
prefix
):
def
get_align_data
(
uid
:
str
,
prefix
:
str
):
mfa_path
=
prefix
+
"mfa_"
mfa_text
=
read_2column_text
(
mfa_path
+
'text'
)[
uid
]
mfa_start
=
load_num_sequence_text
(
...
...
@@ -277,43 +263,45 @@ def get_align_data(uid, prefix):
return
mfa_text
,
mfa_start
,
mfa_end
,
mfa_wav_path
def
get_masked_mel_boundary
(
mfa_start
,
mfa_end
,
fs
,
hop_length
,
span_tobe_replaced
):
def
get_masked_mel_bdy
(
mfa_start
:
List
[
float
],
mfa_end
:
List
[
float
],
fs
:
int
,
hop_length
:
int
,
span_to_repl
:
List
[
List
[
int
]]):
align_start
=
paddle
.
to_tensor
(
mfa_start
).
unsqueeze
(
0
)
align_end
=
paddle
.
to_tensor
(
mfa_end
).
unsqueeze
(
0
)
align_start
=
paddle
.
floor
(
fs
*
align_start
/
hop_length
).
int
()
align_end
=
paddle
.
floor
(
fs
*
align_end
/
hop_length
).
int
()
if
span_to
be_replaced
[
0
]
>=
len
(
mfa_start
):
span_b
oundar
y
=
[
align_end
[
0
].
tolist
()[
-
1
],
align_end
[
0
].
tolist
()[
-
1
]]
if
span_to
_repl
[
0
]
>=
len
(
mfa_start
):
span_b
d
y
=
[
align_end
[
0
].
tolist
()[
-
1
],
align_end
[
0
].
tolist
()[
-
1
]]
else
:
span_b
oundar
y
=
[
align_start
[
0
].
tolist
()[
span_to
be_replaced
[
0
]],
align_end
[
0
].
tolist
()[
span_to
be_replaced
[
1
]
-
1
]
span_b
d
y
=
[
align_start
[
0
].
tolist
()[
span_to
_repl
[
0
]],
align_end
[
0
].
tolist
()[
span_to
_repl
[
1
]
-
1
]
]
return
span_b
oundar
y
return
span_b
d
y
def
recover_dict
(
word2phns
,
tp_word2phns
):
def
recover_dict
(
word2phns
:
Dict
[
str
,
str
],
tp_word2phns
:
Dict
[
str
,
str
]
):
dic
=
{}
need_del_key
=
[]
exist_i
nde
x
=
[]
keys_to_del
=
[]
exist_i
d
x
=
[]
sp_count
=
0
add_sp_count
=
0
for
key
in
word2phns
.
keys
():
idx
,
wrd
=
key
.
split
(
'_'
)
if
wrd
==
'sp'
:
sp_count
+=
1
exist_i
nde
x
.
append
(
int
(
idx
))
exist_i
d
x
.
append
(
int
(
idx
))
else
:
need_del_key
.
append
(
key
)
keys_to_del
.
append
(
key
)
for
key
in
need_del_key
:
for
key
in
keys_to_del
:
del
word2phns
[
key
]
cur_id
=
0
for
key
in
tp_word2phns
.
keys
():
# print("debug: ",key)
if
cur_id
in
exist_index
:
if
cur_id
in
exist_idx
:
dic
[
str
(
cur_id
)
+
"_sp"
]
=
'sp'
cur_id
+=
1
add_sp_count
+=
1
...
...
@@ -329,14 +317,17 @@ def recover_dict(word2phns, tp_word2phns):
return
dic
def
get_phns_and_spans
(
wav_path
,
old_str
,
new_str
,
source_language
,
clone_target_language
):
def
get_phns_and_spans
(
wav_path
:
str
,
old_str
:
str
=
""
,
new_str
:
str
=
""
,
source_lang
:
str
=
"english"
,
target_lang
:
str
=
"english"
):
append_new_str
=
(
old_str
==
new_str
[:
len
(
old_str
)])
old_phns
,
mfa_start
,
mfa_end
=
[],
[],
[]
if
source_lang
uage
==
"english"
:
if
source_lang
==
"english"
:
times2
,
word2phns
=
alignment
(
wav_path
,
old_str
)
elif
source_lang
uage
==
"chinese"
:
elif
source_lang
==
"chinese"
:
times2
,
word2phns
=
alignment_zh
(
wav_path
,
old_str
)
_
,
tp_word2phns
=
words2phns_zh
(
old_str
)
...
...
@@ -348,14 +339,14 @@ def get_phns_and_spans(wav_path, old_str, new_str, source_language,
word2phns
=
recover_dict
(
word2phns
,
tp_word2phns
)
else
:
assert
source_lang
uage
==
"chinese"
or
source_language
==
"english"
,
"source_language
is wrong..."
assert
source_lang
==
"chinese"
or
source_lang
==
"english"
,
"source_lang
is wrong..."
for
item
in
times2
:
mfa_start
.
append
(
float
(
item
[
1
]))
mfa_end
.
append
(
float
(
item
[
2
]))
old_phns
.
append
(
item
[
0
])
if
append_new_str
and
(
source_lang
uage
!=
clone_target_language
):
if
append_new_str
and
(
source_lang
!=
target_lang
):
is_cross_lingual_clone
=
True
else
:
is_cross_lingual_clone
=
False
...
...
@@ -364,18 +355,21 @@ def get_phns_and_spans(wav_path, old_str, new_str, source_language,
new_str_origin
=
new_str
[:
len
(
old_str
)]
new_str_append
=
new_str
[
len
(
old_str
):]
if
clone_target_language
==
"chinese"
:
if
target_lang
==
"chinese"
:
new_phns_origin
,
new_origin_word2phns
=
words2phns
(
new_str_origin
)
new_phns_append
,
temp_new_append_word2phns
=
words2phns_zh
(
new_str_append
)
elif
clone_target_language
==
"english"
:
elif
target_lang
==
"english"
:
# 原始句子
new_phns_origin
,
new_origin_word2phns
=
words2phns_zh
(
new_str_origin
)
# 原始句子
new_str_origin
)
# clone句子
new_phns_append
,
temp_new_append_word2phns
=
words2phns
(
new_str_append
)
# clone句子
new_str_append
)
else
:
assert
clone_target_language
==
"chinese"
or
clone_target_language
==
"english"
,
"cloning is not support for this language, please check it."
assert
target_lang
==
"chinese"
or
target_lang
==
"english"
,
\
"cloning is not support for this language, please check it."
new_phns
=
new_phns_origin
+
new_phns_append
...
...
@@ -390,16 +384,17 @@ def get_phns_and_spans(wav_path, old_str, new_str, source_language,
new_append_word2phns
.
items
()))
else
:
if
source_lang
uage
==
clone_target_language
and
clone_target_language
==
"english"
:
if
source_lang
==
target_lang
and
target_lang
==
"english"
:
new_phns
,
new_word2phns
=
words2phns
(
new_str
)
elif
source_lang
uage
==
clone_target_language
and
clone_target_language
==
"chinese"
:
elif
source_lang
==
target_lang
and
target_lang
==
"chinese"
:
new_phns
,
new_word2phns
=
words2phns_zh
(
new_str
)
else
:
assert
source_language
==
clone_target_language
,
"source language is not same with target language..."
assert
source_lang
==
target_lang
,
\
"source language is not same with target language..."
span_to
be_replaced
=
[
0
,
len
(
old_phns
)
-
1
]
span_to
be_adde
d
=
[
0
,
len
(
new_phns
)
-
1
]
left_i
nde
x
=
0
span_to
_repl
=
[
0
,
len
(
old_phns
)
-
1
]
span_to
_ad
d
=
[
0
,
len
(
new_phns
)
-
1
]
left_i
d
x
=
0
new_phns_left
=
[]
sp_count
=
0
# find the left different index
...
...
@@ -411,27 +406,27 @@ def get_phns_and_spans(wav_path, old_str, new_str, source_language,
else
:
idx
=
str
(
int
(
idx
)
-
sp_count
)
if
idx
+
'_'
+
wrd
in
new_word2phns
:
left_i
nde
x
+=
len
(
new_word2phns
[
idx
+
'_'
+
wrd
])
left_i
d
x
+=
len
(
new_word2phns
[
idx
+
'_'
+
wrd
])
new_phns_left
.
extend
(
word2phns
[
key
].
split
())
else
:
span_to
be_replaced
[
0
]
=
len
(
new_phns_left
)
span_to
be_adde
d
[
0
]
=
len
(
new_phns_left
)
span_to
_repl
[
0
]
=
len
(
new_phns_left
)
span_to
_ad
d
[
0
]
=
len
(
new_phns_left
)
break
# reverse word2phns and new_word2phns
right_i
nde
x
=
0
right_i
d
x
=
0
new_phns_right
=
[]
sp_count
=
0
word2phns_max_i
nde
x
=
int
(
list
(
word2phns
.
keys
())[
-
1
].
split
(
'_'
)[
0
])
new_word2phns_max_i
nde
x
=
int
(
list
(
new_word2phns
.
keys
())[
-
1
].
split
(
'_'
)[
0
])
new_phns_mid
dle
=
[]
word2phns_max_i
d
x
=
int
(
list
(
word2phns
.
keys
())[
-
1
].
split
(
'_'
)[
0
])
new_word2phns_max_i
d
x
=
int
(
list
(
new_word2phns
.
keys
())[
-
1
].
split
(
'_'
)[
0
])
new_phns_mid
=
[]
if
append_new_str
:
new_phns_right
=
[]
new_phns_mid
dle
=
new_phns
[
left_inde
x
:]
span_to
be_replaced
[
0
]
=
len
(
new_phns_left
)
span_to
be_adde
d
[
0
]
=
len
(
new_phns_left
)
span_to
be_added
[
1
]
=
len
(
new_phns_left
)
+
len
(
new_phns_middle
)
span_to
be_replaced
[
1
]
=
len
(
old_phns
)
-
len
(
new_phns_right
)
new_phns_mid
=
new_phns
[
left_id
x
:]
span_to
_repl
[
0
]
=
len
(
new_phns_left
)
span_to
_ad
d
[
0
]
=
len
(
new_phns_left
)
span_to
_add
[
1
]
=
len
(
new_phns_left
)
+
len
(
new_phns_mid
)
span_to
_repl
[
1
]
=
len
(
old_phns
)
-
len
(
new_phns_right
)
else
:
for
key
in
list
(
word2phns
.
keys
())[::
-
1
]:
idx
,
wrd
=
key
.
split
(
'_'
)
...
...
@@ -439,33 +434,31 @@ def get_phns_and_spans(wav_path, old_str, new_str, source_language,
sp_count
+=
1
new_phns_right
=
[
'sp'
]
+
new_phns_right
else
:
idx
=
str
(
new_word2phns_max_i
ndex
-
(
word2phns_max_index
-
int
(
idx
)
-
sp_count
))
idx
=
str
(
new_word2phns_max_i
dx
-
(
word2phns_max_idx
-
int
(
idx
)
-
sp_count
))
if
idx
+
'_'
+
wrd
in
new_word2phns
:
right_i
nde
x
-=
len
(
new_word2phns
[
idx
+
'_'
+
wrd
])
right_i
d
x
-=
len
(
new_word2phns
[
idx
+
'_'
+
wrd
])
new_phns_right
=
word2phns
[
key
].
split
()
+
new_phns_right
else
:
span_tobe_replaced
[
1
]
=
len
(
old_phns
)
-
len
(
new_phns_right
)
new_phns_middle
=
new_phns
[
left_index
:
right_index
]
span_tobe_added
[
1
]
=
len
(
new_phns_left
)
+
len
(
new_phns_middle
)
if
len
(
new_phns_middle
)
==
0
:
span_tobe_added
[
1
]
=
min
(
span_tobe_added
[
1
]
+
1
,
len
(
new_phns
))
span_tobe_added
[
0
]
=
max
(
0
,
span_tobe_added
[
0
]
-
1
)
span_tobe_replaced
[
0
]
=
max
(
0
,
span_tobe_replaced
[
0
]
-
1
)
span_tobe_replaced
[
1
]
=
min
(
span_tobe_replaced
[
1
]
+
1
,
len
(
old_phns
))
span_to_repl
[
1
]
=
len
(
old_phns
)
-
len
(
new_phns_right
)
new_phns_mid
=
new_phns
[
left_idx
:
right_idx
]
span_to_add
[
1
]
=
len
(
new_phns_left
)
+
len
(
new_phns_mid
)
if
len
(
new_phns_mid
)
==
0
:
span_to_add
[
1
]
=
min
(
span_to_add
[
1
]
+
1
,
len
(
new_phns
))
span_to_add
[
0
]
=
max
(
0
,
span_to_add
[
0
]
-
1
)
span_to_repl
[
0
]
=
max
(
0
,
span_to_repl
[
0
]
-
1
)
span_to_repl
[
1
]
=
min
(
span_to_repl
[
1
]
+
1
,
len
(
old_phns
))
break
new_phns
=
new_phns_left
+
new_phns_mid
dle
+
new_phns_right
new_phns
=
new_phns_left
+
new_phns_mid
+
new_phns_right
return
mfa_start
,
mfa_end
,
old_phns
,
new_phns
,
span_to
be_replaced
,
span_tobe_adde
d
return
mfa_start
,
mfa_end
,
old_phns
,
new_phns
,
span_to
_repl
,
span_to_ad
d
def
duration_adjust_factor
(
original_dur
,
pred_dur
,
phns
):
def
duration_adjust_factor
(
original_dur
:
List
[
int
],
pred_dur
:
List
[
int
],
phns
:
List
[
str
]):
length
=
0
accumulate
=
0
factor_list
=
[]
for
ori
,
pred
,
phn
in
zip
(
original_dur
,
pred_dur
,
phns
):
if
pred
==
0
or
phn
==
'sp'
:
...
...
@@ -481,242 +474,224 @@ def duration_adjust_factor(original_dur, pred_dur, phns):
return
np
.
average
(
factor_list
[
length
:
-
length
])
def
prepare_features_with_duration
(
uid
,
prefix
,
clone_uid
,
clone_prefix
,
source_language
,
target_language
,
mlm_model
,
old_str
,
new_str
,
wav_path
,
duration_preditor_path
,
sid
=
None
,
mask_reconstruct
=
False
,
duration_adjust
=
True
,
start_end_sp
=
False
,
def
prepare_features_with_duration
(
uid
:
str
,
prefix
:
str
,
wav_path
:
str
,
mlm_model
:
nn
.
Layer
,
source_lang
:
str
=
"English"
,
target_lang
:
str
=
"English"
,
old_str
:
str
=
""
,
new_str
:
str
=
""
,
duration_preditor_path
:
str
=
None
,
sid
:
str
=
None
,
mask_reconstruct
:
bool
=
False
,
duration_adjust
:
bool
=
True
,
start_end_sp
:
bool
=
False
,
train_args
=
None
):
wav_org
,
rate
=
librosa
.
load
(
wav_path
,
sr
=
train_args
.
feats_extract_conf
[
'fs'
])
fs
=
train_args
.
feats_extract_conf
[
'fs'
]
hop_length
=
train_args
.
feats_extract_conf
[
'hop_length'
]
mfa_start
,
mfa_end
,
old_phns
,
new_phns
,
span_tobe_replaced
,
span_tobe_added
=
get_phns_and_spans
(
wav_path
,
old_str
,
new_str
,
source_language
,
target_language
)
mfa_start
,
mfa_end
,
old_phns
,
new_phns
,
span_to_repl
,
span_to_add
=
get_phns_and_spans
(
wav_path
=
wav_path
,
old_str
=
old_str
,
new_str
=
new_str
,
source_lang
=
source_lang
,
target_lang
=
target_lang
)
if
start_end_sp
:
if
new_phns
[
-
1
]
!=
'sp'
:
new_phns
=
new_phns
+
[
'sp'
]
if
target_language
==
"english"
:
old_durations
=
evaluate_durations
(
old_phns
,
target_language
=
target_language
)
if
target_lang
==
"english"
:
old_durations
=
evaluate_durations
(
old_phns
,
target_lang
=
target_lang
)
elif
target_lang
uage
==
"chinese"
:
elif
target_lang
==
"chinese"
:
if
source_lang
uage
==
"english"
:
if
source_lang
==
"english"
:
old_durations
=
evaluate_durations
(
old_phns
,
target_lang
uage
=
source_language
)
old_phns
,
target_lang
=
source_lang
)
elif
source_lang
uage
==
"chinese"
:
elif
source_lang
==
"chinese"
:
old_durations
=
evaluate_durations
(
old_phns
,
target_lang
uage
=
source_language
)
old_phns
,
target_lang
=
source_lang
)
else
:
assert
target_lang
uage
==
"chinese"
or
target_language
==
"english"
,
"calculate duration_predict is not support for this language..."
assert
target_lang
==
"chinese"
or
target_lang
==
"english"
,
"calculate duration_predict is not support for this language..."
original_old_durations
=
[
e
-
s
for
e
,
s
in
zip
(
mfa_end
,
mfa_start
)]
if
'[MASK]'
in
new_str
:
new_phns
=
old_phns
span_to
be_added
=
span_tobe_replaced
span_to
_add
=
span_to_repl
d_factor_left
=
duration_adjust_factor
(
original_old_durations
[:
span_tobe_replaced
[
0
]],
old_durations
[:
span_tobe_replaced
[
0
]],
old_phns
[:
span_tobe_replaced
[
0
]])
original_old_durations
[:
span_to_repl
[
0
]],
old_durations
[:
span_to_repl
[
0
]],
old_phns
[:
span_to_repl
[
0
]])
d_factor_right
=
duration_adjust_factor
(
original_old_durations
[
span_tobe_replaced
[
1
]:],
old_durations
[
span_tobe_replaced
[
1
]:],
old_phns
[
span_tobe_replaced
[
1
]:])
original_old_durations
[
span_to_repl
[
1
]:],
old_durations
[
span_to_repl
[
1
]:],
old_phns
[
span_to_repl
[
1
]:])
d_factor
=
(
d_factor_left
+
d_factor_right
)
/
2
new_durations_adjusted
=
[
d_factor
*
i
for
i
in
old_durations
]
else
:
if
duration_adjust
:
d_factor
=
duration_adjust_factor
(
original_old_durations
,
old_durations
,
old_phns
)
d_factor_paddle
=
duration_adjust_factor
(
original_old_durations
,
old_durations
,
old_phns
)
d_factor
=
d_factor
*
1.25
else
:
d_factor
=
1
if
target_lang
uage
==
"english"
:
if
target_lang
==
"english"
:
new_durations
=
evaluate_durations
(
new_phns
,
target_lang
uage
=
target_language
)
new_phns
,
target_lang
=
target_lang
)
elif
target_lang
uage
==
"chinese"
:
elif
target_lang
==
"chinese"
:
new_durations
=
evaluate_durations
(
new_phns
,
target_lang
uage
=
target_language
)
new_phns
,
target_lang
=
target_lang
)
new_durations_adjusted
=
[
d_factor
*
i
for
i
in
new_durations
]
if
span_tobe_replaced
[
0
]
<
len
(
old_phns
)
and
old_phns
[
span_tobe_replaced
[
0
]]
==
new_phns
[
span_tobe_added
[
0
]]:
new_durations_adjusted
[
span_tobe_added
[
0
]]
=
original_old_durations
[
span_tobe_replaced
[
0
]]
if
span_tobe_replaced
[
1
]
<
len
(
old_phns
)
and
span_tobe_added
[
1
]
<
len
(
new_phns
):
if
old_phns
[
span_tobe_replaced
[
1
]]
==
new_phns
[
span_tobe_added
[
1
]]:
new_durations_adjusted
[
span_tobe_added
[
1
]]
=
original_old_durations
[
span_tobe_replaced
[
1
]]
if
span_to_repl
[
0
]
<
len
(
old_phns
)
and
old_phns
[
span_to_repl
[
0
]]
==
new_phns
[
span_to_add
[
0
]]:
new_durations_adjusted
[
span_to_add
[
0
]]
=
original_old_durations
[
span_to_repl
[
0
]]
if
span_to_repl
[
1
]
<
len
(
old_phns
)
and
span_to_add
[
1
]
<
len
(
new_phns
):
if
old_phns
[
span_to_repl
[
1
]]
==
new_phns
[
span_to_add
[
1
]]:
new_durations_adjusted
[
span_to_add
[
1
]]
=
original_old_durations
[
span_to_repl
[
1
]]
new_span_duration_sum
=
sum
(
new_durations_adjusted
[
span_to
be_added
[
0
]:
span_tobe_adde
d
[
1
]])
new_durations_adjusted
[
span_to
_add
[
0
]:
span_to_ad
d
[
1
]])
old_span_duration_sum
=
sum
(
original_old_durations
[
span_to
be_replaced
[
0
]:
span_tobe_replaced
[
1
]])
original_old_durations
[
span_to
_repl
[
0
]:
span_to_repl
[
1
]])
duration_offset
=
new_span_duration_sum
-
old_span_duration_sum
new_mfa_start
=
mfa_start
[:
span_to
be_replaced
[
0
]]
new_mfa_end
=
mfa_end
[:
span_to
be_replaced
[
0
]]
for
i
in
new_durations_adjusted
[
span_to
be_added
[
0
]:
span_tobe_adde
d
[
1
]]:
new_mfa_start
=
mfa_start
[:
span_to
_repl
[
0
]]
new_mfa_end
=
mfa_end
[:
span_to
_repl
[
0
]]
for
i
in
new_durations_adjusted
[
span_to
_add
[
0
]:
span_to_ad
d
[
1
]]:
if
len
(
new_mfa_end
)
==
0
:
new_mfa_start
.
append
(
0
)
new_mfa_end
.
append
(
i
)
else
:
new_mfa_start
.
append
(
new_mfa_end
[
-
1
])
new_mfa_end
.
append
(
new_mfa_end
[
-
1
]
+
i
)
new_mfa_start
+=
[
i
+
duration_offset
for
i
in
mfa_start
[
span_tobe_replaced
[
1
]:]
]
new_mfa_end
+=
[
i
+
duration_offset
for
i
in
mfa_end
[
span_tobe_replaced
[
1
]:]
]
new_mfa_start
+=
[
i
+
duration_offset
for
i
in
mfa_start
[
span_to_repl
[
1
]:]]
new_mfa_end
+=
[
i
+
duration_offset
for
i
in
mfa_end
[
span_to_repl
[
1
]:]]
# 3. get new wav
if
span_to
be_replaced
[
0
]
>=
len
(
mfa_start
):
left_i
nde
x
=
len
(
wav_org
)
right_i
ndex
=
left_inde
x
if
span_to
_repl
[
0
]
>=
len
(
mfa_start
):
left_i
d
x
=
len
(
wav_org
)
right_i
dx
=
left_id
x
else
:
left_i
ndex
=
int
(
np
.
floor
(
mfa_start
[
span_tobe_replaced
[
0
]]
*
fs
))
right_i
ndex
=
int
(
np
.
ceil
(
mfa_end
[
span_tobe_replaced
[
1
]
-
1
]
*
fs
))
left_i
dx
=
int
(
np
.
floor
(
mfa_start
[
span_to_repl
[
0
]]
*
fs
))
right_i
dx
=
int
(
np
.
ceil
(
mfa_end
[
span_to_repl
[
1
]
-
1
]
*
fs
))
new_blank_wav
=
np
.
zeros
(
(
int
(
np
.
ceil
(
new_span_duration_sum
*
fs
)),
),
dtype
=
wav_org
.
dtype
)
new_wav_org
=
np
.
concatenate
(
[
wav_org
[:
left_i
ndex
],
new_blank_wav
,
wav_org
[
right_inde
x
:]])
[
wav_org
[:
left_i
dx
],
new_blank_wav
,
wav_org
[
right_id
x
:]])
# 4. get old and new mel span to be mask
old_span_boundary
=
get_masked_mel_boundary
(
mfa_start
,
mfa_end
,
fs
,
hop_length
,
span_tobe_replaced
)
# [92, 92]
new_span_boundary
=
get_masked_mel_boundary
(
new_mfa_start
,
new_mfa_end
,
fs
,
hop_length
,
span_tobe_added
)
# [92, 174]
return
new_wav_org
,
new_phns
,
new_mfa_start
,
new_mfa_end
,
old_span_boundary
,
new_span_boundary
def
prepare_features
(
uid
,
prefix
,
clone_uid
,
clone_prefix
,
source_language
,
target_language
,
mlm_model
,
# [92, 92]
old_span_bdy
=
get_masked_mel_bdy
(
mfa_start
,
mfa_end
,
fs
,
hop_length
,
span_to_repl
)
# [92, 174]
new_span_bdy
=
get_masked_mel_bdy
(
new_mfa_start
,
new_mfa_end
,
fs
,
hop_length
,
span_to_add
)
return
new_wav_org
,
new_phns
,
new_mfa_start
,
new_mfa_end
,
old_span_bdy
,
new_span_bdy
def
prepare_features
(
uid
:
str
,
mlm_model
:
nn
.
Layer
,
processor
,
wav_path
,
old_str
,
new_str
,
duration_preditor_path
,
sid
=
None
,
duration_adjust
=
True
,
start_end_sp
=
False
,
mask_reconstruct
=
False
,
wav_path
:
str
,
prefix
:
str
=
"./prompt/dev/"
,
source_lang
:
str
=
"english"
,
target_lang
:
str
=
"english"
,
old_str
:
str
=
""
,
new_str
:
str
=
""
,
duration_preditor_path
:
str
=
None
,
sid
:
str
=
None
,
duration_adjust
:
bool
=
True
,
start_end_sp
:
bool
=
False
,
mask_reconstruct
:
bool
=
False
,
train_args
=
None
):
wav_org
,
phns_list
,
mfa_start
,
mfa_end
,
old_span_boundary
,
new_span_boundary
=
prepare_features_with_duration
(
uid
,
prefix
,
clone_uid
,
clone_prefix
,
source_language
,
target_language
,
mlm_model
,
old_str
,
new_str
,
wav_path
,
duration_preditor_path
,
wav_org
,
phns_list
,
mfa_start
,
mfa_end
,
old_span_bdy
,
new_span_bdy
=
prepare_features_with_duration
(
uid
=
uid
,
prefix
=
prefix
,
source_lang
=
source_lang
,
target_lang
=
target_lang
,
mlm_model
=
mlm_model
,
old_str
=
old_str
,
new_str
=
new_str
,
wav_path
=
wav_path
,
duration_preditor_path
=
duration_preditor_path
,
sid
=
sid
,
duration_adjust
=
duration_adjust
,
start_end_sp
=
start_end_sp
,
mask_reconstruct
=
mask_reconstruct
,
train_args
=
train_args
)
speech
=
np
.
array
(
wav_org
,
dtype
=
np
.
float32
)
speech
=
wav_org
align_start
=
np
.
array
(
mfa_start
)
align_end
=
np
.
array
(
mfa_end
)
token_to_id
=
{
item
:
i
for
i
,
item
in
enumerate
(
train_args
.
token_list
)}
text
=
np
.
array
(
list
(
map
(
lambda
x
:
token_to_id
.
get
(
x
,
token_to_id
[
'<unk>'
]),
phns_list
)))
# print('unk id is', token_to_id['<unk>'])
# text = np.array(processor(uid='1', data={'text':" ".join(phns_list)})['text'])
span_boundary
=
np
.
array
(
new_span_boundary
)
span_bdy
=
np
.
array
(
new_span_bdy
)
batch
=
[(
'1'
,
{
"speech"
:
speech
,
"align_start"
:
align_start
,
"align_end"
:
align_end
,
"text"
:
text
,
"span_b
oundary"
:
span_boundar
y
"span_b
dy"
:
span_bd
y
})]
return
batch
,
old_span_b
oundary
,
new_span_boundar
y
return
batch
,
old_span_b
dy
,
new_span_bd
y
def
decode_with_model
(
uid
,
prefix
,
clone_uid
,
clone_prefix
,
source_language
,
target_language
,
mlm_model
,
def
decode_with_model
(
uid
:
str
,
mlm_model
:
nn
.
Layer
,
processor
,
collate_fn
,
wav_path
,
old_str
,
new_str
,
duration_preditor_path
,
sid
=
None
,
decoder
=
False
,
use_teacher_forcing
=
False
,
duration_adjust
=
True
,
start_end_sp
=
False
,
wav_path
:
str
,
prefix
:
str
=
"./prompt/dev/"
,
source_lang
:
str
=
"english"
,
target_lang
:
str
=
"english"
,
old_str
:
str
=
""
,
new_str
:
str
=
""
,
duration_preditor_path
:
str
=
None
,
sid
:
str
=
None
,
decoder
:
bool
=
False
,
use_teacher_forcing
:
bool
=
False
,
duration_adjust
:
bool
=
True
,
start_end_sp
:
bool
=
False
,
train_args
=
None
):
fs
,
hop_length
=
train_args
.
feats_extract_conf
[
'fs'
],
train_args
.
feats_extract_conf
[
'hop_length'
]
batch
,
old_span_boundary
,
new_span_boundary
=
prepare_features
(
uid
,
prefix
,
clone_uid
,
clone_prefix
,
source_language
,
target_language
,
mlm_model
,
processor
,
wav_path
,
old_str
,
new_str
,
duration_preditor_path
,
sid
,
batch
,
old_span_bdy
,
new_span_bdy
=
prepare_features
(
uid
=
uid
,
prefix
=
prefix
,
source_lang
=
source_lang
,
target_lang
=
target_lang
,
mlm_model
=
mlm_model
,
processor
=
processor
,
wav_path
=
wav_path
,
old_str
=
old_str
,
new_str
=
new_str
,
duration_preditor_path
=
duration_preditor_path
,
sid
=
sid
,
duration_adjust
=
duration_adjust
,
start_end_sp
=
start_end_sp
,
train_args
=
train_args
)
feats
=
collate_fn
(
batch
)[
1
]
if
'text_masked_pos
ition
'
in
feats
.
keys
():
feats
.
pop
(
'text_masked_pos
ition
'
)
if
'text_masked_pos'
in
feats
.
keys
():
feats
.
pop
(
'text_masked_pos'
)
for
k
,
v
in
feats
.
items
():
feats
[
k
]
=
paddle
.
to_tensor
(
v
)
rtn
=
mlm_model
.
inference
(
**
feats
,
span_boundary
=
new_span_boundary
,
use_teacher_forcing
=
use_teacher_forcing
)
**
feats
,
span_bdy
=
new_span_bdy
,
use_teacher_forcing
=
use_teacher_forcing
)
output
=
rtn
[
'feat_gen'
]
if
0
in
output
[
0
].
shape
and
0
not
in
output
[
-
1
].
shape
:
output_feat
=
paddle
.
concat
(
...
...
@@ -731,12 +706,9 @@ def decode_with_model(uid,
[
output
[
0
].
squeeze
(
0
)]
+
output
[
1
:
-
1
]
+
[
output
[
-
1
].
squeeze
(
0
)],
axis
=
0
).
cpu
()
wav_org
,
rate
=
librosa
.
load
(
wav_org
,
_
=
librosa
.
load
(
wav_path
,
sr
=
train_args
.
feats_extract_conf
[
'fs'
])
origin_speech
=
paddle
.
to_tensor
(
np
.
array
(
wav_org
,
dtype
=
np
.
float32
)).
unsqueeze
(
0
)
speech_lengths
=
paddle
.
to_tensor
(
len
(
wav_org
)).
unsqueeze
(
0
)
return
wav_org
,
None
,
output_feat
,
old_span_boundary
,
new_span_boundary
,
fs
,
hop_length
return
wav_org
,
None
,
output_feat
,
old_span_bdy
,
new_span_bdy
,
fs
,
hop_length
class
MLMCollateFn
:
...
...
@@ -800,33 +772,15 @@ def mlm_collate_fn(
sega_emb
:
bool
=
False
,
duration_collect
:
bool
=
False
,
text_masking
:
bool
=
False
)
->
Tuple
[
List
[
str
],
Dict
[
str
,
paddle
.
Tensor
]]:
"""Concatenate ndarray-list to an array and convert to torch.Tensor.
Examples:
>>> from espnet2.samplers.constant_batch_sampler import ConstantBatchSampler,
>>> import espnet2.tasks.abs_task
>>> from espnet2.train.dataset import ESPnetDataset
>>> sampler = ConstantBatchSampler(...)
>>> dataset = ESPnetDataset(...)
>>> keys = next(iter(sampler)
>>> batch = [dataset[key] for key in keys]
>>> batch = common_collate_fn(batch)
>>> model(**batch)
Note that the dict-keys of batch are propagated from
that of the dataset as they are.
"""
uttids
=
[
u
for
u
,
_
in
data
]
data
=
[
d
for
_
,
d
in
data
]
assert
all
(
set
(
data
[
0
])
==
set
(
d
)
for
d
in
data
),
"dict-keys mismatching"
assert
all
(
not
k
.
endswith
(
"_len
gth
s"
)
for
k
in
data
[
0
]),
f
"*_len
gth
s is reserved:
{
list
(
data
[
0
])
}
"
assert
all
(
not
k
.
endswith
(
"_lens"
)
for
k
in
data
[
0
]),
f
"*_lens is reserved:
{
list
(
data
[
0
])
}
"
output
=
{}
for
key
in
data
[
0
]:
# NOTE(kamo):
# Each models, which accepts these values finally, are responsible
# to repaint the pad_value to the desired value for each tasks.
if
data
[
0
][
key
].
dtype
.
kind
==
"i"
:
...
...
@@ -846,37 +800,35 @@ def mlm_collate_fn(
# lens: (Batch,)
if
key
not
in
not_sequence
:
lens
=
paddle
.
to_tensor
(
[
d
[
key
].
shape
[
0
]
for
d
in
data
],
dtype
=
paddle
.
long
)
output
[
key
+
"_len
gth
s"
]
=
lens
[
d
[
key
].
shape
[
0
]
for
d
in
data
],
dtype
=
paddle
.
int64
)
output
[
key
+
"_lens"
]
=
lens
feats
=
feats_extract
.
get_log_mel_fbank
(
np
.
array
(
output
[
"speech"
][
0
]))
feats
=
paddle
.
to_tensor
(
feats
)
# print('out shape', paddle.shape(feats))
feats_lengths
=
paddle
.
shape
(
feats
)[
0
]
feats_lens
=
paddle
.
shape
(
feats
)[
0
]
feats
=
paddle
.
unsqueeze
(
feats
,
0
)
batch_size
=
paddle
.
shape
(
feats
)[
0
]
if
'text'
not
in
output
:
text
=
paddle
.
zeros
_like
(
feats_lengths
.
unsqueeze
(
-
1
))
-
2
text_len
gths
=
paddle
.
zeros_like
(
feats_lengths
)
+
1
text
=
paddle
.
zeros
(
paddle
.
shape
(
feats_lens
.
unsqueeze
(
-
1
)
))
-
2
text_len
s
=
paddle
.
zeros
(
paddle
.
shape
(
feats_lens
)
)
+
1
max_tlen
=
1
align_start
=
paddle
.
zeros_like
(
text
)
align_end
=
paddle
.
zeros_like
(
text
)
align_start_lengths
=
paddle
.
zeros_like
(
feats_lengths
)
align_end_lengths
=
paddle
.
zeros_like
(
feats_lengths
)
align_start
=
paddle
.
zeros
(
paddle
.
shape
(
text
))
align_end
=
paddle
.
zeros
(
paddle
.
shape
(
text
))
align_start_lens
=
paddle
.
zeros
(
paddle
.
shape
(
feats_lens
))
sega_emb
=
False
mean_phn_span
=
0
mlm_prob
=
0.15
else
:
text
,
text_lengths
=
output
[
"text"
],
output
[
"text_lengths"
]
align_start
,
align_start_lengths
,
align_end
,
align_end_lengths
=
output
[
"align_start"
],
output
[
"align_start_lengths"
],
output
[
"align_end"
],
output
[
"align_end_lengths"
]
text
=
output
[
"text"
]
text_lens
=
output
[
"text_lens"
]
align_start
=
output
[
"align_start"
]
align_start_lens
=
output
[
"align_start_lens"
]
align_end
=
output
[
"align_end"
]
align_start
=
paddle
.
floor
(
feats_extract
.
sr
*
align_start
/
feats_extract
.
hop_length
).
int
()
align_end
=
paddle
.
floor
(
feats_extract
.
sr
*
align_end
/
feats_extract
.
hop_length
).
int
()
max_tlen
=
max
(
text_len
gths
).
item
(
)
max_slen
=
max
(
feats_len
gths
).
item
(
)
max_tlen
=
max
(
text_len
s
)
max_slen
=
max
(
feats_len
s
)
speech_pad
=
feats
[:,
:
max_slen
]
if
attention_window
>
0
and
pad_speech
:
speech_pad
,
max_slen
=
pad_to_longformer_att_window
(
...
...
@@ -888,51 +840,49 @@ def mlm_collate_fn(
else
:
text_pad
=
text
text_mask
=
make_non_pad_mask
(
text_len
gths
.
tolist
()
,
text_pad
,
length_dim
=
1
).
unsqueeze
(
-
2
)
text_len
s
,
text_pad
,
length_dim
=
1
).
unsqueeze
(
-
2
)
if
attention_window
>
0
:
text_mask
=
text_mask
*
2
speech_mask
=
make_non_pad_mask
(
feats_len
gths
.
tolist
()
,
speech_pad
[:,
:,
0
],
length_dim
=
1
).
unsqueeze
(
-
2
)
span_b
oundar
y
=
None
if
'span_b
oundar
y'
in
output
.
keys
():
span_b
oundary
=
output
[
'span_boundar
y'
]
feats_len
s
,
speech_pad
[:,
:,
0
],
length_dim
=
1
).
unsqueeze
(
-
2
)
span_b
d
y
=
None
if
'span_b
d
y'
in
output
.
keys
():
span_b
dy
=
output
[
'span_bd
y'
]
if
text_masking
:
masked_pos
ition
,
text_masked_position
,
_
=
phones_text_masking
(
masked_pos
,
text_masked_pos
,
_
=
phones_text_masking
(
speech_pad
,
speech_mask
,
text_pad
,
text_mask
,
align_start
,
align_end
,
align_start_lengths
,
mlm_prob
,
mean_phn_span
,
span_boundary
)
align_end
,
align_start_lens
,
mlm_prob
,
mean_phn_span
,
span_bdy
)
else
:
text_masked_pos
ition
=
np
.
zeros
(
text_pad
.
size
(
))
masked_pos
ition
,
_
=
phones_masking
(
speech_pad
,
speech_mask
,
align_start
,
align_end
,
align_start_lengths
,
mlm_prob
,
mean_phn_span
,
span_boundar
y
)
text_masked_pos
=
paddle
.
zeros
(
paddle
.
shape
(
text_pad
))
masked_pos
,
_
=
phones_masking
(
speech_pad
,
speech_mask
,
align_start
,
align_end
,
align_start_lens
,
mlm_prob
,
mean_phn_span
,
span_bd
y
)
output_dict
=
{}
if
duration_collect
and
'text'
in
output
:
reordered_i
ndex
,
speech_segment_pos
,
text_segment_pos
,
durations
,
feats_lengths
=
get_segment
_pos_reduce_duration
(
speech_pad
,
text_pad
,
align_start
,
align_end
,
align_start_len
gth
s
,
sega_emb
,
masked_pos
ition
,
feats_length
s
)
reordered_i
dx
,
speech_seg_pos
,
text_seg_pos
,
durations
,
feats_lens
=
get_seg
_pos_reduce_duration
(
speech_pad
,
text_pad
,
align_start
,
align_end
,
align_start_lens
,
sega_emb
,
masked_pos
,
feats_len
s
)
speech_mask
=
make_non_pad_mask
(
feats_lengths
.
tolist
(),
speech_pad
[:,
:
reordered_index
.
shape
[
1
],
0
],
feats_lens
,
speech_pad
[:,
:
reordered_idx
.
shape
[
1
],
0
],
length_dim
=
1
).
unsqueeze
(
-
2
)
output_dict
[
'durations'
]
=
durations
output_dict
[
'reordered_i
ndex'
]
=
reordered_inde
x
output_dict
[
'reordered_i
dx'
]
=
reordered_id
x
else
:
speech_seg
ment_pos
,
text_segment_pos
=
get_segment_pos
(
speech_pad
,
text_pad
,
align_start
,
align_end
,
align_start_lengths
,
sega_emb
)
speech_seg
_pos
,
text_seg_pos
=
get_seg_pos
(
speech_pad
,
text_pad
,
align_start
,
align_end
,
align_start_lens
,
sega_emb
)
output_dict
[
'speech'
]
=
speech_pad
output_dict
[
'text'
]
=
text_pad
output_dict
[
'masked_pos
ition'
]
=
masked_position
output_dict
[
'text_masked_pos
ition'
]
=
text_masked_position
output_dict
[
'masked_pos
'
]
=
masked_pos
output_dict
[
'text_masked_pos
'
]
=
text_masked_pos
output_dict
[
'speech_mask'
]
=
speech_mask
output_dict
[
'text_mask'
]
=
text_mask
output_dict
[
'speech_seg
ment_pos'
]
=
speech_segment
_pos
output_dict
[
'text_seg
ment_pos'
]
=
text_segment
_pos
output_dict
[
'speech_len
gths'
]
=
output
[
"speech_length
s"
]
output_dict
[
'text_len
gths'
]
=
text_length
s
output_dict
[
'speech_seg
_pos'
]
=
speech_seg
_pos
output_dict
[
'text_seg
_pos'
]
=
text_seg
_pos
output_dict
[
'speech_len
s'
]
=
output
[
"speech_len
s"
]
output_dict
[
'text_len
s'
]
=
text_len
s
output
=
(
uttids
,
output_dict
)
return
output
...
...
@@ -940,13 +890,13 @@ def mlm_collate_fn(
def
build_collate_fn
(
args
:
argparse
.
Namespace
,
train
:
bool
,
epoch
=-
1
):
# -> Callable[
# [Collection[Tuple[str, Dict[str, np.ndarray]]]],
# Tuple[List[str], Dict[str,
torch.
Tensor]],
# Tuple[List[str], Dict[str, Tensor]],
# ]:
# assert check_argument_types()
# return CommonCollateFn(float_pad_value=0.0, int_pad_value=0)
feats_extract_class
=
LogMelFBank
if
args
.
feats_extract_conf
[
'win_length'
]
==
None
:
if
args
.
feats_extract_conf
[
'win_length'
]
is
None
:
args
.
feats_extract_conf
[
'win_length'
]
=
args
.
feats_extract_conf
[
'n_fft'
]
args_dic
=
{}
...
...
@@ -955,7 +905,6 @@ def build_collate_fn(args: argparse.Namespace, train: bool, epoch=-1):
args_dic
[
'sr'
]
=
v
else
:
args_dic
[
k
]
=
v
# feats_extract = feats_extract_class(**args.feats_extract_conf)
feats_extract
=
feats_extract_class
(
**
args_dic
)
sega_emb
=
True
if
args
.
encoder_conf
[
'input_layer'
]
==
'sega_mlm'
else
False
...
...
@@ -969,8 +918,7 @@ def build_collate_fn(args: argparse.Namespace, train: bool, epoch=-1):
if
epoch
==
-
1
:
mlm_prob_factor
=
1
else
:
mlm_probs
=
[
1.0
,
1.0
,
0.7
,
0.6
,
0.5
]
mlm_prob_factor
=
0.8
#mlm_probs[epoch // 100]
mlm_prob_factor
=
0.8
if
'duration_predictor_layers'
in
args
.
model_conf
.
keys
(
)
and
args
.
model_conf
[
'duration_predictor_layers'
]
>
0
:
duration_collect
=
True
...
...
@@ -989,42 +937,37 @@ def build_collate_fn(args: argparse.Namespace, train: bool, epoch=-1):
duration_collect
=
duration_collect
)
def
get_mlm_output
(
uid
,
prefix
,
clone_uid
,
clone_prefix
,
source_language
,
target_language
,
model_name
,
wav_path
,
old_str
,
new_str
,
duration_preditor_path
,
sid
=
None
,
decoder
=
False
,
use_teacher_forcing
=
False
,
dynamic_eval
=
(
0
,
0
),
duration_adjust
=
True
,
start_end_sp
=
False
):
def
get_mlm_output
(
uid
:
str
,
wav_path
:
str
,
prefix
:
str
=
"./prompt/dev/"
,
model_name
:
str
=
"conformer"
,
source_lang
:
str
=
"english"
,
target_lang
:
str
=
"english"
,
old_str
:
str
=
""
,
new_str
:
str
=
""
,
duration_preditor_path
:
str
=
None
,
sid
:
str
=
None
,
decoder
:
bool
=
False
,
use_teacher_forcing
:
bool
=
False
,
duration_adjust
:
bool
=
True
,
start_end_sp
:
bool
=
False
):
mlm_model
,
train_args
=
load_model
(
model_name
)
mlm_model
.
eval
()
processor
=
None
collate_fn
=
build_collate_fn
(
train_args
,
False
)
return
decode_with_model
(
uid
,
prefix
,
clone_uid
,
clone_prefix
,
source_language
,
target_language
,
mlm_model
,
processor
,
collate_fn
,
wav_path
,
old_str
,
new_str
,
duration_preditor_path
,
uid
=
uid
,
prefix
=
prefix
,
source_lang
=
source_lang
,
target_lang
=
target_lang
,
mlm_model
=
mlm_model
,
processor
=
processor
,
collate_fn
=
collate_fn
,
wav_path
=
wav_path
,
old_str
=
old_str
,
new_str
=
new_str
,
duration_preditor_path
=
duration_preditor_path
,
sid
=
sid
,
decoder
=
decoder
,
use_teacher_forcing
=
use_teacher_forcing
,
...
...
@@ -1033,23 +976,20 @@ def get_mlm_output(uid,
train_args
=
train_args
)
def
test_vctk
(
uid
,
clone_uid
,
clone_prefix
,
source_language
,
target_language
,
vocoder
,
prefix
=
'dump/raw/dev'
,
model_name
=
"conformer"
,
old_str
=
""
,
new_str
=
""
,
prompt_decoding
=
False
,
dynamic_eval
=
(
0
,
0
),
task_name
=
None
):
def
evaluate
(
uid
:
str
,
source_lang
:
str
=
"english"
,
target_lang
:
str
=
"english"
,
use_pt_vocoder
:
bool
=
False
,
prefix
:
str
=
"./prompt/dev/"
,
model_name
:
str
=
"conformer"
,
old_str
:
str
=
""
,
new_str
:
str
=
""
,
prompt_decoding
:
bool
=
False
,
task_name
:
str
=
None
):
duration_preditor_path
=
None
spemd
=
None
full_origin_str
,
wav_path
=
read_data
(
uid
,
prefix
)
full_origin_str
,
wav_path
=
read_data
(
uid
=
uid
,
prefix
=
prefix
)
if
task_name
==
'edit'
:
new_str
=
new_str
...
...
@@ -1065,19 +1005,17 @@ def test_vctk(uid,
old_str
=
full_origin_str
results_dict
,
old_span
=
plot_mel_and_vocode_wav
(
uid
,
prefix
,
clone_uid
,
clone_prefix
,
source_language
,
target_language
,
model_name
,
wav_path
,
full_origin_str
,
old_str
,
new_str
,
vocoder
,
duration_preditor_path
,
uid
=
uid
,
prefix
=
prefix
,
source_lang
=
source_lang
,
target_lang
=
target_lang
,
model_name
=
model_name
,
wav_path
=
wav_path
,
full_origin_str
=
full_origin_str
,
old_str
=
old_str
,
new_str
=
new_str
,
use_pt_vocoder
=
use_pt_vocoder
,
duration_preditor_path
=
duration_preditor_path
,
sid
=
spemd
)
return
results_dict
...
...
@@ -1086,17 +1024,14 @@ if __name__ == "__main__":
# parse config and args
args
=
parse_args
()
data_dict
=
test_vctk
(
args
.
uid
,
args
.
clone_uid
,
args
.
clone_prefix
,
args
.
source_language
,
args
.
target_language
,
args
.
use_pt_vocoder
,
args
.
prefix
,
args
.
model_name
,
data_dict
=
evaluate
(
uid
=
args
.
uid
,
source_lang
=
args
.
source_lang
,
target_lang
=
args
.
target_lang
,
use_pt_vocoder
=
args
.
use_pt_vocoder
,
prefix
=
args
.
prefix
,
model_name
=
args
.
model_name
,
new_str
=
args
.
new_str
,
task_name
=
args
.
task_name
)
sf
.
write
(
args
.
output_name
,
data_dict
[
'output'
],
samplerate
=
24000
)
print
(
"finished..."
)
# exit()
ernie-sat/model_paddle.py
浏览文件 @
b81832ce
...
...
@@ -121,12 +121,10 @@ class NewMaskInputLayer(nn.Layer):
default_initializer
=
paddle
.
nn
.
initializer
.
Assign
(
paddle
.
normal
(
shape
=
(
1
,
1
,
out_features
))))
def
forward
(
self
,
input
:
paddle
.
Tensor
,
masked_position
=
None
)
->
paddle
.
Tensor
:
masked_position
=
paddle
.
expand_as
(
paddle
.
unsqueeze
(
masked_position
,
-
1
),
input
)
masked_input
=
masked_fill
(
input
,
masked_position
,
0
)
+
masked_fill
(
paddle
.
expand_as
(
self
.
mask_feature
,
input
),
~
masked_position
,
0
)
def
forward
(
self
,
input
:
paddle
.
Tensor
,
masked_pos
=
None
)
->
paddle
.
Tensor
:
masked_pos
=
paddle
.
expand_as
(
paddle
.
unsqueeze
(
masked_pos
,
-
1
),
input
)
masked_input
=
masked_fill
(
input
,
masked_pos
,
0
)
+
masked_fill
(
paddle
.
expand_as
(
self
.
mask_feature
,
input
),
~
masked_pos
,
0
)
return
masked_input
...
...
@@ -443,37 +441,34 @@ class MLMEncoder(nn.Layer):
def
forward
(
self
,
speech_pad
,
text_pad
,
masked_pos
ition
,
masked_pos
,
speech_mask
=
None
,
text_mask
=
None
,
speech_seg
ment
_pos
=
None
,
text_seg
ment
_pos
=
None
):
speech_seg_pos
=
None
,
text_seg_pos
=
None
):
"""Encode input sequence.
"""
if
masked_pos
ition
is
not
None
:
speech_pad
=
self
.
speech_embed
(
speech_pad
,
masked_pos
ition
)
if
masked_pos
is
not
None
:
speech_pad
=
self
.
speech_embed
(
speech_pad
,
masked_pos
)
else
:
speech_pad
=
self
.
speech_embed
(
speech_pad
)
# pure speech input
if
-
2
in
np
.
array
(
text_pad
):
text_pad
=
text_pad
+
3
text_mask
=
paddle
.
unsqueeze
(
bool
(
text_pad
),
1
)
text_seg
ment
_pos
=
paddle
.
zeros_like
(
text_pad
)
text_seg_pos
=
paddle
.
zeros_like
(
text_pad
)
text_pad
=
self
.
text_embed
(
text_pad
)
text_pad
=
(
text_pad
[
0
]
+
self
.
segment_emb
(
text_seg
ment
_pos
),
text_pad
=
(
text_pad
[
0
]
+
self
.
segment_emb
(
text_seg_pos
),
text_pad
[
1
])
text_seg
ment
_pos
=
None
text_seg_pos
=
None
elif
text_pad
is
not
None
:
text_pad
=
self
.
text_embed
(
text_pad
)
segment_emb
=
None
if
speech_segment_pos
is
not
None
and
text_segment_pos
is
not
None
and
self
.
segment_emb
:
speech_segment_emb
=
self
.
segment_emb
(
speech_segment_pos
)
text_segment_emb
=
self
.
segment_emb
(
text_segment_pos
)
text_pad
=
(
text_pad
[
0
]
+
text_segment_emb
,
text_pad
[
1
])
speech_pad
=
(
speech_pad
[
0
]
+
speech_segment_emb
,
speech_pad
[
1
])
segment_emb
=
paddle
.
concat
(
[
speech_segment_emb
,
text_segment_emb
],
axis
=
1
)
if
speech_seg_pos
is
not
None
and
text_seg_pos
is
not
None
and
self
.
segment_emb
:
speech_seg_emb
=
self
.
segment_emb
(
speech_seg_pos
)
text_seg_emb
=
self
.
segment_emb
(
text_seg_pos
)
text_pad
=
(
text_pad
[
0
]
+
text_seg_emb
,
text_pad
[
1
])
speech_pad
=
(
speech_pad
[
0
]
+
speech_seg_emb
,
speech_pad
[
1
])
if
self
.
pre_speech_encoders
:
speech_pad
,
_
=
self
.
pre_speech_encoders
(
speech_pad
,
speech_mask
)
...
...
@@ -493,11 +488,11 @@ class MLMEncoder(nn.Layer):
if
self
.
normalize_before
:
xs
=
self
.
after_norm
(
xs
)
return
xs
,
masks
#, segment_emb
return
xs
,
masks
class
MLMDecoder
(
MLMEncoder
):
def
forward
(
self
,
xs
,
masks
,
masked_pos
ition
=
None
,
segment_emb
=
None
):
def
forward
(
self
,
xs
,
masks
,
masked_pos
=
None
,
segment_emb
=
None
):
"""Encode input sequence.
Args:
...
...
@@ -509,9 +504,8 @@ class MLMDecoder(MLMEncoder):
paddle.Tensor: Mask tensor (#batch, time).
"""
emb
,
mlm_position
=
None
,
None
if
not
self
.
training
:
masked_pos
ition
=
None
masked_pos
=
None
xs
=
self
.
embed
(
xs
)
if
segment_emb
:
xs
=
(
xs
[
0
]
+
segment_emb
,
xs
[
1
])
...
...
@@ -632,18 +626,18 @@ class MLMModel(nn.Layer):
def
collect_feats
(
self
,
speech
,
speech_len
gth
s
,
speech_lens
,
text
,
text_len
gth
s
,
masked_pos
ition
,
text_lens
,
masked_pos
,
speech_mask
,
text_mask
,
speech_seg
ment
_pos
,
text_seg
ment
_pos
,
speech_seg_pos
,
text_seg_pos
,
y_masks
=
None
)
->
Dict
[
str
,
paddle
.
Tensor
]:
return
{
"feats"
:
speech
,
"feats_len
gths"
:
speech_length
s
}
return
{
"feats"
:
speech
,
"feats_len
s"
:
speech_len
s
}
def
forward
(
self
,
batch
,
speech_seg
ment
_pos
,
y_masks
=
None
):
def
forward
(
self
,
batch
,
speech_seg_pos
,
y_masks
=
None
):
# feats: (Batch, Length, Dim)
# -> encoder_out: (Batch, Length2, Dim2)
speech_pad_placeholder
=
batch
[
'speech_pad'
]
...
...
@@ -654,7 +648,7 @@ class MLMModel(nn.Layer):
if
self
.
decoder
is
not
None
:
zs
,
_
=
self
.
decoder
(
ys_in
,
y_masks
,
encoder_out
,
bool
(
h_masks
),
self
.
encoder
.
segment_emb
(
speech_seg
ment
_pos
))
self
.
encoder
.
segment_emb
(
speech_seg_pos
))
speech_hidden_states
=
zs
else
:
speech_hidden_states
=
encoder_out
[:,
:
paddle
.
shape
(
batch
[
...
...
@@ -672,21 +666,21 @@ class MLMModel(nn.Layer):
else
:
after_outs
=
None
return
before_outs
,
after_outs
,
speech_pad_placeholder
,
batch
[
'masked_pos
ition
'
]
'masked_pos'
]
def
inference
(
self
,
speech
,
text
,
masked_pos
ition
,
masked_pos
,
speech_mask
,
text_mask
,
speech_seg
ment
_pos
,
text_seg
ment
_pos
,
span_b
oundar
y
,
speech_seg_pos
,
text_seg_pos
,
span_b
d
y
,
y_masks
=
None
,
speech_len
gth
s
=
None
,
text_len
gth
s
=
None
,
speech_lens
=
None
,
text_lens
=
None
,
feats
:
Optional
[
paddle
.
Tensor
]
=
None
,
spembs
:
Optional
[
paddle
.
Tensor
]
=
None
,
sids
:
Optional
[
paddle
.
Tensor
]
=
None
,
...
...
@@ -699,24 +693,24 @@ class MLMModel(nn.Layer):
batch
=
dict
(
speech_pad
=
speech
,
text_pad
=
text
,
masked_pos
ition
=
masked_position
,
masked_pos
=
masked_pos
,
speech_mask
=
speech_mask
,
text_mask
=
text_mask
,
speech_seg
ment_pos
=
speech_segment
_pos
,
text_seg
ment_pos
=
text_segment
_pos
,
)
speech_seg
_pos
=
speech_seg
_pos
,
text_seg
_pos
=
text_seg
_pos
,
)
# # inference with teacher forcing
# hs, h_masks = self.encoder(**batch)
outs
=
[
batch
[
'speech_pad'
][:,
:
span_b
oundar
y
[
0
]]]
outs
=
[
batch
[
'speech_pad'
][:,
:
span_b
d
y
[
0
]]]
z_cache
=
None
if
use_teacher_forcing
:
before
,
zs
,
_
,
_
=
self
.
forward
(
batch
,
speech_seg
ment
_pos
,
y_masks
=
y_masks
)
batch
,
speech_seg_pos
,
y_masks
=
y_masks
)
if
zs
is
None
:
zs
=
before
outs
+=
[
zs
[
0
][
span_b
oundary
[
0
]:
span_boundar
y
[
1
]]]
outs
+=
[
batch
[
'speech_pad'
][:,
span_b
oundar
y
[
1
]:]]
outs
+=
[
zs
[
0
][
span_b
dy
[
0
]:
span_bd
y
[
1
]]]
outs
+=
[
batch
[
'speech_pad'
][:,
span_b
d
y
[
1
]:]]
return
dict
(
feat_gen
=
outs
)
return
None
...
...
@@ -733,7 +727,7 @@ class MLMModel(nn.Layer):
class
MLMEncAsDecoderModel
(
MLMModel
):
def
forward
(
self
,
batch
,
speech_seg
ment
_pos
,
y_masks
=
None
):
def
forward
(
self
,
batch
,
speech_seg_pos
,
y_masks
=
None
):
# feats: (Batch, Length, Dim)
# -> encoder_out: (Batch, Length2, Dim2)
speech_pad_placeholder
=
batch
[
'speech_pad'
]
...
...
@@ -756,7 +750,7 @@ class MLMEncAsDecoderModel(MLMModel):
else
:
after_outs
=
None
return
before_outs
,
after_outs
,
speech_pad_placeholder
,
batch
[
'masked_pos
ition
'
]
'masked_pos'
]
class
MLMDualMaksingModel
(
MLMModel
):
...
...
@@ -767,9 +761,9 @@ class MLMDualMaksingModel(MLMModel):
batch
):
xs_pad
=
batch
[
'speech_pad'
]
text_pad
=
batch
[
'text_pad'
]
masked_pos
ition
=
batch
[
'masked_position
'
]
text_masked_pos
ition
=
batch
[
'text_masked_position
'
]
mlm_loss_pos
ition
=
masked_position
>
0
masked_pos
=
batch
[
'masked_pos
'
]
text_masked_pos
=
batch
[
'text_masked_pos
'
]
mlm_loss_pos
=
masked_pos
>
0
loss
=
paddle
.
sum
(
self
.
l1_loss_func
(
paddle
.
reshape
(
before_outs
,
(
-
1
,
self
.
odim
)),
...
...
@@ -782,19 +776,17 @@ class MLMDualMaksingModel(MLMModel):
paddle
.
reshape
(
xs_pad
,
(
-
1
,
self
.
odim
))),
axis
=-
1
)
loss_mlm
=
paddle
.
sum
((
loss
*
paddle
.
reshape
(
mlm_loss_pos
ition
,
[
-
1
])))
/
paddle
.
sum
((
mlm_loss_position
)
+
1e-10
)
mlm_loss_pos
,
[
-
1
])))
/
paddle
.
sum
((
mlm_loss_pos
)
+
1e-10
)
loss_text
=
paddle
.
sum
((
self
.
text_mlm_loss
(
paddle
.
reshape
(
text_outs
,
(
-
1
,
self
.
vocab_size
)),
paddle
.
reshape
(
text_pad
,
(
-
1
)))
*
paddle
.
reshape
(
text_masked_position
,
(
-
1
))))
/
paddle
.
sum
((
text_masked_position
)
+
1e-10
)
text_masked_pos
,
(
-
1
))))
/
paddle
.
sum
((
text_masked_pos
)
+
1e-10
)
return
loss_mlm
,
loss_text
def
forward
(
self
,
batch
,
speech_seg
ment
_pos
,
y_masks
=
None
):
def
forward
(
self
,
batch
,
speech_seg_pos
,
y_masks
=
None
):
# feats: (Batch, Length, Dim)
# -> encoder_out: (Batch, Length2, Dim2)
speech_pad_placeholder
=
batch
[
'speech_pad'
]
encoder_out
,
h_masks
=
self
.
encoder
(
**
batch
)
# segment_emb
if
self
.
decoder
is
not
None
:
zs
,
_
=
self
.
decoder
(
encoder_out
,
h_masks
)
...
...
@@ -819,7 +811,7 @@ class MLMDualMaksingModel(MLMModel):
[
0
,
2
,
1
])
else
:
after_outs
=
None
return
before_outs
,
after_outs
,
text_outs
,
None
#, speech_pad_placeholder, batch['masked_pos
ition'],batch['text_masked_position
']
return
before_outs
,
after_outs
,
text_outs
,
None
#, speech_pad_placeholder, batch['masked_pos
'],batch['text_masked_pos
']
def
build_model_from_file
(
config_file
,
model_file
):
...
...
ernie-sat/paddlespeech/t2s/modules/nets_utils.py
浏览文件 @
b81832ce
...
...
@@ -38,7 +38,7 @@ def pad_list(xs, pad_value):
"""
n_batch
=
len
(
xs
)
max_len
=
max
(
x
.
shape
[
0
]
for
x
in
xs
)
pad
=
paddle
.
full
([
n_batch
,
max_len
,
*
xs
[
0
].
shape
[
1
:]],
pad_value
)
pad
=
paddle
.
full
([
n_batch
,
max_len
,
*
xs
[
0
].
shape
[
1
:]],
pad_value
,
dtype
=
xs
[
0
].
dtype
)
for
i
in
range
(
n_batch
):
pad
[
i
,
:
xs
[
i
].
shape
[
0
]]
=
xs
[
i
]
...
...
@@ -46,13 +46,18 @@ def pad_list(xs, pad_value):
return
pad
def
make_pad_mask
(
lengths
,
length_dim
=-
1
):
def
make_pad_mask
(
lengths
,
xs
=
None
,
length_dim
=-
1
):
"""Make mask tensor containing indices of padded part.
Args:
lengths (Tensor(int64)): Batch of lengths (B,).
xs (Tensor, optional): The reference tensor.
If set, masks will be the same shape as this tensor.
length_dim (int, optional): Dimension indicator of the above tensor.
See the example.
Returns:
Returns:
Tensor(bool): Mask tensor containing indices of padded part bool.
Examples:
...
...
@@ -61,23 +66,98 @@ def make_pad_mask(lengths, length_dim=-1):
>>> lengths = [5, 3, 2]
>>> make_non_pad_mask(lengths)
masks = [[0, 0, 0, 0 ,0],
[0, 0, 0, 1, 1],
[0, 0, 1, 1, 1]]
[0, 0, 0, 1, 1],
[0, 0, 1, 1, 1]]
With the reference tensor.
>>> xs = paddle.zeros((3, 2, 4))
>>> make_pad_mask(lengths, xs)
tensor([[[0, 0, 0, 0],
[0, 0, 0, 0]],
[[0, 0, 0, 1],
[0, 0, 0, 1]],
[[0, 0, 1, 1],
[0, 0, 1, 1]]])
>>> xs = paddle.zeros((3, 2, 6))
>>> make_pad_mask(lengths, xs)
tensor([[[0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 1]],
[[0, 0, 0, 1, 1, 1],
[0, 0, 0, 1, 1, 1]],
[[0, 0, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1]]])
With the reference tensor and dimension indicator.
>>> xs = paddle.zeros((3, 6, 6))
>>> make_pad_mask(lengths, xs, 1)
tensor([[[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1]],
[[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1]],
[[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1]]])
>>> make_pad_mask(lengths, xs, 2)
tensor([[[0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 1]],
[[0, 0, 0, 1, 1, 1],
[0, 0, 0, 1, 1, 1],
[0, 0, 0, 1, 1, 1],
[0, 0, 0, 1, 1, 1],
[0, 0, 0, 1, 1, 1],
[0, 0, 0, 1, 1, 1]],
[[0, 0, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1]]],)
"""
if
length_dim
==
0
:
raise
ValueError
(
"length_dim cannot be 0: {}"
.
format
(
length_dim
))
bs
=
paddle
.
shape
(
lengths
)[
0
]
maxlen
=
lengths
.
max
()
if
xs
is
None
:
maxlen
=
lengths
.
max
()
else
:
maxlen
=
paddle
.
shape
(
xs
)[
length_dim
]
seq_range
=
paddle
.
arange
(
0
,
maxlen
,
dtype
=
paddle
.
int64
)
seq_range_expand
=
seq_range
.
unsqueeze
(
0
).
expand
([
bs
,
maxlen
])
seq_length_expand
=
lengths
.
unsqueeze
(
-
1
)
mask
=
seq_range_expand
>=
seq_length_expand
return
mask
if
xs
is
not
None
:
assert
paddle
.
shape
(
xs
)[
0
]
==
bs
,
(
paddle
.
shape
(
xs
)[
0
],
bs
)
if
length_dim
<
0
:
length_dim
=
len
(
paddle
.
shape
(
xs
))
+
length_dim
# ind = (:, None, ..., None, :, , None, ..., None)
ind
=
tuple
(
slice
(
None
)
if
i
in
(
0
,
length_dim
)
else
None
for
i
in
range
(
len
(
paddle
.
shape
(
xs
))))
mask
=
paddle
.
expand
(
mask
[
ind
],
paddle
.
shape
(
xs
))
return
mask
def
make_non_pad_mask
(
lengths
,
length_dim
=-
1
):
def
make_non_pad_mask
(
lengths
,
xs
=
None
,
length_dim
=-
1
):
"""Make mask tensor containing indices of non-padded part.
Args:
...
...
@@ -90,16 +170,78 @@ def make_non_pad_mask(lengths, length_dim=-1):
Returns:
Tensor(bool): mask tensor containing indices of padded part bool.
Examples:
Examples:
With only lengths.
>>> lengths = [5, 3, 2]
>>> make_non_pad_mask(lengths)
masks = [[1, 1, 1, 1 ,1],
[1, 1, 1, 0, 0],
[1, 1, 0, 0, 0]]
[1, 1, 1, 0, 0],
[1, 1, 0, 0, 0]]
With the reference tensor.
>>> xs = paddle.zeros((3, 2, 4))
>>> make_non_pad_mask(lengths, xs)
tensor([[[1, 1, 1, 1],
[1, 1, 1, 1]],
[[1, 1, 1, 0],
[1, 1, 1, 0]],
[[1, 1, 0, 0],
[1, 1, 0, 0]]])
>>> xs = paddle.zeros((3, 2, 6))
>>> make_non_pad_mask(lengths, xs)
tensor([[[1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 0]],
[[1, 1, 1, 0, 0, 0],
[1, 1, 1, 0, 0, 0]],
[[1, 1, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0]]])
With the reference tensor and dimension indicator.
>>> xs = paddle.zeros((3, 6, 6))
>>> make_non_pad_mask(lengths, xs, 1)
tensor([[[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[0, 0, 0, 0, 0, 0]],
[[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0]],
[[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0]]])
>>> make_non_pad_mask(lengths, xs, 2)
tensor([[[1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 0]],
[[1, 1, 1, 0, 0, 0],
[1, 1, 1, 0, 0, 0],
[1, 1, 1, 0, 0, 0],
[1, 1, 1, 0, 0, 0],
[1, 1, 1, 0, 0, 0],
[1, 1, 1, 0, 0, 0]],
[[1, 1, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0]]])
"""
return
paddle
.
logical_not
(
make_pad_mask
(
lengths
,
length_dim
))
return
paddle
.
logical_not
(
make_pad_mask
(
lengths
,
xs
,
length_dim
))
def
initialize
(
model
:
nn
.
Layer
,
init
:
str
):
...
...
ernie-sat/run_clone_en_to_zh.sh
浏览文件 @
b81832ce
...
...
@@ -10,8 +10,8 @@ python inference.py \
--uid
=
Prompt_003_new
\
--new_str
=
'今天天气很好.'
\
--prefix
=
'./prompt/dev/'
\
--source_lang
uage
=
english
\
--target_lang
uage
=
chinese
\
--source_lang
=
english
\
--target_lang
=
chinese
\
--output_name
=
pred_clone.wav
\
--use_pt_vocoder
=
False
\
--voc
=
pwgan_aishell3
\
...
...
ernie-sat/run_gen_en.sh
浏览文件 @
b81832ce
...
...
@@ -9,8 +9,8 @@ python inference.py \
--uid
=
p299_096
\
--new_str
=
'I enjoy my life, do you?'
\
--prefix
=
'./prompt/dev/'
\
--source_lang
uage
=
english
\
--target_lang
uage
=
english
\
--source_lang
=
english
\
--target_lang
=
english
\
--output_name
=
pred_gen.wav
\
--use_pt_vocoder
=
False
\
--voc
=
pwgan_aishell3
\
...
...
ernie-sat/run_sedit_en.sh
浏览文件 @
b81832ce
...
...
@@ -10,8 +10,8 @@ python inference.py \
--uid
=
p243_new
\
--new_str
=
'for that reason cover is impossible to be given.'
\
--prefix
=
'./prompt/dev/'
\
--source_lang
uage
=
english
\
--target_lang
uage
=
english
\
--source_lang
=
english
\
--target_lang
=
english
\
--output_name
=
pred_edit.wav
\
--use_pt_vocoder
=
False
\
--voc
=
pwgan_aishell3
\
...
...
ernie-sat/sedit_arg_parser.py
浏览文件 @
b81832ce
...
...
@@ -80,10 +80,8 @@ def parse_args():
parser
.
add_argument
(
"--uid"
,
type
=
str
,
help
=
"uid"
)
parser
.
add_argument
(
"--new_str"
,
type
=
str
,
help
=
"new string"
)
parser
.
add_argument
(
"--prefix"
,
type
=
str
,
help
=
"prefix"
)
parser
.
add_argument
(
"--clone_prefix"
,
type
=
str
,
default
=
None
,
help
=
"clone prefix"
)
parser
.
add_argument
(
"--clone_uid"
,
type
=
str
,
default
=
None
,
help
=
"clone uid"
)
parser
.
add_argument
(
"--source_language"
,
type
=
str
,
help
=
"source language"
)
parser
.
add_argument
(
"--target_language"
,
type
=
str
,
help
=
"target language"
)
parser
.
add_argument
(
"--source_lang"
,
type
=
str
,
default
=
"english"
,
help
=
"source language"
)
parser
.
add_argument
(
"--target_lang"
,
type
=
str
,
default
=
"english"
,
help
=
"target language"
)
parser
.
add_argument
(
"--output_name"
,
type
=
str
,
help
=
"output name"
)
parser
.
add_argument
(
"--task_name"
,
type
=
str
,
help
=
"task name"
)
parser
.
add_argument
(
...
...
ernie-sat/tools/
parallel_wavegan_pretrained_vocoder
.py
→
ernie-sat/tools/
torch_pwgan
.py
浏览文件 @
b81832ce
...
...
@@ -9,7 +9,7 @@ import torch
import
yaml
class
ParallelWaveGANPretrainedVocoder
(
torch
.
nn
.
Module
):
class
TorchPWGAN
(
torch
.
nn
.
Module
):
"""Wrapper class to load the vocoder trained with parallel_wavegan repo."""
def
__init__
(
...
...
ernie-sat/utils.py
浏览文件 @
b81832ce
import
os
from
typing
import
List
from
typing
import
Optional
import
numpy
as
np
import
paddle
import
yaml
...
...
@@ -5,11 +9,8 @@ from sedit_arg_parser import parse_args
from
yacs.config
import
CfgNode
from
paddlespeech.s2t.utils.dynamic_import
import
dynamic_import
from
paddlespeech.t2s.exps.syn_utils
import
get_frontend
from
paddlespeech.t2s.exps.syn_utils
import
get_voc_inference
from
paddlespeech.t2s.modules.normalizer
import
ZScore
from
tools.parallel_wavegan_pretrained_vocoder
import
ParallelWaveGANPretrainedVocoder
# new add
from
tools.torch_pwgan
import
TorchPWGAN
model_alias
=
{
# acoustic model
...
...
@@ -25,6 +26,10 @@ model_alias = {
"paddlespeech.t2s.models.tacotron2:Tacotron2"
,
"tacotron2_inference"
:
"paddlespeech.t2s.models.tacotron2:Tacotron2Inference"
,
"pwgan"
:
"paddlespeech.t2s.models.parallel_wavegan:PWGGenerator"
,
"pwgan_inference"
:
"paddlespeech.t2s.models.parallel_wavegan:PWGInference"
,
}
...
...
@@ -43,60 +48,65 @@ def build_vocoder_from_file(
# Build vocoder
if
str
(
vocoder_file
).
endswith
(
".pkl"
):
# If the extension is ".pkl", the model is trained with parallel_wavegan
vocoder
=
ParallelWaveGANPretrainedVocoder
(
vocoder_file
,
vocoder_config_file
)
vocoder
=
TorchPWGAN
(
vocoder_file
,
vocoder_config_file
)
return
vocoder
.
to
(
device
)
else
:
raise
ValueError
(
f
"
{
vocoder_file
}
is not supported format."
)
def
get_voc_out
(
mel
,
target_lang
uage
=
"chinese"
):
def
get_voc_out
(
mel
,
target_lang
:
str
=
"chinese"
):
# vocoder
args
=
parse_args
()
assert
target_lang
uage
==
"chinese"
or
target_language
==
"english"
,
"In get_voc_out function, target_language
is illegal..."
assert
target_lang
==
"chinese"
or
target_lang
==
"english"
,
"In get_voc_out function, target_lang
is illegal..."
# print("current vocoder: ", args.voc)
with
open
(
args
.
voc_config
)
as
f
:
voc_config
=
CfgNode
(
yaml
.
safe_load
(
f
))
# print(voc_config)
voc_inference
=
get_voc_inference
(
args
,
voc_config
)
voc_inference
=
voc_inference
=
get_voc_inference
(
voc
=
args
.
voc
,
voc_config
=
voc_config
,
voc_ckpt
=
args
.
voc_ckpt
,
voc_stat
=
args
.
voc_stat
)
mel
=
paddle
.
to_tensor
(
mel
)
# print("masked_mel: ", mel.shape)
with
paddle
.
no_grad
():
wav
=
voc_inference
(
mel
)
# print("shepe of wav (time x n_channels):%s"%wav.shape)
return
np
.
squeeze
(
wav
)
# dygraph
def
get_am_inference
(
args
,
am_config
):
with
open
(
args
.
phones_dict
,
"r"
)
as
f
:
def
get_am_inference
(
am
:
str
=
'fastspeech2_csmsc'
,
am_config
:
CfgNode
=
None
,
am_ckpt
:
Optional
[
os
.
PathLike
]
=
None
,
am_stat
:
Optional
[
os
.
PathLike
]
=
None
,
phones_dict
:
Optional
[
os
.
PathLike
]
=
None
,
tones_dict
:
Optional
[
os
.
PathLike
]
=
None
,
speaker_dict
:
Optional
[
os
.
PathLike
]
=
None
,
return_am
:
bool
=
False
):
with
open
(
phones_dict
,
"r"
)
as
f
:
phn_id
=
[
line
.
strip
().
split
()
for
line
in
f
.
readlines
()]
vocab_size
=
len
(
phn_id
)
#
print("vocab_size:", vocab_size)
print
(
"vocab_size:"
,
vocab_size
)
tone_size
=
None
if
'tones_dict'
in
args
and
args
.
tones_dict
:
with
open
(
args
.
tones_dict
,
"r"
)
as
f
:
if
tones_dict
is
not
None
:
with
open
(
tones_dict
,
"r"
)
as
f
:
tone_id
=
[
line
.
strip
().
split
()
for
line
in
f
.
readlines
()]
tone_size
=
len
(
tone_id
)
print
(
"tone_size:"
,
tone_size
)
spk_num
=
None
if
'speaker_dict'
in
args
and
args
.
speaker_dict
:
with
open
(
args
.
speaker_dict
,
'rt'
)
as
f
:
if
speaker_dict
is
not
None
:
with
open
(
speaker_dict
,
'rt'
)
as
f
:
spk_id
=
[
line
.
strip
().
split
()
for
line
in
f
.
readlines
()]
spk_num
=
len
(
spk_id
)
print
(
"spk_num:"
,
spk_num
)
odim
=
am_config
.
n_mels
# model: {model_name}_{dataset}
am_name
=
a
rgs
.
am
[:
args
.
am
.
rindex
(
'_'
)]
am_dataset
=
a
rgs
.
am
[
args
.
am
.
rindex
(
'_'
)
+
1
:]
am_name
=
a
m
[:
am
.
rindex
(
'_'
)]
am_dataset
=
a
m
[
am
.
rindex
(
'_'
)
+
1
:]
am_class
=
dynamic_import
(
am_name
,
model_alias
)
am_inference_class
=
dynamic_import
(
am_name
+
'_inference'
,
model_alias
)
...
...
@@ -113,39 +123,61 @@ def get_am_inference(args, am_config):
elif
am_name
==
'tacotron2'
:
am
=
am_class
(
idim
=
vocab_size
,
odim
=
odim
,
**
am_config
[
"model"
])
am
.
set_state_dict
(
paddle
.
load
(
a
rgs
.
a
m_ckpt
)[
"main_params"
])
am
.
set_state_dict
(
paddle
.
load
(
am_ckpt
)[
"main_params"
])
am
.
eval
()
am_mu
,
am_std
=
np
.
load
(
a
rgs
.
a
m_stat
)
am_mu
,
am_std
=
np
.
load
(
am_stat
)
am_mu
=
paddle
.
to_tensor
(
am_mu
)
am_std
=
paddle
.
to_tensor
(
am_std
)
am_normalizer
=
ZScore
(
am_mu
,
am_std
)
am_inference
=
am_inference_class
(
am_normalizer
,
am
)
am_inference
.
eval
()
print
(
"acoustic model done!"
)
return
am
,
am_inference
,
am_name
,
am_dataset
,
phn_id
if
return_am
:
return
am_inference
,
am
else
:
return
am_inference
def
evaluate_durations
(
phns
,
target_language
=
"chinese"
,
fs
=
24000
,
hop_length
=
300
):
def
get_voc_inference
(
voc
:
str
=
'pwgan_csmsc'
,
voc_config
:
Optional
[
os
.
PathLike
]
=
None
,
voc_ckpt
:
Optional
[
os
.
PathLike
]
=
None
,
voc_stat
:
Optional
[
os
.
PathLike
]
=
None
,
):
# model: {model_name}_{dataset}
voc_name
=
voc
[:
voc
.
rindex
(
'_'
)]
voc_class
=
dynamic_import
(
voc_name
,
model_alias
)
voc_inference_class
=
dynamic_import
(
voc_name
+
'_inference'
,
model_alias
)
if
voc_name
!=
'wavernn'
:
voc
=
voc_class
(
**
voc_config
[
"generator_params"
])
voc
.
set_state_dict
(
paddle
.
load
(
voc_ckpt
)[
"generator_params"
])
voc
.
remove_weight_norm
()
voc
.
eval
()
else
:
voc
=
voc_class
(
**
voc_config
[
"model"
])
voc
.
set_state_dict
(
paddle
.
load
(
voc_ckpt
)[
"main_params"
])
voc
.
eval
()
voc_mu
,
voc_std
=
np
.
load
(
voc_stat
)
voc_mu
=
paddle
.
to_tensor
(
voc_mu
)
voc_std
=
paddle
.
to_tensor
(
voc_std
)
voc_normalizer
=
ZScore
(
voc_mu
,
voc_std
)
voc_inference
=
voc_inference_class
(
voc_normalizer
,
voc
)
voc_inference
.
eval
()
print
(
"voc done!"
)
return
voc_inference
def
evaluate_durations
(
phns
:
List
[
str
],
target_lang
:
str
=
"chinese"
,
fs
:
int
=
24000
,
hop_length
:
int
=
300
):
args
=
parse_args
()
if
target_lang
uage
==
'english'
:
if
target_lang
==
'english'
:
args
.
lang
=
'en'
args
.
am
=
"fastspeech2_ljspeech"
args
.
am_config
=
"download/fastspeech2_nosil_ljspeech_ckpt_0.5/default.yaml"
args
.
am_ckpt
=
"download/fastspeech2_nosil_ljspeech_ckpt_0.5/snapshot_iter_100000.pdz"
args
.
am_stat
=
"download/fastspeech2_nosil_ljspeech_ckpt_0.5/speech_stats.npy"
args
.
phones_dict
=
"download/fastspeech2_nosil_ljspeech_ckpt_0.5/phone_id_map.txt"
elif
target_lang
uage
==
'chinese'
:
elif
target_lang
==
'chinese'
:
args
.
lang
=
'zh'
args
.
am
=
"fastspeech2_csmsc"
args
.
am_config
=
"download/fastspeech2_conformer_baker_ckpt_0.5/conformer.yaml"
args
.
am_ckpt
=
"download/fastspeech2_conformer_baker_ckpt_0.5/snapshot_iter_76000.pdz"
args
.
am_stat
=
"download/fastspeech2_conformer_baker_ckpt_0.5/speech_stats.npy"
args
.
phones_dict
=
"download/fastspeech2_conformer_baker_ckpt_0.5/phone_id_map.txt"
# args = parser.parse_args(args=[])
if
args
.
ngpu
==
0
:
...
...
@@ -155,23 +187,28 @@ def evaluate_durations(phns,
else
:
print
(
"ngpu should >= 0 !"
)
assert
target_lang
uage
==
"chinese"
or
target_language
==
"english"
,
"In evaluate_durations function, target_language
is illegal..."
assert
target_lang
==
"chinese"
or
target_lang
==
"english"
,
"In evaluate_durations function, target_lang
is illegal..."
# Init body.
with
open
(
args
.
am_config
)
as
f
:
am_config
=
CfgNode
(
yaml
.
safe_load
(
f
))
# print("========Config========")
# print(am_config)
# print("---------------------")
# acoustic model
am
,
am_inference
,
am_name
,
am_dataset
,
phn_id
=
get_am_inference
(
args
,
am_config
)
am_inference
,
am
=
get_am_inference
(
am
=
args
.
am
,
am_config
=
am_config
,
am_ckpt
=
args
.
am_ckpt
,
am_stat
=
args
.
am_stat
,
phones_dict
=
args
.
phones_dict
,
tones_dict
=
args
.
tones_dict
,
speaker_dict
=
args
.
speaker_dict
,
return_am
=
True
)
torch_phns
=
phns
vocab_phones
=
{}
with
open
(
args
.
phones_dict
,
"r"
)
as
f
:
phn_id
=
[
line
.
strip
().
split
()
for
line
in
f
.
readlines
()]
for
tone
,
id
in
phn_id
:
vocab_phones
[
tone
]
=
int
(
id
)
# print("vocab_phones: ", len(vocab_phones))
vocab_size
=
len
(
vocab_phones
)
phonemes
=
[
phn
if
phn
in
vocab_phones
else
"sp"
for
phn
in
torch_phns
]
...
...
@@ -185,59 +222,3 @@ def evaluate_durations(phns,
phoneme_durations_new
=
pre_d_outs
*
hop_length
/
fs
phoneme_durations_new
=
phoneme_durations_new
.
tolist
()[:
-
1
]
return
phoneme_durations_new
def
sentence2phns
(
sentence
,
target_language
=
"en"
):
args
=
parse_args
()
if
target_language
==
'en'
:
args
.
lang
=
'en'
args
.
phones_dict
=
"download/fastspeech2_nosil_ljspeech_ckpt_0.5/phone_id_map.txt"
elif
target_language
==
'zh'
:
args
.
lang
=
'zh'
args
.
phones_dict
=
"download/fastspeech2_conformer_baker_ckpt_0.5/phone_id_map.txt"
else
:
print
(
"target_language should in {'zh', 'en'}!"
)
frontend
=
get_frontend
(
args
)
merge_sentences
=
True
get_tone_ids
=
False
if
target_language
==
'zh'
:
input_ids
=
frontend
.
get_input_ids
(
sentence
,
merge_sentences
=
merge_sentences
,
get_tone_ids
=
get_tone_ids
,
print_info
=
False
)
phone_ids
=
input_ids
[
"phone_ids"
]
phonemes
=
frontend
.
get_phonemes
(
sentence
,
merge_sentences
=
merge_sentences
,
print_info
=
False
)
return
phonemes
[
0
],
input_ids
[
"phone_ids"
][
0
]
elif
target_language
==
'en'
:
phonemes
=
frontend
.
phoneticize
(
sentence
)
input_ids
=
frontend
.
get_input_ids
(
sentence
,
merge_sentences
=
merge_sentences
)
phone_ids
=
input_ids
[
"phone_ids"
]
phones_list
=
[]
vocab_phones
=
{}
punc
=
":,;。?!“”‘’':,;.?!"
with
open
(
args
.
phones_dict
,
'rt'
)
as
f
:
phn_id
=
[
line
.
strip
().
split
()
for
line
in
f
.
readlines
()]
for
phn
,
id
in
phn_id
:
vocab_phones
[
phn
]
=
int
(
id
)
phones
=
phonemes
[
1
:
-
1
]
phones
=
[
phn
for
phn
in
phones
if
not
phn
.
isspace
()]
# replace unk phone with sp
phones
=
[
phn
if
(
phn
in
vocab_phones
and
phn
not
in
punc
)
else
"sp"
for
phn
in
phones
]
phones_list
.
append
(
phones
)
return
phones_list
[
0
],
input_ids
[
"phone_ids"
][
0
]
else
:
print
(
"lang should in {'zh', 'en'}!"
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录