Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
ERNIE
提交
445e3040
E
ERNIE
项目概览
PaddlePaddle
/
ERNIE
大约 1 年 前同步成功
通知
109
Star
5997
Fork
1270
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
29
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
E
ERNIE
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
29
Issue
29
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
445e3040
编写于
6月 15, 2022
作者:
小湉湉
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix vocoder inference
上级
e522009d
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
45 addition
and
65 deletion
+45
-65
ernie-sat/inference.py
ernie-sat/inference.py
+34
-53
ernie-sat/utils.py
ernie-sat/utils.py
+11
-12
未找到文件。
ernie-sat/inference.py
浏览文件 @
445e3040
...
...
@@ -11,11 +11,7 @@ import paddle
import
soundfile
as
sf
import
torch
from
paddle
import
nn
from
sedit_arg_parser
import
parse_args
from
utils
import
build_vocoder_from_file
from
utils
import
evaluate_durations
from
utils
import
get_voc_out
from
utils
import
is_chinese
from
ParallelWaveGAN.parallel_wavegan.utils.utils
import
download_pretrained_model
from
align
import
alignment
from
align
import
alignment_zh
...
...
@@ -25,20 +21,24 @@ from collect_fn import build_collate_fn
from
mlm
import
build_model_from_file
from
read_text
import
load_num_sequence_text
from
read_text
import
read_2col_text
# from ParallelWaveGAN.parallel_wavegan.utils.utils import download_pretrained_model
from
sedit_arg_parser
import
parse_args
from
utils
import
build_vocoder_from_file
from
utils
import
eval_durs
from
utils
import
get_voc_out
from
utils
import
is_chinese
random
.
seed
(
0
)
np
.
random
.
seed
(
0
)
def
plot_mel_and_vocode
_wav
(
wav_path
:
str
,
source_lang
:
str
=
'english'
,
target_lang
:
str
=
'english'
,
model_name
:
str
=
"paddle_checkpoint_en"
,
old_str
:
str
=
""
,
new_str
:
str
=
""
,
use_pt_vocoder
:
bool
=
False
,
non_autoreg
:
bool
=
True
):
def
get
_wav
(
wav_path
:
str
,
source_lang
:
str
=
'english'
,
target_lang
:
str
=
'english'
,
model_name
:
str
=
"paddle_checkpoint_en"
,
old_str
:
str
=
""
,
new_str
:
str
=
""
,
use_pt_vocoder
:
bool
=
False
,
non_autoreg
:
bool
=
True
):
wav_org
,
output_feat
,
old_span_bdy
,
new_span_bdy
,
fs
,
hop_length
=
get_mlm_output
(
source_lang
=
source_lang
,
target_lang
=
target_lang
,
...
...
@@ -50,41 +50,23 @@ def plot_mel_and_vocode_wav(wav_path: str,
masked_feat
=
output_feat
[
new_span_bdy
[
0
]:
new_span_bdy
[
1
]]
if
target_lang
==
'english'
:
if
use_pt_vocoder
:
output_feat
=
output_feat
.
cpu
().
numpy
()
output_feat
=
torch
.
tensor
(
output_feat
,
dtype
=
torch
.
float
)
vocoder
=
load_vocoder
(
'vctk_parallel_wavegan.v1.long'
)
replaced_wav
=
vocoder
(
output_feat
).
cpu
().
numpy
()
else
:
replaced_wav
=
get_voc_out
(
output_feat
)
if
target_lang
==
'english'
and
use_pt_vocoder
:
masked_feat
=
masked_feat
.
cpu
().
numpy
()
masked_feat
=
torch
.
tensor
(
masked_feat
,
dtype
=
torch
.
float
)
vocoder
=
load_vocoder
(
'vctk_parallel_wavegan.v1.long'
)
alt_wav
=
vocoder
(
masked_feat
).
cpu
().
numpy
()
el
if
target_lang
==
'chinese'
:
replaced_wav_only_mask_fst2_voc
=
get_voc_out
(
masked_feat
)
el
se
:
alt_wav
=
get_voc_out
(
masked_feat
)
old_time_bdy
=
[
hop_length
*
x
for
x
in
old_span_bdy
]
new_time_bdy
=
[
hop_length
*
x
for
x
in
new_span_bdy
]
if
target_lang
==
'english'
:
wav_org_replaced_paddle_voc
=
np
.
concatenate
([
wav_org
[:
old_time_bdy
[
0
]],
replaced_wav
[
new_time_bdy
[
0
]:
new_time_bdy
[
1
]],
wav_org
[
old_time_bdy
[
1
]:]
])
data_dict
=
{
"origin"
:
wav_org
,
"output"
:
wav_org_replaced_paddle_voc
}
wav_replaced
=
np
.
concatenate
(
[
wav_org
[:
old_time_bdy
[
0
]],
alt_wav
,
wav_org
[
old_time_bdy
[
1
]:]])
elif
target_lang
==
'chinese'
:
wav_org_replaced_only_mask_fst2_voc
=
np
.
concatenate
([
wav_org
[:
old_time_bdy
[
0
]],
replaced_wav_only_mask_fst2_voc
,
wav_org
[
old_time_bdy
[
1
]:]
])
data_dict
=
{
"origin"
:
wav_org
,
"output"
:
wav_org_replaced_only_mask_fst2_voc
,
}
data_dict
=
{
"origin"
:
wav_org
,
"output"
:
wav_replaced
}
return
data_dict
,
old_span_bdy
return
data_dict
def
load_vocoder
(
vocoder_tag
:
str
=
"vctk_parallel_wavegan.v1.long"
):
...
...
@@ -323,9 +305,9 @@ def get_phns_and_spans(wav_path: str,
# mfa 获得的 duration 和 fs2 的 duration_predictor 获取的 duration 可能不同
# 此处获得一个缩放比例, 用于预测值和真实值之间的缩放
def
duration_adjust
_factor
(
orig_dur
:
List
[
int
],
pred_dur
:
List
[
int
],
phns
:
List
[
str
]):
def
get_dur_adj
_factor
(
orig_dur
:
List
[
int
],
pred_dur
:
List
[
int
],
phns
:
List
[
str
]):
length
=
0
factor_list
=
[]
for
orig
,
pred
,
phn
in
zip
(
orig_dur
,
pred_dur
,
phns
):
...
...
@@ -376,7 +358,7 @@ def prep_feats_with_dur(wav_path: str,
new_phns
=
new_phns
+
[
'sp'
]
# 中文的 phns 不一定都在 fastspeech2 的字典里, 用 sp 代替
if
target_lang
==
"english"
or
target_lang
==
"chinese"
:
old_durs
=
eval
uate_duration
s
(
old_phns
,
target_lang
=
source_lang
)
old_durs
=
eval
_dur
s
(
old_phns
,
target_lang
=
source_lang
)
else
:
assert
target_lang
==
"chinese"
or
target_lang
==
"english"
,
\
"calculate duration_predict is not support for this language..."
...
...
@@ -385,11 +367,11 @@ def prep_feats_with_dur(wav_path: str,
if
'[MASK]'
in
new_str
:
new_phns
=
old_phns
span_to_add
=
span_to_repl
d_factor_left
=
duration_adjust
_factor
(
d_factor_left
=
get_dur_adj
_factor
(
orig_dur
=
orig_old_durs
[:
span_to_repl
[
0
]],
pred_dur
=
old_durs
[:
span_to_repl
[
0
]],
phns
=
old_phns
[:
span_to_repl
[
0
]])
d_factor_right
=
duration_adjust
_factor
(
d_factor_right
=
get_dur_adj
_factor
(
orig_dur
=
orig_old_durs
[
span_to_repl
[
1
]:],
pred_dur
=
old_durs
[
span_to_repl
[
1
]:],
phns
=
old_phns
[
span_to_repl
[
1
]:])
...
...
@@ -397,15 +379,14 @@ def prep_feats_with_dur(wav_path: str,
new_durs_adjusted
=
[
d_factor
*
i
for
i
in
old_durs
]
else
:
if
duration_adjust
:
d_factor
=
duration_adjust
_factor
(
d_factor
=
get_dur_adj
_factor
(
orig_dur
=
orig_old_durs
,
pred_dur
=
old_durs
,
phns
=
old_phns
)
print
(
"d_factor:"
,
d_factor
)
d_factor
=
d_factor
*
1.25
else
:
d_factor
=
1
if
target_lang
==
"english"
or
target_lang
==
"chinese"
:
new_durs
=
eval
uate_duration
s
(
new_phns
,
target_lang
=
target_lang
)
new_durs
=
eval
_dur
s
(
new_phns
,
target_lang
=
target_lang
)
else
:
assert
target_lang
==
"chinese"
or
target_lang
==
"english"
,
\
"calculate duration_predict is not support for this language..."
...
...
@@ -616,7 +597,7 @@ def evaluate(uid: str,
print
(
'new_str is '
,
new_str
)
results_dict
,
old_span
=
plot_mel_and_vocode
_wav
(
results_dict
=
get
_wav
(
source_lang
=
source_lang
,
target_lang
=
target_lang
,
model_name
=
model_name
,
...
...
ernie-sat/utils.py
浏览文件 @
445e3040
import
os
from
typing
import
List
from
typing
import
Optional
import
numpy
as
np
...
...
@@ -32,6 +31,7 @@ model_alias = {
"paddlespeech.t2s.models.parallel_wavegan:PWGInference"
,
}
def
is_chinese
(
ch
):
if
u
'
\u4e00
'
<=
ch
<=
u
'
\u9fff
'
:
return
True
...
...
@@ -61,7 +61,7 @@ def get_voc_out(mel):
# print("current vocoder: ", args.voc)
with
open
(
args
.
voc_config
)
as
f
:
voc_config
=
CfgNode
(
yaml
.
safe_load
(
f
))
voc_inference
=
voc_inference
=
get_voc_inference
(
voc_inference
=
get_voc_inference
(
voc
=
args
.
voc
,
voc_config
=
voc_config
,
voc_ckpt
=
args
.
voc_ckpt
,
...
...
@@ -164,7 +164,7 @@ def get_voc_inference(
return
voc_inference
def
eval
uate_duration
s
(
phns
,
target_lang
=
"chinese"
,
fs
=
24000
,
hop_length
=
300
):
def
eval
_dur
s
(
phns
,
target_lang
=
"chinese"
,
fs
=
24000
,
hop_length
=
300
):
args
=
parse_args
()
if
target_lang
==
'english'
:
...
...
@@ -176,10 +176,10 @@ def evaluate_durations(phns, target_lang="chinese", fs=24000, hop_length=300):
elif
target_lang
==
'chinese'
:
args
.
am
=
"fastspeech2_csmsc"
args
.
am_config
=
"download/fastspeech2_conformer_baker_ckpt_0.5/conformer.yaml"
args
.
am_config
=
"download/fastspeech2_conformer_baker_ckpt_0.5/conformer.yaml"
args
.
am_ckpt
=
"download/fastspeech2_conformer_baker_ckpt_0.5/snapshot_iter_76000.pdz"
args
.
am_stat
=
"download/fastspeech2_conformer_baker_ckpt_0.5/speech_stats.npy"
args
.
phones_dict
=
"download/fastspeech2_conformer_baker_ckpt_0.5/phone_id_map.txt"
args
.
phones_dict
=
"download/fastspeech2_conformer_baker_ckpt_0.5/phone_id_map.txt"
if
args
.
ngpu
==
0
:
paddle
.
set_device
(
"cpu"
)
...
...
@@ -211,11 +211,10 @@ def evaluate_durations(phns, target_lang="chinese", fs=24000, hop_length=300):
phonemes
=
[
phn
if
phn
in
vocab_phones
else
"sp"
for
phn
in
phns
]
phone_ids
=
[
vocab_phones
[
item
]
for
item
in
phonemes
]
phone_ids_new
=
phone_ids
phone_ids_new
.
append
(
vocab_size
-
1
)
phone_ids_new
=
paddle
.
to_tensor
(
np
.
array
(
phone_ids_new
,
np
.
int64
))
_
,
d_outs
,
_
,
_
=
am
.
inference
(
phone_ids_new
,
spk_id
=
None
,
spk_emb
=
None
)
phone_ids
.
append
(
vocab_size
-
1
)
phone_ids
=
paddle
.
to_tensor
(
np
.
array
(
phone_ids
,
np
.
int64
))
_
,
d_outs
,
_
,
_
=
am
.
inference
(
phone_ids
,
spk_id
=
None
,
spk_emb
=
None
)
pre_d_outs
=
d_outs
ph
oneme_duration
s_new
=
pre_d_outs
*
hop_length
/
fs
ph
oneme_durations_new
=
phoneme_duration
s_new
.
tolist
()[:
-
1
]
return
ph
oneme_duration
s_new
ph
u_dur
s_new
=
pre_d_outs
*
hop_length
/
fs
ph
u_durs_new
=
phu_dur
s_new
.
tolist
()[:
-
1
]
return
ph
u_dur
s_new
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录