Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
ERNIE
提交
b81832ce
E
ERNIE
项目概览
PaddlePaddle
/
ERNIE
大约 1 年 前同步成功
通知
109
Star
5997
Fork
1270
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
29
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
E
ERNIE
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
29
Issue
29
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
b81832ce
编写于
6月 10, 2022
作者:
K
Kennycao123
提交者:
GitHub
6月 10, 2022
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #825 from yt605155624/format
[ernie sat]Format ernie sat
上级
437081f7
76b654cb
变更
13
隐藏空白更改
内联
并排
Showing
13 changed file
with
980 addition
and
1123 deletion
+980
-1123
ernie-sat/README.md
ernie-sat/README.md
+2
-2
ernie-sat/align.py
ernie-sat/align.py
+117
-46
ernie-sat/align_mandarin.py
ernie-sat/align_mandarin.py
+0
-186
ernie-sat/dataset.py
ernie-sat/dataset.py
+209
-285
ernie-sat/inference.py
ernie-sat/inference.py
+351
-416
ernie-sat/model_paddle.py
ernie-sat/model_paddle.py
+51
-59
ernie-sat/paddlespeech/t2s/modules/nets_utils.py
ernie-sat/paddlespeech/t2s/modules/nets_utils.py
+154
-12
ernie-sat/run_clone_en_to_zh.sh
ernie-sat/run_clone_en_to_zh.sh
+2
-2
ernie-sat/run_gen_en.sh
ernie-sat/run_gen_en.sh
+2
-2
ernie-sat/run_sedit_en.sh
ernie-sat/run_sedit_en.sh
+2
-2
ernie-sat/sedit_arg_parser.py
ernie-sat/sedit_arg_parser.py
+2
-4
ernie-sat/tools/torch_pwgan.py
ernie-sat/tools/torch_pwgan.py
+1
-1
ernie-sat/utils.py
ernie-sat/utils.py
+87
-106
未找到文件。
ernie-sat/README.md
浏览文件 @
b81832ce
...
@@ -113,8 +113,8 @@ prompt/dev
...
@@ -113,8 +113,8 @@ prompt/dev
8.
` --uid`
特定提示(prompt)语音的 id
8.
` --uid`
特定提示(prompt)语音的 id
9.
` --new_str`
输入的文本(本次开源暂时先设置特定的文本)
9.
` --new_str`
输入的文本(本次开源暂时先设置特定的文本)
10.
` --prefix`
特定音频对应的文本、音素相关文件的地址
10.
` --prefix`
特定音频对应的文本、音素相关文件的地址
11.
` --source_lang
uage
`
, 源语言
11.
` --source_lang`
, 源语言
12.
` --target_lang
uage
`
, 目标语言
12.
` --target_lang`
, 目标语言
13.
` --output_name`
, 合成语音名称
13.
` --output_name`
, 合成语音名称
14.
` --task_name`
, 任务名称, 包括:语音编辑任务、个性化语音合成任务、跨语言语音合成任务
14.
` --task_name`
, 任务名称, 包括:语音编辑任务、个性化语音合成任务、跨语言语音合成任务
15.
` --use_pt_vocoder`
, 英文场景下是否使用 torch 版本的 vocoder, 默认情况下为 False; 设置为 False 则在英文场景下使用 paddle 版本 vocoder
15.
` --use_pt_vocoder`
, 英文场景下是否使用 torch 版本的 vocoder, 默认情况下为 False; 设置为 False 则在英文场景下使用 paddle 版本 vocoder
...
...
ernie-sat/align
_english
.py
→
ernie-sat/align.py
浏览文件 @
b81832ce
#!/usr/bin/env python
#!/usr/bin/env python
""" Usage:
""" Usage:
align
_english
.py wavfile trsfile outwordfile outphonefile
align.py wavfile trsfile outwordfile outphonefile
"""
"""
import
multiprocessing
as
mp
import
multiprocessing
as
mp
import
os
import
os
...
@@ -9,12 +9,45 @@ import sys
...
@@ -9,12 +9,45 @@ import sys
from
tqdm
import
tqdm
from
tqdm
import
tqdm
PHONEME
=
'tools/aligner/english_envir/english2phoneme/phoneme'
PHONEME
=
'tools/aligner/english_envir/english2phoneme/phoneme'
MODEL_DIR
=
'tools/aligner/english'
MODEL_DIR_EN
=
'tools/aligner/english'
MODEL_DIR_ZH
=
'tools/aligner/mandarin'
HVITE
=
'tools/htk/HTKTools/HVite'
HVITE
=
'tools/htk/HTKTools/HVite'
HCOPY
=
'tools/htk/HTKTools/HCopy'
HCOPY
=
'tools/htk/HTKTools/HCopy'
def
prep_txt
(
line
,
tmpbase
,
dictfile
):
def
prep_txt_zh
(
line
:
str
,
tmpbase
:
str
,
dictfile
:
str
):
words
=
[]
line
=
line
.
strip
()
for
pun
in
[
','
,
'.'
,
':'
,
';'
,
'!'
,
'?'
,
'"'
,
'('
,
')'
,
'--'
,
'---'
,
u
','
,
u
'。'
,
u
':'
,
u
';'
,
u
'!'
,
u
'?'
,
u
'('
,
u
')'
]:
line
=
line
.
replace
(
pun
,
' '
)
for
wrd
in
line
.
split
():
if
(
wrd
[
-
1
]
==
'-'
):
wrd
=
wrd
[:
-
1
]
if
(
wrd
[
0
]
==
"'"
):
wrd
=
wrd
[
1
:]
if
wrd
:
words
.
append
(
wrd
)
ds
=
set
([])
with
open
(
dictfile
,
'r'
)
as
fid
:
for
line
in
fid
:
ds
.
add
(
line
.
split
()[
0
])
unk_words
=
set
([])
with
open
(
tmpbase
+
'.txt'
,
'w'
)
as
fwid
:
for
wrd
in
words
:
if
(
wrd
not
in
ds
):
unk_words
.
add
(
wrd
)
fwid
.
write
(
wrd
+
' '
)
fwid
.
write
(
'
\n
'
)
return
unk_words
def
prep_txt_en
(
line
:
str
,
tmpbase
,
dictfile
):
words
=
[]
words
=
[]
...
@@ -97,7 +130,7 @@ def prep_txt(line, tmpbase, dictfile):
...
@@ -97,7 +130,7 @@ def prep_txt(line, tmpbase, dictfile):
fw
.
close
()
fw
.
close
()
def
prep_mlf
(
txt
,
tmpbase
):
def
prep_mlf
(
txt
:
str
,
tmpbase
:
str
):
with
open
(
tmpbase
+
'.mlf'
,
'w'
)
as
fwid
:
with
open
(
tmpbase
+
'.mlf'
,
'w'
)
as
fwid
:
fwid
.
write
(
'#!MLF!#
\n
'
)
fwid
.
write
(
'#!MLF!#
\n
'
)
...
@@ -110,7 +143,55 @@ def prep_mlf(txt, tmpbase):
...
@@ -110,7 +143,55 @@ def prep_mlf(txt, tmpbase):
fwid
.
write
(
'.
\n
'
)
fwid
.
write
(
'.
\n
'
)
def
gen_res
(
tmpbase
,
outfile1
,
outfile2
):
def
_get_user
():
return
os
.
path
.
expanduser
(
'~'
).
split
(
"/"
)[
-
1
]
def
alignment
(
wav_path
:
str
,
text
:
str
):
tmpbase
=
'/tmp/'
+
_get_user
()
+
'_'
+
str
(
os
.
getpid
())
#prepare wav and trs files
try
:
os
.
system
(
'sox '
+
wav_path
+
' -r 16000 '
+
tmpbase
+
'.wav remix -'
)
except
:
print
(
'sox error!'
)
return
None
#prepare clean_transcript file
try
:
prep_txt_en
(
text
,
tmpbase
,
MODEL_DIR_EN
+
'/dict'
)
except
:
print
(
'prep_txt error!'
)
return
None
#prepare mlf file
try
:
with
open
(
tmpbase
+
'.txt'
,
'r'
)
as
fid
:
txt
=
fid
.
readline
()
prep_mlf
(
txt
,
tmpbase
)
except
:
print
(
'prep_mlf error!'
)
return
None
#prepare scp
try
:
os
.
system
(
HCOPY
+
' -C '
+
MODEL_DIR_EN
+
'/16000/config '
+
tmpbase
+
'.wav'
+
' '
+
tmpbase
+
'.plp'
)
except
:
print
(
'HCopy error!'
)
return
None
#run alignment
try
:
os
.
system
(
HVITE
+
' -a -m -t 10000.0 10000.0 100000.0 -I '
+
tmpbase
+
'.mlf -H '
+
MODEL_DIR_EN
+
'/16000/macros -H '
+
MODEL_DIR_EN
+
'/16000/hmmdefs -i '
+
tmpbase
+
'.aligned '
+
tmpbase
+
'.dict '
+
MODEL_DIR_EN
+
'/monophones '
+
tmpbase
+
'.plp 2>&1 > /dev/null'
)
except
:
print
(
'HVite error!'
)
return
None
with
open
(
tmpbase
+
'.txt'
,
'r'
)
as
fid
:
with
open
(
tmpbase
+
'.txt'
,
'r'
)
as
fid
:
words
=
fid
.
readline
().
strip
().
split
()
words
=
fid
.
readline
().
strip
().
split
()
words
=
txt
.
strip
().
split
()
words
=
txt
.
strip
().
split
()
...
@@ -119,59 +200,47 @@ def gen_res(tmpbase, outfile1, outfile2):
...
@@ -119,59 +200,47 @@ def gen_res(tmpbase, outfile1, outfile2):
with
open
(
tmpbase
+
'.aligned'
,
'r'
)
as
fid
:
with
open
(
tmpbase
+
'.aligned'
,
'r'
)
as
fid
:
lines
=
fid
.
readlines
()
lines
=
fid
.
readlines
()
i
=
2
i
=
2
times1
=
[]
times2
=
[]
times2
=
[]
word2phns
=
{}
current_word
=
''
index
=
0
while
(
i
<
len
(
lines
)):
while
(
i
<
len
(
lines
)):
if
(
len
(
lines
[
i
].
split
())
>=
4
)
and
(
splited_line
=
lines
[
i
].
strip
().
split
()
lines
[
i
].
split
()[
0
]
!=
lines
[
i
].
split
()
[
1
]):
if
(
len
(
splited_line
)
>=
4
)
and
(
splited_line
[
0
]
!=
splited_line
[
1
]):
phn
=
lines
[
i
].
split
()
[
2
]
phn
=
splited_line
[
2
]
pst
=
(
int
(
lines
[
i
].
split
()
[
0
])
/
1000
+
125
)
/
10000
pst
=
(
int
(
splited_line
[
0
])
/
1000
+
125
)
/
10000
pen
=
(
int
(
lines
[
i
].
split
()
[
1
])
/
1000
+
125
)
/
10000
pen
=
(
int
(
splited_line
[
1
])
/
1000
+
125
)
/
10000
times2
.
append
([
phn
,
pst
,
pen
])
times2
.
append
([
phn
,
pst
,
pen
])
if
(
len
(
lines
[
i
].
split
())
==
5
):
# splited_line[-1]!='sp'
if
(
lines
[
i
].
split
()[
0
]
!=
lines
[
i
].
split
()[
1
]):
if
len
(
splited_line
)
==
5
:
wrd
=
lines
[
i
].
split
()[
-
1
].
strip
()
current_word
=
str
(
index
)
+
'_'
+
splited_line
[
-
1
]
st
=
(
int
(
lines
[
i
].
split
()[
0
])
/
1000
+
125
)
/
10000
word2phns
[
current_word
]
=
phn
j
=
i
+
1
index
+=
1
while
(
lines
[
j
]
!=
'.
\n
'
)
and
(
len
(
lines
[
j
].
split
())
!=
5
):
elif
len
(
splited_line
)
==
4
:
j
+=
1
word2phns
[
current_word
]
+=
' '
+
phn
en
=
(
int
(
lines
[
j
-
1
].
split
()[
1
])
/
1000
+
125
)
/
10000
times1
.
append
([
wrd
,
st
,
en
])
i
+=
1
i
+=
1
return
times2
,
word2phns
with
open
(
outfile1
,
'w'
)
as
fwid
:
for
item
in
times1
:
if
(
item
[
0
]
==
'sp'
):
fwid
.
write
(
str
(
item
[
1
])
+
' '
+
str
(
item
[
2
])
+
' SIL
\n
'
)
else
:
wrd
=
words
.
pop
()
fwid
.
write
(
str
(
item
[
1
])
+
' '
+
str
(
item
[
2
])
+
' '
+
wrd
+
'
\n
'
)
if
words
:
print
(
'not matched::'
+
alignfile
)
sys
.
exit
(
1
)
with
open
(
outfile2
,
'w'
)
as
fwid
:
for
item
in
times2
:
fwid
.
write
(
str
(
item
[
1
])
+
' '
+
str
(
item
[
2
])
+
' '
+
item
[
0
]
+
'
\n
'
)
def
_get_user
():
return
os
.
path
.
expanduser
(
'~'
).
split
(
"/"
)[
-
1
]
def
alignment
(
wav_path
,
text_string
):
def
alignment
_zh
(
wav_path
,
text_string
):
tmpbase
=
'/tmp/'
+
_get_user
()
+
'_'
+
str
(
os
.
getpid
())
tmpbase
=
'/tmp/'
+
_get_user
()
+
'_'
+
str
(
os
.
getpid
())
#prepare wav and trs files
#prepare wav and trs files
try
:
try
:
os
.
system
(
'sox '
+
wav_path
+
' -r 16000 '
+
tmpbase
+
'.wav remix -'
)
os
.
system
(
'sox '
+
wav_path
+
' -r 16000 -b 16 '
+
tmpbase
+
'.wav remix -'
)
except
:
except
:
print
(
'sox error!'
)
print
(
'sox error!'
)
return
None
return
None
#prepare clean_transcript file
#prepare clean_transcript file
try
:
try
:
prep_txt
(
text_string
,
tmpbase
,
MODEL_DIR
+
'/dict'
)
unk_words
=
prep_txt_zh
(
text_string
,
tmpbase
,
MODEL_DIR_ZH
+
'/dict'
)
if
unk_words
:
print
(
'Error! Please add the following words to dictionary:'
)
for
unk
in
unk_words
:
print
(
"非法words: "
,
unk
)
except
:
except
:
print
(
'prep_txt error!'
)
print
(
'prep_txt error!'
)
return
None
return
None
...
@@ -187,7 +256,7 @@ def alignment(wav_path, text_string):
...
@@ -187,7 +256,7 @@ def alignment(wav_path, text_string):
#prepare scp
#prepare scp
try
:
try
:
os
.
system
(
HCOPY
+
' -C '
+
MODEL_DIR
+
'/16000/config '
+
tmpbase
+
os
.
system
(
HCOPY
+
' -C '
+
MODEL_DIR
_ZH
+
'/16000/config '
+
tmpbase
+
'.wav'
+
' '
+
tmpbase
+
'.plp'
)
'.wav'
+
' '
+
tmpbase
+
'.plp'
)
except
:
except
:
print
(
'HCopy error!'
)
print
(
'HCopy error!'
)
...
@@ -196,10 +265,11 @@ def alignment(wav_path, text_string):
...
@@ -196,10 +265,11 @@ def alignment(wav_path, text_string):
#run alignment
#run alignment
try
:
try
:
os
.
system
(
HVITE
+
' -a -m -t 10000.0 10000.0 100000.0 -I '
+
tmpbase
+
os
.
system
(
HVITE
+
' -a -m -t 10000.0 10000.0 100000.0 -I '
+
tmpbase
+
'.mlf -H '
+
MODEL_DIR
+
'/16000/macros -H '
+
MODEL_DIR
+
'.mlf -H '
+
MODEL_DIR
_ZH
+
'/16000/macros -H '
+
MODEL_DIR_ZH
'/16000/hmmdefs -i '
+
tmpbase
+
'.aligned '
+
tmpbase
+
+
'/16000/hmmdefs -i '
+
tmpbase
+
'.aligned '
+
MODEL_DIR_ZH
'.dict '
+
MODEL_DIR
+
'/monophones '
+
tmpbase
+
+
'/dict '
+
MODEL_DIR_ZH
+
'/monophones '
+
tmpbase
+
'.plp 2>&1 > /dev/null'
)
'.plp 2>&1 > /dev/null'
)
except
:
except
:
print
(
'HVite error!'
)
print
(
'HVite error!'
)
return
None
return
None
...
@@ -211,6 +281,7 @@ def alignment(wav_path, text_string):
...
@@ -211,6 +281,7 @@ def alignment(wav_path, text_string):
with
open
(
tmpbase
+
'.aligned'
,
'r'
)
as
fid
:
with
open
(
tmpbase
+
'.aligned'
,
'r'
)
as
fid
:
lines
=
fid
.
readlines
()
lines
=
fid
.
readlines
()
i
=
2
i
=
2
times2
=
[]
times2
=
[]
word2phns
=
{}
word2phns
=
{}
...
...
ernie-sat/align_mandarin.py
已删除
100755 → 0
浏览文件 @
437081f7
#!/usr/bin/env python
""" Usage:
align_mandarin.py wavfile trsfile outwordfile putphonefile
"""
import
multiprocessing
as
mp
import
os
import
sys
from
tqdm
import
tqdm
MODEL_DIR
=
'tools/aligner/mandarin'
HVITE
=
'tools/htk/HTKTools/HVite'
HCOPY
=
'tools/htk/HTKTools/HCopy'
def
prep_txt
(
line
,
tmpbase
,
dictfile
):
words
=
[]
line
=
line
.
strip
()
for
pun
in
[
','
,
'.'
,
':'
,
';'
,
'!'
,
'?'
,
'"'
,
'('
,
')'
,
'--'
,
'---'
,
u
','
,
u
'。'
,
u
':'
,
u
';'
,
u
'!'
,
u
'?'
,
u
'('
,
u
')'
]:
line
=
line
.
replace
(
pun
,
' '
)
for
wrd
in
line
.
split
():
if
(
wrd
[
-
1
]
==
'-'
):
wrd
=
wrd
[:
-
1
]
if
(
wrd
[
0
]
==
"'"
):
wrd
=
wrd
[
1
:]
if
wrd
:
words
.
append
(
wrd
)
ds
=
set
([])
with
open
(
dictfile
,
'r'
)
as
fid
:
for
line
in
fid
:
ds
.
add
(
line
.
split
()[
0
])
unk_words
=
set
([])
with
open
(
tmpbase
+
'.txt'
,
'w'
)
as
fwid
:
for
wrd
in
words
:
if
(
wrd
not
in
ds
):
unk_words
.
add
(
wrd
)
fwid
.
write
(
wrd
+
' '
)
fwid
.
write
(
'
\n
'
)
return
unk_words
def
prep_mlf
(
txt
,
tmpbase
):
with
open
(
tmpbase
+
'.mlf'
,
'w'
)
as
fwid
:
fwid
.
write
(
'#!MLF!#
\n
'
)
fwid
.
write
(
'"'
+
tmpbase
+
'.lab"
\n
'
)
fwid
.
write
(
'sp
\n
'
)
wrds
=
txt
.
split
()
for
wrd
in
wrds
:
fwid
.
write
(
wrd
.
upper
()
+
'
\n
'
)
fwid
.
write
(
'sp
\n
'
)
fwid
.
write
(
'.
\n
'
)
def
gen_res
(
tmpbase
,
outfile1
,
outfile2
):
with
open
(
tmpbase
+
'.txt'
,
'r'
)
as
fid
:
words
=
fid
.
readline
().
strip
().
split
()
words
=
txt
.
strip
().
split
()
words
.
reverse
()
with
open
(
tmpbase
+
'.aligned'
,
'r'
)
as
fid
:
lines
=
fid
.
readlines
()
i
=
2
times1
=
[]
times2
=
[]
while
(
i
<
len
(
lines
)):
if
(
len
(
lines
[
i
].
split
())
>=
4
)
and
(
lines
[
i
].
split
()[
0
]
!=
lines
[
i
].
split
()[
1
]):
phn
=
lines
[
i
].
split
()[
2
]
pst
=
(
int
(
lines
[
i
].
split
()[
0
])
/
1000
+
125
)
/
10000
pen
=
(
int
(
lines
[
i
].
split
()[
1
])
/
1000
+
125
)
/
10000
times2
.
append
([
phn
,
pst
,
pen
])
if
(
len
(
lines
[
i
].
split
())
==
5
):
if
(
lines
[
i
].
split
()[
0
]
!=
lines
[
i
].
split
()[
1
]):
wrd
=
lines
[
i
].
split
()[
-
1
].
strip
()
st
=
(
int
(
lines
[
i
].
split
()[
0
])
/
1000
+
125
)
/
10000
j
=
i
+
1
while
(
lines
[
j
]
!=
'.
\n
'
)
and
(
len
(
lines
[
j
].
split
())
!=
5
):
j
+=
1
en
=
(
int
(
lines
[
j
-
1
].
split
()[
1
])
/
1000
+
125
)
/
10000
times1
.
append
([
wrd
,
st
,
en
])
i
+=
1
with
open
(
outfile1
,
'w'
)
as
fwid
:
for
item
in
times1
:
if
(
item
[
0
]
==
'sp'
):
fwid
.
write
(
str
(
item
[
1
])
+
' '
+
str
(
item
[
2
])
+
' SIL
\n
'
)
else
:
wrd
=
words
.
pop
()
fwid
.
write
(
str
(
item
[
1
])
+
' '
+
str
(
item
[
2
])
+
' '
+
wrd
+
'
\n
'
)
if
words
:
print
(
'not matched::'
+
alignfile
)
sys
.
exit
(
1
)
with
open
(
outfile2
,
'w'
)
as
fwid
:
for
item
in
times2
:
fwid
.
write
(
str
(
item
[
1
])
+
' '
+
str
(
item
[
2
])
+
' '
+
item
[
0
]
+
'
\n
'
)
def
alignment_zh
(
wav_path
,
text_string
):
tmpbase
=
'/tmp/'
+
os
.
environ
[
'USER'
]
+
'_'
+
str
(
os
.
getpid
())
#prepare wav and trs files
try
:
os
.
system
(
'sox '
+
wav_path
+
' -r 16000 -b 16 '
+
tmpbase
+
'.wav remix -'
)
except
:
print
(
'sox error!'
)
return
None
#prepare clean_transcript file
try
:
unk_words
=
prep_txt
(
text_string
,
tmpbase
,
MODEL_DIR
+
'/dict'
)
if
unk_words
:
print
(
'Error! Please add the following words to dictionary:'
)
for
unk
in
unk_words
:
print
(
"非法words: "
,
unk
)
except
:
print
(
'prep_txt error!'
)
return
None
#prepare mlf file
try
:
with
open
(
tmpbase
+
'.txt'
,
'r'
)
as
fid
:
txt
=
fid
.
readline
()
prep_mlf
(
txt
,
tmpbase
)
except
:
print
(
'prep_mlf error!'
)
return
None
#prepare scp
try
:
os
.
system
(
HCOPY
+
' -C '
+
MODEL_DIR
+
'/16000/config '
+
tmpbase
+
'.wav'
+
' '
+
tmpbase
+
'.plp'
)
except
:
print
(
'HCopy error!'
)
return
None
#run alignment
try
:
os
.
system
(
HVITE
+
' -a -m -t 10000.0 10000.0 100000.0 -I '
+
tmpbase
+
'.mlf -H '
+
MODEL_DIR
+
'/16000/macros -H '
+
MODEL_DIR
+
'/16000/hmmdefs -i '
+
tmpbase
+
'.aligned '
+
MODEL_DIR
+
'/dict '
+
MODEL_DIR
+
'/monophones '
+
tmpbase
+
'.plp 2>&1 > /dev/null'
)
except
:
print
(
'HVite error!'
)
return
None
with
open
(
tmpbase
+
'.txt'
,
'r'
)
as
fid
:
words
=
fid
.
readline
().
strip
().
split
()
words
=
txt
.
strip
().
split
()
words
.
reverse
()
with
open
(
tmpbase
+
'.aligned'
,
'r'
)
as
fid
:
lines
=
fid
.
readlines
()
i
=
2
times2
=
[]
word2phns
=
{}
current_word
=
''
index
=
0
while
(
i
<
len
(
lines
)):
splited_line
=
lines
[
i
].
strip
().
split
()
if
(
len
(
splited_line
)
>=
4
)
and
(
splited_line
[
0
]
!=
splited_line
[
1
]):
phn
=
splited_line
[
2
]
pst
=
(
int
(
splited_line
[
0
])
/
1000
+
125
)
/
10000
pen
=
(
int
(
splited_line
[
1
])
/
1000
+
125
)
/
10000
times2
.
append
([
phn
,
pst
,
pen
])
# splited_line[-1]!='sp'
if
len
(
splited_line
)
==
5
:
current_word
=
str
(
index
)
+
'_'
+
splited_line
[
-
1
]
word2phns
[
current_word
]
=
phn
index
+=
1
elif
len
(
splited_line
)
==
4
:
word2phns
[
current_word
]
+=
' '
+
phn
i
+=
1
return
times2
,
word2phns
ernie-sat/dataset.py
浏览文件 @
b81832ce
...
@@ -4,37 +4,180 @@ import numpy as np
...
@@ -4,37 +4,180 @@ import numpy as np
import
paddle
import
paddle
def
pad_list
(
xs
,
pad_value
):
def
phones_text_masking
(
xs_pad
:
paddle
.
Tensor
,
"""Perform padding for the list of tensors.
src_mask
:
paddle
.
Tensor
,
text_pad
:
paddle
.
Tensor
,
text_mask
:
paddle
.
Tensor
,
align_start
:
paddle
.
Tensor
,
align_end
:
paddle
.
Tensor
,
align_start_lens
:
paddle
.
Tensor
,
mlm_prob
:
float
,
mean_phn_span
:
float
,
span_bdy
:
paddle
.
Tensor
=
None
):
bz
,
sent_len
,
_
=
paddle
.
shape
(
xs_pad
)
masked_pos
=
paddle
.
zeros
((
bz
,
sent_len
))
_
,
text_len
=
paddle
.
shape
(
text_pad
)
text_mask_num_lower
=
math
.
ceil
(
text_len
*
(
1
-
mlm_prob
)
*
0.5
)
text_masked_pos
=
paddle
.
zeros
((
bz
,
text_len
))
y_masks
=
None
if
mlm_prob
==
1.0
:
masked_pos
+=
1
# y_masks = tril_masks
elif
mean_phn_span
==
0
:
# only speech
length
=
sent_len
mean_phn_span
=
min
(
length
*
mlm_prob
//
3
,
50
)
masked_phn_idxs
=
random_spans_noise_mask
(
length
,
mlm_prob
,
mean_phn_span
).
nonzero
()
masked_pos
[:,
masked_phn_idxs
]
=
1
else
:
for
idx
in
range
(
bz
):
if
span_bdy
is
not
None
:
for
s
,
e
in
zip
(
span_bdy
[
idx
][::
2
],
span_bdy
[
idx
][
1
::
2
]):
masked_pos
[
idx
,
s
:
e
]
=
1
else
:
length
=
align_start_lens
[
idx
]
if
length
<
2
:
continue
masked_phn_idxs
=
random_spans_noise_mask
(
length
,
mlm_prob
,
mean_phn_span
).
nonzero
()
unmasked_phn_idxs
=
list
(
set
(
range
(
length
))
-
set
(
masked_phn_idxs
[
0
].
tolist
()))
np
.
random
.
shuffle
(
unmasked_phn_idxs
)
masked_text_idxs
=
unmasked_phn_idxs
[:
text_mask_num_lower
]
text_masked_pos
[
idx
][
masked_text_idxs
]
=
1
masked_start
=
align_start
[
idx
][
masked_phn_idxs
].
tolist
()
masked_end
=
align_end
[
idx
][
masked_phn_idxs
].
tolist
()
for
s
,
e
in
zip
(
masked_start
,
masked_end
):
masked_pos
[
idx
,
s
:
e
]
=
1
non_eos_mask
=
paddle
.
reshape
(
src_mask
,
paddle
.
shape
(
xs_pad
)[:
2
])
masked_pos
=
masked_pos
*
non_eos_mask
non_eos_text_mask
=
paddle
.
reshape
(
text_mask
,
paddle
.
shape
(
xs_pad
)[:
2
])
text_masked_pos
=
text_masked_pos
*
non_eos_text_mask
masked_pos
=
paddle
.
cast
(
masked_pos
,
'bool'
)
text_masked_pos
=
paddle
.
cast
(
text_masked_pos
,
'bool'
)
return
masked_pos
,
text_masked_pos
,
y_masks
def
get_seg_pos_reduce_duration
(
speech_pad
:
paddle
.
Tensor
,
text_pad
:
paddle
.
Tensor
,
align_start
:
paddle
.
Tensor
,
align_end
:
paddle
.
Tensor
,
align_start_lens
:
paddle
.
Tensor
,
sega_emb
:
bool
,
masked_pos
:
paddle
.
Tensor
,
feats_lens
:
paddle
.
Tensor
,
):
bz
,
speech_len
,
_
=
paddle
.
shape
(
speech_pad
)
text_seg_pos
=
paddle
.
zeros
(
paddle
.
shape
(
text_pad
))
speech_seg_pos
=
paddle
.
zeros
((
bz
,
speech_len
),
dtype
=
text_pad
.
dtype
)
reordered_idx
=
paddle
.
zeros
((
bz
,
speech_len
),
dtype
=
align_start_lens
.
dtype
)
durations
=
paddle
.
ones
((
bz
,
speech_len
),
dtype
=
align_start_lens
.
dtype
)
max_reduced_length
=
0
if
not
sega_emb
:
return
speech_pad
,
masked_pos
,
speech_seg_pos
,
text_seg_pos
,
durations
for
idx
in
range
(
bz
):
first_idx
=
[]
last_idx
=
[]
align_length
=
align_start_lens
[
idx
]
for
j
in
range
(
align_length
):
s
,
e
=
align_start
[
idx
][
j
],
align_end
[
idx
][
j
]
if
j
==
0
:
if
paddle
.
sum
(
masked_pos
[
idx
][
0
:
s
])
==
0
:
first_idx
.
extend
(
range
(
0
,
s
))
else
:
first_idx
.
extend
([
0
])
last_idx
.
extend
(
range
(
1
,
s
))
if
paddle
.
sum
(
masked_pos
[
idx
][
s
:
e
])
==
0
:
first_idx
.
extend
(
range
(
s
,
e
))
else
:
first_idx
.
extend
([
s
])
last_idx
.
extend
(
range
(
s
+
1
,
e
))
durations
[
idx
][
s
]
=
e
-
s
speech_seg_pos
[
idx
][
s
:
e
]
=
j
+
1
text_seg_pos
[
idx
][
j
]
=
j
+
1
max_reduced_length
=
max
(
len
(
first_idx
)
+
feats_lens
[
idx
]
-
e
,
max_reduced_length
)
first_idx
.
extend
(
range
(
e
,
speech_len
))
reordered_idx
[
idx
]
=
paddle
.
to_tensor
(
(
first_idx
+
last_idx
),
dtype
=
align_start_lens
.
dtype
)
feats_lens
[
idx
]
=
len
(
first_idx
)
reordered_idx
=
reordered_idx
[:,
:
max_reduced_length
]
return
reordered_idx
,
speech_seg_pos
,
text_seg_pos
,
durations
,
feats_lens
def
random_spans_noise_mask
(
length
:
int
,
mlm_prob
:
float
,
mean_phn_span
:
float
):
"""This function is copy of `random_spans_helper
<https://github.com/google-research/text-to-text-transfer-transformer/blob/84f8bcc14b5f2c03de51bd3587609ba8f6bbd1cd/t5/data/preprocessors.py#L2682>`__ .
Noise mask consisting of random spans of noise tokens.
The number of noise tokens and the number of noise spans and non-noise spans
are determined deterministically as follows:
num_noise_tokens = round(length * noise_density)
num_nonnoise_spans = num_noise_spans = round(num_noise_tokens / mean_noise_span_length)
Spans alternate between non-noise and noise, beginning with non-noise.
Subject to the above restrictions, all masks are equally likely.
Args:
Args:
xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
length: an int32 scalar (length of the incoming token sequence)
pad_value (float): Value for padding.
noise_density: a float - approximate density of output mask
mean_noise_span_length: a number
Returns:
Returns:
Tensor: Padded tensor (B, Tmax, `*`).
a boolean tensor with shape [length]
Examples:
>>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
>>> x
[tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
>>> pad_list(x, 0)
tensor([[1., 1., 1., 1.],
[1., 1., 0., 0.],
[1., 0., 0., 0.]])
"""
"""
n_batch
=
len
(
xs
)
max_len
=
max
(
paddle
.
shape
(
x
)[
0
]
for
x
in
xs
)
pad
=
paddle
.
full
((
n_batch
,
max_len
),
pad_value
,
dtype
=
xs
[
0
].
dtype
)
for
i
in
range
(
n_batch
):
pad
[
i
,
:
paddle
.
shape
(
xs
[
i
])[
0
]]
=
xs
[
i
]
return
pad
orig_length
=
length
num_noise_tokens
=
int
(
np
.
round
(
length
*
mlm_prob
))
# avoid degeneracy by ensuring positive numbers of noise and nonnoise tokens.
num_noise_tokens
=
min
(
max
(
num_noise_tokens
,
1
),
length
-
1
)
num_noise_spans
=
int
(
np
.
round
(
num_noise_tokens
/
mean_phn_span
))
# avoid degeneracy by ensuring positive number of noise spans
num_noise_spans
=
max
(
num_noise_spans
,
1
)
num_nonnoise_tokens
=
length
-
num_noise_tokens
# pick the lengths of the noise spans and the non-noise spans
def
_random_seg
(
num_items
,
num_segs
):
"""Partition a sequence of items randomly into non-empty segments.
Args:
num_items: an integer scalar > 0
num_segs: an integer scalar in [1, num_items]
Returns:
a Tensor with shape [num_segs] containing positive integers that add
up to num_items
"""
mask_idxs
=
np
.
arange
(
num_items
-
1
)
<
(
num_segs
-
1
)
np
.
random
.
shuffle
(
mask_idxs
)
first_in_seg
=
np
.
pad
(
mask_idxs
,
[[
1
,
0
]])
segment_id
=
np
.
cumsum
(
first_in_seg
)
# count length of sub segments assuming that list is sorted
_
,
segment_length
=
np
.
unique
(
segment_id
,
return_counts
=
True
)
return
segment_length
noise_span_lens
=
_random_seg
(
num_noise_tokens
,
num_noise_spans
)
nonnoise_span_lens
=
_random_seg
(
num_nonnoise_tokens
,
num_noise_spans
)
interleaved_span_lens
=
np
.
reshape
(
np
.
stack
([
nonnoise_span_lens
,
noise_span_lens
],
axis
=
1
),
[
num_noise_spans
*
2
])
span_starts
=
np
.
cumsum
(
interleaved_span_lens
)[:
-
1
]
span_start_indicator
=
np
.
zeros
((
length
,
),
dtype
=
np
.
int8
)
span_start_indicator
[
span_starts
]
=
True
span_num
=
np
.
cumsum
(
span_start_indicator
)
is_noise
=
np
.
equal
(
span_num
%
2
,
1
)
return
is_noise
[:
orig_length
]
def
pad_to_longformer_att_window
(
text
:
paddle
.
Tensor
,
max_len
:
int
,
max_tlen
:
int
,
attention_window
:
int
=
0
):
def
pad_to_longformer_att_window
(
text
,
max_len
,
max_tlen
,
attention_window
):
round
=
max_len
%
attention_window
round
=
max_len
%
attention_window
if
round
!=
0
:
if
round
!=
0
:
max_tlen
+=
(
attention_window
-
round
)
max_tlen
+=
(
attention_window
-
round
)
...
@@ -48,286 +191,67 @@ def pad_to_longformer_att_window(text, max_len, max_tlen, attention_window):
...
@@ -48,286 +191,67 @@ def pad_to_longformer_att_window(text, max_len, max_tlen, attention_window):
return
text_pad
,
max_tlen
return
text_pad
,
max_tlen
def
make_pad_mask
(
lengths
,
xs
=
None
,
length_dim
=-
1
):
def
phones_masking
(
xs_pad
:
paddle
.
Tensor
,
"""Make mask tensor containing indices of padded part.
src_mask
:
paddle
.
Tensor
,
align_start
:
paddle
.
Tensor
,
Args:
align_end
:
paddle
.
Tensor
,
lengths (LongTensor or List): Batch of lengths (B,).
align_start_lens
:
paddle
.
Tensor
,
xs (Tensor, optional): The reference tensor.
mlm_prob
:
float
,
If set, masks will be the same shape as this tensor.
mean_phn_span
:
int
,
length_dim (int, optional): Dimension indicator of the above tensor.
span_bdy
:
paddle
.
Tensor
=
None
):
See the example.
Returns:
Tensor: Mask tensor containing indices of padded part.
dtype=torch.uint8 in PyTorch 1.2-
dtype=torch.bool in PyTorch 1.2+ (including 1.2)
Examples:
With only lengths.
>>> lengths = [5, 3, 2]
>>> make_non_pad_mask(lengths)
masks = [[0, 0, 0, 0 ,0],
[0, 0, 0, 1, 1],
[0, 0, 1, 1, 1]]
With the reference tensor.
>>> xs = torch.zeros((3, 2, 4))
>>> make_pad_mask(lengths, xs)
tensor([[[0, 0, 0, 0],
[0, 0, 0, 0]],
[[0, 0, 0, 1],
[0, 0, 0, 1]],
[[0, 0, 1, 1],
[0, 0, 1, 1]]], dtype=torch.uint8)
>>> xs = torch.zeros((3, 2, 6))
>>> make_pad_mask(lengths, xs)
tensor([[[0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 1]],
[[0, 0, 0, 1, 1, 1],
[0, 0, 0, 1, 1, 1]],
[[0, 0, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)
With the reference tensor and dimension indicator.
>>> xs = torch.zeros((3, 6, 6))
>>> make_pad_mask(lengths, xs, 1)
tensor([[[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1]],
[[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1]],
[[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1]]], dtype=torch.uint8)
>>> make_pad_mask(lengths, xs, 2)
tensor([[[0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 1]],
[[0, 0, 0, 1, 1, 1],
[0, 0, 0, 1, 1, 1],
[0, 0, 0, 1, 1, 1],
[0, 0, 0, 1, 1, 1],
[0, 0, 0, 1, 1, 1],
[0, 0, 0, 1, 1, 1]],
[[0, 0, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)
"""
if
length_dim
==
0
:
raise
ValueError
(
"length_dim cannot be 0: {}"
.
format
(
length_dim
))
if
not
isinstance
(
lengths
,
list
):
lengths
=
list
(
lengths
)
bs
=
int
(
len
(
lengths
))
if
xs
is
None
:
maxlen
=
int
(
max
(
lengths
))
else
:
maxlen
=
paddle
.
shape
(
xs
)[
length_dim
]
seq_range
=
paddle
.
arange
(
0
,
maxlen
,
dtype
=
paddle
.
int64
)
seq_range_expand
=
paddle
.
expand
(
paddle
.
unsqueeze
(
seq_range
,
0
),
(
bs
,
maxlen
))
seq_length_expand
=
paddle
.
unsqueeze
(
paddle
.
to_tensor
(
lengths
),
-
1
)
mask
=
seq_range_expand
>=
seq_length_expand
if
xs
is
not
None
:
assert
paddle
.
shape
(
xs
)[
0
]
==
bs
,
(
paddle
.
shape
(
xs
)[
0
],
bs
)
if
length_dim
<
0
:
length_dim
=
len
(
paddle
.
shape
(
xs
))
+
length_dim
# ind = (:, None, ..., None, :, , None, ..., None)
ind
=
tuple
(
slice
(
None
)
if
i
in
(
0
,
length_dim
)
else
None
for
i
in
range
(
len
(
paddle
.
shape
(
xs
))))
mask
=
paddle
.
expand
(
mask
[
ind
],
paddle
.
shape
(
xs
))
return
mask
def
make_non_pad_mask
(
lengths
,
xs
=
None
,
length_dim
=-
1
):
"""Make mask tensor containing indices of non-padded part.
Args:
lengths (LongTensor or List): Batch of lengths (B,).
xs (Tensor, optional): The reference tensor.
If set, masks will be the same shape as this tensor.
length_dim (int, optional): Dimension indicator of the above tensor.
See the example.
Returns:
ByteTensor: mask tensor containing indices of padded part.
dtype=torch.uint8 in PyTorch 1.2-
dtype=torch.bool in PyTorch 1.2+ (including 1.2)
Examples:
With only lengths.
>>> lengths = [5, 3, 2]
>>> make_non_pad_mask(lengths)
masks = [[1, 1, 1, 1 ,1],
[1, 1, 1, 0, 0],
[1, 1, 0, 0, 0]]
With the reference tensor.
>>> xs = torch.zeros((3, 2, 4))
>>> make_non_pad_mask(lengths, xs)
tensor([[[1, 1, 1, 1],
[1, 1, 1, 1]],
[[1, 1, 1, 0],
[1, 1, 1, 0]],
[[1, 1, 0, 0],
[1, 1, 0, 0]]], dtype=torch.uint8)
>>> xs = torch.zeros((3, 2, 6))
>>> make_non_pad_mask(lengths, xs)
tensor([[[1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 0]],
[[1, 1, 1, 0, 0, 0],
[1, 1, 1, 0, 0, 0]],
[[1, 1, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0]]], dtype=torch.uint8)
With the reference tensor and dimension indicator.
>>> xs = torch.zeros((3, 6, 6))
>>> make_non_pad_mask(lengths, xs, 1)
tensor([[[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[0, 0, 0, 0, 0, 0]],
[[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0]],
[[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0]]], dtype=torch.uint8)
>>> make_non_pad_mask(lengths, xs, 2)
tensor([[[1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 0]],
[[1, 1, 1, 0, 0, 0],
[1, 1, 1, 0, 0, 0],
[1, 1, 1, 0, 0, 0],
[1, 1, 1, 0, 0, 0],
[1, 1, 1, 0, 0, 0],
[1, 1, 1, 0, 0, 0]],
[[1, 1, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0]]], dtype=torch.uint8)
"""
return
~
make_pad_mask
(
lengths
,
xs
,
length_dim
)
def
phones_masking
(
xs_pad
,
src_mask
,
align_start
,
align_end
,
align_start_lengths
,
mlm_prob
,
mean_phn_span
,
span_boundary
=
None
):
bz
,
sent_len
,
_
=
paddle
.
shape
(
xs_pad
)
bz
,
sent_len
,
_
=
paddle
.
shape
(
xs_pad
)
mask_num_lower
=
math
.
ceil
(
sent_len
*
mlm_prob
)
masked_pos
=
paddle
.
zeros
((
bz
,
sent_len
))
masked_position
=
np
.
zeros
((
bz
,
sent_len
))
y_masks
=
None
y_masks
=
None
# y_masks = torch.ones(bz,sent_len,sent_len,device=xs_pad.device,dtype=xs_pad.dtype)
# tril_masks = torch.tril(y_masks)
if
mlm_prob
==
1.0
:
if
mlm_prob
==
1.0
:
masked_position
+=
1
masked_pos
+=
1
# y_masks = tril_masks
elif
mean_phn_span
==
0
:
elif
mean_phn_span
==
0
:
# only speech
# only speech
length
=
sent_len
length
=
sent_len
mean_phn_span
=
min
(
length
*
mlm_prob
//
3
,
50
)
mean_phn_span
=
min
(
length
*
mlm_prob
//
3
,
50
)
masked_phn_i
ndice
s
=
random_spans_noise_mask
(
length
,
mlm_prob
,
masked_phn_i
dx
s
=
random_spans_noise_mask
(
length
,
mlm_prob
,
mean_phn_span
).
nonzero
()
mean_phn_span
).
nonzero
()
masked_pos
ition
[:,
masked_phn_indice
s
]
=
1
masked_pos
[:,
masked_phn_idx
s
]
=
1
else
:
else
:
for
idx
in
range
(
bz
):
for
idx
in
range
(
bz
):
if
span_boundary
is
not
None
:
if
span_bdy
is
not
None
:
for
s
,
e
in
zip
(
span_boundary
[
idx
][::
2
],
for
s
,
e
in
zip
(
span_bdy
[
idx
][::
2
],
span_bdy
[
idx
][
1
::
2
]):
span_boundary
[
idx
][
1
::
2
]):
masked_pos
[
idx
,
s
:
e
]
=
1
masked_position
[
idx
,
s
:
e
]
=
1
# y_masks[idx, :, s:e] = tril_masks[idx, :, s:e]
# y_masks[idx, e:, s:e ] = 0
else
:
else
:
length
=
align_start_len
gths
[
idx
].
item
()
length
=
align_start_len
s
[
idx
]
if
length
<
2
:
if
length
<
2
:
continue
continue
masked_phn_i
ndice
s
=
random_spans_noise_mask
(
masked_phn_i
dx
s
=
random_spans_noise_mask
(
length
,
mlm_prob
,
mean_phn_span
).
nonzero
()
length
,
mlm_prob
,
mean_phn_span
).
nonzero
()
masked_start
=
align_start
[
idx
][
masked_phn_i
ndice
s
].
tolist
()
masked_start
=
align_start
[
idx
][
masked_phn_i
dx
s
].
tolist
()
masked_end
=
align_end
[
idx
][
masked_phn_i
ndice
s
].
tolist
()
masked_end
=
align_end
[
idx
][
masked_phn_i
dx
s
].
tolist
()
for
s
,
e
in
zip
(
masked_start
,
masked_end
):
for
s
,
e
in
zip
(
masked_start
,
masked_end
):
masked_position
[
idx
,
s
:
e
]
=
1
masked_pos
[
idx
,
s
:
e
]
=
1
# y_masks[idx, :, s:e] = tril_masks[idx, :, s:e]
non_eos_mask
=
paddle
.
reshape
(
src_mask
,
paddle
.
shape
(
xs_pad
)[:
2
])
# y_masks[idx, e:, s:e ] = 0
masked_pos
=
masked_pos
*
non_eos_mask
non_eos_mask
=
np
.
array
(
masked_pos
=
paddle
.
cast
(
masked_pos
,
'bool'
)
paddle
.
reshape
(
src_mask
,
paddle
.
shape
(
xs_pad
)[:
2
]).
float
().
cpu
())
masked_position
=
masked_position
*
non_eos_mask
# y_masks = src_mask & y_masks.bool()
return
paddle
.
cast
(
paddle
.
to_tensor
(
masked_position
),
paddle
.
bool
)
,
y_masks
return
masked_pos
,
y_masks
def
get_segment_pos
(
speech_pad
,
text_pad
,
align_start
,
align_end
,
def
get_seg_pos
(
speech_pad
:
paddle
.
Tensor
,
align_start_lengths
,
sega_emb
):
text_pad
:
paddle
.
Tensor
,
bz
,
speech_len
,
_
=
speech_pad
.
size
()
align_start
:
paddle
.
Tensor
,
_
,
text_len
=
text_pad
.
size
()
align_end
:
paddle
.
Tensor
,
align_start_lens
:
paddle
.
Tensor
,
sega_emb
:
bool
):
bz
,
speech_len
,
_
=
paddle
.
shape
(
speech_pad
)
_
,
text_len
=
paddle
.
shape
(
text_pad
)
# text_segment_pos = paddle.zeros_like(text_pad)
text_seg_pos
=
paddle
.
zeros
((
bz
,
text_len
),
dtype
=
'int64'
)
# speech_segment_pos = paddle.zeros((bz, speech_len),dtype=text_pad.dtype)
speech_seg_pos
=
paddle
.
zeros
((
bz
,
speech_len
),
dtype
=
'int64'
)
text_segment_pos
=
np
.
zeros
((
bz
,
text_len
)).
astype
(
'int64'
)
speech_segment_pos
=
np
.
zeros
((
bz
,
speech_len
)).
astype
(
'int64'
)
if
not
sega_emb
:
if
not
sega_emb
:
text_segment_pos
=
paddle
.
to_tensor
(
text_segment_pos
)
return
speech_seg_pos
,
text_seg_pos
speech_segment_pos
=
paddle
.
to_tensor
(
speech_segment_pos
)
return
speech_segment_pos
,
text_segment_pos
for
idx
in
range
(
bz
):
for
idx
in
range
(
bz
):
align_length
=
align_start_len
gths
[
idx
].
item
()
align_length
=
align_start_len
s
[
idx
]
for
j
in
range
(
align_length
):
for
j
in
range
(
align_length
):
s
,
e
=
align_start
[
idx
][
j
].
item
(),
align_end
[
idx
][
j
].
item
()
s
,
e
=
align_start
[
idx
][
j
],
align_end
[
idx
][
j
]
speech_segment_pos
[
idx
][
s
:
e
]
=
j
+
1
speech_seg_pos
[
idx
,
s
:
e
]
=
j
+
1
text_segment_pos
[
idx
][
j
]
=
j
+
1
text_seg_pos
[
idx
,
j
]
=
j
+
1
text_segment_pos
=
paddle
.
to_tensor
(
text_segment_pos
)
speech_segment_pos
=
paddle
.
to_tensor
(
speech_segment_pos
)
return
speech_seg
ment_pos
,
text_segment
_pos
return
speech_seg
_pos
,
text_seg
_pos
ernie-sat/inference.py
浏览文件 @
b81832ce
#!/usr/bin/env python3
#!/usr/bin/env python3
import
argparse
import
argparse
import
math
import
os
import
os
import
pickle
import
random
import
random
import
string
import
sys
from
pathlib
import
Path
from
pathlib
import
Path
from
typing
import
Collection
from
typing
import
Collection
from
typing
import
Dict
from
typing
import
Dict
...
@@ -18,17 +14,17 @@ import numpy as np
...
@@ -18,17 +14,17 @@ import numpy as np
import
paddle
import
paddle
import
soundfile
as
sf
import
soundfile
as
sf
import
torch
import
torch
from
paddle
import
nn
from
ParallelWaveGAN.parallel_wavegan.utils.utils
import
download_pretrained_model
from
ParallelWaveGAN.parallel_wavegan.utils.utils
import
download_pretrained_model
from
align_english
import
alignment
from
align
import
alignment
from
align_mandarin
import
alignment_zh
from
align
import
alignment_zh
from
dataset
import
get_segment_pos
from
dataset
import
get_seg_pos
from
dataset
import
make_non_pad_mask
from
dataset
import
get_seg_pos_reduce_duration
from
dataset
import
make_pad_mask
from
dataset
import
pad_list
from
dataset
import
pad_to_longformer_att_window
from
dataset
import
pad_to_longformer_att_window
from
dataset
import
phones_masking
from
dataset
import
phones_masking
from
dataset
import
phones_text_masking
from
model_paddle
import
build_model_from_file
from
model_paddle
import
build_model_from_file
from
read_text
import
load_num_sequence_text
from
read_text
import
load_num_sequence_text
from
read_text
import
read_2column_text
from
read_text
import
read_2column_text
...
@@ -37,8 +33,9 @@ from utils import build_vocoder_from_file
...
@@ -37,8 +33,9 @@ from utils import build_vocoder_from_file
from
utils
import
evaluate_durations
from
utils
import
evaluate_durations
from
utils
import
get_voc_out
from
utils
import
get_voc_out
from
utils
import
is_chinese
from
utils
import
is_chinese
from
utils
import
sentence2phns
from
paddlespeech.t2s.datasets.get_feats
import
LogMelFBank
from
paddlespeech.t2s.datasets.get_feats
import
LogMelFBank
from
paddlespeech.t2s.modules.nets_utils
import
pad_list
from
paddlespeech.t2s.modules.nets_utils
import
make_non_pad_mask
random
.
seed
(
0
)
random
.
seed
(
0
)
np
.
random
.
seed
(
0
)
np
.
random
.
seed
(
0
)
...
@@ -47,81 +44,72 @@ MODEL_DIR_EN = 'tools/aligner/english'
...
@@ -47,81 +44,72 @@ MODEL_DIR_EN = 'tools/aligner/english'
MODEL_DIR_ZH
=
'tools/aligner/mandarin'
MODEL_DIR_ZH
=
'tools/aligner/mandarin'
def
plot_mel_and_vocode_wav
(
uid
,
def
plot_mel_and_vocode_wav
(
uid
:
str
,
prefix
,
wav_path
:
str
,
clone_uid
,
prefix
:
str
=
"./prompt/dev/"
,
clone_prefix
,
source_lang
:
str
=
'english'
,
source_language
,
target_lang
:
str
=
'english'
,
target_language
,
model_name
:
str
=
"conformer"
,
model_name
,
full_origin_str
:
str
=
""
,
wav_path
,
old_str
:
str
=
""
,
full_origin_str
,
new_str
:
str
=
""
,
old_str
,
duration_preditor_path
:
str
=
None
,
new_str
,
use_pt_vocoder
:
bool
=
False
,
use_pt_vocoder
,
sid
:
str
=
None
,
duration_preditor_path
,
non_autoreg
:
bool
=
True
):
sid
=
None
,
wav_org
,
input_feat
,
output_feat
,
old_span_bdy
,
new_span_bdy
,
fs
,
hop_length
=
get_mlm_output
(
non_autoreg
=
True
):
uid
=
uid
,
wav_org
,
input_feat
,
output_feat
,
old_span_boundary
,
new_span_boundary
,
fs
,
hop_length
=
get_mlm_output
(
prefix
=
prefix
,
uid
,
source_lang
=
source_lang
,
prefix
,
target_lang
=
target_lang
,
clone_uid
,
model_name
=
model_name
,
clone_prefix
,
wav_path
=
wav_path
,
source_language
,
old_str
=
old_str
,
target_language
,
new_str
=
new_str
,
model_name
,
duration_preditor_path
=
duration_preditor_path
,
wav_path
,
old_str
,
new_str
,
duration_preditor_path
,
use_teacher_forcing
=
non_autoreg
,
use_teacher_forcing
=
non_autoreg
,
sid
=
sid
)
sid
=
sid
)
masked_feat
=
output_feat
[
new_span_boundary
[
0
]:
new_span_boundary
[
masked_feat
=
output_feat
[
new_span_bdy
[
0
]:
new_span_bdy
[
1
]]
1
]].
detach
().
float
().
cpu
().
numpy
()
if
target_lang
uage
==
'english'
:
if
target_lang
==
'english'
:
if
use_pt_vocoder
:
if
use_pt_vocoder
:
output_feat
=
output_feat
.
detach
().
float
().
cpu
().
numpy
()
output_feat
=
output_feat
.
cpu
().
numpy
()
output_feat
=
torch
.
tensor
(
output_feat
,
dtype
=
torch
.
float
)
output_feat
=
torch
.
tensor
(
output_feat
,
dtype
=
torch
.
float
)
vocoder
=
load_vocoder
(
'vctk_parallel_wavegan.v1.long'
)
vocoder
=
load_vocoder
(
'vctk_parallel_wavegan.v1.long'
)
replaced_wav
=
vocoder
(
replaced_wav
=
vocoder
(
output_feat
).
cpu
().
numpy
()
output_feat
).
detach
().
float
().
data
.
cpu
().
numpy
()
else
:
else
:
output_feat_np
=
output_feat
.
detach
().
float
().
cpu
().
numpy
()
replaced_wav
=
get_voc_out
(
output_feat
,
target_lang
)
replaced_wav
=
get_voc_out
(
output_feat_np
,
target_language
)
elif
target_language
==
'chinese'
:
elif
target_lang
==
'chinese'
:
output_feat_np
=
output_feat
.
detach
().
float
().
cpu
().
numpy
()
replaced_wav_only_mask_fst2_voc
=
get_voc_out
(
masked_feat
,
target_lang
)
replaced_wav_only_mask_fst2_voc
=
get_voc_out
(
masked_feat
,
target_language
)
old_time_b
oundary
=
[
hop_length
*
x
for
x
in
old_span_boundar
y
]
old_time_b
dy
=
[
hop_length
*
x
for
x
in
old_span_bd
y
]
new_time_b
oundary
=
[
hop_length
*
x
for
x
in
new_span_boundar
y
]
new_time_b
dy
=
[
hop_length
*
x
for
x
in
new_span_bd
y
]
if
target_lang
uage
==
'english'
:
if
target_lang
==
'english'
:
wav_org_replaced_paddle_voc
=
np
.
concatenate
([
wav_org_replaced_paddle_voc
=
np
.
concatenate
([
wav_org
[:
old_time_b
oundar
y
[
0
]],
wav_org
[:
old_time_b
d
y
[
0
]],
replaced_wav
[
new_time_b
oundary
[
0
]:
new_time_boundar
y
[
1
]],
replaced_wav
[
new_time_b
dy
[
0
]:
new_time_bd
y
[
1
]],
wav_org
[
old_time_b
oundar
y
[
1
]:]
wav_org
[
old_time_b
d
y
[
1
]:]
])
])
data_dict
=
{
"origin"
:
wav_org
,
"output"
:
wav_org_replaced_paddle_voc
}
data_dict
=
{
"origin"
:
wav_org
,
"output"
:
wav_org_replaced_paddle_voc
}
elif
target_lang
uage
==
'chinese'
:
elif
target_lang
==
'chinese'
:
wav_org_replaced_only_mask_fst2_voc
=
np
.
concatenate
([
wav_org_replaced_only_mask_fst2_voc
=
np
.
concatenate
([
wav_org
[:
old_time_b
oundar
y
[
0
]],
replaced_wav_only_mask_fst2_voc
,
wav_org
[:
old_time_b
d
y
[
0
]],
replaced_wav_only_mask_fst2_voc
,
wav_org
[
old_time_b
oundar
y
[
1
]:]
wav_org
[
old_time_b
d
y
[
1
]:]
])
])
data_dict
=
{
data_dict
=
{
"origin"
:
wav_org
,
"origin"
:
wav_org
,
"output"
:
wav_org_replaced_only_mask_fst2_voc
,
"output"
:
wav_org_replaced_only_mask_fst2_voc
,
}
}
return
data_dict
,
old_span_b
oundar
y
return
data_dict
,
old_span_b
d
y
def
get_unk_phns
(
word_str
):
def
get_unk_phns
(
word_str
:
str
):
tmpbase
=
'/tmp/tp.'
tmpbase
=
'/tmp/tp.'
f
=
open
(
tmpbase
+
'temp.words'
,
'w'
)
f
=
open
(
tmpbase
+
'temp.words'
,
'w'
)
f
.
write
(
word_str
)
f
.
write
(
word_str
)
...
@@ -160,9 +148,8 @@ def get_unk_phns(word_str):
...
@@ -160,9 +148,8 @@ def get_unk_phns(word_str):
return
phns
return
phns
def
words2phns
(
line
):
def
words2phns
(
line
:
str
):
dictfile
=
MODEL_DIR_EN
+
'/dict'
dictfile
=
MODEL_DIR_EN
+
'/dict'
tmpbase
=
'/tmp/tp.'
line
=
line
.
strip
()
line
=
line
.
strip
()
words
=
[]
words
=
[]
for
pun
in
[
','
,
'.'
,
':'
,
';'
,
'!'
,
'?'
,
'"'
,
'('
,
')'
,
'--'
,
'---'
]:
for
pun
in
[
','
,
'.'
,
':'
,
';'
,
'!'
,
'?'
,
'"'
,
'('
,
')'
,
'--'
,
'---'
]:
...
@@ -200,9 +187,8 @@ def words2phns(line):
...
@@ -200,9 +187,8 @@ def words2phns(line):
return
phns
,
wrd2phns
return
phns
,
wrd2phns
def
words2phns_zh
(
line
):
def
words2phns_zh
(
line
:
str
):
dictfile
=
MODEL_DIR_ZH
+
'/dict'
dictfile
=
MODEL_DIR_ZH
+
'/dict'
tmpbase
=
'/tmp/tp.'
line
=
line
.
strip
()
line
=
line
.
strip
()
words
=
[]
words
=
[]
for
pun
in
[
for
pun
in
[
...
@@ -242,7 +228,7 @@ def words2phns_zh(line):
...
@@ -242,7 +228,7 @@ def words2phns_zh(line):
return
phns
,
wrd2phns
return
phns
,
wrd2phns
def
load_vocoder
(
vocoder_tag
=
"vctk_parallel_wavegan.v1.long"
):
def
load_vocoder
(
vocoder_tag
:
str
=
"vctk_parallel_wavegan.v1.long"
):
vocoder_tag
=
vocoder_tag
.
replace
(
"parallel_wavegan/"
,
""
)
vocoder_tag
=
vocoder_tag
.
replace
(
"parallel_wavegan/"
,
""
)
vocoder_file
=
download_pretrained_model
(
vocoder_tag
)
vocoder_file
=
download_pretrained_model
(
vocoder_tag
)
vocoder_config
=
Path
(
vocoder_file
).
parent
/
"config.yml"
vocoder_config
=
Path
(
vocoder_file
).
parent
/
"config.yml"
...
@@ -250,7 +236,7 @@ def load_vocoder(vocoder_tag="vctk_parallel_wavegan.v1.long"):
...
@@ -250,7 +236,7 @@ def load_vocoder(vocoder_tag="vctk_parallel_wavegan.v1.long"):
return
vocoder
return
vocoder
def
load_model
(
model_name
):
def
load_model
(
model_name
:
str
):
config_path
=
'./pretrained_model/{}/config.yaml'
.
format
(
model_name
)
config_path
=
'./pretrained_model/{}/config.yaml'
.
format
(
model_name
)
model_path
=
'./pretrained_model/{}/model.pdparams'
.
format
(
model_name
)
model_path
=
'./pretrained_model/{}/model.pdparams'
.
format
(
model_name
)
mlm_model
,
args
=
build_model_from_file
(
mlm_model
,
args
=
build_model_from_file
(
...
@@ -258,7 +244,7 @@ def load_model(model_name):
...
@@ -258,7 +244,7 @@ def load_model(model_name):
return
mlm_model
,
args
return
mlm_model
,
args
def
read_data
(
uid
,
prefix
):
def
read_data
(
uid
:
str
,
prefix
:
str
):
mfa_text
=
read_2column_text
(
prefix
+
'/text'
)[
uid
]
mfa_text
=
read_2column_text
(
prefix
+
'/text'
)[
uid
]
mfa_wav_path
=
read_2column_text
(
prefix
+
'/wav.scp'
)[
uid
]
mfa_wav_path
=
read_2column_text
(
prefix
+
'/wav.scp'
)[
uid
]
if
'mnt'
not
in
mfa_wav_path
:
if
'mnt'
not
in
mfa_wav_path
:
...
@@ -266,7 +252,7 @@ def read_data(uid, prefix):
...
@@ -266,7 +252,7 @@ def read_data(uid, prefix):
return
mfa_text
,
mfa_wav_path
return
mfa_text
,
mfa_wav_path
def
get_align_data
(
uid
,
prefix
):
def
get_align_data
(
uid
:
str
,
prefix
:
str
):
mfa_path
=
prefix
+
"mfa_"
mfa_path
=
prefix
+
"mfa_"
mfa_text
=
read_2column_text
(
mfa_path
+
'text'
)[
uid
]
mfa_text
=
read_2column_text
(
mfa_path
+
'text'
)[
uid
]
mfa_start
=
load_num_sequence_text
(
mfa_start
=
load_num_sequence_text
(
...
@@ -277,43 +263,45 @@ def get_align_data(uid, prefix):
...
@@ -277,43 +263,45 @@ def get_align_data(uid, prefix):
return
mfa_text
,
mfa_start
,
mfa_end
,
mfa_wav_path
return
mfa_text
,
mfa_start
,
mfa_end
,
mfa_wav_path
def
get_masked_mel_boundary
(
mfa_start
,
mfa_end
,
fs
,
hop_length
,
def
get_masked_mel_bdy
(
mfa_start
:
List
[
float
],
span_tobe_replaced
):
mfa_end
:
List
[
float
],
fs
:
int
,
hop_length
:
int
,
span_to_repl
:
List
[
List
[
int
]]):
align_start
=
paddle
.
to_tensor
(
mfa_start
).
unsqueeze
(
0
)
align_start
=
paddle
.
to_tensor
(
mfa_start
).
unsqueeze
(
0
)
align_end
=
paddle
.
to_tensor
(
mfa_end
).
unsqueeze
(
0
)
align_end
=
paddle
.
to_tensor
(
mfa_end
).
unsqueeze
(
0
)
align_start
=
paddle
.
floor
(
fs
*
align_start
/
hop_length
).
int
()
align_start
=
paddle
.
floor
(
fs
*
align_start
/
hop_length
).
int
()
align_end
=
paddle
.
floor
(
fs
*
align_end
/
hop_length
).
int
()
align_end
=
paddle
.
floor
(
fs
*
align_end
/
hop_length
).
int
()
if
span_to
be_replaced
[
0
]
>=
len
(
mfa_start
):
if
span_to
_repl
[
0
]
>=
len
(
mfa_start
):
span_b
oundar
y
=
[
align_end
[
0
].
tolist
()[
-
1
],
align_end
[
0
].
tolist
()[
-
1
]]
span_b
d
y
=
[
align_end
[
0
].
tolist
()[
-
1
],
align_end
[
0
].
tolist
()[
-
1
]]
else
:
else
:
span_b
oundar
y
=
[
span_b
d
y
=
[
align_start
[
0
].
tolist
()[
span_to
be_replaced
[
0
]],
align_start
[
0
].
tolist
()[
span_to
_repl
[
0
]],
align_end
[
0
].
tolist
()[
span_to
be_replaced
[
1
]
-
1
]
align_end
[
0
].
tolist
()[
span_to
_repl
[
1
]
-
1
]
]
]
return
span_b
oundar
y
return
span_b
d
y
def
recover_dict
(
word2phns
,
tp_word2phns
):
def
recover_dict
(
word2phns
:
Dict
[
str
,
str
],
tp_word2phns
:
Dict
[
str
,
str
]
):
dic
=
{}
dic
=
{}
need_del_key
=
[]
keys_to_del
=
[]
exist_i
nde
x
=
[]
exist_i
d
x
=
[]
sp_count
=
0
sp_count
=
0
add_sp_count
=
0
add_sp_count
=
0
for
key
in
word2phns
.
keys
():
for
key
in
word2phns
.
keys
():
idx
,
wrd
=
key
.
split
(
'_'
)
idx
,
wrd
=
key
.
split
(
'_'
)
if
wrd
==
'sp'
:
if
wrd
==
'sp'
:
sp_count
+=
1
sp_count
+=
1
exist_i
nde
x
.
append
(
int
(
idx
))
exist_i
d
x
.
append
(
int
(
idx
))
else
:
else
:
need_del_key
.
append
(
key
)
keys_to_del
.
append
(
key
)
for
key
in
need_del_key
:
for
key
in
keys_to_del
:
del
word2phns
[
key
]
del
word2phns
[
key
]
cur_id
=
0
cur_id
=
0
for
key
in
tp_word2phns
.
keys
():
for
key
in
tp_word2phns
.
keys
():
# print("debug: ",key)
if
cur_id
in
exist_idx
:
if
cur_id
in
exist_index
:
dic
[
str
(
cur_id
)
+
"_sp"
]
=
'sp'
dic
[
str
(
cur_id
)
+
"_sp"
]
=
'sp'
cur_id
+=
1
cur_id
+=
1
add_sp_count
+=
1
add_sp_count
+=
1
...
@@ -329,14 +317,17 @@ def recover_dict(word2phns, tp_word2phns):
...
@@ -329,14 +317,17 @@ def recover_dict(word2phns, tp_word2phns):
return
dic
return
dic
def
get_phns_and_spans
(
wav_path
,
old_str
,
new_str
,
source_language
,
def
get_phns_and_spans
(
wav_path
:
str
,
clone_target_language
):
old_str
:
str
=
""
,
new_str
:
str
=
""
,
source_lang
:
str
=
"english"
,
target_lang
:
str
=
"english"
):
append_new_str
=
(
old_str
==
new_str
[:
len
(
old_str
)])
append_new_str
=
(
old_str
==
new_str
[:
len
(
old_str
)])
old_phns
,
mfa_start
,
mfa_end
=
[],
[],
[]
old_phns
,
mfa_start
,
mfa_end
=
[],
[],
[]
if
source_lang
uage
==
"english"
:
if
source_lang
==
"english"
:
times2
,
word2phns
=
alignment
(
wav_path
,
old_str
)
times2
,
word2phns
=
alignment
(
wav_path
,
old_str
)
elif
source_lang
uage
==
"chinese"
:
elif
source_lang
==
"chinese"
:
times2
,
word2phns
=
alignment_zh
(
wav_path
,
old_str
)
times2
,
word2phns
=
alignment_zh
(
wav_path
,
old_str
)
_
,
tp_word2phns
=
words2phns_zh
(
old_str
)
_
,
tp_word2phns
=
words2phns_zh
(
old_str
)
...
@@ -348,14 +339,14 @@ def get_phns_and_spans(wav_path, old_str, new_str, source_language,
...
@@ -348,14 +339,14 @@ def get_phns_and_spans(wav_path, old_str, new_str, source_language,
word2phns
=
recover_dict
(
word2phns
,
tp_word2phns
)
word2phns
=
recover_dict
(
word2phns
,
tp_word2phns
)
else
:
else
:
assert
source_lang
uage
==
"chinese"
or
source_language
==
"english"
,
"source_language
is wrong..."
assert
source_lang
==
"chinese"
or
source_lang
==
"english"
,
"source_lang
is wrong..."
for
item
in
times2
:
for
item
in
times2
:
mfa_start
.
append
(
float
(
item
[
1
]))
mfa_start
.
append
(
float
(
item
[
1
]))
mfa_end
.
append
(
float
(
item
[
2
]))
mfa_end
.
append
(
float
(
item
[
2
]))
old_phns
.
append
(
item
[
0
])
old_phns
.
append
(
item
[
0
])
if
append_new_str
and
(
source_lang
uage
!=
clone_target_language
):
if
append_new_str
and
(
source_lang
!=
target_lang
):
is_cross_lingual_clone
=
True
is_cross_lingual_clone
=
True
else
:
else
:
is_cross_lingual_clone
=
False
is_cross_lingual_clone
=
False
...
@@ -364,18 +355,21 @@ def get_phns_and_spans(wav_path, old_str, new_str, source_language,
...
@@ -364,18 +355,21 @@ def get_phns_and_spans(wav_path, old_str, new_str, source_language,
new_str_origin
=
new_str
[:
len
(
old_str
)]
new_str_origin
=
new_str
[:
len
(
old_str
)]
new_str_append
=
new_str
[
len
(
old_str
):]
new_str_append
=
new_str
[
len
(
old_str
):]
if
clone_target_language
==
"chinese"
:
if
target_lang
==
"chinese"
:
new_phns_origin
,
new_origin_word2phns
=
words2phns
(
new_str_origin
)
new_phns_origin
,
new_origin_word2phns
=
words2phns
(
new_str_origin
)
new_phns_append
,
temp_new_append_word2phns
=
words2phns_zh
(
new_phns_append
,
temp_new_append_word2phns
=
words2phns_zh
(
new_str_append
)
new_str_append
)
elif
clone_target_language
==
"english"
:
elif
target_lang
==
"english"
:
# 原始句子
new_phns_origin
,
new_origin_word2phns
=
words2phns_zh
(
new_phns_origin
,
new_origin_word2phns
=
words2phns_zh
(
new_str_origin
)
# 原始句子
new_str_origin
)
# clone句子
new_phns_append
,
temp_new_append_word2phns
=
words2phns
(
new_phns_append
,
temp_new_append_word2phns
=
words2phns
(
new_str_append
)
# clone句子
new_str_append
)
else
:
else
:
assert
clone_target_language
==
"chinese"
or
clone_target_language
==
"english"
,
"cloning is not support for this language, please check it."
assert
target_lang
==
"chinese"
or
target_lang
==
"english"
,
\
"cloning is not support for this language, please check it."
new_phns
=
new_phns_origin
+
new_phns_append
new_phns
=
new_phns_origin
+
new_phns_append
...
@@ -390,16 +384,17 @@ def get_phns_and_spans(wav_path, old_str, new_str, source_language,
...
@@ -390,16 +384,17 @@ def get_phns_and_spans(wav_path, old_str, new_str, source_language,
new_append_word2phns
.
items
()))
new_append_word2phns
.
items
()))
else
:
else
:
if
source_lang
uage
==
clone_target_language
and
clone_target_language
==
"english"
:
if
source_lang
==
target_lang
and
target_lang
==
"english"
:
new_phns
,
new_word2phns
=
words2phns
(
new_str
)
new_phns
,
new_word2phns
=
words2phns
(
new_str
)
elif
source_lang
uage
==
clone_target_language
and
clone_target_language
==
"chinese"
:
elif
source_lang
==
target_lang
and
target_lang
==
"chinese"
:
new_phns
,
new_word2phns
=
words2phns_zh
(
new_str
)
new_phns
,
new_word2phns
=
words2phns_zh
(
new_str
)
else
:
else
:
assert
source_language
==
clone_target_language
,
"source language is not same with target language..."
assert
source_lang
==
target_lang
,
\
"source language is not same with target language..."
span_to
be_replaced
=
[
0
,
len
(
old_phns
)
-
1
]
span_to
_repl
=
[
0
,
len
(
old_phns
)
-
1
]
span_to
be_adde
d
=
[
0
,
len
(
new_phns
)
-
1
]
span_to
_ad
d
=
[
0
,
len
(
new_phns
)
-
1
]
left_i
nde
x
=
0
left_i
d
x
=
0
new_phns_left
=
[]
new_phns_left
=
[]
sp_count
=
0
sp_count
=
0
# find the left different index
# find the left different index
...
@@ -411,27 +406,27 @@ def get_phns_and_spans(wav_path, old_str, new_str, source_language,
...
@@ -411,27 +406,27 @@ def get_phns_and_spans(wav_path, old_str, new_str, source_language,
else
:
else
:
idx
=
str
(
int
(
idx
)
-
sp_count
)
idx
=
str
(
int
(
idx
)
-
sp_count
)
if
idx
+
'_'
+
wrd
in
new_word2phns
:
if
idx
+
'_'
+
wrd
in
new_word2phns
:
left_i
nde
x
+=
len
(
new_word2phns
[
idx
+
'_'
+
wrd
])
left_i
d
x
+=
len
(
new_word2phns
[
idx
+
'_'
+
wrd
])
new_phns_left
.
extend
(
word2phns
[
key
].
split
())
new_phns_left
.
extend
(
word2phns
[
key
].
split
())
else
:
else
:
span_to
be_replaced
[
0
]
=
len
(
new_phns_left
)
span_to
_repl
[
0
]
=
len
(
new_phns_left
)
span_to
be_adde
d
[
0
]
=
len
(
new_phns_left
)
span_to
_ad
d
[
0
]
=
len
(
new_phns_left
)
break
break
# reverse word2phns and new_word2phns
# reverse word2phns and new_word2phns
right_i
nde
x
=
0
right_i
d
x
=
0
new_phns_right
=
[]
new_phns_right
=
[]
sp_count
=
0
sp_count
=
0
word2phns_max_i
nde
x
=
int
(
list
(
word2phns
.
keys
())[
-
1
].
split
(
'_'
)[
0
])
word2phns_max_i
d
x
=
int
(
list
(
word2phns
.
keys
())[
-
1
].
split
(
'_'
)[
0
])
new_word2phns_max_i
nde
x
=
int
(
list
(
new_word2phns
.
keys
())[
-
1
].
split
(
'_'
)[
0
])
new_word2phns_max_i
d
x
=
int
(
list
(
new_word2phns
.
keys
())[
-
1
].
split
(
'_'
)[
0
])
new_phns_mid
dle
=
[]
new_phns_mid
=
[]
if
append_new_str
:
if
append_new_str
:
new_phns_right
=
[]
new_phns_right
=
[]
new_phns_mid
dle
=
new_phns
[
left_inde
x
:]
new_phns_mid
=
new_phns
[
left_id
x
:]
span_to
be_replaced
[
0
]
=
len
(
new_phns_left
)
span_to
_repl
[
0
]
=
len
(
new_phns_left
)
span_to
be_adde
d
[
0
]
=
len
(
new_phns_left
)
span_to
_ad
d
[
0
]
=
len
(
new_phns_left
)
span_to
be_added
[
1
]
=
len
(
new_phns_left
)
+
len
(
new_phns_middle
)
span_to
_add
[
1
]
=
len
(
new_phns_left
)
+
len
(
new_phns_mid
)
span_to
be_replaced
[
1
]
=
len
(
old_phns
)
-
len
(
new_phns_right
)
span_to
_repl
[
1
]
=
len
(
old_phns
)
-
len
(
new_phns_right
)
else
:
else
:
for
key
in
list
(
word2phns
.
keys
())[::
-
1
]:
for
key
in
list
(
word2phns
.
keys
())[::
-
1
]:
idx
,
wrd
=
key
.
split
(
'_'
)
idx
,
wrd
=
key
.
split
(
'_'
)
...
@@ -439,33 +434,31 @@ def get_phns_and_spans(wav_path, old_str, new_str, source_language,
...
@@ -439,33 +434,31 @@ def get_phns_and_spans(wav_path, old_str, new_str, source_language,
sp_count
+=
1
sp_count
+=
1
new_phns_right
=
[
'sp'
]
+
new_phns_right
new_phns_right
=
[
'sp'
]
+
new_phns_right
else
:
else
:
idx
=
str
(
new_word2phns_max_i
ndex
-
(
word2phns_max_index
-
int
(
idx
=
str
(
new_word2phns_max_i
dx
-
(
word2phns_max_idx
-
int
(
idx
)
idx
)
-
sp_count
))
-
sp_count
))
if
idx
+
'_'
+
wrd
in
new_word2phns
:
if
idx
+
'_'
+
wrd
in
new_word2phns
:
right_i
nde
x
-=
len
(
new_word2phns
[
idx
+
'_'
+
wrd
])
right_i
d
x
-=
len
(
new_word2phns
[
idx
+
'_'
+
wrd
])
new_phns_right
=
word2phns
[
key
].
split
()
+
new_phns_right
new_phns_right
=
word2phns
[
key
].
split
()
+
new_phns_right
else
:
else
:
span_tobe_replaced
[
1
]
=
len
(
old_phns
)
-
len
(
new_phns_right
)
span_to_repl
[
1
]
=
len
(
old_phns
)
-
len
(
new_phns_right
)
new_phns_middle
=
new_phns
[
left_index
:
right_index
]
new_phns_mid
=
new_phns
[
left_idx
:
right_idx
]
span_tobe_added
[
1
]
=
len
(
new_phns_left
)
+
len
(
span_to_add
[
1
]
=
len
(
new_phns_left
)
+
len
(
new_phns_mid
)
new_phns_middle
)
if
len
(
new_phns_mid
)
==
0
:
if
len
(
new_phns_middle
)
==
0
:
span_to_add
[
1
]
=
min
(
span_to_add
[
1
]
+
1
,
len
(
new_phns
))
span_tobe_added
[
1
]
=
min
(
span_tobe_added
[
1
]
+
1
,
span_to_add
[
0
]
=
max
(
0
,
span_to_add
[
0
]
-
1
)
len
(
new_phns
))
span_to_repl
[
0
]
=
max
(
0
,
span_to_repl
[
0
]
-
1
)
span_tobe_added
[
0
]
=
max
(
0
,
span_tobe_added
[
0
]
-
1
)
span_to_repl
[
1
]
=
min
(
span_to_repl
[
1
]
+
1
,
span_tobe_replaced
[
0
]
=
max
(
0
,
len
(
old_phns
))
span_tobe_replaced
[
0
]
-
1
)
span_tobe_replaced
[
1
]
=
min
(
span_tobe_replaced
[
1
]
+
1
,
len
(
old_phns
))
break
break
new_phns
=
new_phns_left
+
new_phns_mid
dle
+
new_phns_right
new_phns
=
new_phns_left
+
new_phns_mid
+
new_phns_right
return
mfa_start
,
mfa_end
,
old_phns
,
new_phns
,
span_to
be_replaced
,
span_tobe_adde
d
return
mfa_start
,
mfa_end
,
old_phns
,
new_phns
,
span_to
_repl
,
span_to_ad
d
def
duration_adjust_factor
(
original_dur
,
pred_dur
,
phns
):
def
duration_adjust_factor
(
original_dur
:
List
[
int
],
pred_dur
:
List
[
int
],
phns
:
List
[
str
]):
length
=
0
length
=
0
accumulate
=
0
factor_list
=
[]
factor_list
=
[]
for
ori
,
pred
,
phn
in
zip
(
original_dur
,
pred_dur
,
phns
):
for
ori
,
pred
,
phn
in
zip
(
original_dur
,
pred_dur
,
phns
):
if
pred
==
0
or
phn
==
'sp'
:
if
pred
==
0
or
phn
==
'sp'
:
...
@@ -481,242 +474,224 @@ def duration_adjust_factor(original_dur, pred_dur, phns):
...
@@ -481,242 +474,224 @@ def duration_adjust_factor(original_dur, pred_dur, phns):
return
np
.
average
(
factor_list
[
length
:
-
length
])
return
np
.
average
(
factor_list
[
length
:
-
length
])
def
prepare_features_with_duration
(
uid
,
def
prepare_features_with_duration
(
uid
:
str
,
prefix
,
prefix
:
str
,
clone_uid
,
wav_path
:
str
,
clone_prefix
,
mlm_model
:
nn
.
Layer
,
source_language
,
source_lang
:
str
=
"English"
,
target_language
,
target_lang
:
str
=
"English"
,
mlm_model
,
old_str
:
str
=
""
,
old_str
,
new_str
:
str
=
""
,
new_str
,
duration_preditor_path
:
str
=
None
,
wav_path
,
sid
:
str
=
None
,
duration_preditor_path
,
mask_reconstruct
:
bool
=
False
,
sid
=
None
,
duration_adjust
:
bool
=
True
,
mask_reconstruct
=
False
,
start_end_sp
:
bool
=
False
,
duration_adjust
=
True
,
start_end_sp
=
False
,
train_args
=
None
):
train_args
=
None
):
wav_org
,
rate
=
librosa
.
load
(
wav_org
,
rate
=
librosa
.
load
(
wav_path
,
sr
=
train_args
.
feats_extract_conf
[
'fs'
])
wav_path
,
sr
=
train_args
.
feats_extract_conf
[
'fs'
])
fs
=
train_args
.
feats_extract_conf
[
'fs'
]
fs
=
train_args
.
feats_extract_conf
[
'fs'
]
hop_length
=
train_args
.
feats_extract_conf
[
'hop_length'
]
hop_length
=
train_args
.
feats_extract_conf
[
'hop_length'
]
mfa_start
,
mfa_end
,
old_phns
,
new_phns
,
span_tobe_replaced
,
span_tobe_added
=
get_phns_and_spans
(
mfa_start
,
mfa_end
,
old_phns
,
new_phns
,
span_to_repl
,
span_to_add
=
get_phns_and_spans
(
wav_path
,
old_str
,
new_str
,
source_language
,
target_language
)
wav_path
=
wav_path
,
old_str
=
old_str
,
new_str
=
new_str
,
source_lang
=
source_lang
,
target_lang
=
target_lang
)
if
start_end_sp
:
if
start_end_sp
:
if
new_phns
[
-
1
]
!=
'sp'
:
if
new_phns
[
-
1
]
!=
'sp'
:
new_phns
=
new_phns
+
[
'sp'
]
new_phns
=
new_phns
+
[
'sp'
]
if
target_language
==
"english"
:
if
target_lang
==
"english"
:
old_durations
=
evaluate_durations
(
old_durations
=
evaluate_durations
(
old_phns
,
target_lang
=
target_lang
)
old_phns
,
target_language
=
target_language
)
elif
target_lang
uage
==
"chinese"
:
elif
target_lang
==
"chinese"
:
if
source_lang
uage
==
"english"
:
if
source_lang
==
"english"
:
old_durations
=
evaluate_durations
(
old_durations
=
evaluate_durations
(
old_phns
,
target_lang
uage
=
source_language
)
old_phns
,
target_lang
=
source_lang
)
elif
source_lang
uage
==
"chinese"
:
elif
source_lang
==
"chinese"
:
old_durations
=
evaluate_durations
(
old_durations
=
evaluate_durations
(
old_phns
,
target_lang
uage
=
source_language
)
old_phns
,
target_lang
=
source_lang
)
else
:
else
:
assert
target_lang
uage
==
"chinese"
or
target_language
==
"english"
,
"calculate duration_predict is not support for this language..."
assert
target_lang
==
"chinese"
or
target_lang
==
"english"
,
"calculate duration_predict is not support for this language..."
original_old_durations
=
[
e
-
s
for
e
,
s
in
zip
(
mfa_end
,
mfa_start
)]
original_old_durations
=
[
e
-
s
for
e
,
s
in
zip
(
mfa_end
,
mfa_start
)]
if
'[MASK]'
in
new_str
:
if
'[MASK]'
in
new_str
:
new_phns
=
old_phns
new_phns
=
old_phns
span_to
be_added
=
span_tobe_replaced
span_to
_add
=
span_to_repl
d_factor_left
=
duration_adjust_factor
(
d_factor_left
=
duration_adjust_factor
(
original_old_durations
[:
span_tobe_replaced
[
0
]],
original_old_durations
[:
span_to_repl
[
0
]],
old_durations
[:
span_tobe_replaced
[
0
]],
old_durations
[:
span_to_repl
[
0
]],
old_phns
[:
span_to_repl
[
0
]])
old_phns
[:
span_tobe_replaced
[
0
]])
d_factor_right
=
duration_adjust_factor
(
d_factor_right
=
duration_adjust_factor
(
original_old_durations
[
span_tobe_replaced
[
1
]:],
original_old_durations
[
span_to_repl
[
1
]:],
old_durations
[
span_tobe_replaced
[
1
]:],
old_durations
[
span_to_repl
[
1
]:],
old_phns
[
span_to_repl
[
1
]:])
old_phns
[
span_tobe_replaced
[
1
]:])
d_factor
=
(
d_factor_left
+
d_factor_right
)
/
2
d_factor
=
(
d_factor_left
+
d_factor_right
)
/
2
new_durations_adjusted
=
[
d_factor
*
i
for
i
in
old_durations
]
new_durations_adjusted
=
[
d_factor
*
i
for
i
in
old_durations
]
else
:
else
:
if
duration_adjust
:
if
duration_adjust
:
d_factor
=
duration_adjust_factor
(
original_old_durations
,
d_factor
=
duration_adjust_factor
(
original_old_durations
,
old_durations
,
old_phns
)
old_durations
,
old_phns
)
d_factor_paddle
=
duration_adjust_factor
(
original_old_durations
,
old_durations
,
old_phns
)
d_factor
=
d_factor
*
1.25
d_factor
=
d_factor
*
1.25
else
:
else
:
d_factor
=
1
d_factor
=
1
if
target_lang
uage
==
"english"
:
if
target_lang
==
"english"
:
new_durations
=
evaluate_durations
(
new_durations
=
evaluate_durations
(
new_phns
,
target_lang
uage
=
target_language
)
new_phns
,
target_lang
=
target_lang
)
elif
target_lang
uage
==
"chinese"
:
elif
target_lang
==
"chinese"
:
new_durations
=
evaluate_durations
(
new_durations
=
evaluate_durations
(
new_phns
,
target_lang
uage
=
target_language
)
new_phns
,
target_lang
=
target_lang
)
new_durations_adjusted
=
[
d_factor
*
i
for
i
in
new_durations
]
new_durations_adjusted
=
[
d_factor
*
i
for
i
in
new_durations
]
if
span_tobe_replaced
[
0
]
<
len
(
old_phns
)
and
old_phns
[
if
span_to_repl
[
0
]
<
len
(
old_phns
)
and
old_phns
[
span_to_repl
[
span_tobe_replaced
[
0
]]
==
new_phns
[
span_tobe_added
[
0
]]:
0
]]
==
new_phns
[
span_to_add
[
0
]]:
new_durations_adjusted
[
span_tobe_added
[
0
]]
=
original_old_durations
[
new_durations_adjusted
[
span_to_add
[
0
]]
=
original_old_durations
[
span_tobe_replaced
[
0
]]
span_to_repl
[
0
]]
if
span_tobe_replaced
[
1
]
<
len
(
old_phns
)
and
span_tobe_added
[
1
]
<
len
(
if
span_to_repl
[
1
]
<
len
(
old_phns
)
and
span_to_add
[
1
]
<
len
(
new_phns
):
new_phns
):
if
old_phns
[
span_to_repl
[
1
]]
==
new_phns
[
span_to_add
[
1
]]:
if
old_phns
[
span_tobe_replaced
[
1
]]
==
new_phns
[
span_tobe_added
[
1
]]:
new_durations_adjusted
[
span_to_add
[
1
]]
=
original_old_durations
[
new_durations_adjusted
[
span_tobe_added
[
span_to_repl
[
1
]]
1
]]
=
original_old_durations
[
span_tobe_replaced
[
1
]]
new_span_duration_sum
=
sum
(
new_span_duration_sum
=
sum
(
new_durations_adjusted
[
span_to
be_added
[
0
]:
span_tobe_adde
d
[
1
]])
new_durations_adjusted
[
span_to
_add
[
0
]:
span_to_ad
d
[
1
]])
old_span_duration_sum
=
sum
(
old_span_duration_sum
=
sum
(
original_old_durations
[
span_to
be_replaced
[
0
]:
span_tobe_replaced
[
1
]])
original_old_durations
[
span_to
_repl
[
0
]:
span_to_repl
[
1
]])
duration_offset
=
new_span_duration_sum
-
old_span_duration_sum
duration_offset
=
new_span_duration_sum
-
old_span_duration_sum
new_mfa_start
=
mfa_start
[:
span_to
be_replaced
[
0
]]
new_mfa_start
=
mfa_start
[:
span_to
_repl
[
0
]]
new_mfa_end
=
mfa_end
[:
span_to
be_replaced
[
0
]]
new_mfa_end
=
mfa_end
[:
span_to
_repl
[
0
]]
for
i
in
new_durations_adjusted
[
span_to
be_added
[
0
]:
span_tobe_adde
d
[
1
]]:
for
i
in
new_durations_adjusted
[
span_to
_add
[
0
]:
span_to_ad
d
[
1
]]:
if
len
(
new_mfa_end
)
==
0
:
if
len
(
new_mfa_end
)
==
0
:
new_mfa_start
.
append
(
0
)
new_mfa_start
.
append
(
0
)
new_mfa_end
.
append
(
i
)
new_mfa_end
.
append
(
i
)
else
:
else
:
new_mfa_start
.
append
(
new_mfa_end
[
-
1
])
new_mfa_start
.
append
(
new_mfa_end
[
-
1
])
new_mfa_end
.
append
(
new_mfa_end
[
-
1
]
+
i
)
new_mfa_end
.
append
(
new_mfa_end
[
-
1
]
+
i
)
new_mfa_start
+=
[
new_mfa_start
+=
[
i
+
duration_offset
for
i
in
mfa_start
[
span_to_repl
[
1
]:]]
i
+
duration_offset
for
i
in
mfa_start
[
span_tobe_replaced
[
1
]:]
new_mfa_end
+=
[
i
+
duration_offset
for
i
in
mfa_end
[
span_to_repl
[
1
]:]]
]
new_mfa_end
+=
[
i
+
duration_offset
for
i
in
mfa_end
[
span_tobe_replaced
[
1
]:]
]
# 3. get new wav
# 3. get new wav
if
span_to
be_replaced
[
0
]
>=
len
(
mfa_start
):
if
span_to
_repl
[
0
]
>=
len
(
mfa_start
):
left_i
nde
x
=
len
(
wav_org
)
left_i
d
x
=
len
(
wav_org
)
right_i
ndex
=
left_inde
x
right_i
dx
=
left_id
x
else
:
else
:
left_i
ndex
=
int
(
np
.
floor
(
mfa_start
[
span_tobe_replaced
[
0
]]
*
fs
))
left_i
dx
=
int
(
np
.
floor
(
mfa_start
[
span_to_repl
[
0
]]
*
fs
))
right_i
ndex
=
int
(
np
.
ceil
(
mfa_end
[
span_tobe_replaced
[
1
]
-
1
]
*
fs
))
right_i
dx
=
int
(
np
.
ceil
(
mfa_end
[
span_to_repl
[
1
]
-
1
]
*
fs
))
new_blank_wav
=
np
.
zeros
(
new_blank_wav
=
np
.
zeros
(
(
int
(
np
.
ceil
(
new_span_duration_sum
*
fs
)),
),
dtype
=
wav_org
.
dtype
)
(
int
(
np
.
ceil
(
new_span_duration_sum
*
fs
)),
),
dtype
=
wav_org
.
dtype
)
new_wav_org
=
np
.
concatenate
(
new_wav_org
=
np
.
concatenate
(
[
wav_org
[:
left_i
ndex
],
new_blank_wav
,
wav_org
[
right_inde
x
:]])
[
wav_org
[:
left_i
dx
],
new_blank_wav
,
wav_org
[
right_id
x
:]])
# 4. get old and new mel span to be mask
# 4. get old and new mel span to be mask
old_span_boundary
=
get_masked_mel_boundary
(
# [92, 92]
mfa_start
,
mfa_end
,
fs
,
hop_length
,
span_tobe_replaced
)
# [92, 92]
old_span_bdy
=
get_masked_mel_bdy
(
mfa_start
,
mfa_end
,
fs
,
hop_length
,
new_span_boundary
=
get_masked_mel_boundary
(
new_mfa_start
,
new_mfa_end
,
fs
,
span_to_repl
)
hop_length
,
# [92, 174]
span_tobe_added
)
# [92, 174]
new_span_bdy
=
get_masked_mel_bdy
(
new_mfa_start
,
new_mfa_end
,
fs
,
hop_length
,
span_to_add
)
return
new_wav_org
,
new_phns
,
new_mfa_start
,
new_mfa_end
,
old_span_boundary
,
new_span_boundary
return
new_wav_org
,
new_phns
,
new_mfa_start
,
new_mfa_end
,
old_span_bdy
,
new_span_bdy
def
prepare_features
(
uid
,
prefix
,
def
prepare_features
(
uid
:
str
,
clone_uid
,
mlm_model
:
nn
.
Layer
,
clone_prefix
,
source_language
,
target_language
,
mlm_model
,
processor
,
processor
,
wav_path
,
wav_path
:
str
,
old_str
,
prefix
:
str
=
"./prompt/dev/"
,
new_str
,
source_lang
:
str
=
"english"
,
duration_preditor_path
,
target_lang
:
str
=
"english"
,
sid
=
None
,
old_str
:
str
=
""
,
duration_adjust
=
True
,
new_str
:
str
=
""
,
start_end_sp
=
False
,
duration_preditor_path
:
str
=
None
,
mask_reconstruct
=
False
,
sid
:
str
=
None
,
duration_adjust
:
bool
=
True
,
start_end_sp
:
bool
=
False
,
mask_reconstruct
:
bool
=
False
,
train_args
=
None
):
train_args
=
None
):
wav_org
,
phns_list
,
mfa_start
,
mfa_end
,
old_span_boundary
,
new_span_boundary
=
prepare_features_with_duration
(
wav_org
,
phns_list
,
mfa_start
,
mfa_end
,
old_span_bdy
,
new_span_bdy
=
prepare_features_with_duration
(
uid
,
uid
=
uid
,
prefix
,
prefix
=
prefix
,
clone_uid
,
source_lang
=
source_lang
,
clone_prefix
,
target_lang
=
target_lang
,
source_language
,
mlm_model
=
mlm_model
,
target_language
,
old_str
=
old_str
,
mlm_model
,
new_str
=
new_str
,
old_str
,
wav_path
=
wav_path
,
new_str
,
duration_preditor_path
=
duration_preditor_path
,
wav_path
,
duration_preditor_path
,
sid
=
sid
,
sid
=
sid
,
duration_adjust
=
duration_adjust
,
duration_adjust
=
duration_adjust
,
start_end_sp
=
start_end_sp
,
start_end_sp
=
start_end_sp
,
mask_reconstruct
=
mask_reconstruct
,
mask_reconstruct
=
mask_reconstruct
,
train_args
=
train_args
)
train_args
=
train_args
)
speech
=
np
.
array
(
wav_org
,
dtype
=
np
.
float32
)
speech
=
wav_org
align_start
=
np
.
array
(
mfa_start
)
align_start
=
np
.
array
(
mfa_start
)
align_end
=
np
.
array
(
mfa_end
)
align_end
=
np
.
array
(
mfa_end
)
token_to_id
=
{
item
:
i
for
i
,
item
in
enumerate
(
train_args
.
token_list
)}
token_to_id
=
{
item
:
i
for
i
,
item
in
enumerate
(
train_args
.
token_list
)}
text
=
np
.
array
(
text
=
np
.
array
(
list
(
list
(
map
(
lambda
x
:
token_to_id
.
get
(
x
,
token_to_id
[
'<unk>'
]),
phns_list
)))
map
(
lambda
x
:
token_to_id
.
get
(
x
,
token_to_id
[
'<unk>'
]),
phns_list
)))
# print('unk id is', token_to_id['<unk>'])
# text = np.array(processor(uid='1', data={'text':" ".join(phns_list)})['text'])
span_bdy
=
np
.
array
(
new_span_bdy
)
span_boundary
=
np
.
array
(
new_span_boundary
)
batch
=
[(
'1'
,
{
batch
=
[(
'1'
,
{
"speech"
:
speech
,
"speech"
:
speech
,
"align_start"
:
align_start
,
"align_start"
:
align_start
,
"align_end"
:
align_end
,
"align_end"
:
align_end
,
"text"
:
text
,
"text"
:
text
,
"span_b
oundary"
:
span_boundar
y
"span_b
dy"
:
span_bd
y
})]
})]
return
batch
,
old_span_b
oundary
,
new_span_boundar
y
return
batch
,
old_span_b
dy
,
new_span_bd
y
def
decode_with_model
(
uid
,
def
decode_with_model
(
uid
:
str
,
prefix
,
mlm_model
:
nn
.
Layer
,
clone_uid
,
clone_prefix
,
source_language
,
target_language
,
mlm_model
,
processor
,
processor
,
collate_fn
,
collate_fn
,
wav_path
,
wav_path
:
str
,
old_str
,
prefix
:
str
=
"./prompt/dev/"
,
new_str
,
source_lang
:
str
=
"english"
,
duration_preditor_path
,
target_lang
:
str
=
"english"
,
sid
=
None
,
old_str
:
str
=
""
,
decoder
=
False
,
new_str
:
str
=
""
,
use_teacher_forcing
=
False
,
duration_preditor_path
:
str
=
None
,
duration_adjust
=
True
,
sid
:
str
=
None
,
start_end_sp
=
False
,
decoder
:
bool
=
False
,
use_teacher_forcing
:
bool
=
False
,
duration_adjust
:
bool
=
True
,
start_end_sp
:
bool
=
False
,
train_args
=
None
):
train_args
=
None
):
fs
,
hop_length
=
train_args
.
feats_extract_conf
[
fs
,
hop_length
=
train_args
.
feats_extract_conf
[
'fs'
],
train_args
.
feats_extract_conf
[
'hop_length'
]
'fs'
],
train_args
.
feats_extract_conf
[
'hop_length'
]
batch
,
old_span_boundary
,
new_span_boundary
=
prepare_features
(
batch
,
old_span_bdy
,
new_span_bdy
=
prepare_features
(
uid
,
uid
=
uid
,
prefix
,
prefix
=
prefix
,
clone_uid
,
source_lang
=
source_lang
,
clone_prefix
,
target_lang
=
target_lang
,
source_language
,
mlm_model
=
mlm_model
,
target_language
,
processor
=
processor
,
mlm_model
,
wav_path
=
wav_path
,
processor
,
old_str
=
old_str
,
wav_path
,
new_str
=
new_str
,
old_str
,
duration_preditor_path
=
duration_preditor_path
,
new_str
,
sid
=
sid
,
duration_preditor_path
,
sid
,
duration_adjust
=
duration_adjust
,
duration_adjust
=
duration_adjust
,
start_end_sp
=
start_end_sp
,
start_end_sp
=
start_end_sp
,
train_args
=
train_args
)
train_args
=
train_args
)
feats
=
collate_fn
(
batch
)[
1
]
feats
=
collate_fn
(
batch
)[
1
]
if
'text_masked_pos
ition
'
in
feats
.
keys
():
if
'text_masked_pos'
in
feats
.
keys
():
feats
.
pop
(
'text_masked_pos
ition
'
)
feats
.
pop
(
'text_masked_pos'
)
for
k
,
v
in
feats
.
items
():
for
k
,
v
in
feats
.
items
():
feats
[
k
]
=
paddle
.
to_tensor
(
v
)
feats
[
k
]
=
paddle
.
to_tensor
(
v
)
rtn
=
mlm_model
.
inference
(
rtn
=
mlm_model
.
inference
(
**
feats
,
**
feats
,
span_bdy
=
new_span_bdy
,
use_teacher_forcing
=
use_teacher_forcing
)
span_boundary
=
new_span_boundary
,
use_teacher_forcing
=
use_teacher_forcing
)
output
=
rtn
[
'feat_gen'
]
output
=
rtn
[
'feat_gen'
]
if
0
in
output
[
0
].
shape
and
0
not
in
output
[
-
1
].
shape
:
if
0
in
output
[
0
].
shape
and
0
not
in
output
[
-
1
].
shape
:
output_feat
=
paddle
.
concat
(
output_feat
=
paddle
.
concat
(
...
@@ -731,12 +706,9 @@ def decode_with_model(uid,
...
@@ -731,12 +706,9 @@ def decode_with_model(uid,
[
output
[
0
].
squeeze
(
0
)]
+
output
[
1
:
-
1
]
+
[
output
[
-
1
].
squeeze
(
0
)],
[
output
[
0
].
squeeze
(
0
)]
+
output
[
1
:
-
1
]
+
[
output
[
-
1
].
squeeze
(
0
)],
axis
=
0
).
cpu
()
axis
=
0
).
cpu
()
wav_org
,
rate
=
librosa
.
load
(
wav_org
,
_
=
librosa
.
load
(
wav_path
,
sr
=
train_args
.
feats_extract_conf
[
'fs'
])
wav_path
,
sr
=
train_args
.
feats_extract_conf
[
'fs'
])
origin_speech
=
paddle
.
to_tensor
(
return
wav_org
,
None
,
output_feat
,
old_span_bdy
,
new_span_bdy
,
fs
,
hop_length
np
.
array
(
wav_org
,
dtype
=
np
.
float32
)).
unsqueeze
(
0
)
speech_lengths
=
paddle
.
to_tensor
(
len
(
wav_org
)).
unsqueeze
(
0
)
return
wav_org
,
None
,
output_feat
,
old_span_boundary
,
new_span_boundary
,
fs
,
hop_length
class
MLMCollateFn
:
class
MLMCollateFn
:
...
@@ -800,33 +772,15 @@ def mlm_collate_fn(
...
@@ -800,33 +772,15 @@ def mlm_collate_fn(
sega_emb
:
bool
=
False
,
sega_emb
:
bool
=
False
,
duration_collect
:
bool
=
False
,
duration_collect
:
bool
=
False
,
text_masking
:
bool
=
False
)
->
Tuple
[
List
[
str
],
Dict
[
str
,
paddle
.
Tensor
]]:
text_masking
:
bool
=
False
)
->
Tuple
[
List
[
str
],
Dict
[
str
,
paddle
.
Tensor
]]:
"""Concatenate ndarray-list to an array and convert to torch.Tensor.
Examples:
>>> from espnet2.samplers.constant_batch_sampler import ConstantBatchSampler,
>>> import espnet2.tasks.abs_task
>>> from espnet2.train.dataset import ESPnetDataset
>>> sampler = ConstantBatchSampler(...)
>>> dataset = ESPnetDataset(...)
>>> keys = next(iter(sampler)
>>> batch = [dataset[key] for key in keys]
>>> batch = common_collate_fn(batch)
>>> model(**batch)
Note that the dict-keys of batch are propagated from
that of the dataset as they are.
"""
uttids
=
[
u
for
u
,
_
in
data
]
uttids
=
[
u
for
u
,
_
in
data
]
data
=
[
d
for
_
,
d
in
data
]
data
=
[
d
for
_
,
d
in
data
]
assert
all
(
set
(
data
[
0
])
==
set
(
d
)
for
d
in
data
),
"dict-keys mismatching"
assert
all
(
set
(
data
[
0
])
==
set
(
d
)
for
d
in
data
),
"dict-keys mismatching"
assert
all
(
not
k
.
endswith
(
"_len
gth
s"
)
assert
all
(
not
k
.
endswith
(
"_lens"
)
for
k
in
data
[
0
]),
f
"*_len
gth
s is reserved:
{
list
(
data
[
0
])
}
"
for
k
in
data
[
0
]),
f
"*_lens is reserved:
{
list
(
data
[
0
])
}
"
output
=
{}
output
=
{}
for
key
in
data
[
0
]:
for
key
in
data
[
0
]:
# NOTE(kamo):
# Each models, which accepts these values finally, are responsible
# Each models, which accepts these values finally, are responsible
# to repaint the pad_value to the desired value for each tasks.
# to repaint the pad_value to the desired value for each tasks.
if
data
[
0
][
key
].
dtype
.
kind
==
"i"
:
if
data
[
0
][
key
].
dtype
.
kind
==
"i"
:
...
@@ -846,37 +800,35 @@ def mlm_collate_fn(
...
@@ -846,37 +800,35 @@ def mlm_collate_fn(
# lens: (Batch,)
# lens: (Batch,)
if
key
not
in
not_sequence
:
if
key
not
in
not_sequence
:
lens
=
paddle
.
to_tensor
(
lens
=
paddle
.
to_tensor
(
[
d
[
key
].
shape
[
0
]
for
d
in
data
],
dtype
=
paddle
.
long
)
[
d
[
key
].
shape
[
0
]
for
d
in
data
],
dtype
=
paddle
.
int64
)
output
[
key
+
"_len
gth
s"
]
=
lens
output
[
key
+
"_lens"
]
=
lens
feats
=
feats_extract
.
get_log_mel_fbank
(
np
.
array
(
output
[
"speech"
][
0
]))
feats
=
feats_extract
.
get_log_mel_fbank
(
np
.
array
(
output
[
"speech"
][
0
]))
feats
=
paddle
.
to_tensor
(
feats
)
feats
=
paddle
.
to_tensor
(
feats
)
# print('out shape', paddle.shape(feats))
feats_lens
=
paddle
.
shape
(
feats
)[
0
]
feats_lengths
=
paddle
.
shape
(
feats
)[
0
]
feats
=
paddle
.
unsqueeze
(
feats
,
0
)
feats
=
paddle
.
unsqueeze
(
feats
,
0
)
batch_size
=
paddle
.
shape
(
feats
)[
0
]
if
'text'
not
in
output
:
if
'text'
not
in
output
:
text
=
paddle
.
zeros
_like
(
feats_lengths
.
unsqueeze
(
-
1
))
-
2
text
=
paddle
.
zeros
(
paddle
.
shape
(
feats_lens
.
unsqueeze
(
-
1
)
))
-
2
text_len
gths
=
paddle
.
zeros_like
(
feats_lengths
)
+
1
text_len
s
=
paddle
.
zeros
(
paddle
.
shape
(
feats_lens
)
)
+
1
max_tlen
=
1
max_tlen
=
1
align_start
=
paddle
.
zeros_like
(
text
)
align_start
=
paddle
.
zeros
(
paddle
.
shape
(
text
))
align_end
=
paddle
.
zeros_like
(
text
)
align_end
=
paddle
.
zeros
(
paddle
.
shape
(
text
))
align_start_lengths
=
paddle
.
zeros_like
(
feats_lengths
)
align_start_lens
=
paddle
.
zeros
(
paddle
.
shape
(
feats_lens
))
align_end_lengths
=
paddle
.
zeros_like
(
feats_lengths
)
sega_emb
=
False
sega_emb
=
False
mean_phn_span
=
0
mean_phn_span
=
0
mlm_prob
=
0.15
mlm_prob
=
0.15
else
:
else
:
text
,
text_lengths
=
output
[
"text"
],
output
[
"text_lengths"
]
text
=
output
[
"text"
]
align_start
,
align_start_lengths
,
align_end
,
align_end_lengths
=
output
[
text_lens
=
output
[
"text_lens"
]
"align_start"
],
output
[
"align_start_lengths"
],
output
[
align_start
=
output
[
"align_start"
]
"align_end"
],
output
[
"align_end_lengths"
]
align_start_lens
=
output
[
"align_start_lens"
]
align_end
=
output
[
"align_end"
]
align_start
=
paddle
.
floor
(
feats_extract
.
sr
*
align_start
/
align_start
=
paddle
.
floor
(
feats_extract
.
sr
*
align_start
/
feats_extract
.
hop_length
).
int
()
feats_extract
.
hop_length
).
int
()
align_end
=
paddle
.
floor
(
feats_extract
.
sr
*
align_end
/
align_end
=
paddle
.
floor
(
feats_extract
.
sr
*
align_end
/
feats_extract
.
hop_length
).
int
()
feats_extract
.
hop_length
).
int
()
max_tlen
=
max
(
text_len
gths
).
item
(
)
max_tlen
=
max
(
text_len
s
)
max_slen
=
max
(
feats_len
gths
).
item
(
)
max_slen
=
max
(
feats_len
s
)
speech_pad
=
feats
[:,
:
max_slen
]
speech_pad
=
feats
[:,
:
max_slen
]
if
attention_window
>
0
and
pad_speech
:
if
attention_window
>
0
and
pad_speech
:
speech_pad
,
max_slen
=
pad_to_longformer_att_window
(
speech_pad
,
max_slen
=
pad_to_longformer_att_window
(
...
@@ -888,51 +840,49 @@ def mlm_collate_fn(
...
@@ -888,51 +840,49 @@ def mlm_collate_fn(
else
:
else
:
text_pad
=
text
text_pad
=
text
text_mask
=
make_non_pad_mask
(
text_mask
=
make_non_pad_mask
(
text_len
gths
.
tolist
()
,
text_pad
,
length_dim
=
1
).
unsqueeze
(
-
2
)
text_len
s
,
text_pad
,
length_dim
=
1
).
unsqueeze
(
-
2
)
if
attention_window
>
0
:
if
attention_window
>
0
:
text_mask
=
text_mask
*
2
text_mask
=
text_mask
*
2
speech_mask
=
make_non_pad_mask
(
speech_mask
=
make_non_pad_mask
(
feats_len
gths
.
tolist
()
,
speech_pad
[:,
:,
0
],
length_dim
=
1
).
unsqueeze
(
-
2
)
feats_len
s
,
speech_pad
[:,
:,
0
],
length_dim
=
1
).
unsqueeze
(
-
2
)
span_b
oundar
y
=
None
span_b
d
y
=
None
if
'span_b
oundar
y'
in
output
.
keys
():
if
'span_b
d
y'
in
output
.
keys
():
span_b
oundary
=
output
[
'span_boundar
y'
]
span_b
dy
=
output
[
'span_bd
y'
]
if
text_masking
:
if
text_masking
:
masked_pos
ition
,
text_masked_position
,
_
=
phones_text_masking
(
masked_pos
,
text_masked_pos
,
_
=
phones_text_masking
(
speech_pad
,
speech_mask
,
text_pad
,
text_mask
,
align_start
,
speech_pad
,
speech_mask
,
text_pad
,
text_mask
,
align_start
,
align_end
,
align_start_lengths
,
mlm_prob
,
mean_phn_span
,
align_end
,
align_start_lens
,
mlm_prob
,
mean_phn_span
,
span_bdy
)
span_boundary
)
else
:
else
:
text_masked_pos
ition
=
np
.
zeros
(
text_pad
.
size
(
))
text_masked_pos
=
paddle
.
zeros
(
paddle
.
shape
(
text_pad
))
masked_pos
ition
,
_
=
phones_masking
(
masked_pos
,
_
=
phones_masking
(
speech_pad
,
speech_mask
,
align_start
,
speech_pad
,
speech_mask
,
align_start
,
align_end
,
align_end
,
align_start_lens
,
mlm_prob
,
align_start_lengths
,
mlm_prob
,
mean_phn_span
,
span_boundar
y
)
mean_phn_span
,
span_bd
y
)
output_dict
=
{}
output_dict
=
{}
if
duration_collect
and
'text'
in
output
:
if
duration_collect
and
'text'
in
output
:
reordered_i
ndex
,
speech_segment_pos
,
text_segment_pos
,
durations
,
feats_lengths
=
get_segment
_pos_reduce_duration
(
reordered_i
dx
,
speech_seg_pos
,
text_seg_pos
,
durations
,
feats_lens
=
get_seg
_pos_reduce_duration
(
speech_pad
,
text_pad
,
align_start
,
align_end
,
align_start_len
gth
s
,
speech_pad
,
text_pad
,
align_start
,
align_end
,
align_start_lens
,
sega_emb
,
masked_pos
ition
,
feats_length
s
)
sega_emb
,
masked_pos
,
feats_len
s
)
speech_mask
=
make_non_pad_mask
(
speech_mask
=
make_non_pad_mask
(
feats_lengths
.
tolist
(),
feats_lens
,
speech_pad
[:,
:
reordered_idx
.
shape
[
1
],
0
],
speech_pad
[:,
:
reordered_index
.
shape
[
1
],
0
],
length_dim
=
1
).
unsqueeze
(
-
2
)
length_dim
=
1
).
unsqueeze
(
-
2
)
output_dict
[
'durations'
]
=
durations
output_dict
[
'durations'
]
=
durations
output_dict
[
'reordered_i
ndex'
]
=
reordered_inde
x
output_dict
[
'reordered_i
dx'
]
=
reordered_id
x
else
:
else
:
speech_seg
ment_pos
,
text_segment_pos
=
get_segment_pos
(
speech_seg
_pos
,
text_seg_pos
=
get_seg_pos
(
speech_pad
,
text_pad
,
speech_pad
,
text_pad
,
align_start
,
align_end
,
align_start_lengths
,
align_start
,
align_end
,
sega_emb
)
align_start_lens
,
sega_emb
)
output_dict
[
'speech'
]
=
speech_pad
output_dict
[
'speech'
]
=
speech_pad
output_dict
[
'text'
]
=
text_pad
output_dict
[
'text'
]
=
text_pad
output_dict
[
'masked_pos
ition'
]
=
masked_position
output_dict
[
'masked_pos
'
]
=
masked_pos
output_dict
[
'text_masked_pos
ition'
]
=
text_masked_position
output_dict
[
'text_masked_pos
'
]
=
text_masked_pos
output_dict
[
'speech_mask'
]
=
speech_mask
output_dict
[
'speech_mask'
]
=
speech_mask
output_dict
[
'text_mask'
]
=
text_mask
output_dict
[
'text_mask'
]
=
text_mask
output_dict
[
'speech_seg
ment_pos'
]
=
speech_segment
_pos
output_dict
[
'speech_seg
_pos'
]
=
speech_seg
_pos
output_dict
[
'text_seg
ment_pos'
]
=
text_segment
_pos
output_dict
[
'text_seg
_pos'
]
=
text_seg
_pos
output_dict
[
'speech_len
gths'
]
=
output
[
"speech_length
s"
]
output_dict
[
'speech_len
s'
]
=
output
[
"speech_len
s"
]
output_dict
[
'text_len
gths'
]
=
text_length
s
output_dict
[
'text_len
s'
]
=
text_len
s
output
=
(
uttids
,
output_dict
)
output
=
(
uttids
,
output_dict
)
return
output
return
output
...
@@ -940,13 +890,13 @@ def mlm_collate_fn(
...
@@ -940,13 +890,13 @@ def mlm_collate_fn(
def
build_collate_fn
(
args
:
argparse
.
Namespace
,
train
:
bool
,
epoch
=-
1
):
def
build_collate_fn
(
args
:
argparse
.
Namespace
,
train
:
bool
,
epoch
=-
1
):
# -> Callable[
# -> Callable[
# [Collection[Tuple[str, Dict[str, np.ndarray]]]],
# [Collection[Tuple[str, Dict[str, np.ndarray]]]],
# Tuple[List[str], Dict[str,
torch.
Tensor]],
# Tuple[List[str], Dict[str, Tensor]],
# ]:
# ]:
# assert check_argument_types()
# assert check_argument_types()
# return CommonCollateFn(float_pad_value=0.0, int_pad_value=0)
# return CommonCollateFn(float_pad_value=0.0, int_pad_value=0)
feats_extract_class
=
LogMelFBank
feats_extract_class
=
LogMelFBank
if
args
.
feats_extract_conf
[
'win_length'
]
==
None
:
if
args
.
feats_extract_conf
[
'win_length'
]
is
None
:
args
.
feats_extract_conf
[
'win_length'
]
=
args
.
feats_extract_conf
[
'n_fft'
]
args
.
feats_extract_conf
[
'win_length'
]
=
args
.
feats_extract_conf
[
'n_fft'
]
args_dic
=
{}
args_dic
=
{}
...
@@ -955,7 +905,6 @@ def build_collate_fn(args: argparse.Namespace, train: bool, epoch=-1):
...
@@ -955,7 +905,6 @@ def build_collate_fn(args: argparse.Namespace, train: bool, epoch=-1):
args_dic
[
'sr'
]
=
v
args_dic
[
'sr'
]
=
v
else
:
else
:
args_dic
[
k
]
=
v
args_dic
[
k
]
=
v
# feats_extract = feats_extract_class(**args.feats_extract_conf)
feats_extract
=
feats_extract_class
(
**
args_dic
)
feats_extract
=
feats_extract_class
(
**
args_dic
)
sega_emb
=
True
if
args
.
encoder_conf
[
'input_layer'
]
==
'sega_mlm'
else
False
sega_emb
=
True
if
args
.
encoder_conf
[
'input_layer'
]
==
'sega_mlm'
else
False
...
@@ -969,8 +918,7 @@ def build_collate_fn(args: argparse.Namespace, train: bool, epoch=-1):
...
@@ -969,8 +918,7 @@ def build_collate_fn(args: argparse.Namespace, train: bool, epoch=-1):
if
epoch
==
-
1
:
if
epoch
==
-
1
:
mlm_prob_factor
=
1
mlm_prob_factor
=
1
else
:
else
:
mlm_probs
=
[
1.0
,
1.0
,
0.7
,
0.6
,
0.5
]
mlm_prob_factor
=
0.8
mlm_prob_factor
=
0.8
#mlm_probs[epoch // 100]
if
'duration_predictor_layers'
in
args
.
model_conf
.
keys
(
if
'duration_predictor_layers'
in
args
.
model_conf
.
keys
(
)
and
args
.
model_conf
[
'duration_predictor_layers'
]
>
0
:
)
and
args
.
model_conf
[
'duration_predictor_layers'
]
>
0
:
duration_collect
=
True
duration_collect
=
True
...
@@ -989,42 +937,37 @@ def build_collate_fn(args: argparse.Namespace, train: bool, epoch=-1):
...
@@ -989,42 +937,37 @@ def build_collate_fn(args: argparse.Namespace, train: bool, epoch=-1):
duration_collect
=
duration_collect
)
duration_collect
=
duration_collect
)
def
get_mlm_output
(
uid
,
def
get_mlm_output
(
uid
:
str
,
prefix
,
wav_path
:
str
,
clone_uid
,
prefix
:
str
=
"./prompt/dev/"
,
clone_prefix
,
model_name
:
str
=
"conformer"
,
source_language
,
source_lang
:
str
=
"english"
,
target_language
,
target_lang
:
str
=
"english"
,
model_name
,
old_str
:
str
=
""
,
wav_path
,
new_str
:
str
=
""
,
old_str
,
duration_preditor_path
:
str
=
None
,
new_str
,
sid
:
str
=
None
,
duration_preditor_path
,
decoder
:
bool
=
False
,
sid
=
None
,
use_teacher_forcing
:
bool
=
False
,
decoder
=
False
,
duration_adjust
:
bool
=
True
,
use_teacher_forcing
=
False
,
start_end_sp
:
bool
=
False
):
dynamic_eval
=
(
0
,
0
),
duration_adjust
=
True
,
start_end_sp
=
False
):
mlm_model
,
train_args
=
load_model
(
model_name
)
mlm_model
,
train_args
=
load_model
(
model_name
)
mlm_model
.
eval
()
mlm_model
.
eval
()
processor
=
None
processor
=
None
collate_fn
=
build_collate_fn
(
train_args
,
False
)
collate_fn
=
build_collate_fn
(
train_args
,
False
)
return
decode_with_model
(
return
decode_with_model
(
uid
,
uid
=
uid
,
prefix
,
prefix
=
prefix
,
clone_uid
,
source_lang
=
source_lang
,
clone_prefix
,
target_lang
=
target_lang
,
source_language
,
mlm_model
=
mlm_model
,
target_language
,
processor
=
processor
,
mlm_model
,
collate_fn
=
collate_fn
,
processor
,
wav_path
=
wav_path
,
collate_fn
,
old_str
=
old_str
,
wav_path
,
new_str
=
new_str
,
old_str
,
duration_preditor_path
=
duration_preditor_path
,
new_str
,
duration_preditor_path
,
sid
=
sid
,
sid
=
sid
,
decoder
=
decoder
,
decoder
=
decoder
,
use_teacher_forcing
=
use_teacher_forcing
,
use_teacher_forcing
=
use_teacher_forcing
,
...
@@ -1033,23 +976,20 @@ def get_mlm_output(uid,
...
@@ -1033,23 +976,20 @@ def get_mlm_output(uid,
train_args
=
train_args
)
train_args
=
train_args
)
def
test_vctk
(
uid
,
def
evaluate
(
uid
:
str
,
clone_uid
,
source_lang
:
str
=
"english"
,
clone_prefix
,
target_lang
:
str
=
"english"
,
source_language
,
use_pt_vocoder
:
bool
=
False
,
target_language
,
prefix
:
str
=
"./prompt/dev/"
,
vocoder
,
model_name
:
str
=
"conformer"
,
prefix
=
'dump/raw/dev'
,
old_str
:
str
=
""
,
model_name
=
"conformer"
,
new_str
:
str
=
""
,
old_str
=
""
,
prompt_decoding
:
bool
=
False
,
new_str
=
""
,
task_name
:
str
=
None
):
prompt_decoding
=
False
,
dynamic_eval
=
(
0
,
0
),
task_name
=
None
):
duration_preditor_path
=
None
duration_preditor_path
=
None
spemd
=
None
spemd
=
None
full_origin_str
,
wav_path
=
read_data
(
uid
,
prefix
)
full_origin_str
,
wav_path
=
read_data
(
uid
=
uid
,
prefix
=
prefix
)
if
task_name
==
'edit'
:
if
task_name
==
'edit'
:
new_str
=
new_str
new_str
=
new_str
...
@@ -1065,19 +1005,17 @@ def test_vctk(uid,
...
@@ -1065,19 +1005,17 @@ def test_vctk(uid,
old_str
=
full_origin_str
old_str
=
full_origin_str
results_dict
,
old_span
=
plot_mel_and_vocode_wav
(
results_dict
,
old_span
=
plot_mel_and_vocode_wav
(
uid
,
uid
=
uid
,
prefix
,
prefix
=
prefix
,
clone_uid
,
source_lang
=
source_lang
,
clone_prefix
,
target_lang
=
target_lang
,
source_language
,
model_name
=
model_name
,
target_language
,
wav_path
=
wav_path
,
model_name
,
full_origin_str
=
full_origin_str
,
wav_path
,
old_str
=
old_str
,
full_origin_str
,
new_str
=
new_str
,
old_str
,
use_pt_vocoder
=
use_pt_vocoder
,
new_str
,
duration_preditor_path
=
duration_preditor_path
,
vocoder
,
duration_preditor_path
,
sid
=
spemd
)
sid
=
spemd
)
return
results_dict
return
results_dict
...
@@ -1086,17 +1024,14 @@ if __name__ == "__main__":
...
@@ -1086,17 +1024,14 @@ if __name__ == "__main__":
# parse config and args
# parse config and args
args
=
parse_args
()
args
=
parse_args
()
data_dict
=
test_vctk
(
data_dict
=
evaluate
(
args
.
uid
,
uid
=
args
.
uid
,
args
.
clone_uid
,
source_lang
=
args
.
source_lang
,
args
.
clone_prefix
,
target_lang
=
args
.
target_lang
,
args
.
source_language
,
use_pt_vocoder
=
args
.
use_pt_vocoder
,
args
.
target_language
,
prefix
=
args
.
prefix
,
args
.
use_pt_vocoder
,
model_name
=
args
.
model_name
,
args
.
prefix
,
args
.
model_name
,
new_str
=
args
.
new_str
,
new_str
=
args
.
new_str
,
task_name
=
args
.
task_name
)
task_name
=
args
.
task_name
)
sf
.
write
(
args
.
output_name
,
data_dict
[
'output'
],
samplerate
=
24000
)
sf
.
write
(
args
.
output_name
,
data_dict
[
'output'
],
samplerate
=
24000
)
print
(
"finished..."
)
print
(
"finished..."
)
# exit()
ernie-sat/model_paddle.py
浏览文件 @
b81832ce
...
@@ -121,12 +121,10 @@ class NewMaskInputLayer(nn.Layer):
...
@@ -121,12 +121,10 @@ class NewMaskInputLayer(nn.Layer):
default_initializer
=
paddle
.
nn
.
initializer
.
Assign
(
default_initializer
=
paddle
.
nn
.
initializer
.
Assign
(
paddle
.
normal
(
shape
=
(
1
,
1
,
out_features
))))
paddle
.
normal
(
shape
=
(
1
,
1
,
out_features
))))
def
forward
(
self
,
input
:
paddle
.
Tensor
,
def
forward
(
self
,
input
:
paddle
.
Tensor
,
masked_pos
=
None
)
->
paddle
.
Tensor
:
masked_position
=
None
)
->
paddle
.
Tensor
:
masked_pos
=
paddle
.
expand_as
(
paddle
.
unsqueeze
(
masked_pos
,
-
1
),
input
)
masked_position
=
paddle
.
expand_as
(
masked_input
=
masked_fill
(
input
,
masked_pos
,
0
)
+
masked_fill
(
paddle
.
unsqueeze
(
masked_position
,
-
1
),
input
)
paddle
.
expand_as
(
self
.
mask_feature
,
input
),
~
masked_pos
,
0
)
masked_input
=
masked_fill
(
input
,
masked_position
,
0
)
+
masked_fill
(
paddle
.
expand_as
(
self
.
mask_feature
,
input
),
~
masked_position
,
0
)
return
masked_input
return
masked_input
...
@@ -443,37 +441,34 @@ class MLMEncoder(nn.Layer):
...
@@ -443,37 +441,34 @@ class MLMEncoder(nn.Layer):
def
forward
(
self
,
def
forward
(
self
,
speech_pad
,
speech_pad
,
text_pad
,
text_pad
,
masked_pos
ition
,
masked_pos
,
speech_mask
=
None
,
speech_mask
=
None
,
text_mask
=
None
,
text_mask
=
None
,
speech_seg
ment
_pos
=
None
,
speech_seg_pos
=
None
,
text_seg
ment
_pos
=
None
):
text_seg_pos
=
None
):
"""Encode input sequence.
"""Encode input sequence.
"""
"""
if
masked_pos
ition
is
not
None
:
if
masked_pos
is
not
None
:
speech_pad
=
self
.
speech_embed
(
speech_pad
,
masked_pos
ition
)
speech_pad
=
self
.
speech_embed
(
speech_pad
,
masked_pos
)
else
:
else
:
speech_pad
=
self
.
speech_embed
(
speech_pad
)
speech_pad
=
self
.
speech_embed
(
speech_pad
)
# pure speech input
# pure speech input
if
-
2
in
np
.
array
(
text_pad
):
if
-
2
in
np
.
array
(
text_pad
):
text_pad
=
text_pad
+
3
text_pad
=
text_pad
+
3
text_mask
=
paddle
.
unsqueeze
(
bool
(
text_pad
),
1
)
text_mask
=
paddle
.
unsqueeze
(
bool
(
text_pad
),
1
)
text_seg
ment
_pos
=
paddle
.
zeros_like
(
text_pad
)
text_seg_pos
=
paddle
.
zeros_like
(
text_pad
)
text_pad
=
self
.
text_embed
(
text_pad
)
text_pad
=
self
.
text_embed
(
text_pad
)
text_pad
=
(
text_pad
[
0
]
+
self
.
segment_emb
(
text_seg
ment
_pos
),
text_pad
=
(
text_pad
[
0
]
+
self
.
segment_emb
(
text_seg_pos
),
text_pad
[
1
])
text_pad
[
1
])
text_seg
ment
_pos
=
None
text_seg_pos
=
None
elif
text_pad
is
not
None
:
elif
text_pad
is
not
None
:
text_pad
=
self
.
text_embed
(
text_pad
)
text_pad
=
self
.
text_embed
(
text_pad
)
segment_emb
=
None
if
speech_seg_pos
is
not
None
and
text_seg_pos
is
not
None
and
self
.
segment_emb
:
if
speech_segment_pos
is
not
None
and
text_segment_pos
is
not
None
and
self
.
segment_emb
:
speech_seg_emb
=
self
.
segment_emb
(
speech_seg_pos
)
speech_segment_emb
=
self
.
segment_emb
(
speech_segment_pos
)
text_seg_emb
=
self
.
segment_emb
(
text_seg_pos
)
text_segment_emb
=
self
.
segment_emb
(
text_segment_pos
)
text_pad
=
(
text_pad
[
0
]
+
text_seg_emb
,
text_pad
[
1
])
text_pad
=
(
text_pad
[
0
]
+
text_segment_emb
,
text_pad
[
1
])
speech_pad
=
(
speech_pad
[
0
]
+
speech_seg_emb
,
speech_pad
[
1
])
speech_pad
=
(
speech_pad
[
0
]
+
speech_segment_emb
,
speech_pad
[
1
])
segment_emb
=
paddle
.
concat
(
[
speech_segment_emb
,
text_segment_emb
],
axis
=
1
)
if
self
.
pre_speech_encoders
:
if
self
.
pre_speech_encoders
:
speech_pad
,
_
=
self
.
pre_speech_encoders
(
speech_pad
,
speech_mask
)
speech_pad
,
_
=
self
.
pre_speech_encoders
(
speech_pad
,
speech_mask
)
...
@@ -493,11 +488,11 @@ class MLMEncoder(nn.Layer):
...
@@ -493,11 +488,11 @@ class MLMEncoder(nn.Layer):
if
self
.
normalize_before
:
if
self
.
normalize_before
:
xs
=
self
.
after_norm
(
xs
)
xs
=
self
.
after_norm
(
xs
)
return
xs
,
masks
#, segment_emb
return
xs
,
masks
class
MLMDecoder
(
MLMEncoder
):
class
MLMDecoder
(
MLMEncoder
):
def
forward
(
self
,
xs
,
masks
,
masked_pos
ition
=
None
,
segment_emb
=
None
):
def
forward
(
self
,
xs
,
masks
,
masked_pos
=
None
,
segment_emb
=
None
):
"""Encode input sequence.
"""Encode input sequence.
Args:
Args:
...
@@ -509,9 +504,8 @@ class MLMDecoder(MLMEncoder):
...
@@ -509,9 +504,8 @@ class MLMDecoder(MLMEncoder):
paddle.Tensor: Mask tensor (#batch, time).
paddle.Tensor: Mask tensor (#batch, time).
"""
"""
emb
,
mlm_position
=
None
,
None
if
not
self
.
training
:
if
not
self
.
training
:
masked_pos
ition
=
None
masked_pos
=
None
xs
=
self
.
embed
(
xs
)
xs
=
self
.
embed
(
xs
)
if
segment_emb
:
if
segment_emb
:
xs
=
(
xs
[
0
]
+
segment_emb
,
xs
[
1
])
xs
=
(
xs
[
0
]
+
segment_emb
,
xs
[
1
])
...
@@ -632,18 +626,18 @@ class MLMModel(nn.Layer):
...
@@ -632,18 +626,18 @@ class MLMModel(nn.Layer):
def
collect_feats
(
self
,
def
collect_feats
(
self
,
speech
,
speech
,
speech_len
gth
s
,
speech_lens
,
text
,
text
,
text_len
gth
s
,
text_lens
,
masked_pos
ition
,
masked_pos
,
speech_mask
,
speech_mask
,
text_mask
,
text_mask
,
speech_seg
ment
_pos
,
speech_seg_pos
,
text_seg
ment
_pos
,
text_seg_pos
,
y_masks
=
None
)
->
Dict
[
str
,
paddle
.
Tensor
]:
y_masks
=
None
)
->
Dict
[
str
,
paddle
.
Tensor
]:
return
{
"feats"
:
speech
,
"feats_len
gths"
:
speech_length
s
}
return
{
"feats"
:
speech
,
"feats_len
s"
:
speech_len
s
}
def
forward
(
self
,
batch
,
speech_seg
ment
_pos
,
y_masks
=
None
):
def
forward
(
self
,
batch
,
speech_seg_pos
,
y_masks
=
None
):
# feats: (Batch, Length, Dim)
# feats: (Batch, Length, Dim)
# -> encoder_out: (Batch, Length2, Dim2)
# -> encoder_out: (Batch, Length2, Dim2)
speech_pad_placeholder
=
batch
[
'speech_pad'
]
speech_pad_placeholder
=
batch
[
'speech_pad'
]
...
@@ -654,7 +648,7 @@ class MLMModel(nn.Layer):
...
@@ -654,7 +648,7 @@ class MLMModel(nn.Layer):
if
self
.
decoder
is
not
None
:
if
self
.
decoder
is
not
None
:
zs
,
_
=
self
.
decoder
(
ys_in
,
y_masks
,
encoder_out
,
zs
,
_
=
self
.
decoder
(
ys_in
,
y_masks
,
encoder_out
,
bool
(
h_masks
),
bool
(
h_masks
),
self
.
encoder
.
segment_emb
(
speech_seg
ment
_pos
))
self
.
encoder
.
segment_emb
(
speech_seg_pos
))
speech_hidden_states
=
zs
speech_hidden_states
=
zs
else
:
else
:
speech_hidden_states
=
encoder_out
[:,
:
paddle
.
shape
(
batch
[
speech_hidden_states
=
encoder_out
[:,
:
paddle
.
shape
(
batch
[
...
@@ -672,21 +666,21 @@ class MLMModel(nn.Layer):
...
@@ -672,21 +666,21 @@ class MLMModel(nn.Layer):
else
:
else
:
after_outs
=
None
after_outs
=
None
return
before_outs
,
after_outs
,
speech_pad_placeholder
,
batch
[
return
before_outs
,
after_outs
,
speech_pad_placeholder
,
batch
[
'masked_pos
ition
'
]
'masked_pos'
]
def
inference
(
def
inference
(
self
,
self
,
speech
,
speech
,
text
,
text
,
masked_pos
ition
,
masked_pos
,
speech_mask
,
speech_mask
,
text_mask
,
text_mask
,
speech_seg
ment
_pos
,
speech_seg_pos
,
text_seg
ment
_pos
,
text_seg_pos
,
span_b
oundar
y
,
span_b
d
y
,
y_masks
=
None
,
y_masks
=
None
,
speech_len
gth
s
=
None
,
speech_lens
=
None
,
text_len
gth
s
=
None
,
text_lens
=
None
,
feats
:
Optional
[
paddle
.
Tensor
]
=
None
,
feats
:
Optional
[
paddle
.
Tensor
]
=
None
,
spembs
:
Optional
[
paddle
.
Tensor
]
=
None
,
spembs
:
Optional
[
paddle
.
Tensor
]
=
None
,
sids
:
Optional
[
paddle
.
Tensor
]
=
None
,
sids
:
Optional
[
paddle
.
Tensor
]
=
None
,
...
@@ -699,24 +693,24 @@ class MLMModel(nn.Layer):
...
@@ -699,24 +693,24 @@ class MLMModel(nn.Layer):
batch
=
dict
(
batch
=
dict
(
speech_pad
=
speech
,
speech_pad
=
speech
,
text_pad
=
text
,
text_pad
=
text
,
masked_pos
ition
=
masked_position
,
masked_pos
=
masked_pos
,
speech_mask
=
speech_mask
,
speech_mask
=
speech_mask
,
text_mask
=
text_mask
,
text_mask
=
text_mask
,
speech_seg
ment_pos
=
speech_segment
_pos
,
speech_seg
_pos
=
speech_seg
_pos
,
text_seg
ment_pos
=
text_segment
_pos
,
)
text_seg
_pos
=
text_seg
_pos
,
)
# # inference with teacher forcing
# # inference with teacher forcing
# hs, h_masks = self.encoder(**batch)
# hs, h_masks = self.encoder(**batch)
outs
=
[
batch
[
'speech_pad'
][:,
:
span_b
oundar
y
[
0
]]]
outs
=
[
batch
[
'speech_pad'
][:,
:
span_b
d
y
[
0
]]]
z_cache
=
None
z_cache
=
None
if
use_teacher_forcing
:
if
use_teacher_forcing
:
before
,
zs
,
_
,
_
=
self
.
forward
(
before
,
zs
,
_
,
_
=
self
.
forward
(
batch
,
speech_seg
ment
_pos
,
y_masks
=
y_masks
)
batch
,
speech_seg_pos
,
y_masks
=
y_masks
)
if
zs
is
None
:
if
zs
is
None
:
zs
=
before
zs
=
before
outs
+=
[
zs
[
0
][
span_b
oundary
[
0
]:
span_boundar
y
[
1
]]]
outs
+=
[
zs
[
0
][
span_b
dy
[
0
]:
span_bd
y
[
1
]]]
outs
+=
[
batch
[
'speech_pad'
][:,
span_b
oundar
y
[
1
]:]]
outs
+=
[
batch
[
'speech_pad'
][:,
span_b
d
y
[
1
]:]]
return
dict
(
feat_gen
=
outs
)
return
dict
(
feat_gen
=
outs
)
return
None
return
None
...
@@ -733,7 +727,7 @@ class MLMModel(nn.Layer):
...
@@ -733,7 +727,7 @@ class MLMModel(nn.Layer):
class
MLMEncAsDecoderModel
(
MLMModel
):
class
MLMEncAsDecoderModel
(
MLMModel
):
def
forward
(
self
,
batch
,
speech_seg
ment
_pos
,
y_masks
=
None
):
def
forward
(
self
,
batch
,
speech_seg_pos
,
y_masks
=
None
):
# feats: (Batch, Length, Dim)
# feats: (Batch, Length, Dim)
# -> encoder_out: (Batch, Length2, Dim2)
# -> encoder_out: (Batch, Length2, Dim2)
speech_pad_placeholder
=
batch
[
'speech_pad'
]
speech_pad_placeholder
=
batch
[
'speech_pad'
]
...
@@ -756,7 +750,7 @@ class MLMEncAsDecoderModel(MLMModel):
...
@@ -756,7 +750,7 @@ class MLMEncAsDecoderModel(MLMModel):
else
:
else
:
after_outs
=
None
after_outs
=
None
return
before_outs
,
after_outs
,
speech_pad_placeholder
,
batch
[
return
before_outs
,
after_outs
,
speech_pad_placeholder
,
batch
[
'masked_pos
ition
'
]
'masked_pos'
]
class
MLMDualMaksingModel
(
MLMModel
):
class
MLMDualMaksingModel
(
MLMModel
):
...
@@ -767,9 +761,9 @@ class MLMDualMaksingModel(MLMModel):
...
@@ -767,9 +761,9 @@ class MLMDualMaksingModel(MLMModel):
batch
):
batch
):
xs_pad
=
batch
[
'speech_pad'
]
xs_pad
=
batch
[
'speech_pad'
]
text_pad
=
batch
[
'text_pad'
]
text_pad
=
batch
[
'text_pad'
]
masked_pos
ition
=
batch
[
'masked_position
'
]
masked_pos
=
batch
[
'masked_pos
'
]
text_masked_pos
ition
=
batch
[
'text_masked_position
'
]
text_masked_pos
=
batch
[
'text_masked_pos
'
]
mlm_loss_pos
ition
=
masked_position
>
0
mlm_loss_pos
=
masked_pos
>
0
loss
=
paddle
.
sum
(
loss
=
paddle
.
sum
(
self
.
l1_loss_func
(
self
.
l1_loss_func
(
paddle
.
reshape
(
before_outs
,
(
-
1
,
self
.
odim
)),
paddle
.
reshape
(
before_outs
,
(
-
1
,
self
.
odim
)),
...
@@ -782,19 +776,17 @@ class MLMDualMaksingModel(MLMModel):
...
@@ -782,19 +776,17 @@ class MLMDualMaksingModel(MLMModel):
paddle
.
reshape
(
xs_pad
,
(
-
1
,
self
.
odim
))),
paddle
.
reshape
(
xs_pad
,
(
-
1
,
self
.
odim
))),
axis
=-
1
)
axis
=-
1
)
loss_mlm
=
paddle
.
sum
((
loss
*
paddle
.
reshape
(
loss_mlm
=
paddle
.
sum
((
loss
*
paddle
.
reshape
(
mlm_loss_pos
ition
,
[
-
1
])))
/
paddle
.
sum
((
mlm_loss_position
)
+
1e-10
)
mlm_loss_pos
,
[
-
1
])))
/
paddle
.
sum
((
mlm_loss_pos
)
+
1e-10
)
loss_text
=
paddle
.
sum
((
self
.
text_mlm_loss
(
loss_text
=
paddle
.
sum
((
self
.
text_mlm_loss
(
paddle
.
reshape
(
text_outs
,
(
-
1
,
self
.
vocab_size
)),
paddle
.
reshape
(
text_outs
,
(
-
1
,
self
.
vocab_size
)),
paddle
.
reshape
(
text_pad
,
(
-
1
)))
*
paddle
.
reshape
(
paddle
.
reshape
(
text_pad
,
(
-
1
)))
*
paddle
.
reshape
(
text_masked_position
,
text_masked_pos
,
(
-
1
))))
/
paddle
.
sum
((
text_masked_pos
)
+
1e-10
)
(
-
1
))))
/
paddle
.
sum
((
text_masked_position
)
+
1e-10
)
return
loss_mlm
,
loss_text
return
loss_mlm
,
loss_text
def
forward
(
self
,
batch
,
speech_seg
ment
_pos
,
y_masks
=
None
):
def
forward
(
self
,
batch
,
speech_seg_pos
,
y_masks
=
None
):
# feats: (Batch, Length, Dim)
# feats: (Batch, Length, Dim)
# -> encoder_out: (Batch, Length2, Dim2)
# -> encoder_out: (Batch, Length2, Dim2)
speech_pad_placeholder
=
batch
[
'speech_pad'
]
encoder_out
,
h_masks
=
self
.
encoder
(
**
batch
)
# segment_emb
encoder_out
,
h_masks
=
self
.
encoder
(
**
batch
)
# segment_emb
if
self
.
decoder
is
not
None
:
if
self
.
decoder
is
not
None
:
zs
,
_
=
self
.
decoder
(
encoder_out
,
h_masks
)
zs
,
_
=
self
.
decoder
(
encoder_out
,
h_masks
)
...
@@ -819,7 +811,7 @@ class MLMDualMaksingModel(MLMModel):
...
@@ -819,7 +811,7 @@ class MLMDualMaksingModel(MLMModel):
[
0
,
2
,
1
])
[
0
,
2
,
1
])
else
:
else
:
after_outs
=
None
after_outs
=
None
return
before_outs
,
after_outs
,
text_outs
,
None
#, speech_pad_placeholder, batch['masked_pos
ition'],batch['text_masked_position
']
return
before_outs
,
after_outs
,
text_outs
,
None
#, speech_pad_placeholder, batch['masked_pos
'],batch['text_masked_pos
']
def
build_model_from_file
(
config_file
,
model_file
):
def
build_model_from_file
(
config_file
,
model_file
):
...
...
ernie-sat/paddlespeech/t2s/modules/nets_utils.py
浏览文件 @
b81832ce
...
@@ -38,7 +38,7 @@ def pad_list(xs, pad_value):
...
@@ -38,7 +38,7 @@ def pad_list(xs, pad_value):
"""
"""
n_batch
=
len
(
xs
)
n_batch
=
len
(
xs
)
max_len
=
max
(
x
.
shape
[
0
]
for
x
in
xs
)
max_len
=
max
(
x
.
shape
[
0
]
for
x
in
xs
)
pad
=
paddle
.
full
([
n_batch
,
max_len
,
*
xs
[
0
].
shape
[
1
:]],
pad_value
)
pad
=
paddle
.
full
([
n_batch
,
max_len
,
*
xs
[
0
].
shape
[
1
:]],
pad_value
,
dtype
=
xs
[
0
].
dtype
)
for
i
in
range
(
n_batch
):
for
i
in
range
(
n_batch
):
pad
[
i
,
:
xs
[
i
].
shape
[
0
]]
=
xs
[
i
]
pad
[
i
,
:
xs
[
i
].
shape
[
0
]]
=
xs
[
i
]
...
@@ -46,13 +46,18 @@ def pad_list(xs, pad_value):
...
@@ -46,13 +46,18 @@ def pad_list(xs, pad_value):
return
pad
return
pad
def
make_pad_mask
(
lengths
,
length_dim
=-
1
):
def
make_pad_mask
(
lengths
,
xs
=
None
,
length_dim
=-
1
):
"""Make mask tensor containing indices of padded part.
"""Make mask tensor containing indices of padded part.
Args:
Args:
lengths (Tensor(int64)): Batch of lengths (B,).
lengths (Tensor(int64)): Batch of lengths (B,).
xs (Tensor, optional): The reference tensor.
If set, masks will be the same shape as this tensor.
length_dim (int, optional): Dimension indicator of the above tensor.
See the example.
Returns:
Returns:
Tensor(bool): Mask tensor containing indices of padded part bool.
Tensor(bool): Mask tensor containing indices of padded part bool.
Examples:
Examples:
...
@@ -61,23 +66,98 @@ def make_pad_mask(lengths, length_dim=-1):
...
@@ -61,23 +66,98 @@ def make_pad_mask(lengths, length_dim=-1):
>>> lengths = [5, 3, 2]
>>> lengths = [5, 3, 2]
>>> make_non_pad_mask(lengths)
>>> make_non_pad_mask(lengths)
masks = [[0, 0, 0, 0 ,0],
masks = [[0, 0, 0, 0 ,0],
[0, 0, 0, 1, 1],
[0, 0, 0, 1, 1],
[0, 0, 1, 1, 1]]
[0, 0, 1, 1, 1]]
With the reference tensor.
>>> xs = paddle.zeros((3, 2, 4))
>>> make_pad_mask(lengths, xs)
tensor([[[0, 0, 0, 0],
[0, 0, 0, 0]],
[[0, 0, 0, 1],
[0, 0, 0, 1]],
[[0, 0, 1, 1],
[0, 0, 1, 1]]])
>>> xs = paddle.zeros((3, 2, 6))
>>> make_pad_mask(lengths, xs)
tensor([[[0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 1]],
[[0, 0, 0, 1, 1, 1],
[0, 0, 0, 1, 1, 1]],
[[0, 0, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1]]])
With the reference tensor and dimension indicator.
>>> xs = paddle.zeros((3, 6, 6))
>>> make_pad_mask(lengths, xs, 1)
tensor([[[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1]],
[[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1]],
[[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1]]])
>>> make_pad_mask(lengths, xs, 2)
tensor([[[0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 1]],
[[0, 0, 0, 1, 1, 1],
[0, 0, 0, 1, 1, 1],
[0, 0, 0, 1, 1, 1],
[0, 0, 0, 1, 1, 1],
[0, 0, 0, 1, 1, 1],
[0, 0, 0, 1, 1, 1]],
[[0, 0, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1]]],)
"""
"""
if
length_dim
==
0
:
if
length_dim
==
0
:
raise
ValueError
(
"length_dim cannot be 0: {}"
.
format
(
length_dim
))
raise
ValueError
(
"length_dim cannot be 0: {}"
.
format
(
length_dim
))
bs
=
paddle
.
shape
(
lengths
)[
0
]
bs
=
paddle
.
shape
(
lengths
)[
0
]
maxlen
=
lengths
.
max
()
if
xs
is
None
:
maxlen
=
lengths
.
max
()
else
:
maxlen
=
paddle
.
shape
(
xs
)[
length_dim
]
seq_range
=
paddle
.
arange
(
0
,
maxlen
,
dtype
=
paddle
.
int64
)
seq_range
=
paddle
.
arange
(
0
,
maxlen
,
dtype
=
paddle
.
int64
)
seq_range_expand
=
seq_range
.
unsqueeze
(
0
).
expand
([
bs
,
maxlen
])
seq_range_expand
=
seq_range
.
unsqueeze
(
0
).
expand
([
bs
,
maxlen
])
seq_length_expand
=
lengths
.
unsqueeze
(
-
1
)
seq_length_expand
=
lengths
.
unsqueeze
(
-
1
)
mask
=
seq_range_expand
>=
seq_length_expand
mask
=
seq_range_expand
>=
seq_length_expand
return
mask
if
xs
is
not
None
:
assert
paddle
.
shape
(
xs
)[
0
]
==
bs
,
(
paddle
.
shape
(
xs
)[
0
],
bs
)
if
length_dim
<
0
:
length_dim
=
len
(
paddle
.
shape
(
xs
))
+
length_dim
# ind = (:, None, ..., None, :, , None, ..., None)
ind
=
tuple
(
slice
(
None
)
if
i
in
(
0
,
length_dim
)
else
None
for
i
in
range
(
len
(
paddle
.
shape
(
xs
))))
mask
=
paddle
.
expand
(
mask
[
ind
],
paddle
.
shape
(
xs
))
return
mask
def
make_non_pad_mask
(
lengths
,
length_dim
=-
1
):
def
make_non_pad_mask
(
lengths
,
xs
=
None
,
length_dim
=-
1
):
"""Make mask tensor containing indices of non-padded part.
"""Make mask tensor containing indices of non-padded part.
Args:
Args:
...
@@ -90,16 +170,78 @@ def make_non_pad_mask(lengths, length_dim=-1):
...
@@ -90,16 +170,78 @@ def make_non_pad_mask(lengths, length_dim=-1):
Returns:
Returns:
Tensor(bool): mask tensor containing indices of padded part bool.
Tensor(bool): mask tensor containing indices of padded part bool.
Examples:
Examples:
With only lengths.
With only lengths.
>>> lengths = [5, 3, 2]
>>> lengths = [5, 3, 2]
>>> make_non_pad_mask(lengths)
>>> make_non_pad_mask(lengths)
masks = [[1, 1, 1, 1 ,1],
masks = [[1, 1, 1, 1 ,1],
[1, 1, 1, 0, 0],
[1, 1, 1, 0, 0],
[1, 1, 0, 0, 0]]
[1, 1, 0, 0, 0]]
With the reference tensor.
>>> xs = paddle.zeros((3, 2, 4))
>>> make_non_pad_mask(lengths, xs)
tensor([[[1, 1, 1, 1],
[1, 1, 1, 1]],
[[1, 1, 1, 0],
[1, 1, 1, 0]],
[[1, 1, 0, 0],
[1, 1, 0, 0]]])
>>> xs = paddle.zeros((3, 2, 6))
>>> make_non_pad_mask(lengths, xs)
tensor([[[1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 0]],
[[1, 1, 1, 0, 0, 0],
[1, 1, 1, 0, 0, 0]],
[[1, 1, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0]]])
With the reference tensor and dimension indicator.
>>> xs = paddle.zeros((3, 6, 6))
>>> make_non_pad_mask(lengths, xs, 1)
tensor([[[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[0, 0, 0, 0, 0, 0]],
[[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0]],
[[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0]]])
>>> make_non_pad_mask(lengths, xs, 2)
tensor([[[1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 0]],
[[1, 1, 1, 0, 0, 0],
[1, 1, 1, 0, 0, 0],
[1, 1, 1, 0, 0, 0],
[1, 1, 1, 0, 0, 0],
[1, 1, 1, 0, 0, 0],
[1, 1, 1, 0, 0, 0]],
[[1, 1, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0]]])
"""
"""
return
paddle
.
logical_not
(
make_pad_mask
(
lengths
,
length_dim
))
return
paddle
.
logical_not
(
make_pad_mask
(
lengths
,
xs
,
length_dim
))
def
initialize
(
model
:
nn
.
Layer
,
init
:
str
):
def
initialize
(
model
:
nn
.
Layer
,
init
:
str
):
...
...
ernie-sat/run_clone_en_to_zh.sh
浏览文件 @
b81832ce
...
@@ -10,8 +10,8 @@ python inference.py \
...
@@ -10,8 +10,8 @@ python inference.py \
--uid
=
Prompt_003_new
\
--uid
=
Prompt_003_new
\
--new_str
=
'今天天气很好.'
\
--new_str
=
'今天天气很好.'
\
--prefix
=
'./prompt/dev/'
\
--prefix
=
'./prompt/dev/'
\
--source_lang
uage
=
english
\
--source_lang
=
english
\
--target_lang
uage
=
chinese
\
--target_lang
=
chinese
\
--output_name
=
pred_clone.wav
\
--output_name
=
pred_clone.wav
\
--use_pt_vocoder
=
False
\
--use_pt_vocoder
=
False
\
--voc
=
pwgan_aishell3
\
--voc
=
pwgan_aishell3
\
...
...
ernie-sat/run_gen_en.sh
浏览文件 @
b81832ce
...
@@ -9,8 +9,8 @@ python inference.py \
...
@@ -9,8 +9,8 @@ python inference.py \
--uid
=
p299_096
\
--uid
=
p299_096
\
--new_str
=
'I enjoy my life, do you?'
\
--new_str
=
'I enjoy my life, do you?'
\
--prefix
=
'./prompt/dev/'
\
--prefix
=
'./prompt/dev/'
\
--source_lang
uage
=
english
\
--source_lang
=
english
\
--target_lang
uage
=
english
\
--target_lang
=
english
\
--output_name
=
pred_gen.wav
\
--output_name
=
pred_gen.wav
\
--use_pt_vocoder
=
False
\
--use_pt_vocoder
=
False
\
--voc
=
pwgan_aishell3
\
--voc
=
pwgan_aishell3
\
...
...
ernie-sat/run_sedit_en.sh
浏览文件 @
b81832ce
...
@@ -10,8 +10,8 @@ python inference.py \
...
@@ -10,8 +10,8 @@ python inference.py \
--uid
=
p243_new
\
--uid
=
p243_new
\
--new_str
=
'for that reason cover is impossible to be given.'
\
--new_str
=
'for that reason cover is impossible to be given.'
\
--prefix
=
'./prompt/dev/'
\
--prefix
=
'./prompt/dev/'
\
--source_lang
uage
=
english
\
--source_lang
=
english
\
--target_lang
uage
=
english
\
--target_lang
=
english
\
--output_name
=
pred_edit.wav
\
--output_name
=
pred_edit.wav
\
--use_pt_vocoder
=
False
\
--use_pt_vocoder
=
False
\
--voc
=
pwgan_aishell3
\
--voc
=
pwgan_aishell3
\
...
...
ernie-sat/sedit_arg_parser.py
浏览文件 @
b81832ce
...
@@ -80,10 +80,8 @@ def parse_args():
...
@@ -80,10 +80,8 @@ def parse_args():
parser
.
add_argument
(
"--uid"
,
type
=
str
,
help
=
"uid"
)
parser
.
add_argument
(
"--uid"
,
type
=
str
,
help
=
"uid"
)
parser
.
add_argument
(
"--new_str"
,
type
=
str
,
help
=
"new string"
)
parser
.
add_argument
(
"--new_str"
,
type
=
str
,
help
=
"new string"
)
parser
.
add_argument
(
"--prefix"
,
type
=
str
,
help
=
"prefix"
)
parser
.
add_argument
(
"--prefix"
,
type
=
str
,
help
=
"prefix"
)
parser
.
add_argument
(
"--clone_prefix"
,
type
=
str
,
default
=
None
,
help
=
"clone prefix"
)
parser
.
add_argument
(
"--source_lang"
,
type
=
str
,
default
=
"english"
,
help
=
"source language"
)
parser
.
add_argument
(
"--clone_uid"
,
type
=
str
,
default
=
None
,
help
=
"clone uid"
)
parser
.
add_argument
(
"--target_lang"
,
type
=
str
,
default
=
"english"
,
help
=
"target language"
)
parser
.
add_argument
(
"--source_language"
,
type
=
str
,
help
=
"source language"
)
parser
.
add_argument
(
"--target_language"
,
type
=
str
,
help
=
"target language"
)
parser
.
add_argument
(
"--output_name"
,
type
=
str
,
help
=
"output name"
)
parser
.
add_argument
(
"--output_name"
,
type
=
str
,
help
=
"output name"
)
parser
.
add_argument
(
"--task_name"
,
type
=
str
,
help
=
"task name"
)
parser
.
add_argument
(
"--task_name"
,
type
=
str
,
help
=
"task name"
)
parser
.
add_argument
(
parser
.
add_argument
(
...
...
ernie-sat/tools/
parallel_wavegan_pretrained_vocoder
.py
→
ernie-sat/tools/
torch_pwgan
.py
浏览文件 @
b81832ce
...
@@ -9,7 +9,7 @@ import torch
...
@@ -9,7 +9,7 @@ import torch
import
yaml
import
yaml
class
ParallelWaveGANPretrainedVocoder
(
torch
.
nn
.
Module
):
class
TorchPWGAN
(
torch
.
nn
.
Module
):
"""Wrapper class to load the vocoder trained with parallel_wavegan repo."""
"""Wrapper class to load the vocoder trained with parallel_wavegan repo."""
def
__init__
(
def
__init__
(
...
...
ernie-sat/utils.py
浏览文件 @
b81832ce
import
os
from
typing
import
List
from
typing
import
Optional
import
numpy
as
np
import
numpy
as
np
import
paddle
import
paddle
import
yaml
import
yaml
...
@@ -5,11 +9,8 @@ from sedit_arg_parser import parse_args
...
@@ -5,11 +9,8 @@ from sedit_arg_parser import parse_args
from
yacs.config
import
CfgNode
from
yacs.config
import
CfgNode
from
paddlespeech.s2t.utils.dynamic_import
import
dynamic_import
from
paddlespeech.s2t.utils.dynamic_import
import
dynamic_import
from
paddlespeech.t2s.exps.syn_utils
import
get_frontend
from
paddlespeech.t2s.exps.syn_utils
import
get_voc_inference
from
paddlespeech.t2s.modules.normalizer
import
ZScore
from
paddlespeech.t2s.modules.normalizer
import
ZScore
from
tools.parallel_wavegan_pretrained_vocoder
import
ParallelWaveGANPretrainedVocoder
from
tools.torch_pwgan
import
TorchPWGAN
# new add
model_alias
=
{
model_alias
=
{
# acoustic model
# acoustic model
...
@@ -25,6 +26,10 @@ model_alias = {
...
@@ -25,6 +26,10 @@ model_alias = {
"paddlespeech.t2s.models.tacotron2:Tacotron2"
,
"paddlespeech.t2s.models.tacotron2:Tacotron2"
,
"tacotron2_inference"
:
"tacotron2_inference"
:
"paddlespeech.t2s.models.tacotron2:Tacotron2Inference"
,
"paddlespeech.t2s.models.tacotron2:Tacotron2Inference"
,
"pwgan"
:
"paddlespeech.t2s.models.parallel_wavegan:PWGGenerator"
,
"pwgan_inference"
:
"paddlespeech.t2s.models.parallel_wavegan:PWGInference"
,
}
}
...
@@ -43,60 +48,65 @@ def build_vocoder_from_file(
...
@@ -43,60 +48,65 @@ def build_vocoder_from_file(
# Build vocoder
# Build vocoder
if
str
(
vocoder_file
).
endswith
(
".pkl"
):
if
str
(
vocoder_file
).
endswith
(
".pkl"
):
# If the extension is ".pkl", the model is trained with parallel_wavegan
# If the extension is ".pkl", the model is trained with parallel_wavegan
vocoder
=
ParallelWaveGANPretrainedVocoder
(
vocoder_file
,
vocoder
=
TorchPWGAN
(
vocoder_file
,
vocoder_config_file
)
vocoder_config_file
)
return
vocoder
.
to
(
device
)
return
vocoder
.
to
(
device
)
else
:
else
:
raise
ValueError
(
f
"
{
vocoder_file
}
is not supported format."
)
raise
ValueError
(
f
"
{
vocoder_file
}
is not supported format."
)
def
get_voc_out
(
mel
,
target_lang
uage
=
"chinese"
):
def
get_voc_out
(
mel
,
target_lang
:
str
=
"chinese"
):
# vocoder
# vocoder
args
=
parse_args
()
args
=
parse_args
()
assert
target_lang
uage
==
"chinese"
or
target_language
==
"english"
,
"In get_voc_out function, target_language
is illegal..."
assert
target_lang
==
"chinese"
or
target_lang
==
"english"
,
"In get_voc_out function, target_lang
is illegal..."
# print("current vocoder: ", args.voc)
# print("current vocoder: ", args.voc)
with
open
(
args
.
voc_config
)
as
f
:
with
open
(
args
.
voc_config
)
as
f
:
voc_config
=
CfgNode
(
yaml
.
safe_load
(
f
))
voc_config
=
CfgNode
(
yaml
.
safe_load
(
f
))
# print(voc_config)
voc_inference
=
voc_inference
=
get_voc_inference
(
voc
=
args
.
voc
,
voc_inference
=
get_voc_inference
(
args
,
voc_config
)
voc_config
=
voc_config
,
voc_ckpt
=
args
.
voc_ckpt
,
voc_stat
=
args
.
voc_stat
)
mel
=
paddle
.
to_tensor
(
mel
)
# print("masked_mel: ", mel.shape)
with
paddle
.
no_grad
():
with
paddle
.
no_grad
():
wav
=
voc_inference
(
mel
)
wav
=
voc_inference
(
mel
)
# print("shepe of wav (time x n_channels):%s"%wav.shape)
return
np
.
squeeze
(
wav
)
return
np
.
squeeze
(
wav
)
# dygraph
# dygraph
def
get_am_inference
(
args
,
am_config
):
def
get_am_inference
(
am
:
str
=
'fastspeech2_csmsc'
,
with
open
(
args
.
phones_dict
,
"r"
)
as
f
:
am_config
:
CfgNode
=
None
,
am_ckpt
:
Optional
[
os
.
PathLike
]
=
None
,
am_stat
:
Optional
[
os
.
PathLike
]
=
None
,
phones_dict
:
Optional
[
os
.
PathLike
]
=
None
,
tones_dict
:
Optional
[
os
.
PathLike
]
=
None
,
speaker_dict
:
Optional
[
os
.
PathLike
]
=
None
,
return_am
:
bool
=
False
):
with
open
(
phones_dict
,
"r"
)
as
f
:
phn_id
=
[
line
.
strip
().
split
()
for
line
in
f
.
readlines
()]
phn_id
=
[
line
.
strip
().
split
()
for
line
in
f
.
readlines
()]
vocab_size
=
len
(
phn_id
)
vocab_size
=
len
(
phn_id
)
#
print("vocab_size:", vocab_size)
print
(
"vocab_size:"
,
vocab_size
)
tone_size
=
None
tone_size
=
None
if
'tones_dict'
in
args
and
args
.
tones_dict
:
if
tones_dict
is
not
None
:
with
open
(
args
.
tones_dict
,
"r"
)
as
f
:
with
open
(
tones_dict
,
"r"
)
as
f
:
tone_id
=
[
line
.
strip
().
split
()
for
line
in
f
.
readlines
()]
tone_id
=
[
line
.
strip
().
split
()
for
line
in
f
.
readlines
()]
tone_size
=
len
(
tone_id
)
tone_size
=
len
(
tone_id
)
print
(
"tone_size:"
,
tone_size
)
print
(
"tone_size:"
,
tone_size
)
spk_num
=
None
spk_num
=
None
if
'speaker_dict'
in
args
and
args
.
speaker_dict
:
if
speaker_dict
is
not
None
:
with
open
(
args
.
speaker_dict
,
'rt'
)
as
f
:
with
open
(
speaker_dict
,
'rt'
)
as
f
:
spk_id
=
[
line
.
strip
().
split
()
for
line
in
f
.
readlines
()]
spk_id
=
[
line
.
strip
().
split
()
for
line
in
f
.
readlines
()]
spk_num
=
len
(
spk_id
)
spk_num
=
len
(
spk_id
)
print
(
"spk_num:"
,
spk_num
)
print
(
"spk_num:"
,
spk_num
)
odim
=
am_config
.
n_mels
odim
=
am_config
.
n_mels
# model: {model_name}_{dataset}
# model: {model_name}_{dataset}
am_name
=
a
rgs
.
am
[:
args
.
am
.
rindex
(
'_'
)]
am_name
=
a
m
[:
am
.
rindex
(
'_'
)]
am_dataset
=
a
rgs
.
am
[
args
.
am
.
rindex
(
'_'
)
+
1
:]
am_dataset
=
a
m
[
am
.
rindex
(
'_'
)
+
1
:]
am_class
=
dynamic_import
(
am_name
,
model_alias
)
am_class
=
dynamic_import
(
am_name
,
model_alias
)
am_inference_class
=
dynamic_import
(
am_name
+
'_inference'
,
model_alias
)
am_inference_class
=
dynamic_import
(
am_name
+
'_inference'
,
model_alias
)
...
@@ -113,39 +123,61 @@ def get_am_inference(args, am_config):
...
@@ -113,39 +123,61 @@ def get_am_inference(args, am_config):
elif
am_name
==
'tacotron2'
:
elif
am_name
==
'tacotron2'
:
am
=
am_class
(
idim
=
vocab_size
,
odim
=
odim
,
**
am_config
[
"model"
])
am
=
am_class
(
idim
=
vocab_size
,
odim
=
odim
,
**
am_config
[
"model"
])
am
.
set_state_dict
(
paddle
.
load
(
a
rgs
.
a
m_ckpt
)[
"main_params"
])
am
.
set_state_dict
(
paddle
.
load
(
am_ckpt
)[
"main_params"
])
am
.
eval
()
am
.
eval
()
am_mu
,
am_std
=
np
.
load
(
a
rgs
.
a
m_stat
)
am_mu
,
am_std
=
np
.
load
(
am_stat
)
am_mu
=
paddle
.
to_tensor
(
am_mu
)
am_mu
=
paddle
.
to_tensor
(
am_mu
)
am_std
=
paddle
.
to_tensor
(
am_std
)
am_std
=
paddle
.
to_tensor
(
am_std
)
am_normalizer
=
ZScore
(
am_mu
,
am_std
)
am_normalizer
=
ZScore
(
am_mu
,
am_std
)
am_inference
=
am_inference_class
(
am_normalizer
,
am
)
am_inference
=
am_inference_class
(
am_normalizer
,
am
)
am_inference
.
eval
()
am_inference
.
eval
()
print
(
"acoustic model done!"
)
print
(
"acoustic model done!"
)
return
am
,
am_inference
,
am_name
,
am_dataset
,
phn_id
if
return_am
:
return
am_inference
,
am
else
:
return
am_inference
def
evaluate_durations
(
phns
,
def
get_voc_inference
(
target_language
=
"chinese"
,
voc
:
str
=
'pwgan_csmsc'
,
fs
=
24000
,
voc_config
:
Optional
[
os
.
PathLike
]
=
None
,
hop_length
=
300
):
voc_ckpt
:
Optional
[
os
.
PathLike
]
=
None
,
voc_stat
:
Optional
[
os
.
PathLike
]
=
None
,
):
# model: {model_name}_{dataset}
voc_name
=
voc
[:
voc
.
rindex
(
'_'
)]
voc_class
=
dynamic_import
(
voc_name
,
model_alias
)
voc_inference_class
=
dynamic_import
(
voc_name
+
'_inference'
,
model_alias
)
if
voc_name
!=
'wavernn'
:
voc
=
voc_class
(
**
voc_config
[
"generator_params"
])
voc
.
set_state_dict
(
paddle
.
load
(
voc_ckpt
)[
"generator_params"
])
voc
.
remove_weight_norm
()
voc
.
eval
()
else
:
voc
=
voc_class
(
**
voc_config
[
"model"
])
voc
.
set_state_dict
(
paddle
.
load
(
voc_ckpt
)[
"main_params"
])
voc
.
eval
()
voc_mu
,
voc_std
=
np
.
load
(
voc_stat
)
voc_mu
=
paddle
.
to_tensor
(
voc_mu
)
voc_std
=
paddle
.
to_tensor
(
voc_std
)
voc_normalizer
=
ZScore
(
voc_mu
,
voc_std
)
voc_inference
=
voc_inference_class
(
voc_normalizer
,
voc
)
voc_inference
.
eval
()
print
(
"voc done!"
)
return
voc_inference
def
evaluate_durations
(
phns
:
List
[
str
],
target_lang
:
str
=
"chinese"
,
fs
:
int
=
24000
,
hop_length
:
int
=
300
):
args
=
parse_args
()
args
=
parse_args
()
if
target_lang
uage
==
'english'
:
if
target_lang
==
'english'
:
args
.
lang
=
'en'
args
.
lang
=
'en'
args
.
am
=
"fastspeech2_ljspeech"
args
.
am_config
=
"download/fastspeech2_nosil_ljspeech_ckpt_0.5/default.yaml"
args
.
am_ckpt
=
"download/fastspeech2_nosil_ljspeech_ckpt_0.5/snapshot_iter_100000.pdz"
args
.
am_stat
=
"download/fastspeech2_nosil_ljspeech_ckpt_0.5/speech_stats.npy"
args
.
phones_dict
=
"download/fastspeech2_nosil_ljspeech_ckpt_0.5/phone_id_map.txt"
elif
target_lang
uage
==
'chinese'
:
elif
target_lang
==
'chinese'
:
args
.
lang
=
'zh'
args
.
lang
=
'zh'
args
.
am
=
"fastspeech2_csmsc"
args
.
am_config
=
"download/fastspeech2_conformer_baker_ckpt_0.5/conformer.yaml"
args
.
am_ckpt
=
"download/fastspeech2_conformer_baker_ckpt_0.5/snapshot_iter_76000.pdz"
args
.
am_stat
=
"download/fastspeech2_conformer_baker_ckpt_0.5/speech_stats.npy"
args
.
phones_dict
=
"download/fastspeech2_conformer_baker_ckpt_0.5/phone_id_map.txt"
# args = parser.parse_args(args=[])
# args = parser.parse_args(args=[])
if
args
.
ngpu
==
0
:
if
args
.
ngpu
==
0
:
...
@@ -155,23 +187,28 @@ def evaluate_durations(phns,
...
@@ -155,23 +187,28 @@ def evaluate_durations(phns,
else
:
else
:
print
(
"ngpu should >= 0 !"
)
print
(
"ngpu should >= 0 !"
)
assert
target_lang
uage
==
"chinese"
or
target_language
==
"english"
,
"In evaluate_durations function, target_language
is illegal..."
assert
target_lang
==
"chinese"
or
target_lang
==
"english"
,
"In evaluate_durations function, target_lang
is illegal..."
# Init body.
# Init body.
with
open
(
args
.
am_config
)
as
f
:
with
open
(
args
.
am_config
)
as
f
:
am_config
=
CfgNode
(
yaml
.
safe_load
(
f
))
am_config
=
CfgNode
(
yaml
.
safe_load
(
f
))
# print("========Config========")
# print(am_config)
am_inference
,
am
=
get_am_inference
(
# print("---------------------")
am
=
args
.
am
,
# acoustic model
am_config
=
am_config
,
am
,
am_inference
,
am_name
,
am_dataset
,
phn_id
=
get_am_inference
(
args
,
am_ckpt
=
args
.
am_ckpt
,
am_config
)
am_stat
=
args
.
am_stat
,
phones_dict
=
args
.
phones_dict
,
tones_dict
=
args
.
tones_dict
,
speaker_dict
=
args
.
speaker_dict
,
return_am
=
True
)
torch_phns
=
phns
torch_phns
=
phns
vocab_phones
=
{}
vocab_phones
=
{}
with
open
(
args
.
phones_dict
,
"r"
)
as
f
:
phn_id
=
[
line
.
strip
().
split
()
for
line
in
f
.
readlines
()]
for
tone
,
id
in
phn_id
:
for
tone
,
id
in
phn_id
:
vocab_phones
[
tone
]
=
int
(
id
)
vocab_phones
[
tone
]
=
int
(
id
)
# print("vocab_phones: ", len(vocab_phones))
vocab_size
=
len
(
vocab_phones
)
vocab_size
=
len
(
vocab_phones
)
phonemes
=
[
phn
if
phn
in
vocab_phones
else
"sp"
for
phn
in
torch_phns
]
phonemes
=
[
phn
if
phn
in
vocab_phones
else
"sp"
for
phn
in
torch_phns
]
...
@@ -185,59 +222,3 @@ def evaluate_durations(phns,
...
@@ -185,59 +222,3 @@ def evaluate_durations(phns,
phoneme_durations_new
=
pre_d_outs
*
hop_length
/
fs
phoneme_durations_new
=
pre_d_outs
*
hop_length
/
fs
phoneme_durations_new
=
phoneme_durations_new
.
tolist
()[:
-
1
]
phoneme_durations_new
=
phoneme_durations_new
.
tolist
()[:
-
1
]
return
phoneme_durations_new
return
phoneme_durations_new
def sentence2phns(sentence, target_language="en"):
    """Convert a sentence to its phoneme sequence and phone ids.

    Args:
        sentence (str): Input text to phoneticize.
        target_language (str): ``'en'`` or ``'zh'``; selects the text
            frontend and the phone-id map file (hard-coded checkpoint
            paths under ``download/``).

    Returns:
        tuple: ``(phonemes, phone_ids)`` for the first (merged) sentence,
        or ``None`` when ``target_language`` is not ``'zh'``/``'en'``
        (an error message is printed instead).
    """
    args = parse_args()
    if target_language == 'en':
        args.lang = 'en'
        args.phones_dict = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/phone_id_map.txt"
    elif target_language == 'zh':
        args.lang = 'zh'
        args.phones_dict = "download/fastspeech2_conformer_baker_ckpt_0.5/phone_id_map.txt"
    else:
        print("target_language should in {'zh', 'en'}!")

    frontend = get_frontend(args)
    merge_sentences = True
    get_tone_ids = False

    if target_language == 'zh':
        input_ids = frontend.get_input_ids(
            sentence,
            merge_sentences=merge_sentences,
            get_tone_ids=get_tone_ids,
            print_info=False)
        phonemes = frontend.get_phonemes(
            sentence, merge_sentences=merge_sentences, print_info=False)
        return phonemes[0], input_ids["phone_ids"][0]
    elif target_language == 'en':
        phonemes = frontend.phoneticize(sentence)
        input_ids = frontend.get_input_ids(
            sentence, merge_sentences=merge_sentences)
        punc = ":,;。?!“”‘’':,;.?!"
        # Build the phone -> id vocabulary from the phone map file
        # (one "<phone> <id>" pair per line).
        with open(args.phones_dict, 'rt') as f:
            vocab_phones = {
                phn: int(phn_id)
                for phn, phn_id in (line.strip().split() for line in f)
            }
        # Drop the first/last tokens (presumably boundary markers added
        # by phoneticize — TODO confirm) and any whitespace tokens.
        phones = [phn for phn in phonemes[1:-1] if not phn.isspace()]
        # replace unk phone with sp
        phones = [
            phn if (phn in vocab_phones and phn not in punc) else "sp"
            for phn in phones
        ]
        return phones, input_ids["phone_ids"][0]
    else:
        print("lang should in {'zh', 'en'}!")
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录