Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
ERNIE
提交
30986264
E
ERNIE
项目概览
PaddlePaddle
/
ERNIE
大约 1 年 前同步成功
通知
109
Star
5997
Fork
1270
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
29
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
E
ERNIE
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
29
Issue
29
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
30986264
编写于
6月 16, 2022
作者:
K
Kennycao123
提交者:
GitHub
6月 16, 2022
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #827 from yt605155624/format
[ernie sat]add docstring
上级
b81832ce
445e3040
变更
12
展开全部
隐藏空白更改
内联
并排
Showing
12 changed file
with
1167 addition
and
1328 deletion
+1167
-1328
ernie-sat/README.md
ernie-sat/README.md
+4
-5
ernie-sat/align.py
ernie-sat/align.py
+160
-24
ernie-sat/collect_fn.py
ernie-sat/collect_fn.py
+217
-0
ernie-sat/dataset.py
ernie-sat/dataset.py
+147
-134
ernie-sat/inference.py
ernie-sat/inference.py
+258
-670
ernie-sat/mlm.py
ernie-sat/mlm.py
+150
-463
ernie-sat/mlm_loss.py
ernie-sat/mlm_loss.py
+53
-0
ernie-sat/paddlespeech/t2s/modules/transformer/attention.py
ernie-sat/paddlespeech/t2s/modules/transformer/attention.py
+96
-0
ernie-sat/paddlespeech/t2s/modules/transformer/embedding.py
ernie-sat/paddlespeech/t2s/modules/transformer/embedding.py
+60
-0
ernie-sat/read_text.py
ernie-sat/read_text.py
+2
-2
ernie-sat/sedit_arg_parser.py
ernie-sat/sedit_arg_parser.py
+0
-6
ernie-sat/utils.py
ernie-sat/utils.py
+20
-24
未找到文件。
ernie-sat/README.md
浏览文件 @
30986264
...
...
@@ -39,9 +39,9 @@ ERNIE-SAT 中我们提出了两项创新:
### 2.预训练模型
预训练模型 ERNIE-SAT 的模型如下所示:
-
[
ERNIE-SAT_ZH
](
http
://bj.bcebos.com/wenxin-models
/model-ernie-sat-base-zh.tar.gz
)
-
[
ERNIE-SAT_EN
](
http
://bj.bcebos.com/wenxin-models
/model-ernie-sat-base-en.tar.gz
)
-
[
ERNIE-SAT_ZH_and_EN
](
http
://bj.bcebos.com/wenxin-models
/model-ernie-sat-base-en_zh.tar.gz
)
-
[
ERNIE-SAT_ZH
](
http
s://paddlespeech.bj.bcebos.com/Parakeet/released_models/ernie_sat/old
/model-ernie-sat-base-zh.tar.gz
)
-
[
ERNIE-SAT_EN
](
http
s://paddlespeech.bj.bcebos.com/Parakeet/released_models/ernie_sat/old
/model-ernie-sat-base-en.tar.gz
)
-
[
ERNIE-SAT_ZH_and_EN
](
http
s://paddlespeech.bj.bcebos.com/Parakeet/released_models/ernie_sat/old
/model-ernie-sat-base-en_zh.tar.gz
)
创建 pretrained_model 文件夹,下载上述 ERNIE-SAT 预训练模型并将其解压:
...
...
@@ -108,7 +108,7 @@ prompt/dev
3.
`--voc`
声码器(vocoder)格式是否符合 {model_name}_{dataset}
4.
`--voc_config`
,
`--voc_checkpoint`
,
`--voc_stat`
是声码器的参数,对应于 parallel wavegan 预训练模型中的 3 个文件。
5.
`--lang`
对应模型的语言可以是
`zh`
或
`en`
。
6.
`--ngpu`
要使用的
GPU
数,如果 ngpu==0,则使用 cpu。
6.
`--ngpu`
要使用的
GPU
数,如果 ngpu==0,则使用 cpu。
7.
` --model_name`
模型名称
8.
` --uid`
特定提示(prompt)语音的 id
9.
` --new_str`
输入的文本(本次开源暂时先设置特定的文本)
...
...
@@ -125,4 +125,3 @@ sh run_sedit_en.sh # 语音编辑任务(英文)
sh run_gen_en.sh
# 个性化语音合成任务(英文)
sh run_clone_en_to_zh.sh
# 跨语言语音合成任务(英文到中文的语音克隆)
```
ernie-sat/align.py
浏览文件 @
30986264
#!/usr/bin/env python
""" Usage:
align.py wavfile trsfile outwordfile outphonefile
"""
import
multiprocessing
as
mp
import
os
import
sys
from
tqdm
import
tqdm
PHONEME
=
'tools/aligner/english_envir/english2phoneme/phoneme'
MODEL_DIR_EN
=
'tools/aligner/english'
MODEL_DIR_ZH
=
'tools/aligner/mandarin'
...
...
@@ -15,6 +11,142 @@ HVITE = 'tools/htk/HTKTools/HVite'
HCOPY
=
'tools/htk/HTKTools/HCopy'
def
get_unk_phns
(
word_str
:
str
):
tmpbase
=
'/tmp/tp.'
f
=
open
(
tmpbase
+
'temp.words'
,
'w'
)
f
.
write
(
word_str
)
f
.
close
()
os
.
system
(
PHONEME
+
' '
+
tmpbase
+
'temp.words'
+
' '
+
tmpbase
+
'temp.phons'
)
f
=
open
(
tmpbase
+
'temp.phons'
,
'r'
)
lines2
=
f
.
readline
().
strip
().
split
()
f
.
close
()
phns
=
[]
for
phn
in
lines2
:
phons
=
phn
.
replace
(
'
\n
'
,
''
).
replace
(
' '
,
''
)
seq
=
[]
j
=
0
while
(
j
<
len
(
phons
)):
if
(
phons
[
j
]
>
'Z'
):
if
(
phons
[
j
]
==
'j'
):
seq
.
append
(
'JH'
)
elif
(
phons
[
j
]
==
'h'
):
seq
.
append
(
'HH'
)
else
:
seq
.
append
(
phons
[
j
].
upper
())
j
+=
1
else
:
p
=
phons
[
j
:
j
+
2
]
if
(
p
==
'WH'
):
seq
.
append
(
'W'
)
elif
(
p
in
[
'TH'
,
'SH'
,
'HH'
,
'DH'
,
'CH'
,
'ZH'
,
'NG'
]):
seq
.
append
(
p
)
elif
(
p
==
'AX'
):
seq
.
append
(
'AH0'
)
else
:
seq
.
append
(
p
+
'1'
)
j
+=
2
phns
.
extend
(
seq
)
return
phns
def
words2phns
(
line
:
str
):
'''
Args:
line (str): input text.
eg: for that reason cover is impossible to be given.
Returns:
List[str]: phones of input text.
eg:
['F', 'AO1', 'R', 'DH', 'AE1', 'T', 'R', 'IY1', 'Z', 'AH0', 'N', 'K', 'AH1', 'V', 'ER0',
'IH1', 'Z', 'IH2', 'M', 'P', 'AA1', 'S', 'AH0', 'B', 'AH0', 'L', 'T', 'UW1', 'B', 'IY1',
'G', 'IH1', 'V', 'AH0', 'N']
Dict(str, str): key - idx_word
value - phones
eg:
{'0_FOR': ['F', 'AO1', 'R'], '1_THAT': ['DH', 'AE1', 'T'], '2_REASON': ['R', 'IY1', 'Z', 'AH0', 'N'],
'3_COVER': ['K', 'AH1', 'V', 'ER0'], '4_IS': ['IH1', 'Z'], '5_IMPOSSIBLE': ['IH2', 'M', 'P', 'AA1', 'S', 'AH0', 'B', 'AH0', 'L'],
'6_TO': ['T', 'UW1'], '7_BE': ['B', 'IY1'], '8_GIVEN': ['G', 'IH1', 'V', 'AH0', 'N']}
'''
dictfile
=
MODEL_DIR_EN
+
'/dict'
line
=
line
.
strip
()
words
=
[]
for
pun
in
[
','
,
'.'
,
':'
,
';'
,
'!'
,
'?'
,
'"'
,
'('
,
')'
,
'--'
,
'---'
]:
line
=
line
.
replace
(
pun
,
' '
)
for
wrd
in
line
.
split
():
if
(
wrd
[
-
1
]
==
'-'
):
wrd
=
wrd
[:
-
1
]
if
(
wrd
[
0
]
==
"'"
):
wrd
=
wrd
[
1
:]
if
wrd
:
words
.
append
(
wrd
)
ds
=
set
([])
word2phns_dict
=
{}
with
open
(
dictfile
,
'r'
)
as
fid
:
for
line
in
fid
:
word
=
line
.
split
()[
0
]
ds
.
add
(
word
)
if
word
not
in
word2phns_dict
.
keys
():
word2phns_dict
[
word
]
=
" "
.
join
(
line
.
split
()[
1
:])
phns
=
[]
wrd2phns
=
{}
for
index
,
wrd
in
enumerate
(
words
):
if
wrd
==
'[MASK]'
:
wrd2phns
[
str
(
index
)
+
"_"
+
wrd
]
=
[
wrd
]
phns
.
append
(
wrd
)
elif
(
wrd
.
upper
()
not
in
ds
):
wrd2phns
[
str
(
index
)
+
"_"
+
wrd
.
upper
()]
=
get_unk_phns
(
wrd
)
phns
.
extend
(
get_unk_phns
(
wrd
))
else
:
wrd2phns
[
str
(
index
)
+
"_"
+
wrd
.
upper
()]
=
word2phns_dict
[
wrd
.
upper
()].
split
()
phns
.
extend
(
word2phns_dict
[
wrd
.
upper
()].
split
())
return
phns
,
wrd2phns
def
words2phns_zh
(
line
:
str
):
dictfile
=
MODEL_DIR_ZH
+
'/dict'
line
=
line
.
strip
()
words
=
[]
for
pun
in
[
','
,
'.'
,
':'
,
';'
,
'!'
,
'?'
,
'"'
,
'('
,
')'
,
'--'
,
'---'
,
u
','
,
u
'。'
,
u
':'
,
u
';'
,
u
'!'
,
u
'?'
,
u
'('
,
u
')'
]:
line
=
line
.
replace
(
pun
,
' '
)
for
wrd
in
line
.
split
():
if
(
wrd
[
-
1
]
==
'-'
):
wrd
=
wrd
[:
-
1
]
if
(
wrd
[
0
]
==
"'"
):
wrd
=
wrd
[
1
:]
if
wrd
:
words
.
append
(
wrd
)
ds
=
set
([])
word2phns_dict
=
{}
with
open
(
dictfile
,
'r'
)
as
fid
:
for
line
in
fid
:
word
=
line
.
split
()[
0
]
ds
.
add
(
word
)
if
word
not
in
word2phns_dict
.
keys
():
word2phns_dict
[
word
]
=
" "
.
join
(
line
.
split
()[
1
:])
phns
=
[]
wrd2phns
=
{}
for
index
,
wrd
in
enumerate
(
words
):
if
wrd
==
'[MASK]'
:
wrd2phns
[
str
(
index
)
+
"_"
+
wrd
]
=
[
wrd
]
phns
.
append
(
wrd
)
elif
(
wrd
.
upper
()
not
in
ds
):
print
(
"出现非法词错误,请输入正确的文本..."
)
else
:
wrd2phns
[
str
(
index
)
+
"_"
+
wrd
]
=
word2phns_dict
[
wrd
].
split
()
phns
.
extend
(
word2phns_dict
[
wrd
].
split
())
return
phns
,
wrd2phns
def
prep_txt_zh
(
line
:
str
,
tmpbase
:
str
,
dictfile
:
str
):
words
=
[]
...
...
@@ -82,7 +214,7 @@ def prep_txt_en(line: str, tmpbase, dictfile):
try
:
os
.
system
(
PHONEME
+
' '
+
tmpbase
+
'_unk.words'
+
' '
+
tmpbase
+
'_unk.phons'
)
except
:
except
Exception
:
print
(
'english2phoneme error!'
)
sys
.
exit
(
1
)
...
...
@@ -148,19 +280,22 @@ def _get_user():
def
alignment
(
wav_path
:
str
,
text
:
str
):
'''
intervals: List[phn, start, end]
'''
tmpbase
=
'/tmp/'
+
_get_user
()
+
'_'
+
str
(
os
.
getpid
())
#prepare wav and trs files
try
:
os
.
system
(
'sox '
+
wav_path
+
' -r 16000 '
+
tmpbase
+
'.wav remix -'
)
except
:
except
Exception
:
print
(
'sox error!'
)
return
None
#prepare clean_transcript file
try
:
prep_txt_en
(
text
,
tmpbase
,
MODEL_DIR_EN
+
'/dict'
)
except
:
prep_txt_en
(
line
=
text
,
tmpbase
=
tmpbase
,
dictfile
=
MODEL_DIR_EN
+
'/dict'
)
except
Exception
:
print
(
'prep_txt error!'
)
return
None
...
...
@@ -169,7 +304,7 @@ def alignment(wav_path: str, text: str):
with
open
(
tmpbase
+
'.txt'
,
'r'
)
as
fid
:
txt
=
fid
.
readline
()
prep_mlf
(
txt
,
tmpbase
)
except
:
except
Exception
:
print
(
'prep_mlf error!'
)
return
None
...
...
@@ -177,7 +312,7 @@ def alignment(wav_path: str, text: str):
try
:
os
.
system
(
HCOPY
+
' -C '
+
MODEL_DIR_EN
+
'/16000/config '
+
tmpbase
+
'.wav'
+
' '
+
tmpbase
+
'.plp'
)
except
:
except
Exception
:
print
(
'HCopy error!'
)
return
None
...
...
@@ -188,7 +323,7 @@ def alignment(wav_path: str, text: str):
+
'/16000/hmmdefs -i '
+
tmpbase
+
'.aligned '
+
tmpbase
+
'.dict '
+
MODEL_DIR_EN
+
'/monophones '
+
tmpbase
+
'.plp 2>&1 > /dev/null'
)
except
:
except
Exception
:
print
(
'HVite error!'
)
return
None
...
...
@@ -200,7 +335,7 @@ def alignment(wav_path: str, text: str):
with
open
(
tmpbase
+
'.aligned'
,
'r'
)
as
fid
:
lines
=
fid
.
readlines
()
i
=
2
times2
=
[]
intervals
=
[]
word2phns
=
{}
current_word
=
''
index
=
0
...
...
@@ -210,7 +345,7 @@ def alignment(wav_path: str, text: str):
phn
=
splited_line
[
2
]
pst
=
(
int
(
splited_line
[
0
])
/
1000
+
125
)
/
10000
pen
=
(
int
(
splited_line
[
1
])
/
1000
+
125
)
/
10000
times2
.
append
([
phn
,
pst
,
pen
])
intervals
.
append
([
phn
,
pst
,
pen
])
# splited_line[-1]!='sp'
if
len
(
splited_line
)
==
5
:
current_word
=
str
(
index
)
+
'_'
+
splited_line
[
-
1
]
...
...
@@ -219,10 +354,10 @@ def alignment(wav_path: str, text: str):
elif
len
(
splited_line
)
==
4
:
word2phns
[
current_word
]
+=
' '
+
phn
i
+=
1
return
times2
,
word2phns
return
intervals
,
word2phns
def
alignment_zh
(
wav_path
,
text_string
):
def
alignment_zh
(
wav_path
:
str
,
text
:
str
):
tmpbase
=
'/tmp/'
+
_get_user
()
+
'_'
+
str
(
os
.
getpid
())
#prepare wav and trs files
...
...
@@ -230,18 +365,19 @@ def alignment_zh(wav_path, text_string):
os
.
system
(
'sox '
+
wav_path
+
' -r 16000 -b 16 '
+
tmpbase
+
'.wav remix -'
)
except
:
except
Exception
:
print
(
'sox error!'
)
return
None
#prepare clean_transcript file
try
:
unk_words
=
prep_txt_zh
(
text_string
,
tmpbase
,
MODEL_DIR_ZH
+
'/dict'
)
unk_words
=
prep_txt_zh
(
line
=
text
,
tmpbase
=
tmpbase
,
dictfile
=
MODEL_DIR_ZH
+
'/dict'
)
if
unk_words
:
print
(
'Error! Please add the following words to dictionary:'
)
for
unk
in
unk_words
:
print
(
"非法words: "
,
unk
)
except
:
except
Exception
:
print
(
'prep_txt error!'
)
return
None
...
...
@@ -250,7 +386,7 @@ def alignment_zh(wav_path, text_string):
with
open
(
tmpbase
+
'.txt'
,
'r'
)
as
fid
:
txt
=
fid
.
readline
()
prep_mlf
(
txt
,
tmpbase
)
except
:
except
Exception
:
print
(
'prep_mlf error!'
)
return
None
...
...
@@ -258,7 +394,7 @@ def alignment_zh(wav_path, text_string):
try
:
os
.
system
(
HCOPY
+
' -C '
+
MODEL_DIR_ZH
+
'/16000/config '
+
tmpbase
+
'.wav'
+
' '
+
tmpbase
+
'.plp'
)
except
:
except
Exception
:
print
(
'HCopy error!'
)
return
None
...
...
@@ -270,7 +406,7 @@ def alignment_zh(wav_path, text_string):
+
'/dict '
+
MODEL_DIR_ZH
+
'/monophones '
+
tmpbase
+
'.plp 2>&1 > /dev/null'
)
except
:
except
Exception
:
print
(
'HVite error!'
)
return
None
...
...
@@ -283,7 +419,7 @@ def alignment_zh(wav_path, text_string):
lines
=
fid
.
readlines
()
i
=
2
times2
=
[]
intervals
=
[]
word2phns
=
{}
current_word
=
''
index
=
0
...
...
@@ -293,7 +429,7 @@ def alignment_zh(wav_path, text_string):
phn
=
splited_line
[
2
]
pst
=
(
int
(
splited_line
[
0
])
/
1000
+
125
)
/
10000
pen
=
(
int
(
splited_line
[
1
])
/
1000
+
125
)
/
10000
times2
.
append
([
phn
,
pst
,
pen
])
intervals
.
append
([
phn
,
pst
,
pen
])
# splited_line[-1]!='sp'
if
len
(
splited_line
)
==
5
:
current_word
=
str
(
index
)
+
'_'
+
splited_line
[
-
1
]
...
...
@@ -302,4 +438,4 @@ def alignment_zh(wav_path, text_string):
elif
len
(
splited_line
)
==
4
:
word2phns
[
current_word
]
+=
' '
+
phn
i
+=
1
return
times2
,
word2phns
return
intervals
,
word2phns
ernie-sat/collect_fn.py
0 → 100644
浏览文件 @
30986264
from
typing
import
Collection
from
typing
import
Dict
from
typing
import
List
from
typing
import
Tuple
from
typing
import
Union
import
numpy
as
np
import
paddle
from
dataset
import
get_seg_pos
from
dataset
import
phones_masking
from
dataset
import
phones_text_masking
from
paddlespeech.t2s.datasets.get_feats
import
LogMelFBank
from
paddlespeech.t2s.modules.nets_utils
import
make_non_pad_mask
from
paddlespeech.t2s.modules.nets_utils
import
pad_list
class
MLMCollateFn
:
"""Functor class of common_collate_fn()"""
def
__init__
(
self
,
feats_extract
,
float_pad_value
:
Union
[
float
,
int
]
=
0.0
,
int_pad_value
:
int
=-
32768
,
not_sequence
:
Collection
[
str
]
=
(),
mlm_prob
:
float
=
0.8
,
mean_phn_span
:
int
=
8
,
attention_window
:
int
=
0
,
pad_speech
:
bool
=
False
,
seg_emb
:
bool
=
False
,
text_masking
:
bool
=
False
):
self
.
mlm_prob
=
mlm_prob
self
.
mean_phn_span
=
mean_phn_span
self
.
feats_extract
=
feats_extract
self
.
float_pad_value
=
float_pad_value
self
.
int_pad_value
=
int_pad_value
self
.
not_sequence
=
set
(
not_sequence
)
self
.
attention_window
=
attention_window
self
.
pad_speech
=
pad_speech
self
.
seg_emb
=
seg_emb
self
.
text_masking
=
text_masking
def
__repr__
(
self
):
return
(
f
"
{
self
.
__class__
}
(float_pad_value=
{
self
.
float_pad_value
}
, "
f
"int_pad_value=
{
self
.
float_pad_value
}
)"
)
def
__call__
(
self
,
data
:
Collection
[
Tuple
[
str
,
Dict
[
str
,
np
.
ndarray
]]]
)
->
Tuple
[
List
[
str
],
Dict
[
str
,
paddle
.
Tensor
]]:
return
mlm_collate_fn
(
data
,
float_pad_value
=
self
.
float_pad_value
,
int_pad_value
=
self
.
int_pad_value
,
not_sequence
=
self
.
not_sequence
,
mlm_prob
=
self
.
mlm_prob
,
mean_phn_span
=
self
.
mean_phn_span
,
feats_extract
=
self
.
feats_extract
,
attention_window
=
self
.
attention_window
,
pad_speech
=
self
.
pad_speech
,
seg_emb
=
self
.
seg_emb
,
text_masking
=
self
.
text_masking
)
def
mlm_collate_fn
(
data
:
Collection
[
Tuple
[
str
,
Dict
[
str
,
np
.
ndarray
]]],
float_pad_value
:
Union
[
float
,
int
]
=
0.0
,
int_pad_value
:
int
=-
32768
,
not_sequence
:
Collection
[
str
]
=
(),
mlm_prob
:
float
=
0.8
,
mean_phn_span
:
int
=
8
,
feats_extract
=
None
,
attention_window
:
int
=
0
,
pad_speech
:
bool
=
False
,
seg_emb
:
bool
=
False
,
text_masking
:
bool
=
False
)
->
Tuple
[
List
[
str
],
Dict
[
str
,
paddle
.
Tensor
]]:
uttids
=
[
u
for
u
,
_
in
data
]
data
=
[
d
for
_
,
d
in
data
]
assert
all
(
set
(
data
[
0
])
==
set
(
d
)
for
d
in
data
),
"dict-keys mismatching"
assert
all
(
not
k
.
endswith
(
"_lens"
)
for
k
in
data
[
0
]),
f
"*_lens is reserved:
{
list
(
data
[
0
])
}
"
output
=
{}
for
key
in
data
[
0
]:
# Each models, which accepts these values finally, are responsible
# to repaint the pad_value to the desired value for each tasks.
if
data
[
0
][
key
].
dtype
.
kind
==
"i"
:
pad_value
=
int_pad_value
else
:
pad_value
=
float_pad_value
array_list
=
[
d
[
key
]
for
d
in
data
]
# Assume the first axis is length:
# tensor_list: Batch x (Length, ...)
tensor_list
=
[
paddle
.
to_tensor
(
a
)
for
a
in
array_list
]
# tensor: (Batch, Length, ...)
tensor
=
pad_list
(
tensor_list
,
pad_value
)
output
[
key
]
=
tensor
# lens: (Batch,)
if
key
not
in
not_sequence
:
lens
=
paddle
.
to_tensor
(
[
d
[
key
].
shape
[
0
]
for
d
in
data
],
dtype
=
paddle
.
int64
)
output
[
key
+
"_lens"
]
=
lens
feats
=
feats_extract
.
get_log_mel_fbank
(
np
.
array
(
output
[
"speech"
][
0
]))
feats
=
paddle
.
to_tensor
(
feats
)
feats_lens
=
paddle
.
shape
(
feats
)[
0
]
feats
=
paddle
.
unsqueeze
(
feats
,
0
)
text
=
output
[
"text"
]
text_lens
=
output
[
"text_lens"
]
align_start
=
output
[
"align_start"
]
align_start_lens
=
output
[
"align_start_lens"
]
align_end
=
output
[
"align_end"
]
max_tlen
=
max
(
text_lens
)
max_slen
=
max
(
feats_lens
)
speech_pad
=
feats
[:,
:
max_slen
]
text_pad
=
text
text_mask
=
make_non_pad_mask
(
text_lens
,
text_pad
,
length_dim
=
1
).
unsqueeze
(
-
2
)
speech_mask
=
make_non_pad_mask
(
feats_lens
,
speech_pad
[:,
:,
0
],
length_dim
=
1
).
unsqueeze
(
-
2
)
span_bdy
=
None
if
'span_bdy'
in
output
.
keys
():
span_bdy
=
output
[
'span_bdy'
]
# dual_mask 的是混合中英时候同时 mask 语音和文本
# ernie sat 在实现跨语言的时候都 mask 了
if
text_masking
:
masked_pos
,
text_masked_pos
=
phones_text_masking
(
xs_pad
=
speech_pad
,
src_mask
=
speech_mask
,
text_pad
=
text_pad
,
text_mask
=
text_mask
,
align_start
=
align_start
,
align_end
=
align_end
,
align_start_lens
=
align_start_lens
,
mlm_prob
=
mlm_prob
,
mean_phn_span
=
mean_phn_span
,
span_bdy
=
span_bdy
)
# 训练纯中文和纯英文的 -> a3t 没有对 phoneme 做 mask, 只对语音 mask 了
# a3t 和 ernie sat 的区别主要在于做 mask 的时候
else
:
masked_pos
=
phones_masking
(
xs_pad
=
speech_pad
,
src_mask
=
speech_mask
,
align_start
=
align_start
,
align_end
=
align_end
,
align_start_lens
=
align_start_lens
,
mlm_prob
=
mlm_prob
,
mean_phn_span
=
mean_phn_span
,
span_bdy
=
span_bdy
)
text_masked_pos
=
paddle
.
zeros
(
paddle
.
shape
(
text_pad
))
output_dict
=
{}
speech_seg_pos
,
text_seg_pos
=
get_seg_pos
(
speech_pad
=
speech_pad
,
text_pad
=
text_pad
,
align_start
=
align_start
,
align_end
=
align_end
,
align_start_lens
=
align_start_lens
,
seg_emb
=
seg_emb
)
output_dict
[
'speech'
]
=
speech_pad
output_dict
[
'text'
]
=
text_pad
output_dict
[
'masked_pos'
]
=
masked_pos
output_dict
[
'text_masked_pos'
]
=
text_masked_pos
output_dict
[
'speech_mask'
]
=
speech_mask
output_dict
[
'text_mask'
]
=
text_mask
output_dict
[
'speech_seg_pos'
]
=
speech_seg_pos
output_dict
[
'text_seg_pos'
]
=
text_seg_pos
output
=
(
uttids
,
output_dict
)
return
output
def
build_collate_fn
(
sr
:
int
=
24000
,
n_fft
:
int
=
2048
,
hop_length
:
int
=
300
,
win_length
:
int
=
None
,
n_mels
:
int
=
80
,
fmin
:
int
=
80
,
fmax
:
int
=
7600
,
mlm_prob
:
float
=
0.8
,
mean_phn_span
:
int
=
8
,
train
:
bool
=
False
,
seg_emb
:
bool
=
False
,
epoch
:
int
=-
1
,
):
feats_extract_class
=
LogMelFBank
feats_extract
=
feats_extract_class
(
sr
=
sr
,
n_fft
=
n_fft
,
hop_length
=
hop_length
,
win_length
=
win_length
,
n_mels
=
n_mels
,
fmin
=
fmin
,
fmax
=
fmax
)
pad_speech
=
False
if
epoch
==
-
1
:
mlm_prob_factor
=
1
else
:
mlm_prob_factor
=
0.8
return
MLMCollateFn
(
feats_extract
=
feats_extract
,
float_pad_value
=
0.0
,
int_pad_value
=
0
,
mlm_prob
=
mlm_prob
*
mlm_prob_factor
,
mean_phn_span
=
mean_phn_span
,
pad_speech
=
pad_speech
,
seg_emb
=
seg_emb
)
ernie-sat/dataset.py
浏览文件 @
30986264
...
...
@@ -4,6 +4,68 @@ import numpy as np
import
paddle
# mask phones
def
phones_masking
(
xs_pad
:
paddle
.
Tensor
,
src_mask
:
paddle
.
Tensor
,
align_start
:
paddle
.
Tensor
,
align_end
:
paddle
.
Tensor
,
align_start_lens
:
paddle
.
Tensor
,
mlm_prob
:
float
=
0.8
,
mean_phn_span
:
int
=
8
,
span_bdy
:
paddle
.
Tensor
=
None
):
'''
Args:
xs_pad (paddle.Tensor): input speech (B, Tmax, D).
src_mask (paddle.Tensor): mask of speech (B, 1, Tmax).
align_start (paddle.Tensor): frame level phone alignment start (B, Tmax2).
align_end (paddle.Tensor): frame level phone alignment end (B, Tmax2).
align_start_lens (paddle.Tensor): length of align_start (B, ).
mlm_prob (float):
mean_phn_span (int):
span_bdy (paddle.Tensor): masked mel boundary of input speech (B, 2).
Returns:
paddle.Tensor[bool]: masked position of input speech (B, Tmax).
'''
bz
,
sent_len
,
_
=
paddle
.
shape
(
xs_pad
)
masked_pos
=
paddle
.
zeros
((
bz
,
sent_len
))
if
mlm_prob
==
1.0
:
masked_pos
+=
1
elif
mean_phn_span
==
0
:
# only speech
length
=
sent_len
mean_phn_span
=
min
(
length
*
mlm_prob
//
3
,
50
)
masked_phn_idxs
=
random_spans_noise_mask
(
length
=
length
,
mlm_prob
=
mlm_prob
,
mean_phn_span
=
mean_phn_span
).
nonzero
()
masked_pos
[:,
masked_phn_idxs
]
=
1
else
:
for
idx
in
range
(
bz
):
# for inference
if
span_bdy
is
not
None
:
for
s
,
e
in
zip
(
span_bdy
[
idx
][::
2
],
span_bdy
[
idx
][
1
::
2
]):
masked_pos
[
idx
,
s
:
e
]
=
1
# for training
else
:
length
=
align_start_lens
[
idx
]
if
length
<
2
:
continue
masked_phn_idxs
=
random_spans_noise_mask
(
length
=
length
,
mlm_prob
=
mlm_prob
,
mean_phn_span
=
mean_phn_span
).
nonzero
()
masked_start
=
align_start
[
idx
][
masked_phn_idxs
].
tolist
()
masked_end
=
align_end
[
idx
][
masked_phn_idxs
].
tolist
()
for
s
,
e
in
zip
(
masked_start
,
masked_end
):
masked_pos
[
idx
,
s
:
e
]
=
1
non_eos_mask
=
paddle
.
reshape
(
src_mask
,
paddle
.
shape
(
xs_pad
)[:
2
])
masked_pos
=
masked_pos
*
non_eos_mask
masked_pos
=
paddle
.
cast
(
masked_pos
,
'bool'
)
return
masked_pos
# mask speech and phones
def
phones_text_masking
(
xs_pad
:
paddle
.
Tensor
,
src_mask
:
paddle
.
Tensor
,
text_pad
:
paddle
.
Tensor
,
...
...
@@ -11,37 +73,56 @@ def phones_text_masking(xs_pad: paddle.Tensor,
align_start
:
paddle
.
Tensor
,
align_end
:
paddle
.
Tensor
,
align_start_lens
:
paddle
.
Tensor
,
mlm_prob
:
float
,
mean_phn_span
:
float
,
mlm_prob
:
float
=
0.8
,
mean_phn_span
:
int
=
8
,
span_bdy
:
paddle
.
Tensor
=
None
):
'''
Args:
xs_pad (paddle.Tensor): input speech (B, Tmax, D).
src_mask (paddle.Tensor): mask of speech (B, 1, Tmax).
text_pad (paddle.Tensor): input text (B, Tmax2).
text_mask (paddle.Tensor): mask of text (B, 1, Tmax2).
align_start (paddle.Tensor): frame level phone alignment start (B, Tmax2).
align_end (paddle.Tensor): frame level phone alignment end (B, Tmax2).
align_start_lens (paddle.Tensor): length of align_start (B, ).
mlm_prob (float):
mean_phn_span (int):
span_bdy (paddle.Tensor): masked mel boundary of input speech (B, 2).
Returns:
paddle.Tensor[bool]: masked position of input speech (B, Tmax).
paddle.Tensor[bool]: masked position of input text (B, Tmax2).
'''
bz
,
sent_len
,
_
=
paddle
.
shape
(
xs_pad
)
masked_pos
=
paddle
.
zeros
((
bz
,
sent_len
))
_
,
text_len
=
paddle
.
shape
(
text_pad
)
text_mask_num_lower
=
math
.
ceil
(
text_len
*
(
1
-
mlm_prob
)
*
0.5
)
text_masked_pos
=
paddle
.
zeros
((
bz
,
text_len
))
y_masks
=
None
if
mlm_prob
==
1.0
:
masked_pos
+=
1
# y_masks = tril_masks
elif
mean_phn_span
==
0
:
# only speech
length
=
sent_len
mean_phn_span
=
min
(
length
*
mlm_prob
//
3
,
50
)
masked_phn_idxs
=
random_spans_noise_mask
(
length
,
mlm_prob
,
mean_phn_span
).
nonzero
()
masked_phn_idxs
=
random_spans_noise_mask
(
length
=
length
,
mlm_prob
=
mlm_prob
,
mean_phn_span
=
mean_phn_span
).
nonzero
()
masked_pos
[:,
masked_phn_idxs
]
=
1
else
:
for
idx
in
range
(
bz
):
# for inference
if
span_bdy
is
not
None
:
for
s
,
e
in
zip
(
span_bdy
[
idx
][::
2
],
span_bdy
[
idx
][
1
::
2
]):
masked_pos
[
idx
,
s
:
e
]
=
1
# for training
else
:
length
=
align_start_lens
[
idx
]
if
length
<
2
:
continue
masked_phn_idxs
=
random_spans_noise_mask
(
length
,
mlm_prob
,
mean_phn_span
).
nonzero
()
length
=
length
,
mlm_prob
=
mlm_prob
,
mean_phn_span
=
mean_phn_span
).
nonzero
()
unmasked_phn_idxs
=
list
(
set
(
range
(
length
))
-
set
(
masked_phn_idxs
[
0
].
tolist
()))
np
.
random
.
shuffle
(
unmasked_phn_idxs
)
...
...
@@ -58,60 +139,76 @@ def phones_text_masking(xs_pad: paddle.Tensor,
masked_pos
=
paddle
.
cast
(
masked_pos
,
'bool'
)
text_masked_pos
=
paddle
.
cast
(
text_masked_pos
,
'bool'
)
return
masked_pos
,
text_masked_pos
,
y_masks
return
masked_pos
,
text_masked_pos
def
get_seg_pos_reduce_duration
(
speech_pad
:
paddle
.
Tensor
,
text_pad
:
paddle
.
Tensor
,
align_start
:
paddle
.
Tensor
,
align_end
:
paddle
.
Tensor
,
align_start_lens
:
paddle
.
Tensor
,
sega_emb
:
bool
,
masked_pos
:
paddle
.
Tensor
,
feats_lens
:
paddle
.
Tensor
,
):
def
get_seg_pos
(
speech_pad
:
paddle
.
Tensor
,
text_pad
:
paddle
.
Tensor
,
align_start
:
paddle
.
Tensor
,
align_end
:
paddle
.
Tensor
,
align_start_lens
:
paddle
.
Tensor
,
seg_emb
:
bool
=
False
):
'''
Args:
speech_pad (paddle.Tensor): input speech (B, Tmax, D).
text_pad (paddle.Tensor): input text (B, Tmax2).
align_start (paddle.Tensor): frame level phone alignment start (B, Tmax2).
align_end (paddle.Tensor): frame level phone alignment end (B, Tmax2).
align_start_lens (paddle.Tensor): length of align_start (B, ).
seg_emb (bool): whether to use segment embedding.
Returns:
paddle.Tensor[int]: n-th phone of each mel, 0<=n<=Tmax2 (B, Tmax).
eg:
Tensor(shape=[1, 328], dtype=int64, place=Place(gpu:0), stop_gradient=True,
[[0 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 ,
1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 ,
1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 ,
1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 ,
1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 2 , 2 , 2 , 3 , 3 , 3 , 4 , 4 , 4 ,
5 , 5 , 5 , 6 , 6 , 6 , 6 , 6 , 6 , 6 , 6 , 7 , 7 , 7 , 7 , 7 , 7 , 7 ,
7 , 8 , 8 , 8 , 8 , 9 , 9 , 9 , 9 , 9 , 9 , 9 , 9 , 10, 10, 10, 10, 10,
10, 10, 10, 11, 11, 11, 11, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 13,
13, 13, 13, 13, 13, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 15, 15,
15, 15, 15, 15, 15, 16, 16, 16, 16, 16, 16, 17, 17, 17, 17, 17, 17, 17,
17, 18, 18, 18, 18, 18, 18, 19, 19, 19, 19, 19, 19, 19, 20, 20, 20, 20,
20, 20, 21, 21, 21, 21, 21, 21, 21, 22, 22, 22, 22, 22, 22, 22, 23, 23,
23, 23, 23, 23, 23, 23, 24, 24, 24, 24, 24, 24, 24, 24, 24, 25, 25, 25,
25, 26, 26, 26, 27, 27, 27, 27, 27, 28, 28, 28, 28, 28, 28, 28, 28, 29,
29, 29, 29, 29, 29, 30, 30, 30, 30, 31, 31, 31, 31, 31, 31, 31, 31, 32,
32, 32, 32, 32, 33, 33, 33, 33, 33, 33, 33, 33, 34, 34, 34, 34, 35, 35,
35, 35, 35, 35, 35, 35, 36, 36, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37,
37, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38,
38, 38, 0 , 0 ]])
paddle.Tensor[int]: n-th phone of each phone, 0<=n<=Tmax2 (B, Tmax2).
eg:
Tensor(shape=[1, 38], dtype=int64, place=Place(gpu:0), stop_gradient=True,
[[1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10, 11, 12, 13, 14, 15, 16, 17,
18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
36, 37, 38]])
'''
bz
,
speech_len
,
_
=
paddle
.
shape
(
speech_pad
)
text_seg_pos
=
paddle
.
zeros
(
paddle
.
shape
(
text_pad
))
speech_seg_pos
=
paddle
.
zeros
((
bz
,
speech_len
),
dtype
=
text_pad
.
dtype
)
_
,
text_len
=
paddle
.
shape
(
text_pad
)
reordered_idx
=
paddle
.
zeros
((
bz
,
speech_len
),
dtype
=
align_start_lens
.
dtype
)
text_seg_pos
=
paddle
.
zeros
((
bz
,
text_len
),
dtype
=
'int64'
)
speech_seg_pos
=
paddle
.
zeros
((
bz
,
speech_len
),
dtype
=
'int64'
)
durations
=
paddle
.
ones
((
bz
,
speech_len
),
dtype
=
align_start_lens
.
dtype
)
max_reduced_length
=
0
if
not
sega_emb
:
return
speech_pad
,
masked_pos
,
speech_seg_pos
,
text_seg_pos
,
durations
if
not
seg_emb
:
return
speech_seg_pos
,
text_seg_pos
for
idx
in
range
(
bz
):
first_idx
=
[]
last_idx
=
[]
align_length
=
align_start_lens
[
idx
]
for
j
in
range
(
align_length
):
s
,
e
=
align_start
[
idx
][
j
],
align_end
[
idx
][
j
]
if
j
==
0
:
if
paddle
.
sum
(
masked_pos
[
idx
][
0
:
s
])
==
0
:
first_idx
.
extend
(
range
(
0
,
s
))
else
:
first_idx
.
extend
([
0
])
last_idx
.
extend
(
range
(
1
,
s
))
if
paddle
.
sum
(
masked_pos
[
idx
][
s
:
e
])
==
0
:
first_idx
.
extend
(
range
(
s
,
e
))
else
:
first_idx
.
extend
([
s
])
last_idx
.
extend
(
range
(
s
+
1
,
e
))
durations
[
idx
][
s
]
=
e
-
s
speech_seg_pos
[
idx
][
s
:
e
]
=
j
+
1
text_seg_pos
[
idx
][
j
]
=
j
+
1
max_reduced_length
=
max
(
len
(
first_idx
)
+
feats_lens
[
idx
]
-
e
,
max_reduced_length
)
first_idx
.
extend
(
range
(
e
,
speech_len
))
reordered_idx
[
idx
]
=
paddle
.
to_tensor
(
(
first_idx
+
last_idx
),
dtype
=
align_start_lens
.
dtype
)
feats_lens
[
idx
]
=
len
(
first_idx
)
reordered_idx
=
reordered_idx
[:,
:
max_reduced_length
]
speech_seg_pos
[
idx
,
s
:
e
]
=
j
+
1
text_seg_pos
[
idx
,
j
]
=
j
+
1
return
reordered_idx
,
speech_seg_pos
,
text_seg_pos
,
durations
,
feats_len
s
return
speech_seg_pos
,
text_seg_po
s
def
random_spans_noise_mask
(
length
:
int
,
mlm_prob
:
float
,
mean_phn_span
:
float
):
# randomly select the range of speech and text to mask during training
def
random_spans_noise_mask
(
length
:
int
,
mlm_prob
:
float
=
0.8
,
mean_phn_span
:
float
=
8
):
"""This function is copy of `random_spans_helper
<https://github.com/google-research/text-to-text-transfer-transformer/blob/84f8bcc14b5f2c03de51bd3587609ba8f6bbd1cd/t5/data/preprocessors.py#L2682>`__ .
Noise mask consisting of random spans of noise tokens.
...
...
@@ -126,7 +223,7 @@ def random_spans_noise_mask(length: int, mlm_prob: float, mean_phn_span: float):
noise_density: a float - approximate density of output mask
mean_noise_span_length: a number
Returns:
a boolean tensor with shape [length]
np.ndarray:
a boolean tensor with shape [length]
"""
orig_length
=
length
...
...
@@ -171,87 +268,3 @@ def random_spans_noise_mask(length: int, mlm_prob: float, mean_phn_span: float):
is_noise
=
np
.
equal
(
span_num
%
2
,
1
)
return
is_noise
[:
orig_length
]
def
pad_to_longformer_att_window
(
text
:
paddle
.
Tensor
,
max_len
:
int
,
max_tlen
:
int
,
attention_window
:
int
=
0
):
round
=
max_len
%
attention_window
if
round
!=
0
:
max_tlen
+=
(
attention_window
-
round
)
n_batch
=
paddle
.
shape
(
text
)[
0
]
text_pad
=
paddle
.
zeros
(
(
n_batch
,
max_tlen
,
*
paddle
.
shape
(
text
[
0
])[
1
:]),
dtype
=
text
.
dtype
)
for
i
in
range
(
n_batch
):
text_pad
[
i
,
:
paddle
.
shape
(
text
[
i
])[
0
]]
=
text
[
i
]
else
:
text_pad
=
text
[:,
:
max_tlen
]
return
text_pad
,
max_tlen
def
phones_masking
(
xs_pad
:
paddle
.
Tensor
,
src_mask
:
paddle
.
Tensor
,
align_start
:
paddle
.
Tensor
,
align_end
:
paddle
.
Tensor
,
align_start_lens
:
paddle
.
Tensor
,
mlm_prob
:
float
,
mean_phn_span
:
int
,
span_bdy
:
paddle
.
Tensor
=
None
):
bz
,
sent_len
,
_
=
paddle
.
shape
(
xs_pad
)
masked_pos
=
paddle
.
zeros
((
bz
,
sent_len
))
y_masks
=
None
if
mlm_prob
==
1.0
:
masked_pos
+=
1
elif
mean_phn_span
==
0
:
# only speech
length
=
sent_len
mean_phn_span
=
min
(
length
*
mlm_prob
//
3
,
50
)
masked_phn_idxs
=
random_spans_noise_mask
(
length
,
mlm_prob
,
mean_phn_span
).
nonzero
()
masked_pos
[:,
masked_phn_idxs
]
=
1
else
:
for
idx
in
range
(
bz
):
if
span_bdy
is
not
None
:
for
s
,
e
in
zip
(
span_bdy
[
idx
][::
2
],
span_bdy
[
idx
][
1
::
2
]):
masked_pos
[
idx
,
s
:
e
]
=
1
else
:
length
=
align_start_lens
[
idx
]
if
length
<
2
:
continue
masked_phn_idxs
=
random_spans_noise_mask
(
length
,
mlm_prob
,
mean_phn_span
).
nonzero
()
masked_start
=
align_start
[
idx
][
masked_phn_idxs
].
tolist
()
masked_end
=
align_end
[
idx
][
masked_phn_idxs
].
tolist
()
for
s
,
e
in
zip
(
masked_start
,
masked_end
):
masked_pos
[
idx
,
s
:
e
]
=
1
non_eos_mask
=
paddle
.
reshape
(
src_mask
,
paddle
.
shape
(
xs_pad
)[:
2
])
masked_pos
=
masked_pos
*
non_eos_mask
masked_pos
=
paddle
.
cast
(
masked_pos
,
'bool'
)
return
masked_pos
,
y_masks
def
get_seg_pos
(
speech_pad
:
paddle
.
Tensor
,
text_pad
:
paddle
.
Tensor
,
align_start
:
paddle
.
Tensor
,
align_end
:
paddle
.
Tensor
,
align_start_lens
:
paddle
.
Tensor
,
sega_emb
:
bool
):
bz
,
speech_len
,
_
=
paddle
.
shape
(
speech_pad
)
_
,
text_len
=
paddle
.
shape
(
text_pad
)
text_seg_pos
=
paddle
.
zeros
((
bz
,
text_len
),
dtype
=
'int64'
)
speech_seg_pos
=
paddle
.
zeros
((
bz
,
speech_len
),
dtype
=
'int64'
)
if
not
sega_emb
:
return
speech_seg_pos
,
text_seg_pos
for
idx
in
range
(
bz
):
align_length
=
align_start_lens
[
idx
]
for
j
in
range
(
align_length
):
s
,
e
=
align_start
[
idx
][
j
],
align_end
[
idx
][
j
]
speech_seg_pos
[
idx
,
s
:
e
]
=
j
+
1
text_seg_pos
[
idx
,
j
]
=
j
+
1
return
speech_seg_pos
,
text_seg_pos
ernie-sat/inference.py
浏览文件 @
30986264
此差异已折叠。
点击以展开。
ernie-sat/m
odel_paddle
.py
→
ernie-sat/m
lm
.py
浏览文件 @
30986264
此差异已折叠。
点击以展开。
ernie-sat/mlm_loss.py
0 → 100644
浏览文件 @
30986264
import
paddle
from
paddle
import
nn
class
MLMLoss
(
nn
.
Layer
):
def
__init__
(
self
,
lsm_weight
:
float
=
0.1
,
ignore_id
:
int
=-
1
,
text_masking
:
bool
=
False
):
super
().
__init__
()
if
text_masking
:
self
.
text_mlm_loss
=
nn
.
CrossEntropyLoss
(
ignore_index
=
ignore_id
)
if
lsm_weight
>
50
:
self
.
l1_loss_func
=
nn
.
MSELoss
()
else
:
self
.
l1_loss_func
=
nn
.
L1Loss
(
reduction
=
'none'
)
self
.
text_masking
=
text_masking
def
forward
(
self
,
speech
:
paddle
.
Tensor
,
before_outs
:
paddle
.
Tensor
,
after_outs
:
paddle
.
Tensor
,
masked_pos
:
paddle
.
Tensor
,
text
:
paddle
.
Tensor
=
None
,
text_outs
:
paddle
.
Tensor
=
None
,
text_masked_pos
:
paddle
.
Tensor
=
None
):
xs_pad
=
speech
mlm_loss_pos
=
masked_pos
>
0
loss
=
paddle
.
sum
(
self
.
l1_loss_func
(
paddle
.
reshape
(
before_outs
,
(
-
1
,
self
.
odim
)),
paddle
.
reshape
(
xs_pad
,
(
-
1
,
self
.
odim
))),
axis
=-
1
)
if
after_outs
is
not
None
:
loss
+=
paddle
.
sum
(
self
.
l1_loss_func
(
paddle
.
reshape
(
after_outs
,
(
-
1
,
self
.
odim
)),
paddle
.
reshape
(
xs_pad
,
(
-
1
,
self
.
odim
))),
axis
=-
1
)
loss_mlm
=
paddle
.
sum
((
loss
*
paddle
.
reshape
(
mlm_loss_pos
,
[
-
1
])))
/
paddle
.
sum
((
mlm_loss_pos
)
+
1e-10
)
if
self
.
text_masking
:
loss_text
=
paddle
.
sum
((
self
.
text_mlm_loss
(
paddle
.
reshape
(
text_outs
,
(
-
1
,
self
.
vocab_size
)),
paddle
.
reshape
(
text
,
(
-
1
)))
*
paddle
.
reshape
(
text_masked_pos
,
(
-
1
))))
/
paddle
.
sum
((
text_masked_pos
)
+
1e-10
)
return
loss_mlm
,
loss_text
return
loss_mlm
ernie-sat/paddlespeech/t2s/modules/transformer/attention.py
浏览文件 @
30986264
...
...
@@ -220,3 +220,99 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
scores
=
(
matrix_ac
+
matrix_bd
)
/
math
.
sqrt
(
self
.
d_k
)
return
self
.
forward_attention
(
v
,
scores
,
mask
)
class
LegacyRelPositionMultiHeadedAttention
(
MultiHeadedAttention
):
"""Multi-Head Attention layer with relative position encoding (old version).
Details can be found in https://github.com/espnet/espnet/pull/2816.
Paper: https://arxiv.org/abs/1901.02860
Args:
n_head (int): The number of heads.
n_feat (int): The number of features.
dropout_rate (float): Dropout rate.
zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
"""
def
__init__
(
self
,
n_head
,
n_feat
,
dropout_rate
,
zero_triu
=
False
):
"""Construct an RelPositionMultiHeadedAttention object."""
super
().
__init__
(
n_head
,
n_feat
,
dropout_rate
)
self
.
zero_triu
=
zero_triu
# linear transformation for positional encoding
self
.
linear_pos
=
nn
.
Linear
(
n_feat
,
n_feat
,
bias_attr
=
False
)
# these two learnable bias are used in matrix c and matrix d
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
self
.
pos_bias_u
=
paddle
.
create_parameter
(
shape
=
(
self
.
h
,
self
.
d_k
),
dtype
=
'float32'
,
default_initializer
=
paddle
.
nn
.
initializer
.
XavierUniform
())
self
.
pos_bias_v
=
paddle
.
create_parameter
(
shape
=
(
self
.
h
,
self
.
d_k
),
dtype
=
'float32'
,
default_initializer
=
paddle
.
nn
.
initializer
.
XavierUniform
())
def
rel_shift
(
self
,
x
):
"""Compute relative positional encoding.
Args:
x(Tensor): Input tensor (batch, head, time1, time2).
Returns:
Tensor:Output tensor.
"""
b
,
h
,
t1
,
t2
=
paddle
.
shape
(
x
)
zero_pad
=
paddle
.
zeros
((
b
,
h
,
t1
,
1
))
x_padded
=
paddle
.
concat
([
zero_pad
,
x
],
axis
=-
1
)
x_padded
=
paddle
.
reshape
(
x_padded
,
[
b
,
h
,
t2
+
1
,
t1
])
# only keep the positions from 0 to time2
x
=
paddle
.
reshape
(
x_padded
[:,
:,
1
:],
[
b
,
h
,
t1
,
t2
])
if
self
.
zero_triu
:
ones
=
paddle
.
ones
((
t1
,
t2
))
x
=
x
*
paddle
.
tril
(
ones
,
t2
-
1
)[
None
,
None
,
:,
:]
return
x
def
forward
(
self
,
query
,
key
,
value
,
pos_emb
,
mask
):
"""Compute 'Scaled Dot Product Attention' with rel. positional encoding.
Args:
query(Tensor): Query tensor (#batch, time1, size).
key(Tensor): Key tensor (#batch, time2, size).
value(Tensor): Value tensor (#batch, time2, size).
pos_emb(Tensor): Positional embedding tensor (#batch, time1, size).
mask(Tensor): Mask tensor (#batch, 1, time2) or (#batch, time1, time2).
Returns:
Tensor: Output tensor (#batch, time1, d_model).
"""
q
,
k
,
v
=
self
.
forward_qkv
(
query
,
key
,
value
)
# (batch, time1, head, d_k)
q
=
paddle
.
transpose
(
q
,
[
0
,
2
,
1
,
3
])
n_batch_pos
=
paddle
.
shape
(
pos_emb
)[
0
]
p
=
paddle
.
reshape
(
self
.
linear_pos
(
pos_emb
),
[
n_batch_pos
,
-
1
,
self
.
h
,
self
.
d_k
])
# (batch, head, time1, d_k)
p
=
paddle
.
transpose
(
p
,
[
0
,
2
,
1
,
3
])
# (batch, head, time1, d_k)
q_with_bias_u
=
paddle
.
transpose
((
q
+
self
.
pos_bias_u
),
[
0
,
2
,
1
,
3
])
# (batch, head, time1, d_k)
q_with_bias_v
=
paddle
.
transpose
((
q
+
self
.
pos_bias_v
),
[
0
,
2
,
1
,
3
])
# compute attention score
# first compute matrix a and matrix c
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
# (batch, head, time1, time2)
matrix_ac
=
paddle
.
matmul
(
q_with_bias_u
,
paddle
.
transpose
(
k
,
[
0
,
1
,
3
,
2
]))
# compute matrix b and matrix d
# (batch, head, time1, time1)
matrix_bd
=
paddle
.
matmul
(
q_with_bias_v
,
paddle
.
transpose
(
p
,
[
0
,
1
,
3
,
2
]))
matrix_bd
=
self
.
rel_shift
(
matrix_bd
)
# (batch, head, time1, time2)
scores
=
(
matrix_ac
+
matrix_bd
)
/
math
.
sqrt
(
self
.
d_k
)
return
self
.
forward_attention
(
v
,
scores
,
mask
)
ernie-sat/paddlespeech/t2s/modules/transformer/embedding.py
浏览文件 @
30986264
...
...
@@ -185,3 +185,63 @@ class RelPositionalEncoding(nn.Layer):
pe_size
=
paddle
.
shape
(
self
.
pe
)
pos_emb
=
self
.
pe
[:,
pe_size
[
1
]
//
2
-
T
+
1
:
pe_size
[
1
]
//
2
+
T
,
]
return
self
.
dropout
(
x
),
self
.
dropout
(
pos_emb
)
class
LegacyRelPositionalEncoding
(
PositionalEncoding
):
"""Relative positional encoding module (old version).
Details can be found in https://github.com/espnet/espnet/pull/2816.
See : Appendix B in https://arxiv.org/abs/1901.02860
Args:
d_model (int): Embedding dimension.
dropout_rate (float): Dropout rate.
max_len (int): Maximum input length.
"""
def
__init__
(
self
,
d_model
:
int
,
dropout_rate
:
float
,
max_len
:
int
=
5000
):
"""
Args:
d_model (int): Embedding dimension.
dropout_rate (float): Dropout rate.
max_len (int, optional): [Maximum input length.]. Defaults to 5000.
"""
super
().
__init__
(
d_model
,
dropout_rate
,
max_len
,
reverse
=
True
)
def
extend_pe
(
self
,
x
):
"""Reset the positional encodings."""
if
self
.
pe
is
not
None
:
if
paddle
.
shape
(
self
.
pe
)[
1
]
>=
paddle
.
shape
(
x
)[
1
]:
return
pe
=
paddle
.
zeros
((
paddle
.
shape
(
x
)[
1
],
self
.
d_model
))
if
self
.
reverse
:
position
=
paddle
.
arange
(
paddle
.
shape
(
x
)[
1
]
-
1
,
-
1
,
-
1.0
,
dtype
=
paddle
.
float32
).
unsqueeze
(
1
)
else
:
position
=
paddle
.
arange
(
0
,
paddle
.
shape
(
x
)[
1
],
dtype
=
paddle
.
float32
).
unsqueeze
(
1
)
div_term
=
paddle
.
exp
(
paddle
.
arange
(
0
,
self
.
d_model
,
2
,
dtype
=
paddle
.
float32
)
*
-
(
math
.
log
(
10000.0
)
/
self
.
d_model
))
pe
[:,
0
::
2
]
=
paddle
.
sin
(
position
*
div_term
)
pe
[:,
1
::
2
]
=
paddle
.
cos
(
position
*
div_term
)
pe
=
pe
.
unsqueeze
(
0
)
self
.
pe
=
pe
def
forward
(
self
,
x
:
paddle
.
Tensor
):
"""Compute positional encoding.
Args:
x (paddle.Tensor): Input tensor (batch, time, `*`).
Returns:
paddle.Tensor: Encoded tensor (batch, time, `*`).
paddle.Tensor: Positional embedding tensor (1, time, `*`).
"""
self
.
extend_pe
(
x
)
x
=
x
*
self
.
xscale
pos_emb
=
self
.
pe
[:,
:
paddle
.
shape
(
x
)[
1
]]
return
self
.
dropout
(
x
),
self
.
dropout
(
pos_emb
)
ernie-sat/read_text.py
浏览文件 @
30986264
...
...
@@ -5,7 +5,7 @@ from typing import List
from
typing
import
Union
def
read_2col
umn
_text
(
path
:
Union
[
Path
,
str
])
->
Dict
[
str
,
str
]:
def
read_2col_text
(
path
:
Union
[
Path
,
str
])
->
Dict
[
str
,
str
]:
"""Read a text file having 2 column as dict object.
Examples:
...
...
@@ -13,7 +13,7 @@ def read_2column_text(path: Union[Path, str]) -> Dict[str, str]:
key1 /some/path/a.wav
key2 /some/path/b.wav
>>> read_2col
umn
_text('wav.scp')
>>> read_2col_text('wav.scp')
{'key1': '/some/path/a.wav', 'key2': '/some/path/b.wav'}
"""
...
...
ernie-sat/sedit_arg_parser.py
浏览文件 @
30986264
...
...
@@ -65,12 +65,6 @@ def parse_args():
help
=
"mean and standard deviation used to normalize spectrogram when training voc."
)
# other
parser
.
add_argument
(
'--lang'
,
type
=
str
,
default
=
'en'
,
help
=
'Choose model language. zh or en'
)
parser
.
add_argument
(
"--ngpu"
,
type
=
int
,
default
=
1
,
help
=
"if ngpu == 0, use cpu."
)
# parser.add_argument("--test_metadata", type=str, help="test metadata.")
...
...
ernie-sat/utils.py
浏览文件 @
30986264
import
os
from
typing
import
List
from
typing
import
Optional
import
numpy
as
np
...
...
@@ -55,16 +54,14 @@ def build_vocoder_from_file(
raise
ValueError
(
f
"
{
vocoder_file
}
is not supported format."
)
def
get_voc_out
(
mel
,
target_lang
:
str
=
"chinese"
):
def
get_voc_out
(
mel
):
# vocoder
args
=
parse_args
()
assert
target_lang
==
"chinese"
or
target_lang
==
"english"
,
"In get_voc_out function, target_lang is illegal..."
# print("current vocoder: ", args.voc)
with
open
(
args
.
voc_config
)
as
f
:
voc_config
=
CfgNode
(
yaml
.
safe_load
(
f
))
voc_inference
=
voc_inference
=
get_voc_inference
(
voc_inference
=
get_voc_inference
(
voc
=
args
.
voc
,
voc_config
=
voc_config
,
voc_ckpt
=
args
.
voc_ckpt
,
...
...
@@ -167,19 +164,23 @@ def get_voc_inference(
return
voc_inference
def
evaluate_durations
(
phns
:
List
[
str
],
target_lang
:
str
=
"chinese"
,
fs
:
int
=
24000
,
hop_length
:
int
=
300
):
def
eval_durs
(
phns
,
target_lang
=
"chinese"
,
fs
=
24000
,
hop_length
=
300
):
args
=
parse_args
()
if
target_lang
==
'english'
:
args
.
lang
=
'en'
args
.
am
=
"fastspeech2_ljspeech"
args
.
am_config
=
"download/fastspeech2_nosil_ljspeech_ckpt_0.5/default.yaml"
args
.
am_ckpt
=
"download/fastspeech2_nosil_ljspeech_ckpt_0.5/snapshot_iter_100000.pdz"
args
.
am_stat
=
"download/fastspeech2_nosil_ljspeech_ckpt_0.5/speech_stats.npy"
args
.
phones_dict
=
"download/fastspeech2_nosil_ljspeech_ckpt_0.5/phone_id_map.txt"
elif
target_lang
==
'chinese'
:
args
.
lang
=
'zh'
args
.
am
=
"fastspeech2_csmsc"
args
.
am_config
=
"download/fastspeech2_conformer_baker_ckpt_0.5/conformer.yaml"
args
.
am_ckpt
=
"download/fastspeech2_conformer_baker_ckpt_0.5/snapshot_iter_76000.pdz"
args
.
am_stat
=
"download/fastspeech2_conformer_baker_ckpt_0.5/speech_stats.npy"
args
.
phones_dict
=
"download/fastspeech2_conformer_baker_ckpt_0.5/phone_id_map.txt"
# args = parser.parse_args(args=[])
if
args
.
ngpu
==
0
:
paddle
.
set_device
(
"cpu"
)
elif
args
.
ngpu
>
0
:
...
...
@@ -187,8 +188,6 @@ def evaluate_durations(phns: List[str],
else
:
print
(
"ngpu should >= 0 !"
)
assert
target_lang
==
"chinese"
or
target_lang
==
"english"
,
"In evaluate_durations function, target_lang is illegal..."
# Init body.
with
open
(
args
.
am_config
)
as
f
:
am_config
=
CfgNode
(
yaml
.
safe_load
(
f
))
...
...
@@ -203,22 +202,19 @@ def evaluate_durations(phns: List[str],
speaker_dict
=
args
.
speaker_dict
,
return_am
=
True
)
torch_phns
=
phns
vocab_phones
=
{}
with
open
(
args
.
phones_dict
,
"r"
)
as
f
:
phn_id
=
[
line
.
strip
().
split
()
for
line
in
f
.
readlines
()]
for
tone
,
id
in
phn_id
:
vocab_phones
[
tone
]
=
int
(
id
)
vocab_size
=
len
(
vocab_phones
)
phonemes
=
[
phn
if
phn
in
vocab_phones
else
"sp"
for
phn
in
torch_
phns
]
phonemes
=
[
phn
if
phn
in
vocab_phones
else
"sp"
for
phn
in
phns
]
phone_ids
=
[
vocab_phones
[
item
]
for
item
in
phonemes
]
phone_ids_new
=
phone_ids
phone_ids_new
.
append
(
vocab_size
-
1
)
phone_ids_new
=
paddle
.
to_tensor
(
np
.
array
(
phone_ids_new
,
np
.
int64
))
normalized_mel
,
d_outs
,
p_outs
,
e_outs
=
am
.
inference
(
phone_ids_new
,
spk_id
=
None
,
spk_emb
=
None
)
phone_ids
.
append
(
vocab_size
-
1
)
phone_ids
=
paddle
.
to_tensor
(
np
.
array
(
phone_ids
,
np
.
int64
))
_
,
d_outs
,
_
,
_
=
am
.
inference
(
phone_ids
,
spk_id
=
None
,
spk_emb
=
None
)
pre_d_outs
=
d_outs
ph
oneme_duration
s_new
=
pre_d_outs
*
hop_length
/
fs
ph
oneme_durations_new
=
phoneme_duration
s_new
.
tolist
()[:
-
1
]
return
ph
oneme_duration
s_new
ph
u_dur
s_new
=
pre_d_outs
*
hop_length
/
fs
ph
u_durs_new
=
phu_dur
s_new
.
tolist
()[:
-
1
]
return
ph
u_dur
s_new
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录