Commit 9d4161ce
Authored Jul 27, 2022 by 小湉湉

update config, test=doc

Parent: 97965f4c
Showing 6 changed files with 956 additions and 8 deletions (+956 -8):

- examples/aishell3/ernie_sat/conf/default.yaml (+3 -3)
- examples/aishell3_vctk/ernie_sat/conf/default.yaml (+3 -3)
- examples/vctk/ernie_sat/conf/default.yaml (+2 -2)
- paddlespeech/t2s/exps/ernie_sat/align.py (+386 -0)
- paddlespeech/t2s/exps/ernie_sat/synthesize_e2e.py (+346 -0)
- paddlespeech/t2s/exps/ernie_sat/utils.py (+216 -0)
examples/aishell3/ernie_sat/conf/default.yaml

```diff
@@ -79,13 +79,13 @@ grad_clip: 1.0
 ###########################################################
 #                    TRAINING SETTING                     #
 ###########################################################
-max_epoch: 600
-num_snapshots: 5
+max_epoch: 1500
+num_snapshots: 50
 ###########################################################
 #                      OTHER SETTING                      #
 ###########################################################
-seed: 10086
+seed: 0
 token_list:
 - <blank>
```
examples/aishell3_vctk/ernie_sat/conf/default.yaml

```diff
@@ -79,13 +79,13 @@ grad_clip: 1.0
 ###########################################################
 #                    TRAINING SETTING                     #
 ###########################################################
-max_epoch: 300
-num_snapshots: 5
+max_epoch: 700
+num_snapshots: 50
 ###########################################################
 #                      OTHER SETTING                      #
 ###########################################################
-seed: 10086
+seed: 0
 token_list:
 - <blank>
```
examples/vctk/ernie_sat/conf/default.yaml

```diff
@@ -79,8 +79,8 @@ grad_clip: 1.0
 ###########################################################
 #                    TRAINING SETTING                     #
 ###########################################################
-max_epoch: 600
-num_snapshots: 5
+max_epoch: 1500
+num_snapshots: 50
 ###########################################################
 #                      OTHER SETTING                      #
 ###########################################################
```
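Taken together, the three config edits lengthen training (`max_epoch` 600 → 1500 for aishell3 and vctk, 300 → 700 for aishell3_vctk), keep ten times as many checkpoints (`num_snapshots` 5 → 50), and pin the random seed to 0. As a quick sanity check, here is a minimal sketch of reading these values the same way the ERNIE-SAT scripts below do (`CfgNode` over `yaml.safe_load`); the path points at one of the files edited above:

```python
import yaml
from yacs.config import CfgNode

# Any of the three edited configs; the aishell3 one is used here.
config_path = 'examples/aishell3/ernie_sat/conf/default.yaml'
with open(config_path) as f:
    config = CfgNode(yaml.safe_load(f))

# After this commit the aishell3 config should report 1500, 50, 0.
print(config.max_epoch, config.num_snapshots, config.seed)
```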
paddlespeech/t2s/exps/ernie_sat/align.py (new file, mode 100755)
```python
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import shutil
from pathlib import Path

import librosa
import numpy as np
import pypinyin
from praatio import textgrid

from paddlespeech.t2s.exps.ernie_sat.utils import get_dict
from paddlespeech.t2s.exps.ernie_sat.utils import get_tmp_name

DICT_EN = 'tools/aligner/cmudict-0.7b'
DICT_ZH = 'tools/aligner/simple.lexicon'
MODEL_DIR_EN = 'tools/aligner/vctk_model.zip'
MODEL_DIR_ZH = 'tools/aligner/aishell3_model.zip'
MFA_PATH = 'tools/montreal-forced-aligner/bin'
os.environ['PATH'] = MFA_PATH + '/:' + os.environ['PATH']


def _get_max_idx(dic):
    return sorted([int(key.split('_')[0]) for key in dic.keys()])[-1]


def _readtg(tg_path: str, lang: str='en', fs: int=24000, n_shift: int=300):
    alignment = textgrid.openTextgrid(tg_path, includeEmptyIntervals=True)
    phones = []
    ends = []
    words = []

    for interval in alignment.tierDict['words'].entryList:
        word = interval.label
        if word:
            words.append(word)
    for interval in alignment.tierDict['phones'].entryList:
        phone = interval.label
        phones.append(phone)
        ends.append(interval.end)
    frame_pos = librosa.time_to_frames(ends, sr=fs, hop_length=n_shift)
    durations = np.diff(frame_pos, prepend=0)
    assert len(durations) == len(phones)
    # merge '' and 'sp' at the end
    if phones[-1] == '' and len(phones) > 1 and phones[-2] == 'sp':
        phones = phones[:-1]
        durations[-2] += durations[-1]
        durations = durations[:-1]
    # replace '' and 'sil' with 'sp'
    phones = ['sp' if (phn == '' or phn == 'sil') else phn for phn in phones]

    if lang == 'en':
        DICT = DICT_EN
    elif lang == 'zh':
        DICT = DICT_ZH
    word2phns_dict = get_dict(DICT)

    phn2word_dict = []
    for word in words:
        if lang == 'en':
            word = word.upper()
        phn2word_dict.append([word2phns_dict[word].split(), word])

    non_sp_idx = 0
    word_idx = 0
    i = 0
    word2phns = {}
    while i < len(phones):
        phn = phones[i]
        if phn == 'sp':
            word2phns[str(word_idx) + '_sp'] = ['sp']
            i += 1
        else:
            phns, word = phn2word_dict[non_sp_idx]
            word2phns[str(word_idx) + '_' + word] = phns
            non_sp_idx += 1
            i += len(phns)
        word_idx += 1
    sum_phn = sum(len(word2phns[k]) for k in word2phns)
    assert sum_phn == len(phones)

    results = ''
    for (p, d) in zip(phones, durations):
        results += p + ' ' + str(d) + ' '
    return results.strip(), word2phns


def alignment(wav_path: str,
              text: str,
              fs: int=24000,
              lang='en',
              n_shift: int=300):
    wav_name = os.path.basename(wav_path)
    utt = wav_name.split('.')[0]
    # prepare data for MFA
    tmp_name = get_tmp_name(text=text)
    tmpbase = './tmp_dir/' + tmp_name
    tmpbase = Path(tmpbase)
    tmpbase.mkdir(parents=True, exist_ok=True)
    print("tmp_name in alignment:", tmp_name)

    shutil.copyfile(wav_path, tmpbase / wav_name)
    txt_name = utt + '.txt'
    txt_path = tmpbase / txt_name
    with open(txt_path, 'w') as wf:
        wf.write(text + '\n')

    # MFA
    if lang == 'en':
        DICT = DICT_EN
        MODEL_DIR = MODEL_DIR_EN
    elif lang == 'zh':
        DICT = DICT_ZH
        MODEL_DIR = MODEL_DIR_ZH
    else:
        print('please input the right lang!!')

    CMD = 'mfa_align' + ' ' + str(
        tmpbase) + ' ' + DICT + ' ' + MODEL_DIR + ' ' + str(tmpbase)
    os.system(CMD)

    tg_path = str(tmpbase) + '/' + tmp_name + '/' + utt + '.TextGrid'

    phn_dur, word2phns = _readtg(tg_path, lang=lang)
    phn_dur = phn_dur.split()
    phns = phn_dur[::2]
    durs = phn_dur[1::2]
    durs = [int(d) for d in durs]
    assert len(phns) == len(durs)
    return phns, durs, word2phns
```
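The string returned by `_readtg` interleaves phones and frame durations, and `alignment()` unpacks it with stride-2 slices. A self-contained illustration with toy values (no MFA run needed):

```python
# "phone dur phone dur ..." as produced by _readtg; durations are frame counts.
phn_dur = "HH 3 AH0 2 L 4 OW1 6 sp 5".split()
phns = phn_dur[::2]
durs = [int(d) for d in phn_dur[1::2]]
assert len(phns) == len(durs)
print(phns)  # ['HH', 'AH0', 'L', 'OW1', 'sp']
print(durs)  # [3, 2, 4, 6, 5]
```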
```python
def words2phns(text: str, lang='en'):
    '''
    Args:
        text (str):
            input text.
            eg: for that reason cover is impossible to be given.
        lang (str):
            'en' or 'zh'
    Returns:
        List[str]: phones of the input text.
            eg:
            ['F', 'AO1', 'R', 'DH', 'AE1', 'T', 'R', 'IY1', 'Z', 'AH0', 'N', 'K', 'AH1', 'V', 'ER0',
            'IH1', 'Z', 'IH2', 'M', 'P', 'AA1', 'S', 'AH0', 'B', 'AH0', 'L', 'T', 'UW1', 'B', 'IY1',
            'G', 'IH1', 'V', 'AH0', 'N']
        Dict(str, str): key - idx_word
                        value - phones
            eg:
            {'0_FOR': ['F', 'AO1', 'R'], '1_THAT': ['DH', 'AE1', 'T'],
            '2_REASON': ['R', 'IY1', 'Z', 'AH0', 'N'], '3_COVER': ['K', 'AH1', 'V', 'ER0'], '4_IS': ['IH1', 'Z'],
            '5_IMPOSSIBLE': ['IH2', 'M', 'P', 'AA1', 'S', 'AH0', 'B', 'AH0', 'L'],
            '6_TO': ['T', 'UW1'], '7_BE': ['B', 'IY1'], '8_GIVEN': ['G', 'IH1', 'V', 'AH0', 'N']}
    '''
    text = text.strip()
    words = []
    for pun in [
            ',', '.', ':', ';', '!', '?', '"', '(', ')', '--', '---', u',',
            u'。', u':', u';', u'!', u'?', u'(', u')'
    ]:
        text = text.replace(pun, ' ')
    for wrd in text.split():
        if (wrd[-1] == '-'):
            wrd = wrd[:-1]
        if (wrd[0] == "'"):
            wrd = wrd[1:]
        if wrd:
            words.append(wrd)

    if lang == 'en':
        dictfile = DICT_EN
    elif lang == 'zh':
        dictfile = DICT_ZH
    else:
        print('please input the right lang!!')

    word2phns_dict = get_dict(dictfile)
    ds = word2phns_dict.keys()
    phns = []
    wrd2phns = {}
    for index, wrd in enumerate(words):
        if lang == 'en':
            wrd = wrd.upper()
        if (wrd not in ds):
            wrd2phns[str(index) + '_' + wrd] = 'spn'
            phns.extend('spn')
        else:
            wrd2phns[str(index) + '_' + wrd] = word2phns_dict[wrd].split()
            phns.extend(word2phns_dict[wrd].split())
    return phns, wrd2phns


def get_phns_spans(wav_path: str,
                   old_str: str='',
                   new_str: str='',
                   source_lang: str='en',
                   target_lang: str='en',
                   fs: int=24000,
                   n_shift: int=300):
    is_append = (old_str == new_str[:len(old_str)])
    old_phns, mfa_start, mfa_end = [], [], []
    # source
    lang = source_lang
    phn, dur, w2p = alignment(
        wav_path=wav_path, text=old_str, lang=lang, fs=fs, n_shift=n_shift)

    new_d_cumsum = np.pad(np.array(dur).cumsum(0), (1, 0), 'constant').tolist()
    mfa_start = new_d_cumsum[:-1]
    mfa_end = new_d_cumsum[1:]
    old_phns = phn

    # target
    if is_append and (source_lang != target_lang):
        cross_lingual_clone = True
    else:
        cross_lingual_clone = False

    if cross_lingual_clone:
        str_origin = new_str[:len(old_str)]
        str_append = new_str[len(old_str):]

        if target_lang == 'zh':
            phns_origin, origin_w2p = words2phns(str_origin, lang='en')
            phns_append, append_w2p_tmp = words2phns(str_append, lang='zh')
        elif target_lang == 'en':
            # original sentence
            phns_origin, origin_w2p = words2phns(str_origin, lang='zh')
            # cloned sentence
            phns_append, append_w2p_tmp = words2phns(str_append, lang='en')
        else:
            assert target_lang == 'zh' or target_lang == 'en', \
                'cloning is not supported for this language, please check it.'

        new_phns = phns_origin + phns_append

        append_w2p = {}
        length = len(origin_w2p)
        for key, value in append_w2p_tmp.items():
            idx, wrd = key.split('_')
            append_w2p[str(int(idx) + length) + '_' + wrd] = value
        new_w2p = origin_w2p.copy()
        new_w2p.update(append_w2p)

    else:
        if source_lang == target_lang:
            new_phns, new_w2p = words2phns(new_str, lang=source_lang)
        else:
            assert source_lang == target_lang, \
                'source language is not the same as the target language...'

    span_to_repl = [0, len(old_phns) - 1]
    span_to_add = [0, len(new_phns) - 1]
    left_idx = 0
    new_phns_left = []
    sp_count = 0

    # find the left different index
    # words2phns called inside alignment() may insert 'sp' tokens that a
    # direct words2phns call does not
    for key in w2p.keys():
        idx, wrd = key.split('_')
        if wrd == 'sp':
            sp_count += 1
            new_phns_left.append('sp')
        else:
            idx = str(int(idx) - sp_count)
            if idx + '_' + wrd in new_w2p:
                # index into the phn sequence of new_str
                left_idx += len(new_w2p[idx + '_' + wrd])
                # old phn sequence
                new_phns_left.extend(w2p[key])
            else:
                span_to_repl[0] = len(new_phns_left)
                span_to_add[0] = len(new_phns_left)
                break

    # reverse w2p and new_w2p
    right_idx = 0
    new_phns_right = []
    sp_count = 0
    w2p_max_idx = _get_max_idx(w2p)
    new_w2p_max_idx = _get_max_idx(new_w2p)
    new_phns_mid = []
    if is_append:
        new_phns_right = []
        new_phns_mid = new_phns[left_idx:]
        span_to_repl[0] = len(new_phns_left)
        span_to_add[0] = len(new_phns_left)
        span_to_add[1] = len(new_phns_left) + len(new_phns_mid)
        span_to_repl[1] = len(old_phns) - len(new_phns_right)
    # speech edit
    else:
        for key in list(w2p.keys())[::-1]:
            idx, wrd = key.split('_')
            if wrd == 'sp':
                sp_count += 1
                new_phns_right = ['sp'] + new_phns_right
            else:
                idx = str(new_w2p_max_idx - (w2p_max_idx - int(idx) - sp_count))
                if idx + '_' + wrd in new_w2p:
                    right_idx -= len(new_w2p[idx + '_' + wrd])
                    new_phns_right = w2p[key] + new_phns_right
                else:
                    span_to_repl[1] = len(old_phns) - len(new_phns_right)
                    new_phns_mid = new_phns[left_idx:right_idx]
                    span_to_add[1] = len(new_phns_left) + len(new_phns_mid)
                    if len(new_phns_mid) == 0:
                        span_to_add[1] = min(span_to_add[1] + 1, len(new_phns))
                        span_to_add[0] = max(0, span_to_add[0] - 1)
                        span_to_repl[0] = max(0, span_to_repl[0] - 1)
                        span_to_repl[1] = min(span_to_repl[1] + 1,
                                              len(old_phns))
                    break
    new_phns = new_phns_left + new_phns_mid + new_phns_right
    '''
    For that reason cover should not be given.
    For that reason cover is impossible to be given.
    span_to_repl: [17, 23] "should not"
    span_to_add: [17, 30] "is impossible to"
    '''
    outs = {}
    outs['mfa_start'] = mfa_start
    outs['mfa_end'] = mfa_end
    outs['old_phns'] = old_phns
    outs['new_phns'] = new_phns
    outs['span_to_repl'] = span_to_repl
    outs['span_to_add'] = span_to_add
    return outs
```
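Inside `get_phns_spans`, per-phone start and end frames come from a zero-padded cumulative sum of the durations, viewed with a one-element shift. A minimal sketch with toy durations:

```python
import numpy as np

dur = [3, 2, 4]
new_d_cumsum = np.pad(np.array(dur).cumsum(0), (1, 0), 'constant').tolist()
mfa_start = new_d_cumsum[:-1]  # [0, 3, 5]
mfa_end = new_d_cumsum[1:]     # [3, 5, 9]
print(mfa_start, mfa_end)
```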
```python
if __name__ == '__main__':
    text = "For that reason cover should not be given."
    phn, dur, word2phns = alignment("exp/p243_313.wav", text, lang='en')
    print(phn, dur)
    print(word2phns)
    print("---------------------------------")
    # the pinyin sequence here could also be generated by our Chinese frontend
    text_zh = "卡尔普陪外孙玩滑梯。"
    text_zh = pypinyin.lazy_pinyin(
        text_zh,
        neutral_tone_with_five=True,
        style=pypinyin.Style.TONE3,
        tone_sandhi=True)
    text_zh = " ".join(text_zh)
    phn, dur, word2phns = alignment("exp/000001.wav", text_zh, lang='zh')
    print(phn, dur)
    print(word2phns)
    print("---------------------------------")

    phns, wrd2phns = words2phns(text, lang='en')
    print("phns:", phns)
    print("wrd2phns:", wrd2phns)
    print("---------------------------------")

    phns, wrd2phns = words2phns(text_zh, lang='zh')
    print("phns:", phns)
    print("wrd2phns:", wrd2phns)
    print("---------------------------------")

    outs = get_phns_spans(
        wav_path="exp/p243_313.wav",
        old_str="For that reason cover should not be given.",
        new_str="for that reason cover is impossible to be given.")

    mfa_start = outs["mfa_start"]
    mfa_end = outs["mfa_end"]
    old_phns = outs["old_phns"]
    new_phns = outs["new_phns"]
    span_to_repl = outs["span_to_repl"]
    span_to_add = outs["span_to_add"]

    print("mfa_start:", mfa_start)
    print("mfa_end:", mfa_end)
    print("old_phns:", old_phns)
    print("new_phns:", new_phns)
    print("span_to_repl:", span_to_repl)
    print("span_to_add:", span_to_add)
    print("---------------------------------")
```
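Both `words2phns` and `_readtg` index words with `"<position>_<WORD>"` keys (with `sp` entries interleaved for silences), and `_get_max_idx` recovers the highest position. A self-contained illustration built from the mapping shown in the `words2phns` docstring:

```python
# '<position>_<WORD>' keys, with 'sp' entries for silences.
word2phns = {
    '0_FOR': ['F', 'AO1', 'R'],
    '1_THAT': ['DH', 'AE1', 'T'],
    '2_sp': ['sp'],
    '3_GIVEN': ['G', 'IH1', 'V', 'AH0', 'N'],
}

def _get_max_idx(dic):
    return sorted([int(key.split('_')[0]) for key in dic.keys()])[-1]

print(_get_max_idx(word2phns))  # 3
```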
paddlespeech/t2s/exps/ernie_sat/synthesize_e2e.py (new file, mode 100644)
```python
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from pathlib import Path
from typing import List

import librosa
import numpy as np
import paddle
import soundfile as sf
import yaml
from paddle import nn
from yacs.config import CfgNode

from paddlespeech.t2s.datasets.am_batch_fn import build_erniesat_collate_fn
from paddlespeech.t2s.datasets.get_feats import LogMelFBank
from paddlespeech.t2s.exps.ernie_sat.align import get_phns_spans
from paddlespeech.t2s.exps.ernie_sat.utils import eval_durs
from paddlespeech.t2s.exps.ernie_sat.utils import get_dur_adj_factor
from paddlespeech.t2s.exps.ernie_sat.utils import get_span_bdy
from paddlespeech.t2s.exps.ernie_sat.utils import get_tmp_name
from paddlespeech.t2s.exps.syn_utils import get_frontend
from paddlespeech.t2s.exps.syn_utils import norm


# module-level function; reads the global `vocab_phones` built in __main__
def _p2id(phonemes: List[str]) -> np.ndarray:
    # replace unk phone with sp
    phonemes = [phn if phn in vocab_phones else "sp" for phn in phonemes]
    phone_ids = [vocab_phones[item] for item in phonemes]
    return np.array(phone_ids, np.int64)
```
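`_p2id` maps phones to integer ids through the global `vocab_phones` table, falling back to `sp` for anything outside the vocabulary. A toy sketch (the real table is loaded from `dump/phone_id_map.txt` in the `__main__` block below):

```python
import numpy as np

# Toy vocabulary; the real one comes from phone_id_map.txt.
vocab_phones = {'sp': 0, 'F': 1, 'AO1': 2, 'R': 3}

def _p2id(phonemes):
    phonemes = [phn if phn in vocab_phones else 'sp' for phn in phonemes]
    return np.array([vocab_phones[item] for item in phonemes], np.int64)

print(_p2id(['F', 'AO1', 'R', 'UNK']))  # [1 2 3 0]
```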
```python
def prep_feats_with_dur(wav_path: str,
                        old_str: str='',
                        new_str: str='',
                        source_lang: str='en',
                        target_lang: str='en',
                        duration_adjust: bool=True,
                        fs: int=24000,
                        n_shift: int=300):
    '''
    Returns:
        np.ndarray: new wav, replace the part to be edited in original wav with 0
        List[str]: new phones
        List[float]: mfa start of new wav
        List[float]: mfa end of new wav
        List[int]: masked mel boundary of original wav
        List[int]: masked mel boundary of new wav
    '''
    wav_org, _ = librosa.load(wav_path, sr=fs)
    phns_spans_outs = get_phns_spans(
        wav_path=wav_path,
        old_str=old_str,
        new_str=new_str,
        source_lang=source_lang,
        target_lang=target_lang,
        fs=fs,
        n_shift=n_shift)

    mfa_start = phns_spans_outs["mfa_start"]
    mfa_end = phns_spans_outs["mfa_end"]
    old_phns = phns_spans_outs["old_phns"]
    new_phns = phns_spans_outs["new_phns"]
    span_to_repl = phns_spans_outs["span_to_repl"]
    span_to_add = phns_spans_outs["span_to_add"]

    # Chinese phns are not necessarily all in the fastspeech2 dict,
    # so they are replaced with sp
    if target_lang in {'en', 'zh'}:
        old_durs = eval_durs(old_phns, target_lang=source_lang)
    else:
        assert target_lang in {'en', 'zh'}, \
            "calculating duration_predict is not supported for this language..."

    orig_old_durs = [e - s for e, s in zip(mfa_end, mfa_start)]

    if duration_adjust:
        d_factor = get_dur_adj_factor(
            orig_dur=orig_old_durs, pred_dur=old_durs, phns=old_phns)
        d_factor = d_factor * 1.25
    else:
        d_factor = 1

    if target_lang in {'en', 'zh'}:
        new_durs = eval_durs(new_phns, target_lang=target_lang)
    else:
        assert target_lang == "zh" or target_lang == "en", \
            "calculating duration_predict is not supported for this language..."

    # durations must be integers
    new_durs_adjusted = [int(np.ceil(d_factor * i)) for i in new_durs]

    new_span_dur_sum = sum(new_durs_adjusted[span_to_add[0]:span_to_add[1]])
    old_span_dur_sum = sum(orig_old_durs[span_to_repl[0]:span_to_repl[1]])
    dur_offset = new_span_dur_sum - old_span_dur_sum
    new_mfa_start = mfa_start[:span_to_repl[0]]
    new_mfa_end = mfa_end[:span_to_repl[0]]

    for dur in new_durs_adjusted[span_to_add[0]:span_to_add[1]]:
        if len(new_mfa_end) == 0:
            new_mfa_start.append(0)
            new_mfa_end.append(dur)
        else:
            new_mfa_start.append(new_mfa_end[-1])
            new_mfa_end.append(new_mfa_end[-1] + dur)

    new_mfa_start += [i + dur_offset for i in mfa_start[span_to_repl[1]:]]
    new_mfa_end += [i + dur_offset for i in mfa_end[span_to_repl[1]:]]

    # 3. get new wav
    # appending after the original sentence
    if span_to_repl[0] >= len(mfa_start):
        wav_left_idx = len(wav_org)
        wav_right_idx = wav_left_idx
    # replacing inside the original sentence
    else:
        wav_left_idx = int(np.floor(mfa_start[span_to_repl[0]] * n_shift))
        wav_right_idx = int(np.ceil(mfa_end[span_to_repl[1] - 1] * n_shift))
    blank_wav = np.zeros(
        (int(np.ceil(new_span_dur_sum * n_shift)), ), dtype=wav_org.dtype)
    # in the original audio, the part to be edited is replaced with blank
    # audio whose length is decided by the duration_predictor of fs2
    new_wav = np.concatenate(
        [wav_org[:wav_left_idx], blank_wav, wav_org[wav_right_idx:]])

    # check that the audio is masked as expected
    sf.write(str("new_wav.wav"), new_wav, samplerate=fs)

    # 4. get the old and new mel spans to be masked
    old_span_bdy = get_span_bdy(
        mfa_start=mfa_start, mfa_end=mfa_end, span_to_repl=span_to_repl)
    new_span_bdy = get_span_bdy(
        mfa_start=new_mfa_start, mfa_end=new_mfa_end, span_to_repl=span_to_add)
    # old_span_bdy and new_span_bdy are frame-level ranges
    outs = {}
    outs['new_wav'] = new_wav
    outs['new_phns'] = new_phns
    outs['new_mfa_start'] = new_mfa_start
    outs['new_mfa_end'] = new_mfa_end
    outs['old_span_bdy'] = old_span_bdy
    outs['new_span_bdy'] = new_span_bdy
    return outs
```
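The core of the wav preparation is a splice: the samples of the edited region are replaced by zeros, whose count is the adjusted duration of the new span times the hop size. A toy sketch of that splice:

```python
import numpy as np

wav_org = np.ones(10, dtype=np.float32)
wav_left_idx, wav_right_idx = 4, 7  # sample range being edited out
new_span_dur_sum, n_shift = 1, 2    # 1 frame * 2 samples/frame -> 2 zeros
blank_wav = np.zeros(
    (int(np.ceil(new_span_dur_sum * n_shift)), ), dtype=wav_org.dtype)
new_wav = np.concatenate(
    [wav_org[:wav_left_idx], blank_wav, wav_org[wav_right_idx:]])
print(new_wav)  # [1. 1. 1. 1. 0. 0. 1. 1. 1.]
```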
```python
def prep_feats(wav_path: str,
               old_str: str='',
               new_str: str='',
               source_lang: str='en',
               target_lang: str='en',
               duration_adjust: bool=True,
               fs: int=24000,
               n_shift: int=300):
    outs = prep_feats_with_dur(
        wav_path=wav_path,
        old_str=old_str,
        new_str=new_str,
        source_lang=source_lang,
        target_lang=target_lang,
        duration_adjust=duration_adjust,
        fs=fs,
        n_shift=n_shift)

    wav_name = os.path.basename(wav_path)
    utt_id = wav_name.split('.')[0]

    wav = outs['new_wav']
    phns = outs['new_phns']
    mfa_start = outs['new_mfa_start']
    mfa_end = outs['new_mfa_end']
    old_span_bdy = outs['old_span_bdy']
    new_span_bdy = outs['new_span_bdy']
    span_bdy = np.array(new_span_bdy)

    text = _p2id(phns)
    mel = mel_extractor.get_log_mel_fbank(wav)
    erniesat_mean, erniesat_std = np.load(erniesat_stat)
    normed_mel = norm(mel, erniesat_mean, erniesat_std)
    tmp_name = get_tmp_name(text=old_str)
    tmpbase = './tmp_dir/' + tmp_name
    tmpbase = Path(tmpbase)
    tmpbase.mkdir(parents=True, exist_ok=True)
    print("tmp_name in synthesize_e2e:", tmp_name)

    mel_path = tmpbase / 'mel.npy'
    print("mel_path:", mel_path)
    np.save(mel_path, normed_mel)
    durations = [e - s for e, s in zip(mfa_end, mfa_start)]

    datum = {
        "utt_id": utt_id,
        "spk_id": 0,
        "text": text,
        "text_lengths": len(text),
        # hard-coded placeholder length
        "speech_lengths": 115,
        "durations": durations,
        "speech": mel_path,
        "align_start": mfa_start,
        "align_end": mfa_end,
        "span_bdy": span_bdy
    }

    batch = collate_fn([datum])
    print("batch:", batch)

    return batch, old_span_bdy, new_span_bdy


def decode_with_model(mlm_model: nn.Layer,
                      collate_fn,
                      wav_path: str,
                      old_str: str='',
                      new_str: str='',
                      source_lang: str='en',
                      target_lang: str='en',
                      use_teacher_forcing: bool=False,
                      duration_adjust: bool=True,
                      fs: int=24000,
                      n_shift: int=300,
                      token_list: List[str]=[]):
    batch, old_span_bdy, new_span_bdy = prep_feats(
        source_lang=source_lang,
        target_lang=target_lang,
        wav_path=wav_path,
        old_str=old_str,
        new_str=new_str,
        duration_adjust=duration_adjust,
        fs=fs,
        n_shift=n_shift)

    feats = collate_fn(batch)[1]

    if 'text_masked_pos' in feats.keys():
        feats.pop('text_masked_pos')

    output = mlm_model.inference(
        text=feats['text'],
        speech=feats['speech'],
        masked_pos=feats['masked_pos'],
        speech_mask=feats['speech_mask'],
        text_mask=feats['text_mask'],
        speech_seg_pos=feats['speech_seg_pos'],
        text_seg_pos=feats['text_seg_pos'],
        span_bdy=new_span_bdy,
        use_teacher_forcing=use_teacher_forcing)

    # concatenate the audio
    output_feat = paddle.concat(x=output, axis=0)
    wav_org, _ = librosa.load(wav_path, sr=fs)
    return wav_org, output_feat, old_span_bdy, new_span_bdy, fs, n_shift
```
```python
if __name__ == '__main__':
    fs = 24000
    n_shift = 300
    wav_path = "exp/p243_313.wav"
    old_str = "For that reason cover should not be given."
    # for edit
    # new_str = "for that reason cover is impossible to be given."
    # for synthesize
    append_str = "do you love me i love you so much"
    new_str = old_str + append_str
    '''
    outs = prep_feats_with_dur(
        wav_path=wav_path,
        old_str=old_str,
        new_str=new_str,
        fs=fs,
        n_shift=n_shift)

    new_wav = outs['new_wav']
    new_phns = outs['new_phns']
    new_mfa_start = outs['new_mfa_start']
    new_mfa_end = outs['new_mfa_end']
    old_span_bdy = outs['old_span_bdy']
    new_span_bdy = outs['new_span_bdy']

    print("---------------------------------")
    print("new_wav:", new_wav)
    print("new_phns:", new_phns)
    print("new_mfa_start:", new_mfa_start)
    print("new_mfa_end:", new_mfa_end)
    print("old_span_bdy:", old_span_bdy)
    print("new_span_bdy:", new_span_bdy)
    print("---------------------------------")
    '''

    erniesat_config = "/home/yuantian01/PaddleSpeech_ERNIE_SAT/PaddleSpeech/examples/vctk/ernie_sat/local/default.yaml"

    with open(erniesat_config) as f:
        erniesat_config = CfgNode(yaml.safe_load(f))

    erniesat_stat = "/home/yuantian01/PaddleSpeech_ERNIE_SAT/PaddleSpeech/examples/vctk/ernie_sat/dump/train/speech_stats.npy"

    # Extractor
    mel_extractor = LogMelFBank(
        sr=erniesat_config.fs,
        n_fft=erniesat_config.n_fft,
        hop_length=erniesat_config.n_shift,
        win_length=erniesat_config.win_length,
        window=erniesat_config.window,
        n_mels=erniesat_config.n_mels,
        fmin=erniesat_config.fmin,
        fmax=erniesat_config.fmax)

    collate_fn = build_erniesat_collate_fn(
        mlm_prob=erniesat_config.mlm_prob,
        mean_phn_span=erniesat_config.mean_phn_span,
        seg_emb=erniesat_config.model['enc_input_layer'] == 'sega_mlm',
        text_masking=False)

    phones_dict = '/home/yuantian01/PaddleSpeech_ERNIE_SAT/PaddleSpeech/examples/vctk/ernie_sat/dump/phone_id_map.txt'

    vocab_phones = {}
    with open(phones_dict, 'rt') as f:
        phn_id = [line.strip().split() for line in f.readlines()]
    for phn, id in phn_id:
        vocab_phones[phn] = int(id)

    prep_feats(
        wav_path=wav_path,
        old_str=old_str,
        new_str=new_str,
        fs=fs,
        n_shift=n_shift)
```
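The `phone_id_map.txt` consumed above (and again in `eval_durs` below) is a plain two-column file, one `<phone> <id>` pair per line. A toy parse mirroring the loop in the `__main__` block:

```python
lines = ["sp 0", "AH0 1", "AO1 2"]  # stand-in for phone_id_map.txt contents
vocab_phones = {}
for phn, id in (line.strip().split() for line in lines):
    vocab_phones[phn] = int(id)
print(vocab_phones)  # {'sp': 0, 'AH0': 1, 'AO1': 2}
```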
paddlespeech/t2s/exps/ernie_sat/utils.py (new file, mode 100644)
```python
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import hashlib
import os
from pathlib import Path
from typing import Dict
from typing import List
from typing import Union

import numpy as np
import paddle
import yaml
from yacs.config import CfgNode

from paddlespeech.t2s.exps.syn_utils import get_am_inference
from paddlespeech.t2s.exps.syn_utils import get_voc_inference


def _get_user():
    return os.path.expanduser('~').split('/')[-1]


def str2md5(string):
    md5_val = hashlib.md5(string.encode('utf8')).hexdigest()
    return md5_val


def get_tmp_name(text: str):
    return _get_user() + '_' + str(os.getpid()) + '_' + str2md5(text)


def get_dict(dictfile: str):
    word2phns_dict = {}
    with open(dictfile, 'r') as fid:
        for line in fid:
            line_lst = line.split()
            word, phn_lst = line_lst[0], line.split()[1:]
            if word not in word2phns_dict.keys():
                word2phns_dict[word] = ' '.join(phn_lst)
    return word2phns_dict
```
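`get_tmp_name` builds the per-run working directory name used by `alignment()` and `prep_feats`: current user, process id, and an MD5 of the input text. A self-contained equivalent:

```python
import hashlib
import os

text = "For that reason cover should not be given."
md5 = hashlib.md5(text.encode('utf8')).hexdigest()
user = os.path.expanduser('~').split('/')[-1]
print(user + '_' + str(os.getpid()) + '_' + md5)  # e.g. user_12345_<md5>
```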
```python
# get the range of mel frames to be masked
def get_span_bdy(mfa_start: List[float],
                 mfa_end: List[float],
                 span_to_repl: List[List[int]]):
    if span_to_repl[0] >= len(mfa_start):
        span_bdy = [mfa_end[-1], mfa_end[-1]]
    else:
        span_bdy = [mfa_start[span_to_repl[0]], mfa_end[span_to_repl[1] - 1]]
    return span_bdy
```
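A worked example: with four phones whose frame boundaries are given below, `span_to_repl = [1, 3]` masks from the start frame of phone 1 through the end frame of phone 2 (the end index is exclusive, hence `span_to_repl[1] - 1`):

```python
mfa_start = [0, 3, 5, 9]
mfa_end = [3, 5, 9, 12]
span_to_repl = [1, 3]
span_bdy = [mfa_start[span_to_repl[0]], mfa_end[span_to_repl[1] - 1]]
print(span_bdy)  # [3, 9]
```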
```python
# the durations obtained from MFA and those from the duration_predictor
# of fs2 may differ, so compute a scaling factor for mapping between the
# predicted and the ground-truth values
def get_dur_adj_factor(orig_dur: List[int],
                       pred_dur: List[int],
                       phns: List[str]):
    length = 0
    factor_list = []
    for orig, pred, phn in zip(orig_dur, pred_dur, phns):
        if pred == 0 or phn == 'sp':
            continue
        else:
            factor_list.append(orig / pred)
    factor_list = np.array(factor_list)
    factor_list.sort()
    if len(factor_list) < 5:
        return 1
    length = 2
    avg = np.average(factor_list[length:-length])
    return avg
```
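A worked example of the factor: ratios `orig/pred` are collected for non-`sp` phones with non-zero predictions; with fewer than five ratios the factor defaults to 1, otherwise the two smallest and two largest ratios are trimmed before averaging:

```python
import numpy as np

orig_dur = [4, 6, 2, 8, 5, 3, 7]
pred_dur = [2, 3, 2, 4, 5, 3, 7]
phns = ['AH0', 'B', 'sp', 'D', 'IY1', 'K', 'S']
ratios = sorted(o / p for o, p, ph in zip(orig_dur, pred_dur, phns)
                if p != 0 and ph != 'sp')  # [1.0, 1.0, 1.0, 2.0, 2.0, 2.0]
factor = 1 if len(ratios) < 5 else float(np.average(ratios[2:-2]))
print(factor)  # 1.5
```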
```python
def read_2col_text(path: Union[Path, str]) -> Dict[str, str]:
    """Read a text file having 2 columns as a dict object.

    Examples:
        wav.scp:
            key1 /some/path/a.wav
            key2 /some/path/b.wav

        >>> read_2col_text('wav.scp')
        {'key1': '/some/path/a.wav', 'key2': '/some/path/b.wav'}
    """
    data = {}
    with Path(path).open("r", encoding="utf-8") as f:
        for linenum, line in enumerate(f, 1):
            sps = line.rstrip().split(maxsplit=1)
            if len(sps) == 1:
                k, v = sps[0], ""
            else:
                k, v = sps
            if k in data:
                raise RuntimeError(f"{k} is duplicated ({path}:{linenum})")
            data[k] = v
    return data


def load_num_sequence_text(path: Union[Path, str], loader_type: str="csv_int"
                           ) -> Dict[str, List[Union[float, int]]]:
    """Read a text file indicating sequences of numbers.

    Examples:
        key1 1 2 3
        key2 34 5 6

        >>> d = load_num_sequence_text('text')
        >>> np.testing.assert_array_equal(d["key1"], np.array([1, 2, 3]))
    """
    if loader_type == "text_int":
        delimiter = " "
        dtype = int
    elif loader_type == "text_float":
        delimiter = " "
        dtype = float
    elif loader_type == "csv_int":
        delimiter = ","
        dtype = int
    elif loader_type == "csv_float":
        delimiter = ","
        dtype = float
    else:
        raise ValueError(f"Not supported loader_type={loader_type}")
    # path looks like:
    #   utta 1,0
    #   uttb 3,4,5
    # -> return {'utta': np.ndarray([1, 0]),
    #            'uttb': np.ndarray([3, 4, 5])}
    d = read_2col_text(path)
    # Using for-loop instead of dict-comprehension for debuggability
    retval = {}
    for k, v in d.items():
        try:
            retval[k] = [dtype(i) for i in v.split(delimiter)]
        except TypeError:
            print(f'Error happened with path="{path}", id="{k}", value="{v}"')
            raise
    return retval


def is_chinese(ch):
    if u'\u4e00' <= ch <= u'\u9fff':
        return True
    else:
        return False


def get_voc_out(mel):
    # vocoder
    # NOTE: parse_args() is not defined in this file; it is expected to be
    # provided by the calling script
    args = parse_args()
    with open(args.voc_config) as f:
        voc_config = CfgNode(yaml.safe_load(f))
    voc_inference = get_voc_inference(
        voc=args.voc,
        voc_config=voc_config,
        voc_ckpt=args.voc_ckpt,
        voc_stat=args.voc_stat)

    with paddle.no_grad():
        wav = voc_inference(mel)
    return np.squeeze(wav)


def eval_durs(phns, target_lang: str='zh', fs: int=24000, n_shift: int=300):
    if target_lang == 'en':
        am = "fastspeech2_ljspeech"
        am_config = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/default.yaml"
        am_ckpt = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/snapshot_iter_100000.pdz"
        am_stat = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/speech_stats.npy"
        phones_dict = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/phone_id_map.txt"

    elif target_lang == 'zh':
        am = "fastspeech2_csmsc"
        am_config = "download/fastspeech2_conformer_baker_ckpt_0.5/conformer.yaml"
        am_ckpt = "download/fastspeech2_conformer_baker_ckpt_0.5/snapshot_iter_76000.pdz"
        am_stat = "download/fastspeech2_conformer_baker_ckpt_0.5/speech_stats.npy"
        phones_dict = "download/fastspeech2_conformer_baker_ckpt_0.5/phone_id_map.txt"

    # Init body.
    with open(am_config) as f:
        am_config = CfgNode(yaml.safe_load(f))

    am_inference, am = get_am_inference(
        am=am,
        am_config=am_config,
        am_ckpt=am_ckpt,
        am_stat=am_stat,
        phones_dict=phones_dict,
        return_am=True)

    vocab_phones = {}
    with open(phones_dict, "r") as f:
        phn_id = [line.strip().split() for line in f.readlines()]
    for tone, id in phn_id:
        vocab_phones[tone] = int(id)
    vocab_size = len(vocab_phones)
    phonemes = [phn if phn in vocab_phones else "sp" for phn in phns]

    phone_ids = [vocab_phones[item] for item in phonemes]
    phone_ids = paddle.to_tensor(np.array(phone_ids, np.int64))
    _, d_outs, _, _ = am.inference(phone_ids)
    d_outs = d_outs.tolist()
    return d_outs
```