PaddlePaddle / ERNIE
Commit 9224659c
Authored Jun 15, 2022 by 小湉湉

add docstring

Parent 76b654cb

Showing 12 changed files with 1143 additions and 1275 deletions (+1143 −1275)
ernie-sat/README.md                                           +4    -5
ernie-sat/align.py                                            +160  -24
ernie-sat/collect_fn.py                                       +217  -0
ernie-sat/dataset.py                                          +147  -134
ernie-sat/inference.py                                        +242  -625
ernie-sat/mlm.py                                              +148  -462
ernie-sat/mlm_loss.py                                         +53   -0
ernie-sat/paddlespeech/t2s/modules/transformer/attention.py   +96   -0
ernie-sat/paddlespeech/t2s/modules/transformer/embedding.py   +60   -0
ernie-sat/read_text.py                                        +2    -2
ernie-sat/sedit_arg_parser.py                                 +0    -6
ernie-sat/utils.py                                            +14   -17
ernie-sat/README.md

```diff
@@ -39,9 +39,9 @@ In ERNIE-SAT we propose two innovations:
 ### 2. Pretrained models
 The pretrained ERNIE-SAT models are listed below:
-- [ERNIE-SAT_ZH](http://bj.bcebos.com/wenxin-models/model-ernie-sat-base-zh.tar.gz)
-- [ERNIE-SAT_EN](http://bj.bcebos.com/wenxin-models/model-ernie-sat-base-en.tar.gz)
-- [ERNIE-SAT_ZH_and_EN](http://bj.bcebos.com/wenxin-models/model-ernie-sat-base-en_zh.tar.gz)
+- [ERNIE-SAT_ZH](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/ernie_sat/old/model-ernie-sat-base-zh.tar.gz)
+- [ERNIE-SAT_EN](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/ernie_sat/old/model-ernie-sat-base-en.tar.gz)
+- [ERNIE-SAT_ZH_and_EN](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/ernie_sat/old/model-ernie-sat-base-en_zh.tar.gz)
 Create a pretrained_model folder, download the ERNIE-SAT pretrained models above and extract them:
@@ -108,7 +108,7 @@ prompt/dev
 3. `--voc` whether the vocoder format conforms to {model_name}_{dataset}
 4. `--voc_config`, `--voc_checkpoint`, `--voc_stat` are the vocoder parameters, corresponding to the 3 files of the parallel wavegan pretrained model.
 5. `--lang` the language of the model, either `zh` or `en`.
-6. `--ngpu` the number of GPUs to use; if ngpu == 0, the CPU is used.
+6. `--ngpu` the number of GPUs to use; if ngpu == 0, the CPU is used.
 7. `--model_name` model name
 8. `--uid` id of the specific prompt speech
 9. `--new_str` the input text (this release uses fixed text for now)
@@ -125,4 +125,3 @@ sh run_sedit_en.sh  # speech editing task (English)
 sh run_gen_en.sh  # personalized speech synthesis task (English)
 sh run_clone_en_to_zh.sh  # cross-lingual speech synthesis task (English-to-Chinese voice cloning)
```
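As a reading aid, a hedged sketch of how the arguments documented above might be assembled into an inference call; every path, uid, and checkpoint name below is a placeholder for illustration, not something this commit pins down:

```python
import subprocess

# Hypothetical invocation of the ERNIE-SAT inference script; replace all
# placeholder paths/ids with real checkpoint, config, and prompt locations.
cmd = [
    "python", "inference.py",
    "--lang", "en",                       # `zh` or `en`, matching the model
    "--ngpu", "1",                        # 0 falls back to CPU
    "--model_name", "paddle_checkpoint_en",
    "--uid", "p299_096",                  # id of the prompt utterance (placeholder)
    "--new_str", "for that reason cover is impossible to be given.",
    "--voc", "pwgan_aishell3",            # vocoder named as {model_name}_{dataset}
    "--voc_config", "pwg_default.yaml",   # placeholder vocoder config
    "--voc_checkpoint", "pwg_snapshot.pdz",
    "--voc_stat", "pwg_stats.npy",
]
subprocess.run(cmd, check=True)
```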
ernie-sat/align.py

```python
#!/usr/bin/env python
""" Usage:
    align.py wavfile trsfile outwordfile outphonefile
"""
import multiprocessing as mp
import os
import sys

from tqdm import tqdm

PHONEME = 'tools/aligner/english_envir/english2phoneme/phoneme'
MODEL_DIR_EN = 'tools/aligner/english'
MODEL_DIR_ZH = 'tools/aligner/mandarin'
```

@@ -15,6 +11,142 @@ HVITE = 'tools/htk/HTKTools/HVite'
```python
HCOPY = 'tools/htk/HTKTools/HCopy'


def get_unk_phns(word_str: str):
    tmpbase = '/tmp/tp.'
    f = open(tmpbase + 'temp.words', 'w')
    f.write(word_str)
    f.close()
    os.system(PHONEME + ' ' + tmpbase + 'temp.words' + ' ' + tmpbase +
              'temp.phons')
    f = open(tmpbase + 'temp.phons', 'r')
    lines2 = f.readline().strip().split()
    f.close()
    phns = []
    for phn in lines2:
        phons = phn.replace('\n', '').replace(' ', '')
        seq = []
        j = 0
        while (j < len(phons)):
            if (phons[j] > 'Z'):
                if (phons[j] == 'j'):
                    seq.append('JH')
                elif (phons[j] == 'h'):
                    seq.append('HH')
                else:
                    seq.append(phons[j].upper())
                j += 1
            else:
                p = phons[j:j + 2]
                if (p == 'WH'):
                    seq.append('W')
                elif (p in ['TH', 'SH', 'HH', 'DH', 'CH', 'ZH', 'NG']):
                    seq.append(p)
                elif (p == 'AX'):
                    seq.append('AH0')
                else:
                    seq.append(p + '1')
                j += 2
        phns.extend(seq)
    return phns
```
```python
def words2phns(line: str):
    '''
    Args:
        line (str): input text.
            eg: for that reason cover is impossible to be given.
    Returns:
        List[str]: phones of input text.
            eg:
            ['F', 'AO1', 'R', 'DH', 'AE1', 'T', 'R', 'IY1', 'Z', 'AH0', 'N',
             'K', 'AH1', 'V', 'ER0', 'IH1', 'Z', 'IH2', 'M', 'P', 'AA1', 'S',
             'AH0', 'B', 'AH0', 'L', 'T', 'UW1', 'B', 'IY1', 'G', 'IH1', 'V',
             'AH0', 'N']
        Dict(str, str): key - idx_word
                        value - phones
            eg:
            {'0_FOR': ['F', 'AO1', 'R'], '1_THAT': ['DH', 'AE1', 'T'],
             '2_REASON': ['R', 'IY1', 'Z', 'AH0', 'N'],
             '3_COVER': ['K', 'AH1', 'V', 'ER0'], '4_IS': ['IH1', 'Z'],
             '5_IMPOSSIBLE': ['IH2', 'M', 'P', 'AA1', 'S', 'AH0', 'B', 'AH0', 'L'],
             '6_TO': ['T', 'UW1'], '7_BE': ['B', 'IY1'],
             '8_GIVEN': ['G', 'IH1', 'V', 'AH0', 'N']}
    '''
    dictfile = MODEL_DIR_EN + '/dict'
    line = line.strip()
    words = []
    for pun in [',', '.', ':', ';', '!', '?', '"', '(', ')', '--', '---']:
        line = line.replace(pun, ' ')
    for wrd in line.split():
        if (wrd[-1] == '-'):
            wrd = wrd[:-1]
        if (wrd[0] == "'"):
            wrd = wrd[1:]
        if wrd:
            words.append(wrd)
    ds = set([])
    word2phns_dict = {}
    with open(dictfile, 'r') as fid:
        for line in fid:
            word = line.split()[0]
            ds.add(word)
            if word not in word2phns_dict.keys():
                word2phns_dict[word] = " ".join(line.split()[1:])

    phns = []
    wrd2phns = {}
    for index, wrd in enumerate(words):
        if wrd == '[MASK]':
            wrd2phns[str(index) + "_" + wrd] = [wrd]
            phns.append(wrd)
        elif (wrd.upper() not in ds):
            wrd2phns[str(index) + "_" + wrd.upper()] = get_unk_phns(wrd)
            phns.extend(get_unk_phns(wrd))
        else:
            wrd2phns[str(index) + "_" + wrd.upper()] = word2phns_dict[
                wrd.upper()].split()
            phns.extend(word2phns_dict[wrd.upper()].split())
    return phns, wrd2phns
```
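A hedged usage sketch of `words2phns`, mirroring the example already given in its docstring; it assumes the aligner dictionary under `tools/aligner/english` is present on disk:

```python
# Example based on the docstring above; requires MODEL_DIR_EN + '/dict'.
phns, wrd2phns = words2phns('for that reason cover is impossible to be given.')
print(phns[:5])            # ['F', 'AO1', 'R', 'DH', 'AE1']
print(wrd2phns['0_FOR'])   # ['F', 'AO1', 'R']
# Words absent from the dictionary fall back to get_unk_phns(); '[MASK]'
# tokens pass through unchanged.
```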
```python
def words2phns_zh(line: str):
    dictfile = MODEL_DIR_ZH + '/dict'
    line = line.strip()
    words = []
    for pun in [
            ',', '.', ':', ';', '!', '?', '"', '(', ')', '--', '---', u',',
            u'。', u':', u';', u'!', u'?', u'(', u')'
    ]:
        line = line.replace(pun, ' ')
    for wrd in line.split():
        if (wrd[-1] == '-'):
            wrd = wrd[:-1]
        if (wrd[0] == "'"):
            wrd = wrd[1:]
        if wrd:
            words.append(wrd)
    ds = set([])
    word2phns_dict = {}
    with open(dictfile, 'r') as fid:
        for line in fid:
            word = line.split()[0]
            ds.add(word)
            if word not in word2phns_dict.keys():
                word2phns_dict[word] = " ".join(line.split()[1:])

    phns = []
    wrd2phns = {}
    for index, wrd in enumerate(words):
        if wrd == '[MASK]':
            wrd2phns[str(index) + "_" + wrd] = [wrd]
            phns.append(wrd)
        elif (wrd.upper() not in ds):
            # message (zh): "illegal word encountered, please input correct text..."
            print("出现非法词错误,请输入正确的文本...")
        else:
            wrd2phns[str(index) + "_" + wrd] = word2phns_dict[wrd].split()
            phns.extend(word2phns_dict[wrd].split())
    return phns, wrd2phns


def prep_txt_zh(line: str, tmpbase: str, dictfile: str):
    words = []
```

@@ -82,7 +214,7 @@ def prep_txt_en(line: str, tmpbase, dictfile):
```diff
     try:
         os.system(PHONEME + ' ' + tmpbase + '_unk.words' + ' ' + tmpbase +
                   '_unk.phons')
-    except:
+    except Exception:
         print('english2phoneme error!')
         sys.exit(1)
```

@@ -148,19 +280,22 @@ def _get_user():
```diff
 def alignment(wav_path: str, text: str):
+    '''
+    intervals: List[phn, start, end]
+    '''
     tmpbase = '/tmp/' + _get_user() + '_' + str(os.getpid())

     #prepare wav and trs files
     try:
         os.system('sox ' + wav_path + ' -r 16000 ' + tmpbase + '.wav remix -')
-    except:
+    except Exception:
         print('sox error!')
         return None

     #prepare clean_transcript file
     try:
-        prep_txt_en(text, tmpbase, MODEL_DIR_EN + '/dict')
+        prep_txt_en(line=text, tmpbase=tmpbase, dictfile=MODEL_DIR_EN + '/dict')
     except Exception:
         print('prep_txt error!')
         return None
```

@@ -169,7 +304,7 @@ def alignment(wav_path: str, text: str):
```diff
         with open(tmpbase + '.txt', 'r') as fid:
             txt = fid.readline()
         prep_mlf(txt, tmpbase)
-    except:
+    except Exception:
         print('prep_mlf error!')
         return None
```

@@ -177,7 +312,7 @@ def alignment(wav_path: str, text: str):
```diff
     try:
         os.system(HCOPY + ' -C ' + MODEL_DIR_EN + '/16000/config ' + tmpbase +
                   '.wav' + ' ' + tmpbase + '.plp')
-    except:
+    except Exception:
         print('HCopy error!')
         return None
```

@@ -188,7 +323,7 @@ def alignment(wav_path: str, text: str):
```diff
                   + '/16000/hmmdefs -i ' + tmpbase + '.aligned ' + tmpbase +
                   '.dict ' + MODEL_DIR_EN + '/monophones ' + tmpbase +
                   '.plp 2>&1 > /dev/null')
-    except:
+    except Exception:
         print('HVite error!')
         return None
```

@@ -200,7 +335,7 @@ def alignment(wav_path: str, text: str):
```diff
     with open(tmpbase + '.aligned', 'r') as fid:
         lines = fid.readlines()
     i = 2
-    times2 = []
+    intervals = []
     word2phns = {}
     current_word = ''
     index = 0
```

@@ -210,7 +345,7 @@ def alignment(wav_path: str, text: str):
```diff
             phn = splited_line[2]
             pst = (int(splited_line[0]) / 1000 + 125) / 10000
             pen = (int(splited_line[1]) / 1000 + 125) / 10000
-            times2.append([phn, pst, pen])
+            intervals.append([phn, pst, pen])
             # splited_line[-1]!='sp'
             if len(splited_line) == 5:
                 current_word = str(index) + '_' + splited_line[-1]
```

@@ -219,10 +354,10 @@ def alignment(wav_path: str, text: str):
```diff
         elif len(splited_line) == 4:
             word2phns[current_word] += ' ' + phn
         i += 1
-    return times2, word2phns
+    return intervals, word2phns


-def alignment_zh(wav_path, text_string):
+def alignment_zh(wav_path: str, text: str):
     tmpbase = '/tmp/' + _get_user() + '_' + str(os.getpid())

     #prepare wav and trs files
```

@@ -230,18 +365,19 @@ def alignment_zh(wav_path, text_string):
```diff
         os.system('sox ' + wav_path + ' -r 16000 -b 16 ' + tmpbase +
                   '.wav remix -')
-    except:
+    except Exception:
         print('sox error!')
         return None

     #prepare clean_transcript file
     try:
-        unk_words = prep_txt_zh(text_string, tmpbase, MODEL_DIR_ZH + '/dict')
+        unk_words = prep_txt_zh(
+            line=text, tmpbase=tmpbase, dictfile=MODEL_DIR_ZH + '/dict')
         if unk_words:
             print('Error! Please add the following words to dictionary:')
             for unk in unk_words:
                 print("非法words: ", unk)
-    except:
+    except Exception:
         print('prep_txt error!')
         return None
```

@@ -250,7 +386,7 @@ def alignment_zh(wav_path, text_string):
```diff
         with open(tmpbase + '.txt', 'r') as fid:
             txt = fid.readline()
         prep_mlf(txt, tmpbase)
-    except:
+    except Exception:
         print('prep_mlf error!')
         return None
```

@@ -258,7 +394,7 @@ def alignment_zh(wav_path, text_string):
```diff
     try:
         os.system(HCOPY + ' -C ' + MODEL_DIR_ZH + '/16000/config ' + tmpbase +
                   '.wav' + ' ' + tmpbase + '.plp')
-    except:
+    except Exception:
         print('HCopy error!')
         return None
```

@@ -270,7 +406,7 @@ def alignment_zh(wav_path, text_string):
```diff
                   + '/dict ' + MODEL_DIR_ZH + '/monophones ' + tmpbase +
                   '.plp 2>&1 > /dev/null')
-    except:
+    except Exception:
         print('HVite error!')
         return None
```

@@ -283,7 +419,7 @@ def alignment_zh(wav_path, text_string):
```diff
         lines = fid.readlines()
     i = 2
-    times2 = []
+    intervals = []
     word2phns = {}
     current_word = ''
     index = 0
```

@@ -293,7 +429,7 @@ def alignment_zh(wav_path, text_string):
```diff
             phn = splited_line[2]
             pst = (int(splited_line[0]) / 1000 + 125) / 10000
             pen = (int(splited_line[1]) / 1000 + 125) / 10000
-            times2.append([phn, pst, pen])
+            intervals.append([phn, pst, pen])
             # splited_line[-1]!='sp'
             if len(splited_line) == 5:
                 current_word = str(index) + '_' + splited_line[-1]
```

@@ -302,4 +438,4 @@ def alignment_zh(wav_path, text_string):
```diff
         elif len(splited_line) == 4:
             word2phns[current_word] += ' ' + phn
         i += 1
-    return times2, word2phns
+    return intervals, word2phns
```
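A hedged sketch of consuming `alignment()`'s return values, following the docstring added in this commit (the wav path and transcript are placeholders, and the HTK tools plus aligner models must be installed):

```python
# Hypothetical usage; 'exp/p243_313.wav' is a placeholder path.
intervals, word2phns = alignment(
    'exp/p243_313.wav', 'for that reason cover is impossible to be given.')
# `intervals` is a list of [phn, start, end] triples in seconds.
for phn, start, end in intervals:
    print(f'{phn}\t{start:.3f}\t{end:.3f}')
# `word2phns` maps 'idx_WORD' keys to space-joined phone strings,
# e.g. '0_FOR' -> 'F AO1 R'.
```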
ernie-sat/collect_fn.py (new file, mode 100644)

```python
from typing import Collection
from typing import Dict
from typing import List
from typing import Tuple
from typing import Union

import numpy as np
import paddle
from dataset import get_seg_pos
from dataset import phones_masking
from dataset import phones_text_masking
from paddlespeech.t2s.datasets.get_feats import LogMelFBank
from paddlespeech.t2s.modules.nets_utils import make_non_pad_mask
from paddlespeech.t2s.modules.nets_utils import pad_list


class MLMCollateFn:
    """Functor class of common_collate_fn()"""

    def __init__(self,
                 feats_extract,
                 float_pad_value: Union[float, int]=0.0,
                 int_pad_value: int=-32768,
                 not_sequence: Collection[str]=(),
                 mlm_prob: float=0.8,
                 mean_phn_span: int=8,
                 attention_window: int=0,
                 pad_speech: bool=False,
                 seg_emb: bool=False,
                 text_masking: bool=False):
        self.mlm_prob = mlm_prob
        self.mean_phn_span = mean_phn_span
        self.feats_extract = feats_extract
        self.float_pad_value = float_pad_value
        self.int_pad_value = int_pad_value
        self.not_sequence = set(not_sequence)
        self.attention_window = attention_window
        self.pad_speech = pad_speech
        self.seg_emb = seg_emb
        self.text_masking = text_masking

    def __repr__(self):
        return (f"{self.__class__}(float_pad_value={self.float_pad_value}, "
                f"int_pad_value={self.float_pad_value})")

    def __call__(self, data: Collection[Tuple[str, Dict[str, np.ndarray]]]
                 ) -> Tuple[List[str], Dict[str, paddle.Tensor]]:
        return mlm_collate_fn(
            data,
            float_pad_value=self.float_pad_value,
            int_pad_value=self.int_pad_value,
            not_sequence=self.not_sequence,
            mlm_prob=self.mlm_prob,
            mean_phn_span=self.mean_phn_span,
            feats_extract=self.feats_extract,
            attention_window=self.attention_window,
            pad_speech=self.pad_speech,
            seg_emb=self.seg_emb,
            text_masking=self.text_masking)


def mlm_collate_fn(
        data: Collection[Tuple[str, Dict[str, np.ndarray]]],
        float_pad_value: Union[float, int]=0.0,
        int_pad_value: int=-32768,
        not_sequence: Collection[str]=(),
        mlm_prob: float=0.8,
        mean_phn_span: int=8,
        feats_extract=None,
        attention_window: int=0,
        pad_speech: bool=False,
        seg_emb: bool=False,
        text_masking: bool=False) -> Tuple[List[str], Dict[str, paddle.Tensor]]:
    uttids = [u for u, _ in data]
    data = [d for _, d in data]

    assert all(set(data[0]) == set(d) for d in data), "dict-keys mismatching"
    assert all(not k.endswith("_lens")
               for k in data[0]), f"*_lens is reserved: {list(data[0])}"

    output = {}
    for key in data[0]:
        # Each models, which accepts these values finally, are responsible
        # to repaint the pad_value to the desired value for each tasks.
        if data[0][key].dtype.kind == "i":
            pad_value = int_pad_value
        else:
            pad_value = float_pad_value

        array_list = [d[key] for d in data]

        # Assume the first axis is length:
        # tensor_list: Batch x (Length, ...)
        tensor_list = [paddle.to_tensor(a) for a in array_list]
        # tensor: (Batch, Length, ...)
        tensor = pad_list(tensor_list, pad_value)
        output[key] = tensor

        # lens: (Batch,)
        if key not in not_sequence:
            lens = paddle.to_tensor(
                [d[key].shape[0] for d in data], dtype=paddle.int64)
            output[key + "_lens"] = lens

    feats = feats_extract.get_log_mel_fbank(np.array(output["speech"][0]))
    feats = paddle.to_tensor(feats)
    feats_lens = paddle.shape(feats)[0]
    feats = paddle.unsqueeze(feats, 0)

    text = output["text"]
    text_lens = output["text_lens"]
    align_start = output["align_start"]
    align_start_lens = output["align_start_lens"]
    align_end = output["align_end"]

    max_tlen = max(text_lens)
    max_slen = max(feats_lens)
    speech_pad = feats[:, :max_slen]
    text_pad = text
    text_mask = make_non_pad_mask(
        text_lens, text_pad, length_dim=1).unsqueeze(-2)
    speech_mask = make_non_pad_mask(
        feats_lens, speech_pad[:, :, 0], length_dim=1).unsqueeze(-2)

    span_bdy = None
    if 'span_bdy' in output.keys():
        span_bdy = output['span_bdy']

    # dual masking: for mixed Chinese-English input, speech and text are
    # masked at the same time; ERNIE-SAT masks both when doing cross-lingual work
    if text_masking:
        masked_pos, text_masked_pos = phones_text_masking(
            xs_pad=speech_pad,
            src_mask=speech_mask,
            text_pad=text_pad,
            text_mask=text_mask,
            align_start=align_start,
            align_end=align_end,
            align_start_lens=align_start_lens,
            mlm_prob=mlm_prob,
            mean_phn_span=mean_phn_span,
            span_bdy=span_bdy)
    # pure-Chinese and pure-English training -> A3T does not mask the phonemes,
    # only the speech; the main difference between A3T and ERNIE-SAT is the masking
    else:
        masked_pos = phones_masking(
            xs_pad=speech_pad,
            src_mask=speech_mask,
            align_start=align_start,
            align_end=align_end,
            align_start_lens=align_start_lens,
            mlm_prob=mlm_prob,
            mean_phn_span=mean_phn_span,
            span_bdy=span_bdy)
        text_masked_pos = paddle.zeros(paddle.shape(text_pad))

    output_dict = {}

    speech_seg_pos, text_seg_pos = get_seg_pos(
        speech_pad=speech_pad,
        text_pad=text_pad,
        align_start=align_start,
        align_end=align_end,
        align_start_lens=align_start_lens,
        seg_emb=seg_emb)
    output_dict['speech'] = speech_pad
    output_dict['text'] = text_pad
    output_dict['masked_pos'] = masked_pos
    output_dict['text_masked_pos'] = text_masked_pos
    output_dict['speech_mask'] = speech_mask
    output_dict['text_mask'] = text_mask
    output_dict['speech_seg_pos'] = speech_seg_pos
    output_dict['text_seg_pos'] = text_seg_pos
    output = (uttids, output_dict)
    return output


def build_collate_fn(
        sr: int=24000,
        n_fft: int=2048,
        hop_length: int=300,
        win_length: int=None,
        n_mels: int=80,
        fmin: int=80,
        fmax: int=7600,
        mlm_prob: float=0.8,
        mean_phn_span: int=8,
        train: bool=False,
        seg_emb: bool=False,
        epoch: int=-1, ):
    feats_extract_class = LogMelFBank
    feats_extract = feats_extract_class(
        sr=sr,
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        n_mels=n_mels,
        fmin=fmin,
        fmax=fmax)

    pad_speech = False
    if epoch == -1:
        mlm_prob_factor = 1
    else:
        mlm_prob_factor = 0.8

    return MLMCollateFn(
        feats_extract=feats_extract,
        float_pad_value=0.0,
        int_pad_value=0,
        mlm_prob=mlm_prob * mlm_prob_factor,
        mean_phn_span=mean_phn_span,
        pad_speech=pad_speech,
        seg_emb=seg_emb)
```
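A minimal sketch of how this collate function might be exercised on a toy batch; the feature sizes, token ids, and alignment values below are invented for illustration, and real batches come from the dump metadata:

```python
import numpy as np
from collect_fn import build_collate_fn

# Hypothetical toy batch: one utterance with 4800 waveform samples
# (0.2 s at 24 kHz -> 16 mel frames), 3 phone tokens, and frame-level
# alignment boundaries for each phone.
batch = [('utt1', {
    "speech": np.random.randn(4800).astype('float32'),
    "text": np.array([5, 17, 3], dtype='int64'),
    "align_start": np.array([0, 5, 10], dtype='int64'),
    "align_end": np.array([5, 10, 16], dtype='int64'),
})]

collate_fn = build_collate_fn(sr=24000, hop_length=300, mlm_prob=0.8,
                              mean_phn_span=8)
uttids, feats = collate_fn(batch)
# feats holds padded tensors: 'speech' (B, Tmax, 80 mels), 'masked_pos',
# 'speech_mask', 'text_mask', 'speech_seg_pos', 'text_seg_pos', ...
print(uttids, {k: tuple(v.shape) for k, v in feats.items()})
```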
ernie-sat/dataset.py

@@ -4,6 +4,68 @@ import numpy as np
```python
import paddle


# mask phones
def phones_masking(xs_pad: paddle.Tensor,
                   src_mask: paddle.Tensor,
                   align_start: paddle.Tensor,
                   align_end: paddle.Tensor,
                   align_start_lens: paddle.Tensor,
                   mlm_prob: float=0.8,
                   mean_phn_span: int=8,
                   span_bdy: paddle.Tensor=None):
    '''
    Args:
        xs_pad (paddle.Tensor): input speech (B, Tmax, D).
        src_mask (paddle.Tensor): mask of speech (B, 1, Tmax).
        align_start (paddle.Tensor): frame level phone alignment start (B, Tmax2).
        align_end (paddle.Tensor): frame level phone alignment end (B, Tmax2).
        align_start_lens (paddle.Tensor): length of align_start (B, ).
        mlm_prob (float):
        mean_phn_span (int):
        span_bdy (paddle.Tensor): masked mel boundary of input speech (B, 2).
    Returns:
        paddle.Tensor[bool]: masked position of input speech (B, Tmax).
    '''
    bz, sent_len, _ = paddle.shape(xs_pad)
    masked_pos = paddle.zeros((bz, sent_len))
    if mlm_prob == 1.0:
        masked_pos += 1
    elif mean_phn_span == 0:
        # only speech
        length = sent_len
        mean_phn_span = min(length * mlm_prob // 3, 50)
        masked_phn_idxs = random_spans_noise_mask(
            length=length, mlm_prob=mlm_prob,
            mean_phn_span=mean_phn_span).nonzero()
        masked_pos[:, masked_phn_idxs] = 1
    else:
        for idx in range(bz):
            # for inference
            if span_bdy is not None:
                for s, e in zip(span_bdy[idx][::2], span_bdy[idx][1::2]):
                    masked_pos[idx, s:e] = 1
            # for training
            else:
                length = align_start_lens[idx]
                if length < 2:
                    continue
                masked_phn_idxs = random_spans_noise_mask(
                    length=length, mlm_prob=mlm_prob,
                    mean_phn_span=mean_phn_span).nonzero()
                masked_start = align_start[idx][masked_phn_idxs].tolist()
                masked_end = align_end[idx][masked_phn_idxs].tolist()
                for s, e in zip(masked_start, masked_end):
                    masked_pos[idx, s:e] = 1
    non_eos_mask = paddle.reshape(src_mask, paddle.shape(xs_pad)[:2])
    masked_pos = masked_pos * non_eos_mask
    masked_pos = paddle.cast(masked_pos, 'bool')
    return masked_pos
```
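As a plain-numpy illustration of the two masking modes above (not the paddle code itself, just the span semantics): inference masks the frames inside each given boundary pair, while training expands randomly drawn phone indices to their aligned frame ranges.

```python
import numpy as np

T = 10
masked_pos = np.zeros(T, dtype=bool)

# inference mode: span_bdy pairs give [start, end) frame ranges directly
span_bdy = [3, 7]
for s, e in zip(span_bdy[::2], span_bdy[1::2]):
    masked_pos[s:e] = True
print(masked_pos.astype(int))  # [0 0 0 1 1 1 1 0 0 0]

# training mode: chosen phone indices are expanded to aligned frame ranges
align_start, align_end = [0, 3, 7], [3, 7, 10]
masked_phn_idxs = [1]          # e.g. phone 1 drawn by the noise mask
for i in masked_phn_idxs:
    masked_pos[align_start[i]:align_end[i]] = True
```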
```python
# mask speech and phones
def phones_text_masking(xs_pad: paddle.Tensor,
                        src_mask: paddle.Tensor,
                        text_pad: paddle.Tensor,
```

@@ -11,37 +73,56 @@ def phones_text_masking(xs_pad: paddle.Tensor,
```diff
                         align_start: paddle.Tensor,
                         align_end: paddle.Tensor,
                         align_start_lens: paddle.Tensor,
-                        mlm_prob: float,
-                        mean_phn_span: float,
+                        mlm_prob: float=0.8,
+                        mean_phn_span: int=8,
                         span_bdy: paddle.Tensor=None):
+    '''
+    Args:
+        xs_pad (paddle.Tensor): input speech (B, Tmax, D).
+        src_mask (paddle.Tensor): mask of speech (B, 1, Tmax).
+        text_pad (paddle.Tensor): input text (B, Tmax2).
+        text_mask (paddle.Tensor): mask of text (B, 1, Tmax2).
+        align_start (paddle.Tensor): frame level phone alignment start (B, Tmax2).
+        align_end (paddle.Tensor): frame level phone alignment end (B, Tmax2).
+        align_start_lens (paddle.Tensor): length of align_start (B, ).
+        mlm_prob (float):
+        mean_phn_span (int):
+        span_bdy (paddle.Tensor): masked mel boundary of input speech (B, 2).
+    Returns:
+        paddle.Tensor[bool]: masked position of input speech (B, Tmax).
+        paddle.Tensor[bool]: masked position of input text (B, Tmax2).
+    '''
     bz, sent_len, _ = paddle.shape(xs_pad)
     masked_pos = paddle.zeros((bz, sent_len))
     _, text_len = paddle.shape(text_pad)
     text_mask_num_lower = math.ceil(text_len * (1 - mlm_prob) * 0.5)
     text_masked_pos = paddle.zeros((bz, text_len))
     y_masks = None

     if mlm_prob == 1.0:
         masked_pos += 1
         # y_masks = tril_masks
     elif mean_phn_span == 0:
         # only speech
         length = sent_len
         mean_phn_span = min(length * mlm_prob // 3, 50)
-        masked_phn_idxs = random_spans_noise_mask(length, mlm_prob,
-                                                  mean_phn_span).nonzero()
+        masked_phn_idxs = random_spans_noise_mask(
+            length=length, mlm_prob=mlm_prob,
+            mean_phn_span=mean_phn_span).nonzero()
         masked_pos[:, masked_phn_idxs] = 1
     else:
         for idx in range(bz):
             # for inference
             if span_bdy is not None:
                 for s, e in zip(span_bdy[idx][::2], span_bdy[idx][1::2]):
                     masked_pos[idx, s:e] = 1
             # for training
             else:
                 length = align_start_lens[idx]
                 if length < 2:
                     continue
-                masked_phn_idxs = random_spans_noise_mask(
-                    length, mlm_prob, mean_phn_span).nonzero()
+                masked_phn_idxs = random_spans_noise_mask(
+                    length=length, mlm_prob=mlm_prob,
+                    mean_phn_span=mean_phn_span).nonzero()
                 unmasked_phn_idxs = list(
                     set(range(length)) - set(masked_phn_idxs[0].tolist()))
                 np.random.shuffle(unmasked_phn_idxs)
```

@@ -58,60 +139,76 @@ def phones_text_masking(xs_pad: paddle.Tensor,
```diff
     masked_pos = paddle.cast(masked_pos, 'bool')
     text_masked_pos = paddle.cast(text_masked_pos, 'bool')

-    return masked_pos, text_masked_pos, y_masks
+    return masked_pos, text_masked_pos
```
```diff
-def get_seg_pos_reduce_duration(
-        speech_pad: paddle.Tensor,
-        text_pad: paddle.Tensor,
-        align_start: paddle.Tensor,
-        align_end: paddle.Tensor,
-        align_start_lens: paddle.Tensor,
-        sega_emb: bool,
-        masked_pos: paddle.Tensor,
-        feats_lens: paddle.Tensor, ):
+def get_seg_pos(speech_pad: paddle.Tensor,
+                text_pad: paddle.Tensor,
+                align_start: paddle.Tensor,
+                align_end: paddle.Tensor,
+                align_start_lens: paddle.Tensor,
+                seg_emb: bool=False):
+    '''
+    Args:
+        speech_pad (paddle.Tensor): input speech (B, Tmax, D).
+        text_pad (paddle.Tensor): input text (B, Tmax2).
+        align_start (paddle.Tensor): frame level phone alignment start (B, Tmax2).
+        align_end (paddle.Tensor): frame level phone alignment end (B, Tmax2).
+        align_start_lens (paddle.Tensor): length of align_start (B, ).
+        seg_emb (bool): whether to use segment embedding.
+    Returns:
+        paddle.Tensor[int]: n-th phone of each mel, 0<=n<=Tmax2 (B, Tmax).
+            eg:
+            Tensor(shape=[1, 328], dtype=int64, place=Place(gpu:0), stop_gradient=True,
+            [[0 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 ,
+              1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 ,
+              1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 ,
+              1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 ,
+              1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 2 , 2 , 2 , 3 , 3 , 3 , 4 , 4 , 4 ,
+              5 , 5 , 5 , 6 , 6 , 6 , 6 , 6 , 6 , 6 , 6 , 7 , 7 , 7 , 7 , 7 , 7 , 7 ,
+              7 , 8 , 8 , 8 , 8 , 9 , 9 , 9 , 9 , 9 , 9 , 9 , 9 , 10, 10, 10, 10, 10,
+              10, 10, 10, 11, 11, 11, 11, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 13,
+              13, 13, 13, 13, 13, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 15, 15,
+              15, 15, 15, 15, 15, 16, 16, 16, 16, 16, 16, 17, 17, 17, 17, 17, 17, 17,
+              17, 18, 18, 18, 18, 18, 18, 19, 19, 19, 19, 19, 19, 19, 20, 20, 20, 20,
+              20, 20, 21, 21, 21, 21, 21, 21, 21, 22, 22, 22, 22, 22, 22, 22, 23, 23,
+              23, 23, 23, 23, 23, 23, 24, 24, 24, 24, 24, 24, 24, 24, 24, 25, 25, 25,
+              25, 26, 26, 26, 27, 27, 27, 27, 27, 28, 28, 28, 28, 28, 28, 28, 28, 29,
+              29, 29, 29, 29, 29, 30, 30, 30, 30, 31, 31, 31, 31, 31, 31, 31, 31, 32,
+              32, 32, 32, 32, 33, 33, 33, 33, 33, 33, 33, 33, 34, 34, 34, 34, 35, 35,
+              35, 35, 35, 35, 35, 35, 36, 36, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37,
+              37, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38,
+              38, 38, 0 , 0 ]])
+        paddle.Tensor[int]: n-th phone of each phone, 0<=n<=Tmax2 (B, Tmax2).
+            eg:
+            Tensor(shape=[1, 38], dtype=int64, place=Place(gpu:0), stop_gradient=True,
+            [[1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10, 11, 12, 13, 14, 15, 16, 17,
+              18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
+              36, 37, 38]])
+    '''
     bz, speech_len, _ = paddle.shape(speech_pad)
-    text_seg_pos = paddle.zeros(paddle.shape(text_pad))
-    speech_seg_pos = paddle.zeros((bz, speech_len), dtype=text_pad.dtype)
-
     _, text_len = paddle.shape(text_pad)
-    reordered_idx = paddle.zeros((bz, speech_len), dtype=align_start_lens.dtype)
+    text_seg_pos = paddle.zeros((bz, text_len), dtype='int64')
+    speech_seg_pos = paddle.zeros((bz, speech_len), dtype='int64')

-    durations = paddle.ones((bz, speech_len), dtype=align_start_lens.dtype)
-    max_reduced_length = 0
-    if not sega_emb:
-        return speech_pad, masked_pos, speech_seg_pos, text_seg_pos, durations
+    if not seg_emb:
+        return speech_seg_pos, text_seg_pos
     for idx in range(bz):
-        first_idx = []
-        last_idx = []
         align_length = align_start_lens[idx]
         for j in range(align_length):
             s, e = align_start[idx][j], align_end[idx][j]
-            if j == 0:
-                if paddle.sum(masked_pos[idx][0:s]) == 0:
-                    first_idx.extend(range(0, s))
-                else:
-                    first_idx.extend([0])
-                    last_idx.extend(range(1, s))
-            if paddle.sum(masked_pos[idx][s:e]) == 0:
-                first_idx.extend(range(s, e))
-            else:
-                first_idx.extend([s])
-                last_idx.extend(range(s + 1, e))
-                durations[idx][s] = e - s
-            speech_seg_pos[idx][s:e] = j + 1
-            text_seg_pos[idx][j] = j + 1
-        max_reduced_length = max(
-            len(first_idx) + feats_lens[idx] - e, max_reduced_length)
-        first_idx.extend(range(e, speech_len))
-        reordered_idx[idx] = paddle.to_tensor(
-            (first_idx + last_idx), dtype=align_start_lens.dtype)
-        feats_lens[idx] = len(first_idx)
-    reordered_idx = reordered_idx[:, :max_reduced_length]
+            speech_seg_pos[idx, s:e] = j + 1
+            text_seg_pos[idx, j] = j + 1

-    return reordered_idx, speech_seg_pos, text_seg_pos, durations, feats_lens
+    return speech_seg_pos, text_seg_pos
```
```diff
-def random_spans_noise_mask(length: int, mlm_prob: float,
-                            mean_phn_span: float):
+# randomly select the range of speech and text to mask during training
+def random_spans_noise_mask(length: int,
+                            mlm_prob: float=0.8,
+                            mean_phn_span: float=8):
     """This function is copy of `random_spans_helper
     <https://github.com/google-research/text-to-text-transfer-transformer/blob/84f8bcc14b5f2c03de51bd3587609ba8f6bbd1cd/t5/data/preprocessors.py#L2682>`__ .

     Noise mask consisting of random spans of noise tokens.
```

@@ -126,7 +223,7 @@ def random_spans_noise_mask(length: int, mlm_prob: float, mean_phn_span: float):
```diff
       noise_density: a float - approximate density of output mask
       mean_noise_span_length: a number
     Returns:
-        a boolean tensor with shape [length]
+        np.ndarray:
+            a boolean tensor with shape [length]
     """
     orig_length = length
```

@@ -171,87 +268,3 @@ def random_spans_noise_mask(length: int, mlm_prob: float, mean_phn_span: float):
```diff
     is_noise = np.equal(span_num % 2, 1)
     return is_noise[:orig_length]
```
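A quick sketch of what this helper produces; the exact True/False pattern varies per draw since the mask is random, so only the shape and rough density are stable:

```python
import numpy as np

# Hypothetical call: mask roughly 80% of a 20-phone utterance in spans
# averaging 8 phones (assumes this module's random_spans_noise_mask is in scope).
mask = random_spans_noise_mask(length=20, mlm_prob=0.8, mean_phn_span=8)
print(mask.shape, mask.dtype)   # (20,) bool
print(mask.nonzero()[0])        # indices of the phones chosen for masking
# Roughly mlm_prob * length entries are True, grouped into contiguous spans.
```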
```diff
-def pad_to_longformer_att_window(text: paddle.Tensor,
-                                 max_len: int,
-                                 max_tlen: int,
-                                 attention_window: int=0):
-    round = max_len % attention_window
-    if round != 0:
-        max_tlen += (attention_window - round)
-        n_batch = paddle.shape(text)[0]
-        text_pad = paddle.zeros(
-            (n_batch, max_tlen, *paddle.shape(text[0])[1:]), dtype=text.dtype)
-        for i in range(n_batch):
-            text_pad[i, :paddle.shape(text[i])[0]] = text[i]
-    else:
-        text_pad = text[:, :max_tlen]
-    return text_pad, max_tlen
-
-
-def phones_masking(xs_pad: paddle.Tensor,
-                   src_mask: paddle.Tensor,
-                   align_start: paddle.Tensor,
-                   align_end: paddle.Tensor,
-                   align_start_lens: paddle.Tensor,
-                   mlm_prob: float,
-                   mean_phn_span: int,
-                   span_bdy: paddle.Tensor=None):
-    bz, sent_len, _ = paddle.shape(xs_pad)
-    masked_pos = paddle.zeros((bz, sent_len))
-    y_masks = None
-    if mlm_prob == 1.0:
-        masked_pos += 1
-    elif mean_phn_span == 0:
-        # only speech
-        length = sent_len
-        mean_phn_span = min(length * mlm_prob // 3, 50)
-        masked_phn_idxs = random_spans_noise_mask(length, mlm_prob,
-                                                  mean_phn_span).nonzero()
-        masked_pos[:, masked_phn_idxs] = 1
-    else:
-        for idx in range(bz):
-            if span_bdy is not None:
-                for s, e in zip(span_bdy[idx][::2], span_bdy[idx][1::2]):
-                    masked_pos[idx, s:e] = 1
-            else:
-                length = align_start_lens[idx]
-                if length < 2:
-                    continue
-                masked_phn_idxs = random_spans_noise_mask(
-                    length, mlm_prob, mean_phn_span).nonzero()
-                masked_start = align_start[idx][masked_phn_idxs].tolist()
-                masked_end = align_end[idx][masked_phn_idxs].tolist()
-                for s, e in zip(masked_start, masked_end):
-                    masked_pos[idx, s:e] = 1
-    non_eos_mask = paddle.reshape(src_mask, paddle.shape(xs_pad)[:2])
-    masked_pos = masked_pos * non_eos_mask
-    masked_pos = paddle.cast(masked_pos, 'bool')
-    return masked_pos, y_masks
-
-
-def get_seg_pos(speech_pad: paddle.Tensor,
-                text_pad: paddle.Tensor,
-                align_start: paddle.Tensor,
-                align_end: paddle.Tensor,
-                align_start_lens: paddle.Tensor,
-                sega_emb: bool):
-    bz, speech_len, _ = paddle.shape(speech_pad)
-    _, text_len = paddle.shape(text_pad)
-    text_seg_pos = paddle.zeros((bz, text_len), dtype='int64')
-    speech_seg_pos = paddle.zeros((bz, speech_len), dtype='int64')
-    if not sega_emb:
-        return speech_seg_pos, text_seg_pos
-    for idx in range(bz):
-        align_length = align_start_lens[idx]
-        for j in range(align_length):
-            s, e = align_start[idx][j], align_end[idx][j]
-            speech_seg_pos[idx, s:e] = j + 1
-            text_seg_pos[idx, j] = j + 1
-    return speech_seg_pos, text_seg_pos
```
ernie-sat/inference.py

```python
#!/usr/bin/env python3
import argparse
import os
import random
from pathlib import Path
from typing import Collection
from typing import Dict
from typing import List
from typing import Tuple
from typing import Union

import librosa
import numpy as np
```

@@ -15,60 +11,42 @@ import paddle
```diff
 import soundfile as sf
 import torch
 from paddle import nn
-from ParallelWaveGAN.parallel_wavegan.utils.utils import download_pretrained_model
-from align import alignment
-from align import alignment_zh
-from dataset import get_seg_pos
-from dataset import get_seg_pos_reduce_duration
-from dataset import pad_to_longformer_att_window
-from dataset import phones_masking
-from dataset import phones_text_masking
-from model_paddle import build_model_from_file
-from read_text import load_num_sequence_text
-from read_text import read_2column_text
 from sedit_arg_parser import parse_args
 from utils import build_vocoder_from_file
 from utils import evaluate_durations
 from utils import get_voc_out
 from utils import is_chinese
 from paddlespeech.t2s.datasets.get_feats import LogMelFBank
 from paddlespeech.t2s.modules.nets_utils import pad_list
 from paddlespeech.t2s.modules.nets_utils import make_non_pad_mask
+from align import alignment
+from align import alignment_zh
+from align import words2phns
+from align import words2phns_zh
+from collect_fn import build_collate_fn
+from mlm import build_model_from_file
+from read_text import load_num_sequence_text
+from read_text import read_2col_text
+# from ParallelWaveGAN.parallel_wavegan.utils.utils import download_pretrained_model

 random.seed(0)
 np.random.seed(0)

 PHONEME = 'tools/aligner/english_envir/english2phoneme/phoneme'
 MODEL_DIR_EN = 'tools/aligner/english'
 MODEL_DIR_ZH = 'tools/aligner/mandarin'
```
```diff
-def plot_mel_and_vocode_wav(uid: str,
-                            wav_path: str,
-                            prefix: str="./prompt/dev/",
+def plot_mel_and_vocode_wav(wav_path: str,
                             source_lang: str='english',
                             target_lang: str='english',
-                            model_name: str="conformer",
-                            full_origin_str: str="",
+                            model_name: str="paddle_checkpoint_en",
                             old_str: str="",
                             new_str: str="",
-                            duration_preditor_path: str=None,
                             use_pt_vocoder: bool=False,
-                            sid: str=None,
                             non_autoreg: bool=True):
-    wav_org, input_feat, output_feat, old_span_bdy, new_span_bdy, fs, hop_length = get_mlm_output(
-        uid=uid,
-        prefix=prefix,
+    wav_org, output_feat, old_span_bdy, new_span_bdy, fs, hop_length = get_mlm_output(
         source_lang=source_lang,
         target_lang=target_lang,
         model_name=model_name,
         wav_path=wav_path,
         old_str=old_str,
         new_str=new_str,
-        duration_preditor_path=duration_preditor_path,
-        use_teacher_forcing=non_autoreg,
-        sid=sid)
+        use_teacher_forcing=non_autoreg)

     masked_feat = output_feat[new_span_bdy[0]:new_span_bdy[1]]
```

@@ -79,10 +57,10 @@ def plot_mel_and_vocode_wav(uid: str,
```diff
             vocoder = load_vocoder('vctk_parallel_wavegan.v1.long')
             replaced_wav = vocoder(output_feat).cpu().numpy()
         else:
-            replaced_wav = get_voc_out(output_feat, target_lang)
+            replaced_wav = get_voc_out(output_feat)

     elif target_lang == 'chinese':
-        replaced_wav_only_mask_fst2_voc = get_voc_out(masked_feat, target_lang)
+        replaced_wav_only_mask_fst2_voc = get_voc_out(masked_feat)

     old_time_bdy = [hop_length * x for x in old_span_bdy]
     new_time_bdy = [hop_length * x for x in new_span_bdy]
```
@@ -109,125 +87,6 @@ def plot_mel_and_vocode_wav(uid: str,
```diff
     return data_dict, old_span_bdy
```

The 119 lines removed in this hunk were inference.py's local copies of `get_unk_phns`, `words2phns`, and `words2phns_zh`, identical to the versions added in ernie-sat/align.py above; they are now imported from `align` instead.
```python
def load_vocoder(vocoder_tag: str="vctk_parallel_wavegan.v1.long"):
    vocoder_tag = vocoder_tag.replace("parallel_wavegan/", "")
    vocoder_file = download_pretrained_model(vocoder_tag)
```

@@ -236,50 +95,52 @@ def load_vocoder(vocoder_tag: str="vctk_parallel_wavegan.v1.long"):
```diff
     return vocoder


-def load_model(model_name: str):
+def load_model(model_name: str="paddle_checkpoint_en"):
     config_path = './pretrained_model/{}/config.yaml'.format(model_name)
     model_path = './pretrained_model/{}/model.pdparams'.format(model_name)
-    mlm_model, args = build_model_from_file(
+    mlm_model, conf = build_model_from_file(
         config_file=config_path, model_file=model_path)
-    return mlm_model, args
+    return mlm_model, conf


-def read_data(uid: str, prefix: str):
-    mfa_text = read_2column_text(prefix + '/text')[uid]
-    mfa_wav_path = read_2column_text(prefix + '/wav.scp')[uid]
-    if 'mnt' not in mfa_wav_path:
-        mfa_wav_path = prefix.split('dump')[0] + mfa_wav_path
+def read_data(uid: str, prefix: os.PathLike):
+    # fetch the text for the given uid
+    mfa_text = read_2col_text(prefix + '/text')[uid]
+    # fetch the audio path for the given uid
+    mfa_wav_path = read_2col_text(prefix + '/wav.scp')[uid]
+    if not os.path.isabs(mfa_wav_path):
+        mfa_wav_path = prefix + mfa_wav_path
     return mfa_text, mfa_wav_path
```
```diff
-def get_align_data(uid: str, prefix: str):
+def get_align_data(uid: str, prefix: os.PathLike):
     mfa_path = prefix + "mfa_"
-    mfa_text = read_2column_text(mfa_path + 'text')[uid]
+    mfa_text = read_2col_text(mfa_path + 'text')[uid]
     mfa_start = load_num_sequence_text(
         mfa_path + 'start', loader_type='text_float')[uid]
     mfa_end = load_num_sequence_text(
         mfa_path + 'end', loader_type='text_float')[uid]
-    mfa_wav_path = read_2column_text(mfa_path + 'wav.scp')[uid]
+    mfa_wav_path = read_2col_text(mfa_path + 'wav.scp')[uid]
     return mfa_text, mfa_start, mfa_end, mfa_wav_path


+# get the range of mel frames to be masked
 def get_masked_mel_bdy(mfa_start: List[float],
                        mfa_end: List[float],
                        fs: int,
                        hop_length: int,
                        span_to_repl: List[List[int]]):
-    align_start = paddle.to_tensor(mfa_start).unsqueeze(0)
-    align_end = paddle.to_tensor(mfa_end).unsqueeze(0)
-    align_start = paddle.floor(fs * align_start / hop_length).int()
-    align_end = paddle.floor(fs * align_end / hop_length).int()
+    align_start = np.array(mfa_start)
+    align_end = np.array(mfa_end)
+    align_start = np.floor(fs * align_start / hop_length).astype('int')
+    align_end = np.floor(fs * align_end / hop_length).astype('int')
     if span_to_repl[0] >= len(mfa_start):
-        span_bdy = [align_end[0].tolist()[-1], align_end[0].tolist()[-1]]
+        span_bdy = [align_end[-1], align_end[-1]]
     else:
-        span_bdy = [
-            align_start[0].tolist()[span_to_repl[0]],
-            align_end[0].tolist()[span_to_repl[1] - 1]
-        ]
-    return span_bdy
+        span_bdy = [
+            align_start[span_to_repl[0]], align_end[span_to_repl[1] - 1]
+        ]
+    return span_bdy, align_start, align_end
```
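A small self-contained illustration of the second-to-frame conversion above; fs and hop_length are the defaults used elsewhere in this commit, and the phone times are invented:

```python
import numpy as np

fs, hop_length = 24000, 300            # 24 kHz audio, 12.5 ms hop -> 80 frames/s
mfa_start = [0.00, 0.12, 0.48]         # phone start times in seconds (illustrative)
mfa_end = [0.12, 0.48, 0.90]

align_start = np.floor(fs * np.array(mfa_start) / hop_length).astype('int')
align_end = np.floor(fs * np.array(mfa_end) / hop_length).astype('int')
print(align_start, align_end)          # [ 0  9 38] [ 9 38 72]

span_to_repl = [1, 2]                  # replace the middle phone
span_bdy = [align_start[1], align_end[1]]
print(span_bdy)                        # [9, 38]: mel frames to re-generate
```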
```python
def recover_dict(word2phns: Dict[str, str], tp_word2phns: Dict[str, str]):
```

@@ -317,18 +178,22 @@ def recover_dict(word2phns: Dict[str, str], tp_word2phns: Dict[str, str]):
```diff
     return dic


+def get_max_idx(dic):
+    return sorted([int(key.split('_')[0]) for key in dic.keys()])[-1]
+
+
 def get_phns_and_spans(wav_path: str,
                        old_str: str="",
                        new_str: str="",
                        source_lang: str="english",
                        target_lang: str="english"):
-    append_new_str = (old_str == new_str[:len(old_str)])
+    is_append = (old_str == new_str[:len(old_str)])
     old_phns, mfa_start, mfa_end = [], [], []
+    # source
     if source_lang == "english":
-        times2, word2phns = alignment(wav_path, old_str)
+        intervals, word2phns = alignment(wav_path, old_str)
     elif source_lang == "chinese":
-        times2, word2phns = alignment_zh(wav_path, old_str)
+        intervals, word2phns = alignment_zh(wav_path, old_str)
         _, tp_word2phns = words2phns_zh(old_str)
         for key, value in tp_word2phns.items():
```

@@ -337,51 +202,46 @@ def get_phns_and_spans(wav_path: str,
```diff
             tp_word2phns[key] = cur_val

         word2phns = recover_dict(word2phns, tp_word2phns)
     else:
-        assert source_lang == "chinese" or source_lang == "english", "source_lang is wrong..."
+        assert source_lang == "chinese" or source_lang == "english", \
+            "source_lang is wrong..."

-    for item in times2:
+    for item in intervals:
         old_phns.append(item[0])
         mfa_start.append(float(item[1]))
         mfa_end.append(float(item[2]))

-    if append_new_str and (source_lang != target_lang):
-        is_cross_lingual_clone = True
+    # target
+    if is_append and (source_lang != target_lang):
+        cross_lingual_clone = True
     else:
-        is_cross_lingual_clone = False
+        cross_lingual_clone = False

-    if is_cross_lingual_clone:
-        new_str_origin = new_str[:len(old_str)]
-        new_str_append = new_str[len(old_str):]
+    if cross_lingual_clone:
+        str_origin = new_str[:len(old_str)]
+        str_append = new_str[len(old_str):]

         if target_lang == "chinese":
-            new_phns_origin, new_origin_word2phns = words2phns(new_str_origin)
-            new_phns_append, temp_new_append_word2phns = words2phns_zh(
-                new_str_append)
+            phns_origin, origin_word2phns = words2phns(str_origin)
+            phns_append, append_word2phns_tmp = words2phns_zh(str_append)

         elif target_lang == "english":
-            new_phns_origin, new_origin_word2phns = words2phns_zh(
-                new_str_origin)
-            new_phns_append, temp_new_append_word2phns = words2phns(
-                new_str_append)
+            # original sentence
+            phns_origin, origin_word2phns = words2phns_zh(str_origin)
+            # cloned sentence
+            phns_append, append_word2phns_tmp = words2phns(str_append)
         else:
             assert target_lang == "chinese" or target_lang == "english", \
                 "cloning is not support for this language, please check it."

-        new_phns = new_phns_origin + new_phns_append
+        new_phns = phns_origin + phns_append

-        new_append_word2phns = {}
-        length = len(new_origin_word2phns)
-        for key, value in temp_new_append_word2phns.items():
+        append_word2phns = {}
+        length = len(origin_word2phns)
+        for key, value in append_word2phns_tmp.items():
             idx, wrd = key.split('_')
-            new_append_word2phns[str(int(idx) + length) + '_' + wrd] = value
-        new_word2phns = dict(
-            list(new_origin_word2phns.items()) +
-            list(new_append_word2phns.items()))
+            append_word2phns[str(int(idx) + length) + '_' + wrd] = value
+        new_word2phns = origin_word2phns.copy()
+        new_word2phns.update(append_word2phns)

     else:
         if source_lang == target_lang and target_lang == "english":
```

@@ -417,16 +277,17 @@ def get_phns_and_spans(wav_path: str,
```diff
     right_idx = 0
     new_phns_right = []
     sp_count = 0
-    word2phns_max_idx = int(list(word2phns.keys())[-1].split('_')[0])
-    new_word2phns_max_idx = int(list(new_word2phns.keys())[-1].split('_')[0])
+    word2phns_max_idx = get_max_idx(word2phns)
+    new_word2phns_max_idx = get_max_idx(new_word2phns)
     new_phns_mid = []
-    if append_new_str:
+    if is_append:
         new_phns_right = []
         new_phns_mid = new_phns[left_idx:]
         span_to_repl[0] = len(new_phns_left)
         span_to_add[0] = len(new_phns_left)
         span_to_add[1] = len(new_phns_left) + len(new_phns_mid)
         span_to_repl[1] = len(old_phns) - len(new_phns_right)
+    # speech edit
     else:
         for key in list(word2phns.keys())[::-1]:
             idx, wrd = key.split('_')
```

@@ -451,47 +312,57 @@ def get_phns_and_spans(wav_path: str,
```diff
                     len(old_phns))
                 break

     new_phns = new_phns_left + new_phns_mid + new_phns_right
+    '''
+    For that reason cover should not be given.
+    For that reason cover is impossible to be given.
+    span_to_repl: [17, 23] "should not"
+    span_to_add: [17, 30]  "is impossible to"
+    '''
     return mfa_start, mfa_end, old_phns, new_phns, span_to_repl, span_to_add
```
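The spans returned here are easiest to read against the docstring's own sentence pair; a short hedged restatement in code form:

```python
# Illustration only, using the example from the docstring above.
old_str = "For that reason cover should not be given."
new_str = "For that reason cover is impossible to be given."
span_to_repl = [17, 23]  # phone indices of "should not" within old_phns
span_to_add = [17, 30]   # phone indices of "is impossible to" within new_phns
# old_phns[17:23] is the region cut out of the original utterance, and
# new_phns[17:30] is what the model must generate in its place; the phones
# outside both spans are shared context on either side of the edit.
```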
```diff
-def duration_adjust_factor(original_dur: List[int],
+# the durations from MFA and the durations predicted by fs2's duration_predictor
+# may differ; compute a scaling factor between predicted and ground-truth values
+def duration_adjust_factor(orig_dur: List[int],
                            pred_dur: List[int],
                            phns: List[str]):
     length = 0
     factor_list = []
-    for ori, pred, phn in zip(original_dur, pred_dur, phns):
+    for orig, pred, phn in zip(orig_dur, pred_dur, phns):
         if pred == 0 or phn == 'sp':
             continue
         else:
-            factor_list.append(ori / pred)
+            factor_list.append(orig / pred)
     factor_list = np.array(factor_list)
     factor_list.sort()
     if len(factor_list) < 5:
         return 1
     length = 2
-    return np.average(factor_list[length:-length])
+    avg = np.average(factor_list[length:-length])
+    return avg
```
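A quick numeric check of the trimming behaviour, with values invented for illustration:

```python
import numpy as np

orig_dur = [10, 12, 8, 9, 30, 11, 10]   # ground-truth (MFA) durations
pred_dur = [5, 6, 4, 3, 3, 11, 5]       # predicted durations
phns = ['F', 'AO1', 'R', 'DH', 'sp', 'AE1', 'T']

# 'sp' is skipped, leaving ratios [2.0, 2.0, 2.0, 3.0, 1.0, 2.0]; after sorting
# and dropping the two smallest and two largest, the average is taken over
# [2.0, 2.0], giving a factor of 2.0.
factors = sorted(o / p for o, p, ph in zip(orig_dur, pred_dur, phns)
                 if p != 0 and ph != 'sp')
print(np.average(factors[2:-2]))  # 2.0
```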
```diff
-def prepare_features_with_duration(uid: str,
-                                   prefix: str,
-                                   wav_path: str,
-                                   mlm_model: nn.Layer,
-                                   source_lang: str="English",
-                                   target_lang: str="English",
-                                   old_str: str="",
-                                   new_str: str="",
-                                   duration_preditor_path: str=None,
-                                   sid: str=None,
-                                   mask_reconstruct: bool=False,
-                                   duration_adjust: bool=True,
-                                   start_end_sp: bool=False,
-                                   train_args=None):
-    wav_org, rate = librosa.load(
-        wav_path, sr=train_args.feats_extract_conf['fs'])
-    fs = train_args.feats_extract_conf['fs']
-    hop_length = train_args.feats_extract_conf['hop_length']
+def prep_feats_with_dur(wav_path: str,
+                        mlm_model: nn.Layer,
+                        source_lang: str="English",
+                        target_lang: str="English",
+                        old_str: str="",
+                        new_str: str="",
+                        mask_reconstruct: bool=False,
+                        duration_adjust: bool=True,
+                        start_end_sp: bool=False,
+                        fs: int=24000,
+                        hop_length: int=300):
+    '''
+    Returns:
+        np.ndarray: new wav, replace the part to be edited in original wav with 0
+        List[str]: new phones
+        List[float]: mfa start of new wav
+        List[float]: mfa end of new wav
+        List[int]: masked mel boundary of original wav
+        List[int]: masked mel boundary of new wav
+    '''
+    wav_org, _ = librosa.load(wav_path, sr=fs)

     mfa_start, mfa_end, old_phns, new_phns, span_to_repl, span_to_add = get_phns_and_spans(
         wav_path=wav_path,
```

@@ -503,144 +374,130 @@ def prepare_features_with_duration(uid: str,
```diff
     if start_end_sp:
         if new_phns[-1] != 'sp':
             new_phns = new_phns + ['sp']
-
-    if target_lang == "english":
-        old_durations = evaluate_durations(old_phns, target_lang=target_lang)
-
-    elif target_lang == "chinese":
-        if source_lang == "english":
-            old_durations = evaluate_durations(
-                old_phns, target_lang=source_lang)
-        elif source_lang == "chinese":
-            old_durations = evaluate_durations(
-                old_phns, target_lang=source_lang)
+    # Chinese phns are not necessarily all in the fastspeech2 dict; use sp instead
+    if target_lang == "english" or target_lang == "chinese":
+        old_durs = evaluate_durations(old_phns, target_lang=source_lang)
     else:
-        assert target_lang == "chinese" or target_lang == "english", "calculate duration_predict is not support for this language..."
+        assert target_lang == "chinese" or target_lang == "english", \
+            "calculate duration_predict is not support for this language..."

-    original_old_durations = [e - s for e, s in zip(mfa_end, mfa_start)]
+    orig_old_durs = [e - s for e, s in zip(mfa_end, mfa_start)]

     if '[MASK]' in new_str:
         new_phns = old_phns
         span_to_add = span_to_repl
-        d_factor_left = duration_adjust_factor(
-            original_old_durations[:span_to_repl[0]],
-            old_durations[:span_to_repl[0]], old_phns[:span_to_repl[0]])
-        d_factor_right = duration_adjust_factor(
-            original_old_durations[span_to_repl[1]:],
-            old_durations[span_to_repl[1]:], old_phns[span_to_repl[1]:])
+        d_factor_left = duration_adjust_factor(
+            orig_dur=orig_old_durs[:span_to_repl[0]],
+            pred_dur=old_durs[:span_to_repl[0]],
+            phns=old_phns[:span_to_repl[0]])
+        d_factor_right = duration_adjust_factor(
+            orig_dur=orig_old_durs[span_to_repl[1]:],
+            pred_dur=old_durs[span_to_repl[1]:],
+            phns=old_phns[span_to_repl[1]:])
         d_factor = (d_factor_left + d_factor_right) / 2
-        new_durations_adjusted = [d_factor * i for i in old_durations]
+        new_durs_adjusted = [d_factor * i for i in old_durs]
     else:
         if duration_adjust:
-            d_factor = duration_adjust_factor(original_old_durations,
-                                              old_durations, old_phns)
+            d_factor = duration_adjust_factor(
+                orig_dur=orig_old_durs, pred_dur=old_durs, phns=old_phns)
+            print("d_factor:", d_factor)
             d_factor = d_factor * 1.25
         else:
             d_factor = 1

-        if target_lang == "english":
-            new_durations = evaluate_durations(
-                new_phns, target_lang=target_lang)
-
-        elif target_lang == "chinese":
-            new_durations = evaluate_durations(
-                new_phns, target_lang=target_lang)
-
-        new_durations_adjusted = [d_factor * i for i in new_durations]
-
-        if span_to_repl[0] < len(old_phns) and old_phns[span_to_repl[
-                0]] == new_phns[span_to_add[0]]:
-            new_durations_adjusted[span_to_add[0]] = original_old_durations[
-                span_to_repl[0]]
-        if span_to_repl[1] < len(old_phns) and span_to_add[1] < len(new_phns):
-            if old_phns[span_to_repl[1]] == new_phns[span_to_add[1]]:
-                new_durations_adjusted[span_to_add[1]] = original_old_durations[
-                    span_to_repl[1]]
-    new_span_duration_sum = sum(
-        new_durations_adjusted[span_to_add[0]:span_to_add[1]])
-    old_span_duration_sum = sum(
-        original_old_durations[span_to_repl[0]:span_to_repl[1]])
-    duration_offset = new_span_duration_sum - old_span_duration_sum
+        if target_lang == "english" or target_lang == "chinese":
+            new_durs = evaluate_durations(new_phns, target_lang=target_lang)
+        else:
+            assert target_lang == "chinese" or target_lang == "english", \
+                "calculate duration_predict is not support for this language..."

+        new_durs_adjusted = [d_factor * i for i in new_durs]
+
+    new_span_dur_sum = sum(new_durs_adjusted[span_to_add[0]:span_to_add[1]])
+    old_span_dur_sum = sum(orig_old_durs[span_to_repl[0]:span_to_repl[1]])
+    dur_offset = new_span_dur_sum - old_span_dur_sum
     new_mfa_start = mfa_start[:span_to_repl[0]]
     new_mfa_end = mfa_end[:span_to_repl[0]]
-    for i in new_durations_adjusted[span_to_add[0]:span_to_add[1]]:
+    for i in new_durs_adjusted[span_to_add[0]:span_to_add[1]]:
         if len(new_mfa_end) == 0:
             new_mfa_start.append(0)
             new_mfa_end.append(i)
         else:
             new_mfa_start.append(new_mfa_end[-1])
             new_mfa_end.append(new_mfa_end[-1] + i)
-    new_mfa_start += [i + duration_offset for i in mfa_start[span_to_repl[1]:]]
-    new_mfa_end += [i + duration_offset for i in mfa_end[span_to_repl[1]:]]
+    new_mfa_start += [i + dur_offset for i in mfa_start[span_to_repl[1]:]]
+    new_mfa_end += [i + dur_offset for i in mfa_end[span_to_repl[1]:]]

     # 3. get new wav
+    # append at the end of the original sentence
     if span_to_repl[0] >= len(mfa_start):
         left_idx = len(wav_org)
         right_idx = left_idx
+    # replace in the middle of the original sentence
     else:
         left_idx = int(np.floor(mfa_start[span_to_repl[0]] * fs))
         right_idx = int(np.ceil(mfa_end[span_to_repl[1] - 1] * fs))
-    new_blank_wav = np.zeros(
-        (int(np.ceil(new_span_duration_sum * fs)), ), dtype=wav_org.dtype)
-    new_wav_org = np.concatenate(
-        [wav_org[:left_idx], new_blank_wav, wav_org[right_idx:]])
+    blank_wav = np.zeros(
+        (int(np.ceil(new_span_dur_sum * fs)), ), dtype=wav_org.dtype)
+    # in the original audio, replace the part to be edited with silence whose
+    # length is decided by fs2's duration_predictor
+    new_wav = np.concatenate(
+        [wav_org[:left_idx], blank_wav, wav_org[right_idx:]])

     # 4. get old and new mel span to be mask
-    old_span_bdy = get_masked_mel_bdy(mfa_start, mfa_end, fs, hop_length,
-                                      span_to_repl)
-    new_span_bdy = get_masked_mel_bdy(new_mfa_start, new_mfa_end, fs,
-                                      hop_length, span_to_add)
+    # [92, 92]
+    old_span_bdy, mfa_start, mfa_end = get_masked_mel_bdy(
+        mfa_start=mfa_start,
+        mfa_end=mfa_end,
+        fs=fs,
+        hop_length=hop_length,
+        span_to_repl=span_to_repl)
+    # [92, 174]
+    # new_mfa_start, new_mfa_end: start/end times in seconds -> frame level
+    new_span_bdy, new_mfa_start, new_mfa_end = get_masked_mel_bdy(
+        mfa_start=new_mfa_start,
+        mfa_end=new_mfa_end,
+        fs=fs,
+        hop_length=hop_length,
+        span_to_repl=span_to_add)
+    # old_span_bdy and new_span_bdy are frame-level ranges
+    return new_wav, new_phns, new_mfa_start, new_mfa_end, old_span_bdy, new_span_bdy
```
```diff
-def prepare_features(uid: str,
-                     mlm_model: nn.Layer,
-                     processor,
-                     wav_path: str,
-                     prefix: str="./prompt/dev/",
-                     source_lang: str="english",
-                     target_lang: str="english",
-                     old_str: str="",
-                     new_str: str="",
-                     duration_preditor_path: str=None,
-                     sid: str=None,
-                     duration_adjust: bool=True,
-                     start_end_sp: bool=False,
-                     mask_reconstruct: bool=False,
-                     train_args=None):
-    wav_org, phns_list, mfa_start, mfa_end, old_span_bdy, new_span_bdy = prepare_features_with_duration(
-        uid=uid,
-        prefix=prefix,
+def prep_feats(mlm_model: nn.Layer,
+               wav_path: str,
+               source_lang: str="english",
+               target_lang: str="english",
+               old_str: str="",
+               new_str: str="",
+               duration_adjust: bool=True,
+               start_end_sp: bool=False,
+               mask_reconstruct: bool=False,
+               fs: int=24000,
+               hop_length: int=300,
+               token_list: List[str]=[]):
+    wav, phns, mfa_start, mfa_end, old_span_bdy, new_span_bdy = prep_feats_with_dur(
         source_lang=source_lang,
         target_lang=target_lang,
         mlm_model=mlm_model,
         old_str=old_str,
         new_str=new_str,
         wav_path=wav_path,
-        duration_preditor_path=duration_preditor_path,
-        sid=sid,
         duration_adjust=duration_adjust,
         start_end_sp=start_end_sp,
         mask_reconstruct=mask_reconstruct,
-        train_args=train_args)
+        fs=fs,
+        hop_length=hop_length)

-    speech = wav_org
-    align_start = np.array(mfa_start)
-    align_end = np.array(mfa_end)
-    token_to_id = {item: i for i, item in enumerate(train_args.token_list)}
-    text = np.array(
-        list(
-            map(lambda x: token_to_id.get(x, token_to_id['<unk>']),
-                phns_list)))
+    token_to_id = {item: i for i, item in enumerate(token_list)}
+    text = np.array(
+        list(map(lambda x: token_to_id.get(x, token_to_id['<unk>']), phns)))
     span_bdy = np.array(new_span_bdy)

     batch = [('1', {
-        "speech": speech,
-        "align_start": align_start,
-        "align_end": align_end,
+        "speech": wav,
+        "align_start": mfa_start,
+        "align_end": mfa_end,
         "text": text,
         "span_bdy": span_bdy
     })]
```

@@ -648,375 +505,135 @@ def prepare_features(uid: str,
```diff
     return batch, old_span_bdy, new_span_bdy
```
def
decode_with_model
(
uid
:
str
,
mlm_model
:
nn
.
Layer
,
processor
,
def
decode_with_model
(
mlm_model
:
nn
.
Layer
,
collate_fn
,
wav_path
:
str
,
prefix
:
str
=
"./prompt/dev/"
,
source_lang
:
str
=
"english"
,
target_lang
:
str
=
"english"
,
old_str
:
str
=
""
,
new_str
:
str
=
""
,
duration_preditor_path
:
str
=
None
,
sid
:
str
=
None
,
decoder
:
bool
=
False
,
use_teacher_forcing
:
bool
=
False
,
duration_adjust
:
bool
=
True
,
start_end_sp
:
bool
=
False
,
train_args
=
None
):
fs
,
hop_length
=
train_args
.
feats_extract_conf
[
'fs'
],
train_args
.
feats_extract_conf
[
'hop_length'
]
batch
,
old_span_bdy
,
new_span_bdy
=
prepare_features
(
uid
=
uid
,
prefix
=
prefix
,
fs
:
int
=
24000
,
hop_length
:
int
=
300
,
token_list
:
List
[
str
]
=
[]):
batch
,
old_span_bdy
,
new_span_bdy
=
prep_feats
(
source_lang
=
source_lang
,
target_lang
=
target_lang
,
mlm_model
=
mlm_model
,
processor
=
processor
,
wav_path
=
wav_path
,
old_str
=
old_str
,
new_str
=
new_str
,
duration_preditor_path
=
duration_preditor_path
,
sid
=
sid
,
duration_adjust
=
duration_adjust
,
start_end_sp
=
start_end_sp
,
train_args
=
train_args
)
fs
=
fs
,
hop_length
=
hop_length
,
token_list
=
token_list
)
feats
=
collate_fn
(
batch
)[
1
]
if
'text_masked_pos'
in
feats
.
keys
():
feats
.
pop
(
'text_masked_pos'
)
for
k
,
v
in
feats
.
items
():
feats
[
k
]
=
paddle
.
to_tensor
(
v
)
rtn
=
mlm_model
.
inference
(
**
feats
,
span_bdy
=
new_span_bdy
,
use_teacher_forcing
=
use_teacher_forcing
)
output
=
rtn
[
'feat_gen'
]
output
=
mlm_model
.
inference
(
text
=
feats
[
'text'
],
speech
=
feats
[
'speech'
],
masked_pos
=
feats
[
'masked_pos'
],
speech_mask
=
feats
[
'speech_mask'
],
text_mask
=
feats
[
'text_mask'
],
speech_seg_pos
=
feats
[
'speech_seg_pos'
],
text_seg_pos
=
feats
[
'text_seg_pos'
],
span_bdy
=
new_span_bdy
,
use_teacher_forcing
=
use_teacher_forcing
)
if
0
in
output
[
0
].
shape
and
0
not
in
output
[
-
1
].
shape
:
output_feat
=
paddle
.
concat
(
output
[
1
:
-
1
]
+
[
output
[
-
1
].
squeeze
()],
axis
=
0
)
.
cpu
()
output
[
1
:
-
1
]
+
[
output
[
-
1
].
squeeze
()],
axis
=
0
)
elif
0
not
in
output
[
0
].
shape
and
0
in
output
[
-
1
].
shape
:
output_feat
=
paddle
.
concat
(
[
output
[
0
].
squeeze
()]
+
output
[
1
:
-
1
],
axis
=
0
)
.
cpu
()
[
output
[
0
].
squeeze
()]
+
output
[
1
:
-
1
],
axis
=
0
)
elif
0
in
output
[
0
].
shape
and
0
in
output
[
-
1
].
shape
:
output_feat
=
paddle
.
concat
(
output
[
1
:
-
1
],
axis
=
0
)
.
cpu
()
output_feat
=
paddle
.
concat
(
output
[
1
:
-
1
],
axis
=
0
)
else
:
output_feat
=
paddle
.
concat
(
[
output
[
0
].
squeeze
(
0
)]
+
output
[
1
:
-
1
]
+
[
output
[
-
1
].
squeeze
(
0
)],
axis
=
0
).
cpu
()
wav_org
,
_
=
librosa
.
load
(
wav_path
,
sr
=
train_args
.
feats_extract_conf
[
'fs'
])
return
wav_org
,
None
,
output_feat
,
old_span_bdy
,
new_span_bdy
,
fs
,
hop_length
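The four `paddle.concat` branches exist because `output` is a list `[prefix, generated..., suffix]` and either edge slice can be empty when the edited span touches the start or end of the utterance. A toy check of one case, assuming a Paddle build that supports zero-size tensors:

```
import paddle

prefix = paddle.zeros([1, 0, 80])   # empty: the edit starts at frame 0
gen = [paddle.ones([35, 80])]       # generated frames for the masked span
suffix = paddle.zeros([1, 67, 80])  # untouched trailing frames

output = [prefix] + gen + [suffix]
if 0 in output[0].shape and 0 not in output[-1].shape:
    feat = paddle.concat(output[1:-1] + [output[-1].squeeze()], axis=0)
print(feat.shape)  # [102, 80]
```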
class MLMCollateFn:
    """Functor class of common_collate_fn()"""

    def __init__(self,
                 feats_extract,
                 float_pad_value: Union[float, int]=0.0,
                 int_pad_value: int=-32768,
                 not_sequence: Collection[str]=(),
                 mlm_prob: float=0.8,
                 mean_phn_span: int=8,
                 attention_window: int=0,
                 pad_speech: bool=False,
                 sega_emb: bool=False,
                 duration_collect: bool=False,
                 text_masking: bool=False):
        self.mlm_prob = mlm_prob
        self.mean_phn_span = mean_phn_span
        self.feats_extract = feats_extract
        self.float_pad_value = float_pad_value
        self.int_pad_value = int_pad_value
        self.not_sequence = set(not_sequence)
        self.attention_window = attention_window
        self.pad_speech = pad_speech
        self.sega_emb = sega_emb
        self.duration_collect = duration_collect
        self.text_masking = text_masking

    def __repr__(self):
        return (f"{self.__class__}(float_pad_value={self.float_pad_value}, "
                f"int_pad_value={self.int_pad_value})")

    def __call__(self, data: Collection[Tuple[str, Dict[str, np.ndarray]]]
                 ) -> Tuple[List[str], Dict[str, paddle.Tensor]]:
        return mlm_collate_fn(
            data,
            float_pad_value=self.float_pad_value,
            int_pad_value=self.int_pad_value,
            not_sequence=self.not_sequence,
            mlm_prob=self.mlm_prob,
            mean_phn_span=self.mean_phn_span,
            feats_extract=self.feats_extract,
            attention_window=self.attention_window,
            pad_speech=self.pad_speech,
            sega_emb=self.sega_emb,
            duration_collect=self.duration_collect,
            text_masking=self.text_masking)


def mlm_collate_fn(
        data: Collection[Tuple[str, Dict[str, np.ndarray]]],
        float_pad_value: Union[float, int]=0.0,
        int_pad_value: int=-32768,
        not_sequence: Collection[str]=(),
        mlm_prob: float=0.8,
        mean_phn_span: int=8,
        feats_extract=None,
        attention_window: int=0,
        pad_speech: bool=False,
        sega_emb: bool=False,
        duration_collect: bool=False,
        text_masking: bool=False) -> Tuple[List[str], Dict[str, paddle.Tensor]]:
    uttids = [u for u, _ in data]
    data = [d for _, d in data]

    assert all(set(data[0]) == set(d) for d in data), "dict-keys mismatching"
    assert all(not k.endswith("_lens")
               for k in data[0]), f"*_lens is reserved: {list(data[0])}"

    output = {}
    for key in data[0]:
        # Each model that finally accepts these values is responsible
        # for repainting the pad_value to the value its task needs.
        if data[0][key].dtype.kind == "i":
            pad_value = int_pad_value
        else:
            pad_value = float_pad_value

        array_list = [d[key] for d in data]
        # Assume the first axis is length:
        # tensor_list: Batch x (Length, ...)
        tensor_list = [paddle.to_tensor(a) for a in array_list]
        # tensor: (Batch, Length, ...)
        tensor = pad_list(tensor_list, pad_value)
        output[key] = tensor
        # lens: (Batch,)
        if key not in not_sequence:
            lens = paddle.to_tensor(
                [d[key].shape[0] for d in data], dtype=paddle.int64)
            output[key + "_lens"] = lens

    feats = feats_extract.get_log_mel_fbank(np.array(output["speech"][0]))
    feats = paddle.to_tensor(feats)
    feats_lens = paddle.shape(feats)[0]
    feats = paddle.unsqueeze(feats, 0)
    if 'text' not in output:
        text = paddle.zeros(paddle.shape(feats_lens.unsqueeze(-1))) - 2
        text_lens = paddle.zeros(paddle.shape(feats_lens)) + 1
        max_tlen = 1
        align_start = paddle.zeros(paddle.shape(text))
        align_end = paddle.zeros(paddle.shape(text))
        align_start_lens = paddle.zeros(paddle.shape(feats_lens))
        sega_emb = False
        mean_phn_span = 0
        mlm_prob = 0.15
    else:
        text = output["text"]
        text_lens = output["text_lens"]
        align_start = output["align_start"]
        align_start_lens = output["align_start_lens"]
        align_end = output["align_end"]
        align_start = paddle.floor(feats_extract.sr * align_start /
                                   feats_extract.hop_length).int()
        align_end = paddle.floor(feats_extract.sr * align_end /
                                 feats_extract.hop_length).int()
        max_tlen = max(text_lens)
    max_slen = max(feats_lens)
    speech_pad = feats[:, :max_slen]
    if attention_window > 0 and pad_speech:
        speech_pad, max_slen = pad_to_longformer_att_window(
            speech_pad, max_slen, max_slen, attention_window)
    max_len = max_slen + max_tlen
    if attention_window > 0:
        text_pad, max_tlen = pad_to_longformer_att_window(
            text, max_len, max_tlen, attention_window)
    else:
        text_pad = text
    text_mask = make_non_pad_mask(
        text_lens, text_pad, length_dim=1).unsqueeze(-2)
    if attention_window > 0:
        text_mask = text_mask * 2
    speech_mask = make_non_pad_mask(
        feats_lens, speech_pad[:, :, 0], length_dim=1).unsqueeze(-2)
    span_bdy = None
    if 'span_bdy' in output.keys():
        span_bdy = output['span_bdy']

    if text_masking:
        masked_pos, text_masked_pos, _ = phones_text_masking(
            speech_pad, speech_mask, text_pad, text_mask, align_start,
            align_end, align_start_lens, mlm_prob, mean_phn_span, span_bdy)
    else:
        text_masked_pos = paddle.zeros(paddle.shape(text_pad))
        masked_pos, _ = phones_masking(speech_pad, speech_mask, align_start,
                                       align_end, align_start_lens, mlm_prob,
                                       mean_phn_span, span_bdy)
    output_dict = {}
    if duration_collect and 'text' in output:
        reordered_idx, speech_seg_pos, text_seg_pos, durations, feats_lens = get_seg_pos_reduce_duration(
            speech_pad, text_pad, align_start, align_end, align_start_lens,
            sega_emb, masked_pos, feats_lens)
        speech_mask = make_non_pad_mask(
            feats_lens, speech_pad[:, :reordered_idx.shape[1], 0],
            length_dim=1).unsqueeze(-2)
        output_dict['durations'] = durations
        output_dict['reordered_idx'] = reordered_idx
    else:
        speech_seg_pos, text_seg_pos = get_seg_pos(
            speech_pad, text_pad, align_start, align_end, align_start_lens,
            sega_emb)
    output_dict['speech'] = speech_pad
    output_dict['text'] = text_pad
    output_dict['masked_pos'] = masked_pos
    output_dict['text_masked_pos'] = text_masked_pos
    output_dict['speech_mask'] = speech_mask
    output_dict['text_mask'] = text_mask
    output_dict['speech_seg_pos'] = speech_seg_pos
    output_dict['text_seg_pos'] = text_seg_pos
    output_dict['speech_lens'] = output["speech_lens"]
    output_dict['text_lens'] = text_lens
    output = (uttids, output_dict)
    return output


def build_collate_fn(args: argparse.Namespace, train: bool, epoch=-1):
    # -> Callable[
    #     [Collection[Tuple[str, Dict[str, np.ndarray]]]],
    #     Tuple[List[str], Dict[str, Tensor]],
    # ]
    # assert check_argument_types()
    # return CommonCollateFn(float_pad_value=0.0, int_pad_value=0)
    feats_extract_class = LogMelFBank
    if args.feats_extract_conf['win_length'] is None:
        args.feats_extract_conf['win_length'] = args.feats_extract_conf['n_fft']
    args_dic = {}
    for k, v in args.feats_extract_conf.items():
        if k == 'fs':
            args_dic['sr'] = v
        else:
            args_dic[k] = v
    feats_extract = feats_extract_class(**args_dic)
    sega_emb = True if args.encoder_conf['input_layer'] == 'sega_mlm' else False
    if args.encoder_conf['selfattention_layer_type'] == 'longformer':
        attention_window = args.encoder_conf['attention_window']
        pad_speech = True if 'pre_speech_layer' in args.encoder_conf and args.encoder_conf['pre_speech_layer'] > 0 else False
    else:
        attention_window = 0
        pad_speech = False
    if epoch == -1:
        mlm_prob_factor = 1
    else:
        mlm_prob_factor = 0.8
    if 'duration_predictor_layers' in args.model_conf.keys(
    ) and args.model_conf['duration_predictor_layers'] > 0:
        duration_collect = True
    else:
        duration_collect = False
    return MLMCollateFn(
        feats_extract,
        float_pad_value=0.0,
        int_pad_value=0,
        mlm_prob=args.model_conf['mlm_prob'] * mlm_prob_factor,
        mean_phn_span=args.model_conf['mean_phn_span'],
        attention_window=attention_window,
        pad_speech=pad_speech,
        sega_emb=sega_emb,
        duration_collect=duration_collect)
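The collate path above leans on `pad_list` to right-pad a ragged batch. A toy stand-in, written to illustrate the expected behaviour rather than the library implementation:

```
import paddle

def pad_list_sketch(tensors, pad_value):
    # Right-pad every tensor along axis 0 to the longest length in the batch.
    max_len = max(t.shape[0] for t in tensors)
    out = paddle.full([len(tensors), max_len, *tensors[0].shape[1:]],
                      pad_value, dtype=tensors[0].dtype)
    for i, t in enumerate(tensors):
        out[i, :t.shape[0]] = t
    return out

print(pad_list_sketch([paddle.ones([3]), paddle.ones([5])], 0.0).shape)  # [2, 5]
```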
-def get_mlm_output(uid: str,
-                   wav_path: str,
-                   prefix: str="./prompt/dev/",
-                   model_name: str="conformer",
+def get_mlm_output(wav_path: str,
+                   model_name: str="paddle_checkpoint_en",
                   source_lang: str="english",
                   target_lang: str="english",
                   old_str: str="",
                   new_str: str="",
-                   duration_preditor_path: str=None,
-                   sid: str=None,
-                   decoder: bool=False,
                   use_teacher_forcing: bool=False,
                   duration_adjust: bool=True,
                   start_end_sp: bool=False):
-    mlm_model, train_args = load_model(model_name)
+    mlm_model, train_conf = load_model(model_name)
    mlm_model.eval()
-    processor = None
-    collate_fn = build_collate_fn(train_args, False)
+    collate_fn = build_collate_fn(
+        sr=train_conf.feats_extract_conf['fs'],
+        n_fft=train_conf.feats_extract_conf['n_fft'],
+        hop_length=train_conf.feats_extract_conf['hop_length'],
+        win_length=train_conf.feats_extract_conf['win_length'],
+        n_mels=train_conf.feats_extract_conf['n_mels'],
+        fmin=train_conf.feats_extract_conf['fmin'],
+        fmax=train_conf.feats_extract_conf['fmax'],
+        mlm_prob=train_conf['mlm_prob'],
+        mean_phn_span=train_conf['mean_phn_span'],
+        train=False,
+        seg_emb=train_conf.encoder_conf['input_layer'] == 'sega_mlm')
    return decode_with_model(
-        uid=uid,
-        prefix=prefix,
        source_lang=source_lang,
        target_lang=target_lang,
        mlm_model=mlm_model,
-        processor=processor,
        collate_fn=collate_fn,
        wav_path=wav_path,
        old_str=old_str,
        new_str=new_str,
-        duration_preditor_path=duration_preditor_path,
-        sid=sid,
-        decoder=decoder,
        use_teacher_forcing=use_teacher_forcing,
        duration_adjust=duration_adjust,
        start_end_sp=start_end_sp,
-        train_args=train_args)
+        fs=train_conf.feats_extract_conf['fs'],
+        hop_length=train_conf.feats_extract_conf['hop_length'],
+        token_list=train_conf.token_list)


def evaluate(uid: str,
             source_lang: str="english",
             target_lang: str="english",
             use_pt_vocoder: bool=False,
-             prefix: str="./prompt/dev/",
-             model_name: str="conformer",
-             old_str: str="",
+             prefix: os.PathLike="./prompt/dev/",
+             model_name: str="paddle_checkpoint_en",
             new_str: str="",
             prompt_decoding: bool=False,
             task_name: str=None):
-    duration_preditor_path = None
-    spemd = None
-    full_origin_str, wav_path = read_data(uid=uid, prefix=prefix)
+    # get origin text and path of origin wav
+    old_str, wav_path = read_data(uid=uid, prefix=prefix)
    if task_name == 'edit':
        new_str = new_str
    elif task_name == 'synthesize':
-        new_str = full_origin_str + new_str
+        new_str = old_str + new_str
    else:
-        new_str = full_origin_str + ' '.join(
-            [ch for ch in new_str if is_chinese(ch)])
+        new_str = old_str + ' '.join([ch for ch in new_str if is_chinese(ch)])
    print('new_str is ', new_str)
-    if not old_str:
-        old_str = full_origin_str
    results_dict, old_span = plot_mel_and_vocode_wav(
-        uid=uid,
-        prefix=prefix,
        source_lang=source_lang,
        target_lang=target_lang,
        model_name=model_name,
        wav_path=wav_path,
-        full_origin_str=full_origin_str,
        old_str=old_str,
        new_str=new_str,
-        use_pt_vocoder=use_pt_vocoder,
-        duration_preditor_path=duration_preditor_path,
-        sid=spemd)
+        use_pt_vocoder=use_pt_vocoder)
    return results_dict
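A hypothetical call for the English speech-editing task; the uid and strings are illustrative, and real values come from the run scripts mentioned in the README:

```
results_dict = evaluate(
    uid='p243_313',
    source_lang='english',
    target_lang='english',
    model_name='paddle_checkpoint_en',
    new_str='for that reason cover is impossible to be given',
    task_name='edit')
```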
...
...
ernie-sat/model_paddle.py → ernie-sat/mlm.py View file @ 9224659c
import argparse
import logging
import math
import os
import sys
from pathlib import Path
from typing import Dict
from typing import List
from typing import Optional
...
...
@@ -20,17 +17,18 @@ for dir_name in os.listdir(pypath):
    if os.path.isdir(dir_path):
        sys.path.append(dir_path)
-from paddlespeech.s2t.utils.error_rate import ErrorCalculator
from paddlespeech.t2s.modules.activation import get_activation
from paddlespeech.t2s.modules.conformer.convolution import ConvolutionModule
from paddlespeech.t2s.modules.conformer.encoder_layer import EncoderLayer
from paddlespeech.t2s.modules.masked_fill import masked_fill
from paddlespeech.t2s.modules.nets_utils import initialize
from paddlespeech.t2s.modules.tacotron2.decoder import Postnet
+from paddlespeech.t2s.modules.transformer.embedding import LegacyRelPositionalEncoding
from paddlespeech.t2s.modules.transformer.embedding import PositionalEncoding
from paddlespeech.t2s.modules.transformer.embedding import ScaledPositionalEncoding
from paddlespeech.t2s.modules.transformer.embedding import RelPositionalEncoding
from paddlespeech.t2s.modules.transformer.subsampling import Conv2dSubsampling
+from paddlespeech.t2s.modules.transformer.attention import LegacyRelPositionMultiHeadedAttention
from paddlespeech.t2s.modules.transformer.attention import MultiHeadedAttention
from paddlespeech.t2s.modules.transformer.attention import RelPositionMultiHeadedAttention
from paddlespeech.t2s.modules.transformer.positionwise_feed_forward import PositionwiseFeedForward
...
...
@@ -39,65 +37,10 @@ from paddlespeech.t2s.modules.transformer.multi_layer_conv import MultiLayeredCo
from paddlespeech.t2s.modules.transformer.repeat import repeat
from paddlespeech.t2s.modules.layer_norm import LayerNorm


class LegacyRelPositionalEncoding(PositionalEncoding):
    """Relative positional encoding module (old version).
    Details can be found in https://github.com/espnet/espnet/pull/2816.
    See : Appendix B in https://arxiv.org/abs/1901.02860
    Args:
        d_model (int): Embedding dimension.
        dropout_rate (float): Dropout rate.
        max_len (int): Maximum input length.
    """

    def __init__(self, d_model: int, dropout_rate: float, max_len: int=5000):
        """
        Args:
            d_model (int): Embedding dimension.
            dropout_rate (float): Dropout rate.
            max_len (int, optional): Maximum input length. Defaults to 5000.
        """
        super().__init__(d_model, dropout_rate, max_len, reverse=True)

    def extend_pe(self, x):
        """Reset the positional encodings."""
        if self.pe is not None:
            if paddle.shape(self.pe)[1] >= paddle.shape(x)[1]:
                return
        pe = paddle.zeros((paddle.shape(x)[1], self.d_model))
        if self.reverse:
            position = paddle.arange(
                paddle.shape(x)[1] - 1, -1, -1.0,
                dtype=paddle.float32).unsqueeze(1)
        else:
            position = paddle.arange(
                0, paddle.shape(x)[1], dtype=paddle.float32).unsqueeze(1)
        div_term = paddle.exp(
            paddle.arange(0, self.d_model, 2, dtype=paddle.float32) *
            -(math.log(10000.0) / self.d_model))
        pe[:, 0::2] = paddle.sin(position * div_term)
        pe[:, 1::2] = paddle.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.pe = pe

    def forward(self, x: paddle.Tensor) -> Tuple[paddle.Tensor, paddle.Tensor]:
        """Compute positional encoding.
        Args:
            x (paddle.Tensor): Input tensor (batch, time, `*`).
        Returns:
            paddle.Tensor: Encoded tensor (batch, time, `*`).
            paddle.Tensor: Positional embedding tensor (1, time, `*`).
        """
        self.extend_pe(x)
        x = x * self.xscale
        pos_emb = self.pe[:, :paddle.shape(x)[1]]
        return self.dropout(x), self.dropout(pos_emb)


+from yacs.config import CfgNode

+# MLM -> Mask Language Model
class mySequential(nn.Sequential):
    def forward(self, *inputs):
        for module in self._sub_layers.values():
...
...
@@ -108,12 +51,8 @@ class mySequential(nn.Sequential):
        return inputs


-class NewMaskInputLayer(nn.Layer):
-    __constants__ = ['out_features']
-    out_features: int
-
-    def __init__(self, out_features: int, device=None, dtype=None) -> None:
-        factory_kwargs = {'device': device, 'dtype': dtype}
+class MaskInputLayer(nn.Layer):
+    def __init__(self, out_features: int) -> None:
        super().__init__()
        self.mask_feature = paddle.create_parameter(
            shape=(1, 1, out_features),
...
...
@@ -121,109 +60,14 @@ class NewMaskInputLayer(nn.Layer):
            default_initializer=paddle.nn.initializer.Assign(
                paddle.normal(shape=(1, 1, out_features))))

-    def forward(self, input: paddle.Tensor, masked_pos=None) -> paddle.Tensor:
+    def forward(self, input: paddle.Tensor,
+                masked_pos: paddle.Tensor=None) -> paddle.Tensor:
        masked_pos = paddle.expand_as(paddle.unsqueeze(masked_pos, -1), input)
        masked_input = masked_fill(input, masked_pos, 0) + masked_fill(
            paddle.expand_as(self.mask_feature, input), ~masked_pos, 0)
        return masked_input
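A quick sanity sketch of what `MaskInputLayer` does: frames flagged in `masked_pos` are replaced by the learned `mask_feature` vector, and the rest pass through unchanged. Shapes are illustrative:

```
import paddle

layer = MaskInputLayer(out_features=4)
x = paddle.randn([1, 3, 4])
masked_pos = paddle.to_tensor([[False, True, False]])
y = layer(x, masked_pos)
# y[:, 0] == x[:, 0], while y[:, 1] equals the learned mask embedding
```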
class LegacyRelPositionMultiHeadedAttention(MultiHeadedAttention):
    """Multi-Head Attention layer with relative position encoding (old version).
    Details can be found in https://github.com/espnet/espnet/pull/2816.
    Paper: https://arxiv.org/abs/1901.02860
    Args:
        n_head (int): The number of heads.
        n_feat (int): The number of features.
        dropout_rate (float): Dropout rate.
        zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
    """

    def __init__(self, n_head, n_feat, dropout_rate, zero_triu=False):
        """Construct an RelPositionMultiHeadedAttention object."""
        super().__init__(n_head, n_feat, dropout_rate)
        self.zero_triu = zero_triu
        # linear transformation for positional encoding
        self.linear_pos = nn.Linear(n_feat, n_feat, bias_attr=False)
        # these two learnable biases are used in matrix c and matrix d
        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
        self.pos_bias_u = paddle.create_parameter(
            shape=(self.h, self.d_k),
            dtype='float32',
            default_initializer=paddle.nn.initializer.XavierUniform())
        self.pos_bias_v = paddle.create_parameter(
            shape=(self.h, self.d_k),
            dtype='float32',
            default_initializer=paddle.nn.initializer.XavierUniform())

    def rel_shift(self, x):
        """Compute relative positional encoding.
        Args:
            x (Tensor): Input tensor (batch, head, time1, time2).
        Returns:
            Tensor: Output tensor.
        """
        b, h, t1, t2 = paddle.shape(x)
        zero_pad = paddle.zeros((b, h, t1, 1))
        x_padded = paddle.concat([zero_pad, x], axis=-1)
        x_padded = paddle.reshape(x_padded, [b, h, t2 + 1, t1])
        # only keep the positions from 0 to time2
        x = paddle.reshape(x_padded[:, :, 1:], [b, h, t1, t2])
        if self.zero_triu:
            ones = paddle.ones((t1, t2))
            x = x * paddle.tril(ones, t2 - 1)[None, None, :, :]
        return x

    def forward(self, query, key, value, pos_emb, mask):
        """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
        Args:
            query (Tensor): Query tensor (#batch, time1, size).
            key (Tensor): Key tensor (#batch, time2, size).
            value (Tensor): Value tensor (#batch, time2, size).
            pos_emb (Tensor): Positional embedding tensor (#batch, time1, size).
            mask (Tensor): Mask tensor (#batch, 1, time2) or (#batch, time1, time2).
        Returns:
            Tensor: Output tensor (#batch, time1, d_model).
        """
        q, k, v = self.forward_qkv(query, key, value)
        # (batch, time1, head, d_k)
        q = paddle.transpose(q, [0, 2, 1, 3])

        n_batch_pos = paddle.shape(pos_emb)[0]
        p = paddle.reshape(
            self.linear_pos(pos_emb), [n_batch_pos, -1, self.h, self.d_k])
        # (batch, head, time1, d_k)
        p = paddle.transpose(p, [0, 2, 1, 3])
        # (batch, head, time1, d_k)
        q_with_bias_u = paddle.transpose((q + self.pos_bias_u), [0, 2, 1, 3])
        # (batch, head, time1, d_k)
        q_with_bias_v = paddle.transpose((q + self.pos_bias_v), [0, 2, 1, 3])

        # compute attention score
        # first compute matrix a and matrix c
        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
        # (batch, head, time1, time2)
        matrix_ac = paddle.matmul(q_with_bias_u,
                                  paddle.transpose(k, [0, 1, 3, 2]))
        # compute matrix b and matrix d
        # (batch, head, time1, time1)
        matrix_bd = paddle.matmul(q_with_bias_v,
                                  paddle.transpose(p, [0, 1, 3, 2]))
        matrix_bd = self.rel_shift(matrix_bd)
        # (batch, head, time1, time2)
        scores = (matrix_ac + matrix_bd) / math.sqrt(self.d_k)

        return self.forward_attention(v, scores, mask)


class MLMEncoder(nn.Layer):
    """Conformer encoder module.
...
...
@@ -253,47 +97,42 @@ class MLMEncoder(nn.Layer):
        cnn_module_kernel (int): Kernel size of convolution module.
        padding_idx (int): Padding idx for input_layer=embed.
        stochastic_depth_rate (float): Maximum probability to skip the encoder layer.
-        intermediate_layers (Union[List[int], None]): indices of intermediate CTC layer.
-            indices start from 1.
-            if not None, intermediate outputs are returned (which changes return type
-            signature.)
    """

    def __init__(self,
-                 idim,
-                 vocab_size=0,
+                 idim: int,
+                 vocab_size: int=0,
+                 pre_speech_layer: int=0,
-                 attention_dim=256,
-                 attention_heads=4,
-                 linear_units=2048,
-                 num_blocks=6,
-                 dropout_rate=0.1,
-                 positional_dropout_rate=0.1,
-                 attention_dropout_rate=0.0,
-                 input_layer="conv2d",
-                 normalize_before=True,
-                 concat_after=False,
-                 positionwise_layer_type="linear",
-                 positionwise_conv_kernel_size=1,
-                 macaron_style=False,
-                 pos_enc_layer_type="abs_pos",
+                 attention_dim: int=256,
+                 attention_heads: int=4,
+                 linear_units: int=2048,
+                 num_blocks: int=6,
+                 dropout_rate: float=0.1,
+                 positional_dropout_rate: float=0.1,
+                 attention_dropout_rate: float=0.0,
+                 input_layer: str="conv2d",
+                 normalize_before: bool=True,
+                 concat_after: bool=False,
+                 positionwise_layer_type: str="linear",
+                 positionwise_conv_kernel_size: int=1,
+                 macaron_style: bool=False,
+                 pos_enc_layer_type: str="abs_pos",
                 pos_enc_class=None,
-                 selfattention_layer_type="selfattn",
-                 activation_type="swish",
-                 use_cnn_module=False,
-                 zero_triu=False,
-                 cnn_module_kernel=31,
-                 padding_idx=-1,
-                 stochastic_depth_rate=0.0,
-                 intermediate_layers=None,
-                 text_masking=False):
+                 selfattention_layer_type: str="selfattn",
+                 activation_type: str="swish",
+                 use_cnn_module: bool=False,
+                 zero_triu: bool=False,
+                 cnn_module_kernel: int=31,
+                 padding_idx: int=-1,
+                 stochastic_depth_rate: float=0.0,
+                 text_masking: bool=False):
        """Construct an Encoder object."""
        super().__init__()
        self._output_size = attention_dim
        self.text_masking = text_masking
        if self.text_masking:
-            self.text_masking_layer = NewMaskInputLayer(attention_dim)
+            self.text_masking_layer = MaskInputLayer(attention_dim)
        activation = get_activation(activation_type)
        if pos_enc_layer_type == "abs_pos":
            pos_enc_class = PositionalEncoding
...
...
@@ -330,7 +169,7 @@ class MLMEncoder(nn.Layer):
        elif input_layer == "mlm":
            self.segment_emb = None
            self.speech_embed = mySequential(
-                NewMaskInputLayer(idim),
+                MaskInputLayer(idim),
                nn.Linear(idim, attention_dim),
                nn.LayerNorm(attention_dim),
                nn.ReLU(),
...
...
@@ -343,7 +182,7 @@ class MLMEncoder(nn.Layer):
            self.segment_emb = nn.Embedding(
                500, attention_dim, padding_idx=padding_idx)
            self.speech_embed = mySequential(
-                NewMaskInputLayer(idim),
+                MaskInputLayer(idim),
                nn.Linear(idim, attention_dim),
                nn.LayerNorm(attention_dim),
                nn.ReLU(),
...
...
@@ -365,7 +204,6 @@ class MLMEncoder(nn.Layer):
        # self-attention module definition
        if selfattention_layer_type == "selfattn":
-            logging.info("encoder self-attention layer type = self-attention")
            encoder_selfattn_layer = MultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                attention_dim,
                attention_dropout_rate, )
...
...
@@ -375,8 +213,6 @@ class MLMEncoder(nn.Layer):
            encoder_selfattn_layer_args = (
                attention_heads,
                attention_dim,
                attention_dropout_rate, )
        elif selfattention_layer_type == "rel_selfattn":
-            logging.info(
-                "encoder self-attention layer type = relative self-attention")
            assert pos_enc_layer_type == "rel_pos"
            encoder_selfattn_layer = RelPositionMultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                attention_dim,
...
@@ -436,49 +272,38 @@ class MLMEncoder(nn.Layer):
if
self
.
normalize_before
:
self
.
after_norm
=
LayerNorm
(
attention_dim
)
self
.
intermediate_layers
=
intermediate_layers
def
forward
(
self
,
speech
_pad
,
text
_pad
,
masked_pos
,
speech_mask
=
None
,
text_mask
=
None
,
speech_seg_pos
=
None
,
text_seg_pos
=
None
):
speech
:
paddle
.
Tensor
,
text
:
paddle
.
Tensor
,
masked_pos
:
paddle
.
Tensor
,
speech_mask
:
paddle
.
Tensor
=
None
,
text_mask
:
paddle
.
Tensor
=
None
,
speech_seg_pos
:
paddle
.
Tensor
=
None
,
text_seg_pos
:
paddle
.
Tensor
=
None
):
"""Encode input sequence.
"""
if
masked_pos
is
not
None
:
speech
_pad
=
self
.
speech_embed
(
speech_pad
,
masked_pos
)
speech
=
self
.
speech_embed
(
speech
,
masked_pos
)
else
:
speech_pad
=
self
.
speech_embed
(
speech_pad
)
# pure speech input
if
-
2
in
np
.
array
(
text_pad
):
text_pad
=
text_pad
+
3
text_mask
=
paddle
.
unsqueeze
(
bool
(
text_pad
),
1
)
text_seg_pos
=
paddle
.
zeros_like
(
text_pad
)
text_pad
=
self
.
text_embed
(
text_pad
)
text_pad
=
(
text_pad
[
0
]
+
self
.
segment_emb
(
text_seg_pos
),
text_pad
[
1
])
text_seg_pos
=
None
elif
text_pad
is
not
None
:
text_pad
=
self
.
text_embed
(
text_pad
)
speech
=
self
.
speech_embed
(
speech
)
if
text
is
not
None
:
text
=
self
.
text_embed
(
text
)
if
speech_seg_pos
is
not
None
and
text_seg_pos
is
not
None
and
self
.
segment_emb
:
speech_seg_emb
=
self
.
segment_emb
(
speech_seg_pos
)
text_seg_emb
=
self
.
segment_emb
(
text_seg_pos
)
text
_pad
=
(
text_pad
[
0
]
+
text_seg_emb
,
text_pad
[
1
])
speech
_pad
=
(
speech_pad
[
0
]
+
speech_seg_emb
,
speech_pad
[
1
])
text
=
(
text
[
0
]
+
text_seg_emb
,
text
[
1
])
speech
=
(
speech
[
0
]
+
speech_seg_emb
,
speech
[
1
])
if
self
.
pre_speech_encoders
:
speech
_pad
,
_
=
self
.
pre_speech_encoders
(
speech_pad
,
speech_mask
)
speech
,
_
=
self
.
pre_speech_encoders
(
speech
,
speech_mask
)
if
text
_pad
is
not
None
:
xs
=
paddle
.
concat
([
speech
_pad
[
0
],
text_pad
[
0
]],
axis
=
1
)
xs_pos_emb
=
paddle
.
concat
([
speech
_pad
[
1
],
text_pad
[
1
]],
axis
=
1
)
if
text
is
not
None
:
xs
=
paddle
.
concat
([
speech
[
0
],
text
[
0
]],
axis
=
1
)
xs_pos_emb
=
paddle
.
concat
([
speech
[
1
],
text
[
1
]],
axis
=
1
)
masks
=
paddle
.
concat
([
speech_mask
,
text_mask
],
axis
=-
1
)
else
:
xs
=
speech
_pad
[
0
]
xs_pos_emb
=
speech
_pad
[
1
]
xs
=
speech
[
0
]
xs_pos_emb
=
speech
[
1
]
masks
=
speech_mask
xs
,
masks
=
self
.
encoders
((
xs
,
xs_pos_emb
),
masks
)
...
...
@@ -492,7 +317,7 @@ class MLMEncoder(nn.Layer):
class MLMDecoder(MLMEncoder):
-    def forward(self, xs, masks, masked_pos=None, segment_emb=None):
+    def forward(self, xs: paddle.Tensor, masks: paddle.Tensor):
        """Encode input sequence.
        Args:
...
...
@@ -504,51 +329,19 @@ class MLMDecoder(MLMEncoder):
            paddle.Tensor: Mask tensor (#batch, time).
        """
-        if not self.training:
-            masked_pos = None
        xs = self.embed(xs)
-        if segment_emb:
-            xs = (xs[0] + segment_emb, xs[1])
-        if self.intermediate_layers is None:
-            xs, masks = self.encoders(xs, masks)
-        else:
-            intermediate_outputs = []
-            for layer_idx, encoder_layer in enumerate(self.encoders):
-                xs, masks = encoder_layer(xs, masks)
-                if (self.intermediate_layers is not None and
-                        layer_idx + 1 in self.intermediate_layers):
-                    encoder_output = xs
-                    # intermediate branches also require normalization.
-                    if self.normalize_before:
-                        encoder_output = self.after_norm(encoder_output)
-                    intermediate_outputs.append(encoder_output)
+        xs, masks = self.encoders(xs, masks)
        if isinstance(xs, tuple):
            xs = xs[0]
        if self.normalize_before:
            xs = self.after_norm(xs)
-        if self.intermediate_layers is not None:
-            return xs, masks, intermediate_outputs
        return xs, masks


def pad_to_longformer_att_window(text, max_len, max_tlen, attention_window):
    round = max_len % attention_window
    if round != 0:
        max_tlen += (attention_window - round)
        n_batch = paddle.shape(text)[0]
        text_pad = paddle.zeros(
            shape=(n_batch, max_tlen, *paddle.shape(text[0])[1:]),
            dtype=text.dtype)
        for i in range(n_batch):
            text_pad[i, :paddle.shape(text[i])[0]] = text[i]
    else:
        text_pad = text[:, :max_tlen]
    return text_pad, max_tlen
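The padding target in `pad_to_longformer_att_window` is just `max_len` rounded up to a multiple of `attention_window`; for example:

```
max_len, attention_window = 10, 4
rem = max_len % attention_window
max_tlen = max_len + (attention_window - rem) if rem != 0 else max_len
print(max_tlen)  # 12: a 10-step sequence is padded to 12 for window size 4
```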
-class MLMModel(nn.Layer):
+# encoder and decoder is nn.Layer, not str
+class MLM(nn.Layer):
    def __init__(self,
                 token_list: Union[Tuple[str, ...], List[str]],
                 odim: int,
...
...
@@ -557,44 +350,15 @@ class MLMModel(nn.Layer):
                 postnet_layers: int=0,
                 postnet_chans: int=0,
                 postnet_filts: int=0,
-                 ignore_id: int=-1,
-                 lsm_weight: float=0.0,
-                 length_normalized_loss: bool=False,
-                 report_cer: bool=True,
-                 report_wer: bool=True,
-                 sym_space: str="<space>",
-                 sym_blank: str="<blank>",
-                 masking_schema: str="span",
-                 mean_phn_span: int=3,
-                 mlm_prob: float=0.25,
-                 dynamic_mlm_prob=False,
-                 decoder_seg_pos=False,
-                 text_masking=False):
+                 text_masking: bool=False):

        super().__init__()
        # note that eos is the same as sos (equivalent ID)
        self.odim = odim
-        self.ignore_id = ignore_id
        self.token_list = token_list.copy()

        self.encoder = encoder
        self.decoder = decoder
        self.vocab_size = encoder.text_embed[0]._num_embeddings

-        if report_cer or report_wer:
-            self.error_calculator = ErrorCalculator(token_list, sym_space,
-                                                    sym_blank, report_cer,
-                                                    report_wer)
-        else:
-            self.error_calculator = None
-
-        self.mlm_weight = 1.0
-        self.mlm_prob = mlm_prob
-        self.mlm_layer = 12
-        self.finetune_wo_mlm = True
-        self.max_span = 50
-        self.min_span = 4
-        self.mean_phn_span = mean_phn_span
-        self.masking_schema = masking_schema

        if self.decoder is None or not (hasattr(self.decoder, 'output_layer')
                                        and self.decoder.output_layer is not None):
...
@@ -606,15 +370,9 @@ class MLMModel(nn.Layer):
                self.encoder.text_embed[0]._embedding_dim,
                self.vocab_size,
                weight_attr=self.encoder.text_embed[0]._weight_attr)
-            self.text_mlm_loss = nn.CrossEntropyLoss(ignore_index=ignore_id)
        else:
            self.text_sfc = None
-            self.text_mlm_loss = None
-        self.decoder_seg_pos = decoder_seg_pos
-        if lsm_weight > 50:
-            self.l1_loss_func = nn.MSELoss()
-        else:
-            self.l1_loss_func = nn.L1Loss(reduction='none')
        self.postnet = (None if postnet_layers == 0 else Postnet(
            idim=self.encoder._output_size,
            odim=odim,
...
...
@@ -624,119 +382,77 @@ class MLMModel(nn.Layer):
            use_batch_norm=True,
            dropout_rate=0.5, ))

-    def collect_feats(self,
-                      speech,
-                      speech_lens,
-                      text,
-                      text_lens,
-                      masked_pos,
-                      speech_mask,
-                      text_mask,
-                      speech_seg_pos,
-                      text_seg_pos,
-                      y_masks=None) -> Dict[str, paddle.Tensor]:
-        return {"feats": speech, "feats_lens": speech_lens}
-
-    def forward(self, batch, speech_seg_pos, y_masks=None):
-        # feats: (Batch, Length, Dim)
-        # -> encoder_out: (Batch, Length2, Dim2)
-        speech_pad_placeholder = batch['speech_pad']
-        if self.decoder is not None:
-            ys_in = self._add_first_frame_and_remove_last_frame(
-                batch['speech_pad'])
-        encoder_out, h_masks = self.encoder(**batch)
-        if self.decoder is not None:
-            zs, _ = self.decoder(ys_in, y_masks, encoder_out, bool(h_masks),
-                                 self.encoder.segment_emb(speech_seg_pos))
-            speech_hidden_states = zs
-        else:
-            speech_hidden_states = encoder_out[:, :paddle.shape(
-                batch['speech_pad'])[1], :]
-        if self.sfc is not None:
-            before_outs = paddle.reshape(
-                self.sfc(speech_hidden_states),
-                (paddle.shape(speech_hidden_states)[0], -1, self.odim))
-        else:
-            before_outs = speech_hidden_states
-        if self.postnet is not None:
-            after_outs = before_outs + paddle.transpose(
-                self.postnet(paddle.transpose(before_outs, [0, 2, 1])),
-                (0, 2, 1))
-        else:
-            after_outs = None
-        return before_outs, after_outs, speech_pad_placeholder, batch['masked_pos']

    def inference(
            self,
-            speech,
-            text,
-            masked_pos,
-            speech_mask,
-            text_mask,
-            speech_seg_pos,
-            text_seg_pos,
-            span_bdy,
-            y_masks=None,
-            speech_lens=None,
-            text_lens=None,
-            feats: Optional[paddle.Tensor]=None,
-            spembs: Optional[paddle.Tensor]=None,
-            sids: Optional[paddle.Tensor]=None,
-            lids: Optional[paddle.Tensor]=None,
-            threshold: float=0.5,
-            minlenratio: float=0.0,
-            maxlenratio: float=10.0,
+            speech: paddle.Tensor,
+            text: paddle.Tensor,
+            masked_pos: paddle.Tensor,
+            speech_mask: paddle.Tensor,
+            text_mask: paddle.Tensor,
+            speech_seg_pos: paddle.Tensor,
+            text_seg_pos: paddle.Tensor,
+            span_bdy: List[int],
            use_teacher_forcing: bool=False, ) -> Dict[str, paddle.Tensor]:
        '''
        Args:
            speech (paddle.Tensor): input speech (B, Tmax, D).
            text (paddle.Tensor): input text (B, Tmax2).
            masked_pos (paddle.Tensor): masked position of input speech (B, Tmax)
            speech_mask (paddle.Tensor): mask of speech (B, 1, Tmax).
            text_mask (paddle.Tensor): mask of text (B, 1, Tmax2).
            speech_seg_pos (paddle.Tensor): n-th phone of each mel, 0<=n<=Tmax2 (B, Tmax).
            text_seg_pos (paddle.Tensor): n-th phone of each phone, 0<=n<=Tmax2 (B, Tmax2).
            span_bdy (List[int]): masked mel boundary of input speech (2,)
            use_teacher_forcing (bool): whether to use teacher forcing
        Returns:
            List[Tensor]:
                eg:
                [Tensor(shape=[1, 181, 80]), Tensor(shape=[80, 80]), Tensor(shape=[1, 67, 80])]
        '''
-        batch = dict(
-            speech_pad=speech,
-            text_pad=text,
-            masked_pos=masked_pos,
-            speech_mask=speech_mask,
-            text_mask=text_mask,
-            speech_seg_pos=speech_seg_pos,
-            text_seg_pos=text_seg_pos, )
-        # # inference with teacher forcing
-        # hs, h_masks = self.encoder(**batch)
-        outs = [batch['speech_pad'][:, :span_bdy[0]]]
-        z_cache = None
+        outs = [speech[:, :span_bdy[0]]]
        if use_teacher_forcing:
-            before, zs, _, _ = self.forward(
-                batch, speech_seg_pos, y_masks=y_masks)
+            before_outs, zs, *_ = self.forward(
+                speech=speech,
+                text=text,
+                masked_pos=masked_pos,
+                speech_mask=speech_mask,
+                text_mask=text_mask,
+                speech_seg_pos=speech_seg_pos,
+                text_seg_pos=text_seg_pos)
            if zs is None:
-                zs = before
+                zs = before_outs
            outs += [zs[0][span_bdy[0]:span_bdy[1]]]
-            outs += [batch['speech_pad'][:, span_bdy[1]:]]
-            return dict(feat_gen=outs)
+            outs += [speech[:, span_bdy[1]:]]
+            return outs
        return None

    def _add_first_frame_and_remove_last_frame(
            self, ys: paddle.Tensor) -> paddle.Tensor:
        ys_in = paddle.concat(
            [
                paddle.zeros(
                    shape=(paddle.shape(ys)[0], 1, paddle.shape(ys)[2]),
                    dtype=ys.dtype), ys[:, :-1]
            ],
            axis=1)
        return ys_in
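`_add_first_frame_and_remove_last_frame` shifts the target mels one frame to the right, the usual teacher-forcing input trick. A toy check:

```
import paddle

ys = paddle.to_tensor([[[1., 1.], [2., 2.], [3., 3.]]])  # (B=1, T=3, D=2)
ys_in = paddle.concat([paddle.zeros([1, 1, 2]), ys[:, :-1]], axis=1)
# ys_in == [[[0., 0.], [1., 1.], [2., 2.]]]
```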
-class MLMEncAsDecoderModel(MLMModel):
-    def forward(self, batch, speech_seg_pos, y_masks=None):
+class MLMEncAsDecoder(MLM):
+    def forward(self,
+                speech: paddle.Tensor,
+                text: paddle.Tensor,
+                masked_pos: paddle.Tensor,
+                speech_mask: paddle.Tensor,
+                text_mask: paddle.Tensor,
+                speech_seg_pos: paddle.Tensor,
+                text_seg_pos: paddle.Tensor):
        # feats: (Batch, Length, Dim)
        # -> encoder_out: (Batch, Length2, Dim2)
-        speech_pad_placeholder = batch['speech_pad']
-        encoder_out, h_masks = self.encoder(**batch)  # segment_emb
+        encoder_out, h_masks = self.encoder(
+            speech=speech,
+            text=text,
+            masked_pos=masked_pos,
+            speech_mask=speech_mask,
+            text_mask=text_mask,
+            speech_seg_pos=speech_seg_pos,
+            text_seg_pos=text_seg_pos)  # segment_emb
        if self.decoder is not None:
            zs, _ = self.decoder(encoder_out, h_masks)
        else:
            zs = encoder_out
-        speech_hidden_states = zs[:, :paddle.shape(batch['speech_pad'])[1], :]
+        speech_hidden_states = zs[:, :paddle.shape(speech)[1], :]
        if self.sfc is not None:
            before_outs = paddle.reshape(
                self.sfc(speech_hidden_states),
...
...
@@ -749,53 +465,35 @@ class MLMEncAsDecoderModel(MLMModel):
                [0, 2, 1])
        else:
            after_outs = None
-        return before_outs, after_outs, speech_pad_placeholder, batch['masked_pos']
+        return before_outs, after_outs, None


-class MLMDualMaksingModel(MLMModel):
-    def _calc_mlm_loss(self,
-                       before_outs: paddle.Tensor,
-                       after_outs: paddle.Tensor,
-                       text_outs: paddle.Tensor,
-                       batch):
-        xs_pad = batch['speech_pad']
-        text_pad = batch['text_pad']
-        masked_pos = batch['masked_pos']
-        text_masked_pos = batch['text_masked_pos']
-        mlm_loss_pos = masked_pos > 0
-        loss = paddle.sum(
-            self.l1_loss_func(
-                paddle.reshape(before_outs, (-1, self.odim)),
-                paddle.reshape(xs_pad, (-1, self.odim))),
-            axis=-1)
-        if after_outs is not None:
-            loss += paddle.sum(
-                self.l1_loss_func(
-                    paddle.reshape(after_outs, (-1, self.odim)),
-                    paddle.reshape(xs_pad, (-1, self.odim))),
-                axis=-1)
-        loss_mlm = paddle.sum((loss * paddle.reshape(
-            mlm_loss_pos, [-1]))) / paddle.sum((mlm_loss_pos) + 1e-10)
-        loss_text = paddle.sum((self.text_mlm_loss(
-            paddle.reshape(text_outs, (-1, self.vocab_size)),
-            paddle.reshape(text_pad, (-1))) * paddle.reshape(
-                text_masked_pos, (-1)))) / paddle.sum((text_masked_pos) + 1e-10)
-        return loss_mlm, loss_text
-
-    def forward(self, batch, speech_seg_pos, y_masks=None):
+class MLMDualMaksing(MLM):
+    def forward(self,
+                speech: paddle.Tensor,
+                text: paddle.Tensor,
+                masked_pos: paddle.Tensor,
+                speech_mask: paddle.Tensor,
+                text_mask: paddle.Tensor,
+                speech_seg_pos: paddle.Tensor,
+                text_seg_pos: paddle.Tensor):
        # feats: (Batch, Length, Dim)
        # -> encoder_out: (Batch, Length2, Dim2)
-        encoder_out, h_masks = self.encoder(**batch)  # segment_emb
+        encoder_out, h_masks = self.encoder(
+            speech=speech,
+            text=text,
+            masked_pos=masked_pos,
+            speech_mask=speech_mask,
+            text_mask=text_mask,
+            speech_seg_pos=speech_seg_pos,
+            text_seg_pos=text_seg_pos)  # segment_emb
        if self.decoder is not None:
            zs, _ = self.decoder(encoder_out, h_masks)
        else:
            zs = encoder_out
-        speech_hidden_states = zs[:, :paddle.shape(batch['speech_pad'])[1], :]
+        speech_hidden_states = zs[:, :paddle.shape(speech)[1], :]
        if self.text_sfc:
-            text_hiddent_states = zs[:, paddle.shape(batch['speech_pad'])[1]:, :]
+            text_hiddent_states = zs[:, paddle.shape(speech)[1]:, :]
            text_outs = paddle.reshape(
                self.text_sfc(text_hiddent_states),
                (paddle.shape(text_hiddent_states)[0], -1, self.vocab_size))
...
...
@@ -811,27 +509,25 @@ class MLMDualMaksingModel(MLMModel):
                [0, 2, 1])
        else:
            after_outs = None
-        return before_outs, after_outs, text_outs, None
-        # , speech_pad_placeholder, batch['masked_pos'], batch['text_masked_pos']
+        return before_outs, after_outs, text_outs


def build_model_from_file(config_file, model_file):
    state_dict = paddle.load(model_file)
-    model_class = MLMDualMaksingModel if 'conformer_combine_vctk_aishell3_dual_masking' in config_file \
-        else MLMEncAsDecoderModel
+    model_class = MLMDualMaksing if 'conformer_combine_vctk_aishell3_dual_masking' in config_file \
+        else MLMEncAsDecoder
+    # build the model
-    args = yaml.safe_load(Path(config_file).open("r", encoding="utf-8"))
-    args = argparse.Namespace(**args)
-    model = build_model(args, model_class)
+    with open(config_file) as f:
+        conf = CfgNode(yaml.safe_load(f))
+    model = build_model(conf, model_class)
    model.set_state_dict(state_dict)
-    return model, args
+    return model, conf
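A hypothetical loading snippet; the file names assume one of the pretrained tarballs from the README has been unpacked under `pretrained_model/` and follow no guaranteed layout:

```
mlm_model, conf = build_model_from_file(
    config_file='pretrained_model/paddle_checkpoint_en/config.yaml',
    model_file='pretrained_model/paddle_checkpoint_en/model.pdparams')
mlm_model.eval()
```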
-def build_model(args: argparse.Namespace,
-                model_class=MLMEncAsDecoderModel) -> MLMModel:
+# select encoder and decoder here
+def build_model(args: argparse.Namespace, model_class=MLMEncAsDecoder) -> MLM:
    if isinstance(args.token_list, str):
        with open(args.token_list, encoding="utf-8") as f:
            token_list = [line.rstrip() for line in f]
...
...
@@ -842,9 +538,8 @@ def build_model(args: argparse.Namespace,
        token_list = list(args.token_list)
    else:
        raise RuntimeError("token_list must be str or list")
-    vocab_size = len(token_list)
-    logging.info(f"Vocabulary size: {vocab_size}")
+    vocab_size = len(token_list)
    odim = 80
    pos_enc_class = ScaledPositionalEncoding if args.use_scaled_pos_enc else PositionalEncoding
...
...
@@ -857,17 +552,8 @@ def build_model(args: argparse.Namespace,
    if conformer_rel_pos_type == "legacy":
        if conformer_pos_enc_layer_type == "rel_pos":
            conformer_pos_enc_layer_type = "legacy_rel_pos"
-            logging.warning(
-                "Fallback to conformer_pos_enc_layer_type = 'legacy_rel_pos' "
-                "due to the compatibility. If you want to use the new one, "
-                "please use conformer_pos_enc_layer_type = 'latest'.")
        if conformer_self_attn_layer_type == "rel_selfattn":
            conformer_self_attn_layer_type = "legacy_rel_selfattn"
-            logging.warning(
-                "Fallback to "
-                "conformer_self_attn_layer_type = 'legacy_rel_selfattn' "
-                "due to the compatibility. If you want to use the new one, "
-                "please use conformer_pos_enc_layer_type = 'latest'.")
    elif conformer_rel_pos_type == "latest":
        assert conformer_pos_enc_layer_type != "legacy_rel_pos"
        assert conformer_self_attn_layer_type != "legacy_rel_selfattn"
...
...
ernie-sat/mlm_loss.py 0 → 100644 View file @ 9224659c
import paddle
from paddle import nn


class MLMLoss(nn.Layer):
    def __init__(self,
                 lsm_weight: float=0.1,
                 ignore_id: int=-1,
                 text_masking: bool=False):
        super().__init__()
        if text_masking:
            self.text_mlm_loss = nn.CrossEntropyLoss(ignore_index=ignore_id)
        if lsm_weight > 50:
            self.l1_loss_func = nn.MSELoss()
        else:
            self.l1_loss_func = nn.L1Loss(reduction='none')
        self.text_masking = text_masking

    def forward(self,
                speech: paddle.Tensor,
                before_outs: paddle.Tensor,
                after_outs: paddle.Tensor,
                masked_pos: paddle.Tensor,
                text: paddle.Tensor=None,
                text_outs: paddle.Tensor=None,
                text_masked_pos: paddle.Tensor=None):
        xs_pad = speech
        mlm_loss_pos = masked_pos > 0
        loss = paddle.sum(
            self.l1_loss_func(
                paddle.reshape(before_outs, (-1, self.odim)),
                paddle.reshape(xs_pad, (-1, self.odim))),
            axis=-1)
        if after_outs is not None:
            loss += paddle.sum(
                self.l1_loss_func(
                    paddle.reshape(after_outs, (-1, self.odim)),
                    paddle.reshape(xs_pad, (-1, self.odim))),
                axis=-1)
        loss_mlm = paddle.sum((loss * paddle.reshape(
            mlm_loss_pos, [-1]))) / paddle.sum((mlm_loss_pos) + 1e-10)
        if self.text_masking:
            loss_text = paddle.sum((self.text_mlm_loss(
                paddle.reshape(text_outs, (-1, self.vocab_size)),
                paddle.reshape(text, (-1))) * paddle.reshape(
                    text_masked_pos, (-1)))) / paddle.sum(
                        (text_masked_pos) + 1e-10)
            return loss_mlm, loss_text
        return loss_mlm
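Note that `forward()` reads `self.odim` (and `self.vocab_size` when `text_masking=True`) even though `__init__` never sets them, so a caller would have to attach those attributes first. A minimal sketch under that assumption, with illustrative shapes:

```
import paddle

criterion = MLMLoss(lsm_weight=0.1, text_masking=False)
criterion.odim = 80  # assumed: mel dimension, matching the model's odim

speech = paddle.randn([2, 100, 80])
before_outs = paddle.randn([2, 100, 80])
masked_pos = paddle.randint(0, 2, [2, 100])
loss_mlm = criterion(speech=speech, before_outs=before_outs,
                     after_outs=None, masked_pos=masked_pos)
```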
ernie-sat/paddlespeech/t2s/modules/transformer/attention.py View file @ 9224659c
...
...
@@ -220,3 +220,99 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
        scores = (matrix_ac + matrix_bd) / math.sqrt(self.d_k)

        return self.forward_attention(v, scores, mask)


class LegacyRelPositionMultiHeadedAttention(MultiHeadedAttention):
    """Multi-Head Attention layer with relative position encoding (old version).
    Details can be found in https://github.com/espnet/espnet/pull/2816.
    Paper: https://arxiv.org/abs/1901.02860
    Args:
        n_head (int): The number of heads.
        n_feat (int): The number of features.
        dropout_rate (float): Dropout rate.
        zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
    """

    def __init__(self, n_head, n_feat, dropout_rate, zero_triu=False):
        """Construct an RelPositionMultiHeadedAttention object."""
        super().__init__(n_head, n_feat, dropout_rate)
        self.zero_triu = zero_triu
        # linear transformation for positional encoding
        self.linear_pos = nn.Linear(n_feat, n_feat, bias_attr=False)
        # these two learnable biases are used in matrix c and matrix d
        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
        self.pos_bias_u = paddle.create_parameter(
            shape=(self.h, self.d_k),
            dtype='float32',
            default_initializer=paddle.nn.initializer.XavierUniform())
        self.pos_bias_v = paddle.create_parameter(
            shape=(self.h, self.d_k),
            dtype='float32',
            default_initializer=paddle.nn.initializer.XavierUniform())

    def rel_shift(self, x):
        """Compute relative positional encoding.
        Args:
            x (Tensor): Input tensor (batch, head, time1, time2).
        Returns:
            Tensor: Output tensor.
        """
        b, h, t1, t2 = paddle.shape(x)
        zero_pad = paddle.zeros((b, h, t1, 1))
        x_padded = paddle.concat([zero_pad, x], axis=-1)
        x_padded = paddle.reshape(x_padded, [b, h, t2 + 1, t1])
        # only keep the positions from 0 to time2
        x = paddle.reshape(x_padded[:, :, 1:], [b, h, t1, t2])
        if self.zero_triu:
            ones = paddle.ones((t1, t2))
            x = x * paddle.tril(ones, t2 - 1)[None, None, :, :]
        return x

    def forward(self, query, key, value, pos_emb, mask):
        """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
        Args:
            query (Tensor): Query tensor (#batch, time1, size).
            key (Tensor): Key tensor (#batch, time2, size).
            value (Tensor): Value tensor (#batch, time2, size).
            pos_emb (Tensor): Positional embedding tensor (#batch, time1, size).
            mask (Tensor): Mask tensor (#batch, 1, time2) or (#batch, time1, time2).
        Returns:
            Tensor: Output tensor (#batch, time1, d_model).
        """
        q, k, v = self.forward_qkv(query, key, value)
        # (batch, time1, head, d_k)
        q = paddle.transpose(q, [0, 2, 1, 3])

        n_batch_pos = paddle.shape(pos_emb)[0]
        p = paddle.reshape(
            self.linear_pos(pos_emb), [n_batch_pos, -1, self.h, self.d_k])
        # (batch, head, time1, d_k)
        p = paddle.transpose(p, [0, 2, 1, 3])
        # (batch, head, time1, d_k)
        q_with_bias_u = paddle.transpose((q + self.pos_bias_u), [0, 2, 1, 3])
        # (batch, head, time1, d_k)
        q_with_bias_v = paddle.transpose((q + self.pos_bias_v), [0, 2, 1, 3])

        # compute attention score
        # first compute matrix a and matrix c
        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
        # (batch, head, time1, time2)
        matrix_ac = paddle.matmul(q_with_bias_u,
                                  paddle.transpose(k, [0, 1, 3, 2]))
        # compute matrix b and matrix d
        # (batch, head, time1, time1)
        matrix_bd = paddle.matmul(q_with_bias_v,
                                  paddle.transpose(p, [0, 1, 3, 2]))
        matrix_bd = self.rel_shift(matrix_bd)
        # (batch, head, time1, time2)
        scores = (matrix_ac + matrix_bd) / math.sqrt(self.d_k)

        return self.forward_attention(v, scores, mask)
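The only non-obvious step above is `rel_shift`; a toy trace of the pad-reshape-slice trick it uses (per Transformer-XL, Appendix B) on a 3x3 score matrix:

```
import paddle

x = paddle.arange(9, dtype='float32').reshape([1, 1, 3, 3])
b, h, t1, t2 = 1, 1, 3, 3
zero_pad = paddle.zeros([b, h, t1, 1])
x_padded = paddle.concat([zero_pad, x], axis=-1).reshape([b, h, t2 + 1, t1])
shifted = x_padded[:, :, 1:].reshape([b, h, t1, t2])
# rows [0,1,2] / [3,4,5] / [6,7,8] become [2,0,3] / [4,5,0] / [6,7,8]
```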
ernie-sat/paddlespeech/t2s/modules/transformer/embedding.py View file @ 9224659c
...
...
@@ -185,3 +185,63 @@ class RelPositionalEncoding(nn.Layer):
        pe_size = paddle.shape(self.pe)
        pos_emb = self.pe[:, pe_size[1] // 2 - T + 1:pe_size[1] // 2 + T, ]
        return self.dropout(x), self.dropout(pos_emb)


class LegacyRelPositionalEncoding(PositionalEncoding):
    """Relative positional encoding module (old version).
    Details can be found in https://github.com/espnet/espnet/pull/2816.
    See : Appendix B in https://arxiv.org/abs/1901.02860
    Args:
        d_model (int): Embedding dimension.
        dropout_rate (float): Dropout rate.
        max_len (int): Maximum input length.
    """

    def __init__(self, d_model: int, dropout_rate: float, max_len: int=5000):
        """
        Args:
            d_model (int): Embedding dimension.
            dropout_rate (float): Dropout rate.
            max_len (int, optional): Maximum input length. Defaults to 5000.
        """
        super().__init__(d_model, dropout_rate, max_len, reverse=True)

    def extend_pe(self, x):
        """Reset the positional encodings."""
        if self.pe is not None:
            if paddle.shape(self.pe)[1] >= paddle.shape(x)[1]:
                return
        pe = paddle.zeros((paddle.shape(x)[1], self.d_model))
        if self.reverse:
            position = paddle.arange(
                paddle.shape(x)[1] - 1, -1, -1.0,
                dtype=paddle.float32).unsqueeze(1)
        else:
            position = paddle.arange(
                0, paddle.shape(x)[1], dtype=paddle.float32).unsqueeze(1)
        div_term = paddle.exp(
            paddle.arange(0, self.d_model, 2, dtype=paddle.float32) *
            -(math.log(10000.0) / self.d_model))
        pe[:, 0::2] = paddle.sin(position * div_term)
        pe[:, 1::2] = paddle.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.pe = pe

    def forward(self, x: paddle.Tensor):
        """Compute positional encoding.
        Args:
            x (paddle.Tensor): Input tensor (batch, time, `*`).
        Returns:
            paddle.Tensor: Encoded tensor (batch, time, `*`).
            paddle.Tensor: Positional embedding tensor (1, time, `*`).
        """
        self.extend_pe(x)
        x = x * self.xscale
        pos_emb = self.pe[:, :paddle.shape(x)[1]]
        return self.dropout(x), self.dropout(pos_emb)
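Usage sketch: with `reverse=True` the table runs from position T-1 down to 0, which is what the legacy attention variant above consumes. Shapes only, values unchecked:

```
import paddle

pos_enc = LegacyRelPositionalEncoding(d_model=8, dropout_rate=0.0)
x = paddle.zeros([1, 5, 8])
x_scaled, pos_emb = pos_enc(x)
print(pos_emb.shape)  # [1, 5, 8]
```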
ernie-sat/read_text.py View file @ 9224659c
...
...
@@ -5,7 +5,7 @@ from typing import List
from typing import Union


-def read_2column_text(path: Union[Path, str]) -> Dict[str, str]:
+def read_2col_text(path: Union[Path, str]) -> Dict[str, str]:
    """Read a text file with two columns as a dict object.

    Examples:
...
...
@@ -13,7 +13,7 @@ def read_2column_text(path: Union[Path, str]) -> Dict[str, str]:
        key1 /some/path/a.wav
        key2 /some/path/b.wav

-        >>> read_2column_text('wav.scp')
+        >>> read_2col_text('wav.scp')
        {'key1': '/some/path/a.wav', 'key2': '/some/path/b.wav'}
    """
...
...
ernie-sat/sedit_arg_parser.py View file @ 9224659c
...
...
@@ -65,12 +65,6 @@ def parse_args():
        help="mean and standard deviation used to normalize spectrogram when training voc.")
    # other
-    parser.add_argument(
-        '--lang', type=str, default='en', help='Choose model language. zh or en')
-    parser.add_argument(
-        "--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.")
    # parser.add_argument("--test_metadata", type=str, help="test metadata.")
...
...
ernie-sat/utils.py View file @ 9224659c
...
...
@@ -32,7 +32,6 @@ model_alias = {
"paddlespeech.t2s.models.parallel_wavegan:PWGInference"
,
}
def
is_chinese
(
ch
):
if
u
'
\u4e00
'
<=
ch
<=
u
'
\u9fff
'
:
return
True
...
...
@@ -55,12 +54,10 @@ def build_vocoder_from_file(
        raise ValueError(f"{vocoder_file} is not supported format.")


-def get_voc_out(mel, target_lang: str="chinese"):
+def get_voc_out(mel):
    # vocoder
    args = parse_args()
-    assert target_lang == "chinese" or target_lang == "english", \
-        "In get_voc_out function, target_lang is illegal..."
    # print("current vocoder: ", args.voc)
    with open(args.voc_config) as f:
        voc_config = CfgNode(yaml.safe_load(f))
...
...
@@ -167,19 +164,23 @@ def get_voc_inference(
    return voc_inference


-def evaluate_durations(phns: List[str],
-                       target_lang: str="chinese",
-                       fs: int=24000,
-                       hop_length: int=300):
+def evaluate_durations(phns, target_lang="chinese", fs=24000, hop_length=300):
    args = parse_args()

    if target_lang == 'english':
        args.lang = 'en'
        args.am = "fastspeech2_ljspeech"
        args.am_config = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/default.yaml"
        args.am_ckpt = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/snapshot_iter_100000.pdz"
        args.am_stat = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/speech_stats.npy"
        args.phones_dict = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/phone_id_map.txt"
    elif target_lang == 'chinese':
        args.lang = 'zh'
        args.am = "fastspeech2_csmsc"
        args.am_config = "download/fastspeech2_conformer_baker_ckpt_0.5/conformer.yaml"
        args.am_ckpt = "download/fastspeech2_conformer_baker_ckpt_0.5/snapshot_iter_76000.pdz"
        args.am_stat = "download/fastspeech2_conformer_baker_ckpt_0.5/speech_stats.npy"
        args.phones_dict = "download/fastspeech2_conformer_baker_ckpt_0.5/phone_id_map.txt"

    # args = parser.parse_args(args=[])
    if args.ngpu == 0:
        paddle.set_device("cpu")
    elif args.ngpu > 0:
...
...
@@ -187,8 +188,6 @@ def evaluate_durations(phns: List[str],
    else:
        print("ngpu should >= 0 !")

-    assert target_lang == "chinese" or target_lang == "english", \
-        "In evaluate_durations function, target_lang is illegal..."

    # Init body.
    with open(args.am_config) as f:
        am_config = CfgNode(yaml.safe_load(f))
...
...
@@ -203,21 +202,19 @@ def evaluate_durations(phns: List[str],
        speaker_dict=args.speaker_dict,
        return_am=True)

-    torch_phns = phns
    vocab_phones = {}
    with open(args.phones_dict, "r") as f:
        phn_id = [line.strip().split() for line in f.readlines()]
    for tone, id in phn_id:
        vocab_phones[tone] = int(id)
    vocab_size = len(vocab_phones)
-    phonemes = [phn if phn in vocab_phones else "sp" for phn in torch_phns]
+    phonemes = [phn if phn in vocab_phones else "sp" for phn in phns]

    phone_ids = [vocab_phones[item] for item in phonemes]
    phone_ids_new = phone_ids
    phone_ids_new.append(vocab_size - 1)
    phone_ids_new = paddle.to_tensor(np.array(phone_ids_new, np.int64))
-    normalized_mel, d_outs, p_outs, e_outs = am.inference(
-        phone_ids_new, spk_id=None, spk_emb=None)
+    _, d_outs, _, _ = am.inference(phone_ids_new, spk_id=None, spk_emb=None)
    pre_d_outs = d_outs
    phoneme_durations_new = pre_d_outs * hop_length / fs
    phoneme_durations_new = phoneme_durations_new.tolist()[:-1]
...
...
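A hypothetical call: the function converts the FastSpeech2 duration predictor's per-phone frame counts to seconds via `hop_length / fs` and drops the trailing entry for the appended final token (its return statement sits in the collapsed context above). The phone names below are illustrative and must exist in the model's phone_id_map:

```
phone_durs = evaluate_durations(
    ['AH0', 'L', 'OW1', 'sp'], target_lang='english', fs=24000, hop_length=300)
```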