Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
ERNIE
提交
c2f4709e
E
ERNIE
项目概览
PaddlePaddle
/
ERNIE
大约 1 年 前同步成功
通知
109
Star
5997
Fork
1270
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
29
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
E
ERNIE
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
29
Issue
29
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
c2f4709e
编写于
5月 20, 2022
作者:
P
pfZhu
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
modify code
上级
75dae192
变更
24
隐藏空白更改
内联
并排
Showing
24 changed file
with
503 addition
and
42 deletion
+503
-42
ernie-sat/.DS_Store
ernie-sat/.DS_Store
+0
-0
ernie-sat/README.md
ernie-sat/README.md
+4
-4
ernie-sat/inference.py
ernie-sat/inference.py
+475
-0
ernie-sat/prompt/dev/mfa_end
ernie-sat/prompt/dev/mfa_end
+1
-1
ernie-sat/prompt/dev/mfa_start
ernie-sat/prompt/dev/mfa_start
+1
-1
ernie-sat/prompt/dev/mfa_text
ernie-sat/prompt/dev/mfa_text
+1
-1
ernie-sat/prompt/dev/mfa_wav.scp
ernie-sat/prompt/dev/mfa_wav.scp
+1
-1
ernie-sat/prompt/dev/text
ernie-sat/prompt/dev/text
+1
-1
ernie-sat/prompt/dev/wav.scp
ernie-sat/prompt/dev/wav.scp
+1
-1
ernie-sat/run_clone_en_to_zh.sh
ernie-sat/run_clone_en_to_zh.sh
+5
-5
ernie-sat/run_gen_en.sh
ernie-sat/run_gen_en.sh
+3
-22
ernie-sat/run_sedit_en.sh
ernie-sat/run_sedit_en.sh
+4
-2
ernie-sat/tmp/tmp_pkl.Prompt_003_new
ernie-sat/tmp/tmp_pkl.Prompt_003_new
+0
-0
ernie-sat/tmp/tmp_pkl.p243_new
ernie-sat/tmp/tmp_pkl.p243_new
+0
-0
ernie-sat/tmp/tmp_pkl.p299_096
ernie-sat/tmp/tmp_pkl.p299_096
+0
-0
ernie-sat/utils.py
ernie-sat/utils.py
+6
-3
ernie-sat/wavs/ori.wav
ernie-sat/wavs/ori.wav
+0
-0
ernie-sat/wavs/pred.wav
ernie-sat/wavs/pred.wav
+0
-0
ernie-sat/wavs/pred_en_edit_paddle_voc.wav
ernie-sat/wavs/pred_en_edit_paddle_voc.wav
+0
-0
ernie-sat/wavs/pred_zh.wav
ernie-sat/wavs/pred_zh.wav
+0
-0
ernie-sat/wavs/pred_zh_fst2_voc.wav
ernie-sat/wavs/pred_zh_fst2_voc.wav
+0
-0
ernie-sat/wavs/task_cross_lingual_pred.wav
ernie-sat/wavs/task_cross_lingual_pred.wav
+0
-0
ernie-sat/wavs/task_edit_pred.wav
ernie-sat/wavs/task_edit_pred.wav
+0
-0
ernie-sat/wavs/task_synthesize_pred.wav
ernie-sat/wavs/task_synthesize_pred.wav
+0
-0
未找到文件。
ernie-sat/.DS_Store
已删除
100644 → 0
浏览文件 @
75dae192
文件已删除
ernie-sat/README.md
浏览文件 @
c2f4709e
...
...
@@ -11,7 +11,7 @@ ERNIE-SAT中我们提出了两项创新:
### 1.安装飞桨
我们
的代码基于 Paddle(version>=2.0)
本项目
的代码基于 Paddle(version>=2.0)
### 2.预训练模型
...
...
@@ -23,7 +23,7 @@ ERNIE-SAT中我们提出了两项创新:
### 3.下载
1.
我们
使用parallel wavegan作为声码器(vocoder):
1.
本项目
使用parallel wavegan作为声码器(vocoder):
-
[
pwg_aishell3_ckpt_0.5.zip
](
https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_aishell3_ckpt_0.5.zip
)
创建download文件夹,下载上述预训练的声码器(vocoder)模型并将其解压
...
...
@@ -34,7 +34,7 @@ cd download
unzip pwg_aishell3_ckpt_0.5.zip
```
2.
我们
使用
[
FastSpeech2
](
https://arxiv.org/abs/2006.04558
)
作为音素(phoneme)的持续时间预测器:
2.
本项目
使用
[
FastSpeech2
](
https://arxiv.org/abs/2006.04558
)
作为音素(phoneme)的持续时间预测器:
-
[
fastspeech2_conformer_baker_ckpt_0.5.zip
](
https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_conformer_baker_ckpt_0.5.zip
)
中文场景下使用
-
[
fastspeech2_nosil_ljspeech_ckpt_0.5.zip
](
https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_ljspeech_ckpt_0.5.zip
)
英文场景下使用
...
...
@@ -48,7 +48,7 @@ unzip fastspeech2_nosil_ljspeech_ckpt_0.5.zip
### 4.推理
我们目
前开源了语音编辑、个性化语音合成、跨语言语音合成的推理代码,后续会逐步开源。
本项目当
前开源了语音编辑、个性化语音合成、跨语言语音合成的推理代码,后续会逐步开源。
注:当前采用的声码器版本与
[
模型训练时版本
](
https://github.com/kan-bayashi/ParallelWaveGAN
)
在英文上存在差异,您可使用模型训练时版本作为您的声码器,模型将在后续更新中升级。
我们提供特定音频文件, 以及其对应的文本、音素相关文件:
...
...
ernie-sat/
sedit_inference_0520
.py
→
ernie-sat/
inference
.py
浏览文件 @
c2f4709e
#!/usr/bin/env python3
"""Script to run the inference of text-to-speeech model."""
import
os
os
.
environ
[
"CUDA_VISIBLE_DEVICES"
]
=
"3"
from
parallel_wavegan.utils
import
download_pretrained_model
from
pathlib
import
Path
import
paddle
import
soundfile
import
os
import
math
import
string
import
numpy
as
np
from
espnet2.tasks.mlm
import
MLMTask
from
read_text
import
read_2column_text
,
load_num_sequence_text
from
util
import
sentence2phns
,
get_voc_out
,
evaluate_durations
from
util
s
import
sentence2phns
,
get_voc_out
,
evaluate_durations
import
librosa
import
random
from
ipywidgets
import
widgets
import
IPython.display
as
ipd
import
soundfile
as
sf
import
sys
import
pickle
...
...
@@ -37,19 +27,10 @@ from typing import Union
from
paddlespeech.t2s.datasets.get_feats
import
LogMelFBank
from
paddlespeech.t2s.modules.nets_utils
import
make_non_pad_mask
duration_path_dict
=
{
"ljspeech"
:
"/mnt/home/v_baihe/projects/espnet/egs2/ljspeech/tts1/exp/kan-bayashi/ljspeech_tts_train_conformer_fastspeech2_raw_phn_tacotron_g2p_en_no_space_train.loss.ave/train.loss.ave_5best.pth"
,
"vctk"
:
"/mnt/home/v_baihe/projects/espnet/egs2/vctk/tts1/exp/kan-bayashi/vctk_tts_train_gst+xvector_conformer_fastspeech2_transformer_teacher_raw_phn_tacotron_g2p_en_no_space_train.loss.ave/train.loss.ave_5best.pth"
,
# "ljspeech":"/home/mnt2/zz/workspace/work/espnet_richard_infer/egs2/ljspeech/tts1/exp/kan-bayashi/ljspeech_tts_train_conformer_fastspeech2_raw_phn_tacotron_g2p_en_no_space_train.loss.ave/train.loss.ave_5best.pth",
# "vctk": "/home/mnt2/zz/workspace/work/espnet_richard_infer/egs2/vctk/tts1/exp/kan-bayashi/vctk_tts_train_gst+xvector_conformer_fastspeech2_transformer_teacher_raw_phn_tacotron_g2p_en_no_space_train.loss.ave/train.loss.ave_5best.pth",
"vctk_unseen"
:
"/mnt/home/v_baihe/projects/espnet/egs2/vctk/tts1/exp/tts_train_fs2_raw_phn_tacotron_g2p_en_no_space/train.loss.ave_5best.pth"
,
"libritts"
:
"/mnt/home/v_baihe/projects/espnet/egs2/libritts/tts1/exp/kan-bayashi/libritts_tts_train_gst+xvector_conformer_fastspeech2_transformer_teacher_raw_phn_tacotron_g2p_en_no_space_train.loss/train.loss.ave_5best.pth"
}
random
.
seed
(
0
)
np
.
random
.
seed
(
0
)
def
plot_mel_and_vocode_wav
(
uid
,
prefix
,
clone_uid
,
clone_prefix
,
source_language
,
target_language
,
model_name
,
wav_path
,
full_origin_str
,
old_str
,
new_str
,
vocoder
,
duration_preditor_path
,
sid
=
None
,
non_autoreg
=
True
):
wav_org
,
input_feat
,
output_feat
,
old_span_boundary
,
new_span_boundary
,
fs
,
hop_length
=
get_mlm_output
(
uid
,
...
...
@@ -66,47 +47,47 @@ def plot_mel_and_vocode_wav(uid, prefix, clone_uid, clone_prefix, source_languag
use_teacher_forcing
=
non_autoreg
,
sid
=
sid
)
masked_feat
=
output_feat
[
new_span_boundary
[
0
]:
new_span_boundary
[
1
]].
detach
().
float
().
cpu
().
numpy
()
if
target_language
==
'english'
:
output_feat_np
=
output_feat
.
detach
().
float
().
cpu
().
numpy
()
replaced_wav_paddle_voc
=
get_voc_out
(
output_feat_np
,
target_language
)
replaced_wav
=
replaced_wav_paddle_voc
elif
target_language
==
'chinese'
:
assert
old_span_boundary
[
1
]
==
new_span_boundary
[
0
],
"old_span_boundary[1] is not same with new_span_boundary[0]."
output_feat_np
=
output_feat
.
detach
().
float
().
cpu
().
numpy
()
replaced_wav
=
get_voc_out
(
output_feat_np
)
replaced_wav_only_mask
=
get_voc_out
(
masked_feat
)
replaced_wav_only_mask_fst2_voc
=
get_voc_out
(
masked_feat
,
target_language
)
old_time_boundary
=
[
hop_length
*
x
for
x
in
old_span_boundary
]
new_time_boundary
=
[
hop_length
*
x
for
x
in
new_span_boundary
]
wav_org_replaced
=
np
.
concatenate
([
wav_org
[:
old_time_boundary
[
0
]],
replaced_wav
[
new_time_boundary
[
0
]:
new_time_boundary
[
1
]],
wav_org
[
old_time_boundary
[
1
]:]])
if
target_language
==
'english'
:
# new add to test paddle vocoder
wav_org_replaced_paddle_voc
=
np
.
concatenate
([
wav_org
[:
old_time_boundary
[
0
]],
replaced_wav_paddle_voc
[
new_time_boundary
[
0
]:
new_time_boundary
[
1
]],
wav_org
[
old_time_boundary
[
1
]:]])
data_dict
=
{
"origin"
:
wav_org
,
data_dict
=
{
"origin"
:
wav_org
,
"output"
:
wav_org_replaced_paddle_voc
}
elif
target_language
==
'chinese'
:
wav_org_replaced_only_mask
=
np
.
concatenate
([
wav_org
[:
old_time_boundary
[
0
]],
replaced_wav_only_mask
,
wav_org
[
old_time_boundary
[
1
]:]])
wav_org_replaced_only_mask_fst2_voc
=
np
.
concatenate
([
wav_org
[:
old_time_boundary
[
0
]],
replaced_wav_only_mask_fst2_voc
,
wav_org
[
old_time_boundary
[
1
]:]])
data_dict
=
{
"origin"
:
wav_org
,
data_dict
=
{
"origin"
:
wav_org
,
"output"
:
wav_org_replaced_only_mask_fst2_voc
,}
return
data_dict
,
old_span_boundary
def
load_vocoder
(
vocoder_tag
=
"parallel_wavegan/libritts_parallel_wavegan.v1"
):
vocoder_tag
=
vocoder_tag
.
replace
(
"parallel_wavegan/"
,
""
)
vocoder_file
=
download_pretrained_model
(
vocoder_tag
)
vocoder_config
=
Path
(
vocoder_file
).
parent
/
"config.yml"
vocoder
=
TTSTask
.
build_vocoder_from_file
(
vocoder_config
,
vocoder_file
,
None
,
'cpu'
)
return
vocoder
def
load_model
(
model_name
):
config_path
=
'./pretrained_model/{}/config.yaml'
.
format
(
model_name
)
model_path
=
'./pretrained_model/{}/model.pdparams'
.
format
(
model_name
)
...
...
@@ -132,12 +113,6 @@ def get_align_data(uid,prefix):
return
mfa_text
,
mfa_start
,
mfa_end
,
mfa_wav_path
def
get_fs2_model
(
model_name
):
model
,
config
=
TTSTask
.
build_model_from_file
(
model_file
=
model_name
)
processor
=
TTSTask
.
build_preprocess_fn
(
config
,
train
=
False
)
return
model
,
processor
def
get_masked_mel_boundary
(
mfa_start
,
mfa_end
,
fs
,
hop_length
,
span_tobe_replaced
):
align_start
=
paddle
.
to_tensor
(
mfa_start
).
unsqueeze
(
0
)
align_end
=
paddle
.
to_tensor
(
mfa_end
).
unsqueeze
(
0
)
...
...
@@ -150,6 +125,16 @@ def get_masked_mel_boundary(mfa_start, mfa_end, fs, hop_length, span_tobe_replac
return
span_boundary
def
gen_phns
(
zh_mapping
,
phns
):
new_phns
=
[]
for
x
in
phns
:
if
x
in
zh_mapping
.
keys
():
new_phns
.
extend
(
zh_mapping
[
x
].
split
(
" "
))
else
:
new_phns
.
extend
([
'<unk>'
])
return
new_phns
def
get_mapping
(
phn_mapping
=
"./phn_mapping.txt"
):
zh_mapping
=
{}
with
open
(
phn_mapping
,
"r"
)
as
f
:
...
...
@@ -160,15 +145,6 @@ def get_mapping(phn_mapping="./phn_mapping.txt"):
return
zh_mapping
def
gen_phns
(
zh_mapping
,
phns
):
new_phns
=
[]
for
x
in
phns
:
if
x
in
zh_mapping
.
keys
():
new_phns
.
extend
(
zh_mapping
[
x
].
split
(
" "
))
else
:
new_phns
.
extend
([
'<unk>'
])
return
new_phns
def
get_phns_and_spans_paddle
(
uid
,
prefix
,
old_str
,
new_str
,
source_language
,
target_language
):
zh_mapping
=
get_mapping
()
old_str
=
old_str
.
strip
()
...
...
@@ -205,6 +181,7 @@ def get_phns_and_spans_paddle(uid, prefix, old_str, new_str, source_language, ta
assert
target_language
==
"chinese"
or
target_language
==
"english"
,
"cloning is not support for this language, please check it."
else
:
if
source_language
==
target_language
and
target_language
==
"english"
:
new_phns_origin
=
old_phns
new_phns_append
,
_
=
sentence2phns
(
new_str_append
,
"en"
)
...
...
@@ -222,11 +199,12 @@ def get_phns_and_spans_paddle(uid, prefix, old_str, new_str, source_language, ta
span_tobe_replaced
=
[
len
(
old_phns
),
len
(
old_phns
)]
span_tobe_added
=
[
len
(
old_phns
),
len
(
new_phns
)]
else
:
if
source_language
==
target_language
and
target_language
==
"english"
:
new_phns
,
_
=
sentence2phns
(
new_str
,
"en"
)
# 纯中文
elif
source_language
==
target_language
and
target_language
==
"chinese"
:
new_phns
,
_
=
sentence2phns
(
new_str
,
"zh"
)
new_phns
=
gen_phns
(
zh_mapping
,
new_phns
)
...
...
@@ -247,7 +225,6 @@ def get_phns_and_spans_paddle(uid, prefix, old_str, new_str, source_language, ta
left_index
=
0
sp_count
=
0
# find the left different index
for
idx
,
phn
in
enumerate
(
old_phns
):
if
phn
==
"sp"
:
sp_count
+=
1
...
...
@@ -292,12 +269,12 @@ def get_phns_and_spans_paddle(uid, prefix, old_str, new_str, source_language, ta
break
new_phns
=
new_phns_left
+
new_phns_middle
+
new_phns_right
return
mfa_start
,
mfa_end
,
old_phns
,
new_phns
,
span_tobe_replaced
,
span_tobe_added
def
duration_adjust_factor
(
original_dur
,
pred_dur
,
phns
):
length
=
0
accumulate
=
0
...
...
@@ -311,17 +288,19 @@ def duration_adjust_factor(original_dur, pred_dur, phns):
factor_list
.
sort
()
if
len
(
factor_list
)
<
5
:
return
1
length
=
2
return
np
.
average
(
factor_list
[
length
:
-
length
])
def
prepare_features_with_duration
(
uid
,
prefix
,
clone_uid
,
clone_prefix
,
source_language
,
target_language
,
mlm_model
,
old_str
,
new_str
,
wav_path
,
duration_preditor_path
,
sid
=
None
,
mask_reconstruct
=
False
,
duration_adjust
=
True
,
start_end_sp
=
False
,
train_args
=
None
):
wav_org
,
rate
=
librosa
.
load
(
wav_path
,
sr
=
train_args
.
feats_extract_conf
[
'fs'
])
fs
=
train_args
.
feats_extract_conf
[
'fs'
]
hop_length
=
train_args
.
feats_extract_conf
[
'hop_length'
]
mfa_start
,
mfa_end
,
old_phns
,
new_phns
,
span_tobe_replaced
,
span_tobe_added
=
get_phns_and_spans_paddle
(
uid
,
prefix
,
old_str
,
new_str
,
source_language
,
target_language
)
if
start_end_sp
:
if
new_phns
[
-
1
]
!=
'sp'
:
new_phns
=
new_phns
+
[
'sp'
]
...
...
@@ -331,8 +310,10 @@ def prepare_features_with_duration(uid, prefix, clone_uid, clone_prefix, source_
old_durations
=
evaluate_durations
(
old_phns
,
target_language
=
target_language
)
elif
target_language
==
"chinese"
:
if
source_language
==
"english"
:
old_durations
=
evaluate_durations
(
old_phns
,
target_language
=
source_language
)
elif
source_language
==
"chinese"
:
old_durations
=
evaluate_durations
(
old_phns
,
target_language
=
source_language
)
...
...
@@ -353,8 +334,7 @@ def prepare_features_with_duration(uid, prefix, clone_uid, clone_prefix, source_
if
duration_adjust
:
d_factor
=
duration_adjust_factor
(
original_old_durations
,
old_durations
,
old_phns
)
d_factor_paddle
=
duration_adjust_factor
(
original_old_durations
,
old_durations
,
old_phns
)
if
target_language
==
"chinese"
:
d_factor
=
d_factor
*
1.35
d_factor
=
d_factor
*
1.25
else
:
d_factor
=
1
...
...
@@ -401,7 +381,7 @@ def prepare_features_with_duration(uid, prefix, clone_uid, clone_prefix, source_
# 4. get old and new mel span to be mask
old_span_boundary
=
get_masked_mel_boundary
(
mfa_start
,
mfa_end
,
fs
,
hop_length
,
span_tobe_replaced
)
# [92, 92]
new_span_boundary
=
get_masked_mel_boundary
(
new_mfa_start
,
new_mfa_end
,
fs
,
hop_length
,
span_tobe_added
)
# [92, 174]
return
new_wav_org
,
new_phns
,
new_mfa_start
,
new_mfa_end
,
old_span_boundary
,
new_span_boundary
...
...
@@ -414,7 +394,7 @@ mask_reconstruct=False, train_args=None):
align_end
=
np
.
array
(
mfa_end
)
token_to_id
=
{
item
:
i
for
i
,
item
in
enumerate
(
train_args
.
token_list
)}
text
=
np
.
array
(
list
(
map
(
lambda
x
:
token_to_id
.
get
(
x
,
token_to_id
[
'<unk>'
]),
phns_list
)))
print
(
'unk id is'
,
token_to_id
[
'<unk>'
])
#
print('unk id is', token_to_id['<unk>'])
# text = np.array(processor(uid='1', data={'text':" ".join(phns_list)})['text'])
span_boundary
=
np
.
array
(
new_span_boundary
)
batch
=
[(
'1'
,
{
"speech"
:
speech
,
"align_start"
:
align_start
,
"align_end"
:
align_end
,
"text"
:
text
,
"span_boundary"
:
span_boundary
})]
...
...
@@ -422,17 +402,19 @@ mask_reconstruct=False, train_args=None):
return
batch
,
old_span_boundary
,
new_span_boundary
def
decode_with_model
(
uid
,
prefix
,
clone_uid
,
clone_prefix
,
source_language
,
target_language
,
mlm_model
,
processor
,
collate_fn
,
wav_path
,
old_str
,
new_str
,
duration_preditor_path
,
sid
=
None
,
decoder
=
False
,
use_teacher_forcing
=
False
,
duration_adjust
=
True
,
start_end_sp
=
False
,
train_args
=
None
):
# fs, hop_length = mlm_model.feats_extract.fs, mlm_model.feats_extract.hop_length
fs
,
hop_length
=
train_args
.
feats_extract_conf
[
'fs'
],
train_args
.
feats_extract_conf
[
'hop_length'
]
batch
,
old_span_boundary
,
new_span_boundary
=
prepare_features
(
uid
,
prefix
,
clone_uid
,
clone_prefix
,
source_language
,
target_language
,
mlm_model
,
processor
,
wav_path
,
old_str
,
new_str
,
duration_preditor_path
,
sid
,
duration_adjust
=
duration_adjust
,
start_end_sp
=
start_end_sp
,
train_args
=
train_args
)
feats
=
pickle
.
load
(
open
(
'tmp/tmp_pkl.'
+
str
(
uid
),
'rb'
))
tmp
=
feats
[
'speech'
][
0
]
# print('feats end')
# wav_len * 80
# set_all_random_seed(9999)
if
'text_masked_position'
in
feats
.
keys
():
feats
.
pop
(
'text_masked_position'
)
for
k
,
v
in
feats
.
items
():
feats
[
k
]
=
paddle
.
to_tensor
(
v
)
rtn
=
mlm_model
.
inference
(
**
feats
,
span_boundary
=
new_span_boundary
,
use_teacher_forcing
=
use_teacher_forcing
)
...
...
@@ -446,612 +428,24 @@ def decode_with_model(uid, prefix, clone_uid, clone_prefix, source_language, tar
else
:
output_feat
=
paddle
.
concat
([
output
[
0
].
squeeze
(
0
)]
+
output
[
1
:
-
1
]
+
[
output
[
-
1
].
squeeze
(
0
)],
axis
=
0
).
cpu
()
# wav_org, rate = soundfile.read(
# wav_path, always_2d=False)
wav_org
,
rate
=
librosa
.
load
(
wav_path
,
sr
=
train_args
.
feats_extract_conf
[
'fs'
])
origin_speech
=
paddle
.
to_tensor
(
np
.
array
(
wav_org
,
dtype
=
np
.
float32
)).
unsqueeze
(
0
)
speech_lengths
=
paddle
.
to_tensor
(
len
(
wav_org
)).
unsqueeze
(
0
)
# input_feat, feats_lengths = mlm_model.feats_extract(origin_speech, speech_lengths)
# return wav_org, input_feat.squeeze(), output_feat, old_span_boundary, new_span_boundary, fs, hop_length
return
wav_org
,
None
,
output_feat
,
old_span_boundary
,
new_span_boundary
,
fs
,
hop_length
class
MLMCollateFn
:
"""Functor class of common_collate_fn()"""
def
__init__
(
self
,
feats_extract
,
float_pad_value
:
Union
[
float
,
int
]
=
0.0
,
int_pad_value
:
int
=
-
32768
,
not_sequence
:
Collection
[
str
]
=
(),
mlm_prob
:
float
=
0.8
,
mean_phn_span
:
int
=
8
,
attention_window
:
int
=
0
,
pad_speech
:
bool
=
False
,
sega_emb
:
bool
=
False
,
duration_collect
:
bool
=
False
,
text_masking
:
bool
=
False
):
self
.
mlm_prob
=
mlm_prob
self
.
mean_phn_span
=
mean_phn_span
self
.
feats_extract
=
feats_extract
self
.
float_pad_value
=
float_pad_value
self
.
int_pad_value
=
int_pad_value
self
.
not_sequence
=
set
(
not_sequence
)
self
.
attention_window
=
attention_window
self
.
pad_speech
=
pad_speech
self
.
sega_emb
=
sega_emb
self
.
duration_collect
=
duration_collect
self
.
text_masking
=
text_masking
def
__repr__
(
self
):
return
(
f
"
{
self
.
__class__
}
(float_pad_value=
{
self
.
float_pad_value
}
, "
f
"int_pad_value=
{
self
.
float_pad_value
}
)"
)
def
__call__
(
self
,
data
:
Collection
[
Tuple
[
str
,
Dict
[
str
,
np
.
ndarray
]]]
)
->
Tuple
[
List
[
str
],
Dict
[
str
,
paddle
.
Tensor
]]:
return
mlm_collate_fn
(
data
,
float_pad_value
=
self
.
float_pad_value
,
int_pad_value
=
self
.
int_pad_value
,
not_sequence
=
self
.
not_sequence
,
mlm_prob
=
self
.
mlm_prob
,
mean_phn_span
=
self
.
mean_phn_span
,
feats_extract
=
self
.
feats_extract
,
attention_window
=
self
.
attention_window
,
pad_speech
=
self
.
pad_speech
,
sega_emb
=
self
.
sega_emb
,
duration_collect
=
self
.
duration_collect
,
text_masking
=
self
.
text_masking
)
def
pad_list
(
xs
,
pad_value
):
"""Perform padding for the list of tensors.
Args:
xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
pad_value (float): Value for padding.
Returns:
Tensor: Padded tensor (B, Tmax, `*`).
Examples:
>>> x = [paddle.ones(4), paddle.ones(2), paddle.ones(1)]
>>> x
[tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
>>> pad_list(x, 0)
tensor([[1., 1., 1., 1.],
[1., 1., 0., 0.],
[1., 0., 0., 0.]])
"""
n_batch
=
len
(
xs
)
max_len
=
max
(
paddle
.
shape
(
x
)[
0
]
for
x
in
xs
)
pad
=
paddle
.
full
((
n_batch
,
max_len
),
pad_value
,
dtype
=
xs
[
0
].
dtype
)
for
i
in
range
(
n_batch
):
pad
[
i
,
:
paddle
.
shape
(
xs
[
i
])[
0
]]
=
xs
[
i
]
return
pad
def
pad_to_longformer_att_window
(
text
,
max_len
,
max_tlen
,
attention_window
):
round
=
max_len
%
attention_window
if
round
!=
0
:
max_tlen
+=
(
attention_window
-
round
)
n_batch
=
paddle
.
shape
(
text
)[
0
]
text_pad
=
paddle
.
zeros
((
n_batch
,
max_tlen
,
*
paddle
.
shape
(
text
[
0
])[
1
:]),
dtype
=
text
.
dtype
)
for
i
in
range
(
n_batch
):
text_pad
[
i
,
:
paddle
.
shape
(
text
[
i
])[
0
]]
=
text
[
i
]
else
:
text_pad
=
text
[:,
:
max_tlen
]
return
text_pad
,
max_tlen
def
make_pad_mask
(
lengths
,
xs
=
None
,
length_dim
=-
1
):
print
(
'inputs are:'
,
lengths
,
xs
,
length_dim
)
"""Make mask tensor containing indices of padded part.
Args:
lengths (LongTensor or List): Batch of lengths (B,).
xs (Tensor, optional): The reference tensor.
If set, masks will be the same shape as this tensor.
length_dim (int, optional): Dimension indicator of the above tensor.
See the example.
Returns:
Tensor: Mask tensor containing indices of padded part.
Examples:
With only lengths.
>>> lengths = [5, 3, 2]
>>> make_non_pad_mask(lengths)
masks = [[0, 0, 0, 0 ,0],
[0, 0, 0, 1, 1],
[0, 0, 1, 1, 1]]
With the reference tensor.
>>> xs = paddle.zeros((3, 2, 4))
>>> make_pad_mask(lengths, xs)
tensor([[[0, 0, 0, 0],
[0, 0, 0, 0]],
[[0, 0, 0, 1],
[0, 0, 0, 1]],
[[0, 0, 1, 1],
[0, 0, 1, 1]]], dtype=paddle.uint8)
>>> xs = paddle.zeros((3, 2, 6))
>>> make_pad_mask(lengths, xs)
tensor([[[0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 1]],
[[0, 0, 0, 1, 1, 1],
[0, 0, 0, 1, 1, 1]],
[[0, 0, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1]]], dtype=paddle.uint8)
With the reference tensor and dimension indicator.
>>> xs = paddle.zeros((3, 6, 6))
>>> make_pad_mask(lengths, xs, 1)
tensor([[[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1]],
[[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1]],
[[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1]]], dtype=paddle.uint8)
>>> make_pad_mask(lengths, xs, 2)
tensor([[[0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 1]],
[[0, 0, 0, 1, 1, 1],
[0, 0, 0, 1, 1, 1],
[0, 0, 0, 1, 1, 1],
[0, 0, 0, 1, 1, 1],
[0, 0, 0, 1, 1, 1],
[0, 0, 0, 1, 1, 1]],
[[0, 0, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1]]], dtype=paddle.uint8)
"""
if
length_dim
==
0
:
raise
ValueError
(
"length_dim cannot be 0: {}"
.
format
(
length_dim
))
if
not
isinstance
(
lengths
,
list
):
lengths
=
list
(
lengths
)
print
(
'lengths'
,
lengths
)
bs
=
int
(
len
(
lengths
))
if
xs
is
None
:
maxlen
=
int
(
max
(
lengths
))
else
:
maxlen
=
paddle
.
shape
(
xs
)[
length_dim
]
seq_range
=
paddle
.
arange
(
0
,
maxlen
,
dtype
=
paddle
.
int64
)
seq_range_expand
=
paddle
.
expand
(
paddle
.
unsqueeze
(
seq_range
,
0
),
(
bs
,
maxlen
))
seq_length_expand
=
paddle
.
unsqueeze
(
paddle
.
to_tensor
(
lengths
),
-
1
)
print
(
'seq_length_expand'
,
paddle
.
shape
(
seq_length_expand
))
print
(
'seq_range_expand'
,
paddle
.
shape
(
seq_range_expand
))
mask
=
seq_range_expand
>=
seq_length_expand
if
xs
is
not
None
:
assert
paddle
.
shape
(
xs
)[
0
]
==
bs
,
(
paddle
.
shape
(
xs
)[
0
],
bs
)
if
length_dim
<
0
:
length_dim
=
len
(
paddle
.
shape
(
xs
))
+
length_dim
# ind = (:, None, ..., None, :, , None, ..., None)
ind
=
tuple
(
slice
(
None
)
if
i
in
(
0
,
length_dim
)
else
None
for
i
in
range
(
len
(
paddle
.
shape
(
xs
)))
)
print
(
'0:'
,
paddle
.
shape
(
mask
))
print
(
'1:'
,
paddle
.
shape
(
mask
[
ind
]))
print
(
'2:'
,
paddle
.
shape
(
xs
))
mask
=
paddle
.
expand
(
mask
[
ind
],
paddle
.
shape
(
xs
))
return
mask
def
make_non_pad_mask
(
lengths
,
xs
=
None
,
length_dim
=-
1
):
"""Make mask tensor containing indices of non-padded part.
Args:
lengths (LongTensor or List): Batch of lengths (B,).
xs (Tensor, optional): The reference tensor.
If set, masks will be the same shape as this tensor.
length_dim (int, optional): Dimension indicator of the above tensor.
See the example.
Returns:
ByteTensor: mask tensor containing indices of padded part.
Examples:
With only lengths.
>>> lengths = [5, 3, 2]
>>> make_non_pad_mask(lengths)
masks = [[1, 1, 1, 1 ,1],
[1, 1, 1, 0, 0],
[1, 1, 0, 0, 0]]
With the reference tensor.
>>> xs = paddle.zeros((3, 2, 4))
>>> make_non_pad_mask(lengths, xs)
tensor([[[1, 1, 1, 1],
[1, 1, 1, 1]],
[[1, 1, 1, 0],
[1, 1, 1, 0]],
[[1, 1, 0, 0],
[1, 1, 0, 0]]], dtype=paddle.uint8)
>>> xs = paddle.zeros((3, 2, 6))
>>> make_non_pad_mask(lengths, xs)
tensor([[[1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 0]],
[[1, 1, 1, 0, 0, 0],
[1, 1, 1, 0, 0, 0]],
[[1, 1, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0]]], dtype=paddle.uint8)
With the reference tensor and dimension indicator.
>>> xs = paddle.zeros((3, 6, 6))
>>> make_non_pad_mask(lengths, xs, 1)
tensor([[[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[0, 0, 0, 0, 0, 0]],
[[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0]],
[[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0]]], dtype=paddle.uint8)
>>> make_non_pad_mask(lengths, xs, 2)
tensor([[[1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 0]],
[[1, 1, 1, 0, 0, 0],
[1, 1, 1, 0, 0, 0],
[1, 1, 1, 0, 0, 0],
[1, 1, 1, 0, 0, 0],
[1, 1, 1, 0, 0, 0],
[1, 1, 1, 0, 0, 0]],
[[1, 1, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0]]], dtype=paddle.uint8)
"""
return
~
make_pad_mask
(
lengths
,
xs
,
length_dim
)
def
phones_masking
(
xs_pad
,
src_mask
,
align_start
,
align_end
,
align_start_lengths
,
mlm_prob
,
mean_phn_span
,
span_boundary
=
None
):
bz
,
sent_len
,
_
=
paddle
.
shape
(
xs_pad
)
mask_num_lower
=
math
.
ceil
(
sent_len
*
mlm_prob
)
masked_position
=
np
.
zeros
((
bz
,
sent_len
))
y_masks
=
None
# y_masks = torch.ones(bz,sent_len,sent_len,device=xs_pad.device,dtype=xs_pad.dtype)
# tril_masks = torch.tril(y_masks)
if
mlm_prob
==
1.0
:
masked_position
+=
1
# y_masks = tril_masks
elif
mean_phn_span
==
0
:
# only speech
length
=
sent_len
mean_phn_span
=
min
(
length
*
mlm_prob
//
3
,
50
)
masked_phn_indices
=
random_spans_noise_mask
(
length
,
mlm_prob
,
mean_phn_span
).
nonzero
()
masked_position
[:,
masked_phn_indices
]
=
1
else
:
for
idx
in
range
(
bz
):
if
span_boundary
is
not
None
:
for
s
,
e
in
zip
(
span_boundary
[
idx
][::
2
],
span_boundary
[
idx
][
1
::
2
]):
masked_position
[
idx
,
s
:
e
]
=
1
# y_masks[idx, :, s:e] = tril_masks[idx, :, s:e]
# y_masks[idx, e:, s:e ] = 0
else
:
length
=
align_start_lengths
[
idx
].
item
()
if
length
<
2
:
continue
masked_phn_indices
=
random_spans_noise_mask
(
length
,
mlm_prob
,
mean_phn_span
).
nonzero
()
masked_start
=
align_start
[
idx
][
masked_phn_indices
].
tolist
()
masked_end
=
align_end
[
idx
][
masked_phn_indices
].
tolist
()
for
s
,
e
in
zip
(
masked_start
,
masked_end
):
masked_position
[
idx
,
s
:
e
]
=
1
# y_masks[idx, :, s:e] = tril_masks[idx, :, s:e]
# y_masks[idx, e:, s:e ] = 0
non_eos_mask
=
np
.
array
(
paddle
.
reshape
(
src_mask
,
paddle
.
shape
(
xs_pad
)[:
2
]).
float
().
cpu
())
masked_position
=
masked_position
*
non_eos_mask
# y_masks = src_mask & y_masks.bool()
return
paddle
.
cast
(
paddle
.
to_tensor
(
masked_position
),
paddle
.
bool
),
y_masks
def
get_segment_pos
(
speech_pad
,
text_pad
,
align_start
,
align_end
,
align_start_lengths
,
sega_emb
):
bz
,
speech_len
,
_
=
speech_pad
.
size
()
text_segment_pos
=
paddle
.
zeros_like
(
text_pad
)
speech_segment_pos
=
paddle
.
zeros
((
bz
,
speech_len
),
dtype
=
text_pad
.
dtype
)
if
not
sega_emb
:
return
speech_segment_pos
,
text_segment_pos
for
idx
in
range
(
bz
):
align_length
=
align_start_lengths
[
idx
].
item
()
for
j
in
range
(
align_length
):
s
,
e
=
align_start
[
idx
][
j
].
item
(),
align_end
[
idx
][
j
].
item
()
speech_segment_pos
[
idx
][
s
:
e
]
=
j
+
1
text_segment_pos
[
idx
][
j
]
=
j
+
1
return
speech_segment_pos
,
text_segment_pos
def
mlm_collate_fn
(
data
:
Collection
[
Tuple
[
str
,
Dict
[
str
,
np
.
ndarray
]]],
float_pad_value
:
Union
[
float
,
int
]
=
0.0
,
int_pad_value
:
int
=
-
32768
,
not_sequence
:
Collection
[
str
]
=
(),
mlm_prob
:
float
=
0.8
,
mean_phn_span
:
int
=
8
,
feats_extract
=
None
,
attention_window
:
int
=
0
,
pad_speech
:
bool
=
False
,
sega_emb
:
bool
=
False
,
duration_collect
:
bool
=
False
,
text_masking
:
bool
=
False
)
->
Tuple
[
List
[
str
],
Dict
[
str
,
paddle
.
Tensor
]]:
"""Concatenate ndarray-list to an array and convert to paddle.Tensor.
Examples:
>>> from espnet2.samplers.constant_batch_sampler import ConstantBatchSampler,
>>> import espnet2.tasks.abs_task
>>> from espnet2.train.dataset import ESPnetDataset
>>> sampler = ConstantBatchSampler(...)
>>> dataset = ESPnetDataset(...)
>>> keys = next(iter(sampler)
>>> batch = [dataset[key] for key in keys]
>>> batch = common_collate_fn(batch)
>>> model(**batch)
Note that the dict-keys of batch are propagated from
that of the dataset as they are.
"""
uttids
=
[
u
for
u
,
_
in
data
]
data
=
[
d
for
_
,
d
in
data
]
assert
all
(
set
(
data
[
0
])
==
set
(
d
)
for
d
in
data
),
"dict-keys mismatching"
assert
all
(
not
k
.
endswith
(
"_lengths"
)
for
k
in
data
[
0
]
),
f
"*_lengths is reserved:
{
list
(
data
[
0
])
}
"
output
=
{}
for
key
in
data
[
0
]:
# NOTE(kamo):
# Each models, which accepts these values finally, are responsible
# to repaint the pad_value to the desired value for each tasks.
if
data
[
0
][
key
].
dtype
.
kind
==
"i"
:
pad_value
=
int_pad_value
else
:
pad_value
=
float_pad_value
array_list
=
[
d
[
key
]
for
d
in
data
]
# Assume the first axis is length:
# tensor_list: Batch x (Length, ...)
tensor_list
=
[
paddle
.
to_tensor
(
a
)
for
a
in
array_list
]
# tensor: (Batch, Length, ...)
tensor
=
pad_list
(
tensor_list
,
pad_value
)
output
[
key
]
=
tensor
# lens: (Batch,)
if
key
not
in
not_sequence
:
lens
=
paddle
.
to_tensor
([
d
[
key
].
shape
[
0
]
for
d
in
data
],
dtype
=
paddle
.
long
)
output
[
key
+
"_lengths"
]
=
lens
f
=
open
(
'tmp_var.out'
,
'w'
)
for
item
in
[
round
(
item
,
6
)
for
item
in
output
[
"speech"
][
0
].
tolist
()]:
f
.
write
(
str
(
item
)
+
'
\n
'
)
feats
=
feats_extract
.
get_log_mel_fbank
(
np
.
array
(
output
[
"speech"
][
0
]))
feats
=
paddle
.
to_tensor
(
feats
)
print
(
'out shape'
,
paddle
.
shape
(
feats
))
feats_lengths
=
paddle
.
shape
(
feats
)[
0
]
feats
=
paddle
.
unsqueeze
(
feats
,
0
)
batch_size
=
paddle
.
shape
(
feats
)[
0
]
if
'text'
not
in
output
:
text
=
paddle
.
zeros_like
(
feats_lengths
.
unsqueeze
(
-
1
))
-
2
text_lengths
=
paddle
.
zeros_like
(
feats_lengths
)
+
1
max_tlen
=
1
align_start
=
paddle
.
zeros_like
(
text
)
align_end
=
paddle
.
zeros_like
(
text
)
align_start_lengths
=
paddle
.
zeros_like
(
feats_lengths
)
align_end_lengths
=
paddle
.
zeros_like
(
feats_lengths
)
sega_emb
=
False
mean_phn_span
=
0
mlm_prob
=
0.15
else
:
text
,
text_lengths
=
output
[
"text"
],
output
[
"text_lengths"
]
align_start
,
align_start_lengths
,
align_end
,
align_end_lengths
=
output
[
"align_start"
],
output
[
"align_start_lengths"
],
output
[
"align_end"
],
output
[
"align_end_lengths"
]
align_start
=
paddle
.
floor
(
feats_extract
.
sr
*
align_start
/
feats_extract
.
hop_length
).
int
()
align_end
=
paddle
.
floor
(
feats_extract
.
sr
*
align_end
/
feats_extract
.
hop_length
).
int
()
max_tlen
=
max
(
text_lengths
).
item
()
max_slen
=
max
(
feats_lengths
).
item
()
speech_pad
=
feats
[:,
:
max_slen
]
if
attention_window
>
0
and
pad_speech
:
speech_pad
,
max_slen
=
pad_to_longformer_att_window
(
speech_pad
,
max_slen
,
max_slen
,
attention_window
)
max_len
=
max_slen
+
max_tlen
if
attention_window
>
0
:
text_pad
,
max_tlen
=
pad_to_longformer_att_window
(
text
,
max_len
,
max_tlen
,
attention_window
)
else
:
text_pad
=
text
text_mask
=
make_non_pad_mask
(
text_lengths
.
tolist
(),
text_pad
,
length_dim
=
1
).
unsqueeze
(
-
2
)
if
attention_window
>
0
:
text_mask
=
text_mask
*
2
speech_mask
=
make_non_pad_mask
(
feats_lengths
.
tolist
(),
speech_pad
[:,:,
0
],
length_dim
=
1
).
unsqueeze
(
-
2
)
span_boundary
=
None
if
'span_boundary'
in
output
.
keys
():
span_boundary
=
output
[
'span_boundary'
]
if
text_masking
:
masked_position
,
text_masked_position
,
_
=
phones_text_masking
(
speech_pad
,
speech_mask
,
text_pad
,
text_mask
,
align_start
,
align_end
,
align_start_lengths
,
mlm_prob
,
mean_phn_span
,
span_boundary
)
else
:
text_masked_position
=
np
.
zeros
(
text_pad
.
size
())
masked_position
,
_
=
phones_masking
(
speech_pad
,
speech_mask
,
align_start
,
align_end
,
align_start_lengths
,
mlm_prob
,
mean_phn_span
,
span_boundary
)
output_dict
=
{}
if
duration_collect
and
'text'
in
output
:
reordered_index
,
speech_segment_pos
,
text_segment_pos
,
durations
,
feats_lengths
=
get_segment_pos_reduce_duration
(
speech_pad
,
text_pad
,
align_start
,
align_end
,
align_start_lengths
,
sega_emb
,
masked_position
,
feats_lengths
)
speech_mask
=
make_non_pad_mask
(
feats_lengths
.
tolist
(),
speech_pad
[:,:
reordered_index
.
shape
[
1
],
0
],
length_dim
=
1
).
unsqueeze
(
-
2
)
output_dict
[
'durations'
]
=
durations
output_dict
[
'reordered_index'
]
=
reordered_index
else
:
speech_segment_pos
,
text_segment_pos
=
get_segment_pos
(
speech_pad
,
text_pad
,
align_start
,
align_end
,
align_start_lengths
,
sega_emb
)
output_dict
[
'speech'
]
=
speech_pad
output_dict
[
'text'
]
=
text_pad
output_dict
[
'masked_position'
]
=
masked_position
output_dict
[
'text_masked_position'
]
=
text_masked_position
output_dict
[
'speech_mask'
]
=
speech_mask
output_dict
[
'text_mask'
]
=
text_mask
output_dict
[
'speech_segment_pos'
]
=
speech_segment_pos
output_dict
[
'text_segment_pos'
]
=
text_segment_pos
# output_dict['y_masks'] = y_masks
output_dict
[
'speech_lengths'
]
=
output
[
"speech_lengths"
]
output_dict
[
'text_lengths'
]
=
text_lengths
output
=
(
uttids
,
output_dict
)
# assert check_return_type(output)
return
output
def
build_collate_fn
(
args
:
argparse
.
Namespace
,
train
:
bool
,
epoch
=-
1
):
# assert check_argument_types()
# return CommonCollateFn(float_pad_value=0.0, int_pad_value=0)
feats_extract_class
=
LogMelFBank
args_dic
=
{}
print
(
'type is'
,
type
(
args
.
feats_extract_conf
))
for
k
,
v
in
args
.
feats_extract_conf
.
items
():
if
k
==
'fs'
:
args_dic
[
'sr'
]
=
v
else
:
args_dic
[
k
]
=
v
# feats_extract = feats_extract_class(**args.feats_extract_conf)
feats_extract
=
feats_extract_class
(
**
args_dic
)
sega_emb
=
True
if
args
.
encoder_conf
[
'input_layer'
]
==
'sega_mlm'
else
False
if
args
.
encoder_conf
[
'selfattention_layer_type'
]
==
'longformer'
:
attention_window
=
args
.
encoder_conf
[
'attention_window'
]
pad_speech
=
True
if
'pre_speech_layer'
in
args
.
encoder_conf
and
args
.
encoder_conf
[
'pre_speech_layer'
]
>
0
else
False
else
:
attention_window
=
0
pad_speech
=
False
if
epoch
==-
1
:
mlm_prob_factor
=
1
else
:
mlm_probs
=
[
1.0
,
1.0
,
0.7
,
0.6
,
0.5
]
mlm_prob_factor
=
0.8
#mlm_probs[epoch // 100]
if
'duration_predictor_layers'
in
args
.
model_conf
.
keys
()
and
args
.
model_conf
[
'duration_predictor_layers'
]
>
0
:
duration_collect
=
True
else
:
duration_collect
=
False
return
MLMCollateFn
(
feats_extract
,
float_pad_value
=
0.0
,
int_pad_value
=
0
,
mlm_prob
=
args
.
model_conf
[
'mlm_prob'
]
*
mlm_prob_factor
,
mean_phn_span
=
args
.
model_conf
[
'mean_phn_span'
],
attention_window
=
attention_window
,
pad_speech
=
pad_speech
,
sega_emb
=
sega_emb
,
duration_collect
=
duration_collect
)
def
get_mlm_output
(
uid
,
prefix
,
clone_uid
,
clone_prefix
,
source_language
,
target_language
,
model_name
,
wav_path
,
old_str
,
new_str
,
duration_preditor_path
,
sid
=
None
,
decoder
=
False
,
use_teacher_forcing
=
False
,
dynamic_eval
=
(
0
,
0
),
duration_adjust
=
True
,
start_end_sp
=
False
):
mlm_model
,
train_args
=
load_model
(
model_name
)
mlm_model
.
eval
()
# processor = MLMTask.build_preprocess_fn(train_args, False)
processor
=
None
collate_fn
=
MLMTask
.
build_collate_fn
(
train_args
,
False
)
# collate_fn = build_collate_fn(train_args, False)
collate_fn
=
None
return
decode_with_model
(
uid
,
prefix
,
clone_uid
,
clone_prefix
,
source_language
,
target_language
,
mlm_model
,
processor
,
collate_fn
,
wav_path
,
old_str
,
new_str
,
duration_preditor_path
,
sid
=
sid
,
decoder
=
decoder
,
use_teacher_forcing
=
use_teacher_forcing
,
duration_adjust
=
duration_adjust
,
start_end_sp
=
start_end_sp
,
train_args
=
train_args
)
def
prompt_decoding_fn
(
model_name
,
wav_path
,
full_origin_str
,
old_str
,
new_str
,
vocoder
,
duration_preditor_path
,
sid
=
None
,
non_autoreg
=
True
,
dynamic_eval
=
(
0
,
0
),
duration_adjust
=
True
):
wav_org
,
input_feat
,
output_feat
,
old_span_boundary
,
new_span_boundary
,
fs
,
hop_length
=
get_mlm_output
(
model_name
,
wav_path
,
old_str
,
new_str
,
duration_preditor_path
,
use_teacher_forcing
=
non_autoreg
,
sid
=
sid
,
dynamic_eval
=
dynamic_eval
,
duration_adjust
=
duration_adjust
,
start_end_sp
=
False
)
replaced_wav
=
vocoder
(
output_feat
).
detach
().
float
().
data
.
cpu
().
numpy
()
old_time_boundary
=
[
hop_length
*
x
for
x
in
old_span_boundary
]
new_time_boundary
=
[
hop_length
*
x
for
x
in
new_span_boundary
]
new_wav
=
replaced_wav
[
new_time_boundary
[
0
]:]
# "origin_vocoder":vocoder_origin_wav,
data_dict
=
{
"prompt"
:
wav_org
,
"new_wav"
:
new_wav
}
return
data_dict
def
test_vctk
(
uid
,
clone_uid
,
clone_prefix
,
source_language
,
target_language
,
vocoder
,
prefix
=
'dump/raw/dev'
,
model_name
=
"conformer"
,
old_str
=
""
,
new_str
=
""
,
prompt_decoding
=
False
,
dynamic_eval
=
(
0
,
0
),
task_name
=
None
):
new_str
=
new_str
.
strip
()
if
clone_uid
is
not
None
and
clone_prefix
is
not
None
:
if
target_language
==
"english"
:
duration_preditor_path
=
duration_path_dict
[
'ljspeech'
]
elif
target_language
==
"chinese"
:
duration_preditor_path
=
duration_path_dict
[
'ljspeech'
]
else
:
assert
target_language
==
"chinese"
or
target_language
==
"english"
,
"duration_preditor_path is not support for this language..."
else
:
duration_preditor_path
=
duration_path_dict
[
'ljspeech'
]
duration_preditor_path
=
None
spemd
=
None
full_origin_str
,
wav_path
=
read_data
(
uid
,
prefix
)
...
...
@@ -1060,12 +454,7 @@ def test_vctk(uid, clone_uid, clone_prefix, source_language, target_language, vo
if
not
old_str
:
old_str
=
full_origin_str
if
not
new_str
:
new_str
=
input
(
"input the new string:"
)
if
prompt_decoding
:
print
(
new_str
)
return
prompt_decoding_fn
(
model_name
,
wav_path
,
full_origin_str
,
old_str
,
new_str
,
vocoder
,
duration_preditor_path
,
sid
=
spemd
,
dynamic_eval
=
dynamic_eval
)
print
(
full_origin_str
)
results_dict
,
old_span
=
plot_mel_and_vocode_wav
(
uid
,
prefix
,
clone_uid
,
clone_prefix
,
source_language
,
target_language
,
model_name
,
wav_path
,
full_origin_str
,
old_str
,
new_str
,
vocoder
,
duration_preditor_path
,
sid
=
spemd
)
return
results_dict
...
...
@@ -1083,4 +472,4 @@ if __name__ == "__main__":
new_str
=
args
.
new_str
,
task_name
=
args
.
task_name
)
sf
.
write
(
'./wavs/%s'
%
args
.
output_name
,
data_dict
[
'output'
],
samplerate
=
24000
)
\ No newline at end of file
# exit()
ernie-sat/prompt/dev/mfa_end
浏览文件 @
c2f4709e
p243_new 1.0225 1.0525 1.0925 1.1325 1.1725 1.2625 1.3625 1.4125 1.5125 1.6225 1.6625 1.7925 1.8625 2.0025 2.0925 2.1725 2.2625 2.4325 2.4725 2.5225 2.5825 2.6125 2.6425 2.7425 2.8025 2.9025 2.9525 3.0525 3.0825 3.2125 3.4525
Prompt_003_new 0.0425 0.0925 0.1825 0.2125 0.2425 0.3225 0.3725 0.4725 0.5325 0.5625 0.6225 0.7425 0.8625 0.9725 0.9975 1.0125 1.0825 1.2625 1.3125
p299_096 0.7525 0.7925 0.8725 0.9125 0.9425 1.0325 1.0625 1.1925 1.2625 1.3225 1.3725 1.4125 1.5125 1.5425 1.6525 1.6925 1.7325 1.7625 1.8425 1.9625 2.0225 2.1825 2.3325 2.6825
p243_new 1.0225 1.0525 1.0925 1.1325 1.1725 1.2625 1.3625 1.4125 1.5125 1.6225 1.6625 1.7925 1.8625 2.0025 2.0925 2.1725 2.2625 2.4325 2.4725 2.5225 2.5825 2.6125 2.6425 2.7425 2.8025 2.9025 2.9525 3.0525 3.0825 3.2125 3.4525
ernie-sat/prompt/dev/mfa_start
浏览文件 @
c2f4709e
Prompt_003_new 0.0125 0.0425 0.0925 0.1825 0.2125 0.2425 0.3225 0.3725 0.4725 0.5325 0.5625 0.6225 0.7425 0.8625 0.9725 0.9975 1.0125 1.0825 1.2625
p243_new 0.0125 1.0225 1.0525 1.0925 1.1325 1.1725 1.2625 1.3625 1.4125 1.5125 1.6225 1.6625 1.7925 1.8625 2.0025 2.0925 2.1725 2.2625 2.4325 2.4725 2.5225 2.5825 2.6125 2.6425 2.7425 2.8025 2.9025 2.9525 3.0525 3.0825 3.2125
Prompt_003_new 0.0125 0.0425 0.0925 0.1825 0.2125 0.2425 0.3225 0.3725 0.4725 0.5325 0.5625 0.6225 0.7425 0.8625 0.9725 0.9975 1.0125 1.0825 1.2625
p299_096 0.0125 0.7525 0.7925 0.8725 0.9125 0.9425 1.0325 1.0625 1.1925 1.2625 1.3225 1.3725 1.4125 1.5125 1.5425 1.6525 1.6925 1.7325 1.7625 1.8425 1.9625 2.0225 2.1825 2.3325
ernie-sat/prompt/dev/mfa_text
浏览文件 @
c2f4709e
p243_new sp F AO1 R DH AE1 T R IY1 Z AH0 N sp K AH1 V ER0 SH UH1 D N AA1 T B IY1 G IH1 V AH0 N sp
Prompt_003_new DH IH1 S W AA1 Z N AA1 T DH AH0 SH OW1 F AO1 R M IY1 sp
p299_096 sp W IY1 AA1 R T R AY1 NG T UW1 AH0 S T AE1 B L IH0 SH AH0 D EY1 T sp
p243_new sp F AO1 R DH AE1 T R IY1 Z AH0 N sp K AH1 V ER0 SH UH1 D N AA1 T B IY1 G IH1 V AH0 N sp
ernie-sat/prompt/dev/mfa_wav.scp
浏览文件 @
c2f4709e
Prompt_003_new ../../prompt_wav/this_was_not_the_show_for_me.wav
p243_new ../../prompt_wav/p243_313.wav
Prompt_003_new ../../prompt_wav/this_was_not_the_show_for_me.wav
p299_096 ../../prompt_wav/p299_096.wav
ernie-sat/prompt/dev/text
浏览文件 @
c2f4709e
Prompt_003_new This was not the show for me.
p243_new For that reason cover should not be given.
Prompt_003_new This was not the show for me.
p299_096 We are trying to establish a date.
ernie-sat/prompt/dev/wav.scp
浏览文件 @
c2f4709e
p243_new ../../prompt_wav/p243_313.wav
Prompt_003_new ../../prompt_wav/this_was_not_the_show_for_me.wav
p299_096 ../../prompt_wav/p299_096.wav
p243_new ../../prompt_wav/p243_313.wav
ernie-sat/run_clone_en_to_zh.sh
浏览文件 @
c2f4709e
# en --> zh 的 clone
python sedit_inference_0520.py
\
# en --> zh 的 语音合成
# 根据Prompt_003_new对应的语音: This was not the show for me. 来合成: '今天天气很好'
python inference.py
\
--task_name
cross-lingual_clone
\
--model_name
paddle_checkpoint_ench
\
--uid
Prompt_003_new
\
--new_str
'今天天气很好'
\
--prefix
./prompt/dev/
\
--clone_prefix
./prompt/dev_aishell3/
\
--clone_uid
SSB07510054
\
--source_language
english
\
--target_language
chinese
\
--output_name
task_cross_lingual_pred
.wav
\
--output_name
pred_zh
.wav
\
--voc
pwgan_aishell3
\
--voc_config
download/pwg_aishell3_ckpt_0.5/default.yaml
\
--voc_ckpt
download/pwg_aishell3_ckpt_0.5/snapshot_iter_1000000.pdz
\
...
...
ernie-sat/run_gen_en.sh
浏览文件 @
c2f4709e
# 纯英文的语音合成
# python sedit_inference_0518.py \
# --task_name synthesize \
# --model_name paddle_checkpoint_en \
# --uid p323_083 \
# --new_str 'I enjoy my life.' \
# --prefix ./prompt/dev/ \
# --source_language english \
# --target_language english \
# --output_name pred.wav \
# --voc pwgan_aishell3 \
# --voc_config download/pwg_aishell3_ckpt_0.5/default.yaml \
# --voc_ckpt download/pwg_aishell3_ckpt_0.5/snapshot_iter_1000000.pdz \
# --voc_stat download/pwg_aishell3_ckpt_0.5/feats_stats.npy \
# --am fastspeech2_ljspeech \
# --am_config download/fastspeech2_nosil_ljspeech_ckpt_0.5/default.yaml \
# --am_ckpt download/fastspeech2_nosil_ljspeech_ckpt_0.5/snapshot_iter_100000.pdz \
# --am_stat download/fastspeech2_nosil_ljspeech_ckpt_0.5/speech_stats.npy \
# --phones_dict download/fastspeech2_nosil_ljspeech_ckpt_0.5/phone_id_map.txt
# 根据p299_096对应的语音: This was not the show for me. 来合成: 'I enjoy my life.'
# 纯英文的语音合成
python sedit_inference_0520.py
\
python inference.py
\
--task_name
synthesize
\
--model_name
paddle_checkpoint_en
\
--uid
p299_096
\
...
...
@@ -28,7 +9,7 @@ python sedit_inference_0520.py \
--prefix
./prompt/dev/
\
--source_language
english
\
--target_language
english
\
--output_name
task_synthesize_
pred.wav
\
--output_name
pred.wav
\
--voc
pwgan_aishell3
\
--voc_config
download/pwg_aishell3_ckpt_0.5/default.yaml
\
--voc_ckpt
download/pwg_aishell3_ckpt_0.5/snapshot_iter_1000000.pdz
\
...
...
ernie-sat/run_sedit_en.sh
浏览文件 @
c2f4709e
# 纯英文的语音编辑
python sedit_inference_0520.py
\
# 将p243_new对应的原始语音: For that reason cover should not be given. 编辑成'for that reason cover is impossible to be given.'对应的语音
python inference.py
\
--task_name
edit
\
--model_name
paddle_checkpoint_en
\
--uid
p243_new
\
...
...
@@ -7,7 +9,7 @@ python sedit_inference_0520.py \
--prefix
./prompt/dev/
\
--source_language
english
\
--target_language
english
\
--output_name
task_edit_
pred.wav
\
--output_name
pred.wav
\
--voc
pwgan_aishell3
\
--voc_config
download/pwg_aishell3_ckpt_0.5/default.yaml
\
--voc_ckpt
download/pwg_aishell3_ckpt_0.5/snapshot_iter_1000000.pdz
\
...
...
ernie-sat/tmp/tmp_pkl.Prompt_003_new
浏览文件 @
c2f4709e
无法预览此类型文件
ernie-sat/tmp/tmp_pkl.p243_new
浏览文件 @
c2f4709e
无法预览此类型文件
ernie-sat/tmp/tmp_pkl.p299_096
浏览文件 @
c2f4709e
无法预览此类型文件
ernie-sat/util.py
→
ernie-sat/util
s
.py
浏览文件 @
c2f4709e
...
...
@@ -70,13 +70,15 @@ def get_voc_out(mel, target_language="chinese"):
print
(
"current vocoder: "
,
args
.
voc
)
with
open
(
args
.
voc_config
)
as
f
:
voc_config
=
CfgNode
(
yaml
.
safe_load
(
f
))
# print(voc_config)
voc_inference
=
get_voc_inference
(
args
,
voc_config
)
mel
=
paddle
.
to_tensor
(
mel
)
# print("masked_mel: ", mel.shape)
with
paddle
.
no_grad
():
wav
=
voc_inference
(
mel
)
print
(
"shepe of wav (time x n_channels):%s"
%
wav
.
shape
)
# (31800,1)
# print("shepe of wav (time x n_channels):%s"%wav.shape)
return
np
.
squeeze
(
wav
)
# dygraph
...
...
@@ -134,6 +136,7 @@ def get_am_inference(args, am_config):
def
evaluate_durations
(
phns
,
target_language
=
"chinese"
,
fs
=
24000
,
hop_length
=
300
):
args
=
parse_args
()
# args = parser.parse_args(args=[])
if
args
.
ngpu
==
0
:
paddle
.
set_device
(
"cpu"
)
elif
args
.
ngpu
>
0
:
...
...
@@ -154,6 +157,7 @@ def evaluate_durations(phns, target_language="chinese", fs=24000, hop_length=300
# acoustic model
am
,
am_inference
,
am_name
,
am_dataset
,
phn_id
=
get_am_inference
(
args
,
am_config
)
torch_phns
=
phns
vocab_phones
=
{}
for
tone
,
id
in
phn_id
:
...
...
@@ -165,17 +169,16 @@ def evaluate_durations(phns, target_language="chinese", fs=24000, hop_length=300
]
phone_ids
=
[
vocab_phones
[
item
]
for
item
in
phonemes
]
phone_ids_new
=
phone_ids
phone_ids_new
.
append
(
vocab_size
-
1
)
phone_ids_new
=
paddle
.
to_tensor
(
np
.
array
(
phone_ids_new
,
np
.
int64
))
normalized_mel
,
d_outs
,
p_outs
,
e_outs
=
am
.
inference
(
phone_ids_new
,
spk_id
=
None
,
spk_emb
=
None
)
pre_d_outs
=
d_outs
phoneme_durations_new
=
pre_d_outs
*
hop_length
/
fs
phoneme_durations_new
=
phoneme_durations_new
.
tolist
()[:
-
1
]
return
phoneme_durations_new
def
sentence2phns
(
sentence
,
target_language
=
"en"
):
args
=
parse_args
()
if
target_language
==
'en'
:
...
...
ernie-sat/wavs/ori.wav
已删除
100644 → 0
浏览文件 @
75dae192
文件已删除
ernie-sat/wavs/pred.wav
浏览文件 @
c2f4709e
无法预览此类型文件
ernie-sat/wavs/pred_en_edit_paddle_voc.wav
浏览文件 @
c2f4709e
无法预览此类型文件
ernie-sat/wavs/pred_zh.wav
浏览文件 @
c2f4709e
无法预览此类型文件
ernie-sat/wavs/pred_zh_fst2_voc.wav
已删除
100644 → 0
浏览文件 @
75dae192
文件已删除
ernie-sat/wavs/task_cross_lingual_pred.wav
已删除
100644 → 0
浏览文件 @
75dae192
文件已删除
ernie-sat/wavs/task_edit_pred.wav
已删除
100644 → 0
浏览文件 @
75dae192
文件已删除
ernie-sat/wavs/task_synthesize_pred.wav
已删除
100644 → 0
浏览文件 @
75dae192
文件已删除
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录