Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
72fa8176
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
72fa8176
编写于
7月 13, 2022
作者:
小湉湉
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix for mix_lang
上级
5503c8bd
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
21 addition
and
10 deletion
+21
-10
paddlespeech/t2s/exps/ernie_sat/train.py
paddlespeech/t2s/exps/ernie_sat/train.py
+2
-0
paddlespeech/t2s/models/ernie_sat/ernie_sat_updater.py
paddlespeech/t2s/models/ernie_sat/ernie_sat_updater.py
+6
-2
paddlespeech/t2s/modules/losses.py
paddlespeech/t2s/modules/losses.py
+9
-5
paddlespeech/t2s/modules/nets_utils.py
paddlespeech/t2s/modules/nets_utils.py
+4
-3
未找到文件。
paddlespeech/t2s/exps/ernie_sat/train.py
浏览文件 @
72fa8176
...
@@ -154,6 +154,7 @@ def train_sp(args, config):
...
@@ -154,6 +154,7 @@ def train_sp(args, config):
dataloader
=
train_dataloader
,
dataloader
=
train_dataloader
,
text_masking
=
config
[
"model"
][
"text_masking"
],
text_masking
=
config
[
"model"
][
"text_masking"
],
odim
=
odim
,
odim
=
odim
,
vocab_size
=
vocab_size
,
output_dir
=
output_dir
)
output_dir
=
output_dir
)
trainer
=
Trainer
(
updater
,
(
config
.
max_epoch
,
'epoch'
),
output_dir
)
trainer
=
Trainer
(
updater
,
(
config
.
max_epoch
,
'epoch'
),
output_dir
)
...
@@ -163,6 +164,7 @@ def train_sp(args, config):
...
@@ -163,6 +164,7 @@ def train_sp(args, config):
dataloader
=
dev_dataloader
,
dataloader
=
dev_dataloader
,
text_masking
=
config
[
"model"
][
"text_masking"
],
text_masking
=
config
[
"model"
][
"text_masking"
],
odim
=
odim
,
odim
=
odim
,
vocab_size
=
vocab_size
,
output_dir
=
output_dir
,
)
output_dir
=
output_dir
,
)
if
dist
.
get_rank
()
==
0
:
if
dist
.
get_rank
()
==
0
:
...
...
paddlespeech/t2s/models/ernie_sat/ernie_sat_updater.py
浏览文件 @
72fa8176
...
@@ -40,11 +40,13 @@ class ErnieSATUpdater(StandardUpdater):
...
@@ -40,11 +40,13 @@ class ErnieSATUpdater(StandardUpdater):
init_state
=
None
,
init_state
=
None
,
text_masking
:
bool
=
False
,
text_masking
:
bool
=
False
,
odim
:
int
=
80
,
odim
:
int
=
80
,
vocab_size
:
int
=
100
,
output_dir
:
Path
=
None
):
output_dir
:
Path
=
None
):
super
().
__init__
(
model
,
optimizer
,
dataloader
,
init_state
=
None
)
super
().
__init__
(
model
,
optimizer
,
dataloader
,
init_state
=
None
)
self
.
scheduler
=
scheduler
self
.
scheduler
=
scheduler
self
.
criterion
=
MLMLoss
(
text_masking
=
text_masking
,
odim
=
odim
)
self
.
criterion
=
MLMLoss
(
text_masking
=
text_masking
,
odim
=
odim
,
vocab_size
=
vocab_size
)
log_file
=
output_dir
/
'worker_{}.log'
.
format
(
dist
.
get_rank
())
log_file
=
output_dir
/
'worker_{}.log'
.
format
(
dist
.
get_rank
())
self
.
filehandler
=
logging
.
FileHandler
(
str
(
log_file
))
self
.
filehandler
=
logging
.
FileHandler
(
str
(
log_file
))
...
@@ -104,6 +106,7 @@ class ErnieSATEvaluator(StandardEvaluator):
...
@@ -104,6 +106,7 @@ class ErnieSATEvaluator(StandardEvaluator):
dataloader
:
DataLoader
,
dataloader
:
DataLoader
,
text_masking
:
bool
=
False
,
text_masking
:
bool
=
False
,
odim
:
int
=
80
,
odim
:
int
=
80
,
vocab_size
:
int
=
100
,
output_dir
:
Path
=
None
):
output_dir
:
Path
=
None
):
super
().
__init__
(
model
,
dataloader
)
super
().
__init__
(
model
,
dataloader
)
...
@@ -113,7 +116,8 @@ class ErnieSATEvaluator(StandardEvaluator):
...
@@ -113,7 +116,8 @@ class ErnieSATEvaluator(StandardEvaluator):
self
.
logger
=
logger
self
.
logger
=
logger
self
.
msg
=
""
self
.
msg
=
""
self
.
criterion
=
MLMLoss
(
text_masking
=
text_masking
,
odim
=
odim
)
self
.
criterion
=
MLMLoss
(
text_masking
=
text_masking
,
odim
=
odim
,
vocab_size
=
vocab_size
)
def
evaluate_core
(
self
,
batch
):
def
evaluate_core
(
self
,
batch
):
self
.
msg
=
"Evaluate: "
self
.
msg
=
"Evaluate: "
...
...
paddlespeech/t2s/modules/losses.py
浏览文件 @
72fa8176
...
@@ -1013,6 +1013,7 @@ class KLDivergenceLoss(nn.Layer):
...
@@ -1013,6 +1013,7 @@ class KLDivergenceLoss(nn.Layer):
class
MLMLoss
(
nn
.
Layer
):
class
MLMLoss
(
nn
.
Layer
):
def
__init__
(
self
,
def
__init__
(
self
,
odim
:
int
,
odim
:
int
,
vocab_size
:
int
=
0
,
lsm_weight
:
float
=
0.1
,
lsm_weight
:
float
=
0.1
,
ignore_id
:
int
=-
1
,
ignore_id
:
int
=-
1
,
text_masking
:
bool
=
False
):
text_masking
:
bool
=
False
):
...
@@ -1025,6 +1026,7 @@ class MLMLoss(nn.Layer):
...
@@ -1025,6 +1026,7 @@ class MLMLoss(nn.Layer):
self
.
l1_loss_func
=
nn
.
L1Loss
(
reduction
=
'none'
)
self
.
l1_loss_func
=
nn
.
L1Loss
(
reduction
=
'none'
)
self
.
text_masking
=
text_masking
self
.
text_masking
=
text_masking
self
.
odim
=
odim
self
.
odim
=
odim
self
.
vocab_size
=
vocab_size
def
forward
(
def
forward
(
self
,
self
,
...
@@ -1059,10 +1061,12 @@ class MLMLoss(nn.Layer):
...
@@ -1059,10 +1061,12 @@ class MLMLoss(nn.Layer):
assert
text
is
not
None
assert
text
is
not
None
assert
text_outs
is
not
None
assert
text_outs
is
not
None
assert
text_masked_pos
is
not
None
assert
text_masked_pos
is
not
None
text_mlm_loss
=
paddle
.
sum
((
self
.
text_mlm_loss
(
text_outs
=
paddle
.
reshape
(
text_outs
,
[
-
1
,
self
.
vocab_size
])
paddle
.
reshape
(
text_outs
,
(
-
1
,
self
.
vocab_size
)),
text
=
paddle
.
reshape
(
text
,
[
-
1
])
paddle
.
reshape
(
text
,
(
-
1
)))
*
paddle
.
reshape
(
text_mlm_loss
=
self
.
text_mlm_loss
(
text_outs
,
text
)
text_masked_pos
,
text_masked_pos_reshape
=
paddle
.
reshape
(
text_masked_pos
,
[
-
1
])
(
-
1
))))
/
paddle
.
sum
((
text_masked_pos
)
+
1e-10
)
text_mlm_loss
=
paddle
.
sum
(
text_mlm_loss
*
text_masked_pos_reshape
)
/
paddle
.
sum
((
text_masked_pos
)
+
1e-10
)
return
mlm_loss
,
text_mlm_loss
return
mlm_loss
,
text_mlm_loss
paddlespeech/t2s/modules/nets_utils.py
浏览文件 @
72fa8176
...
@@ -464,14 +464,15 @@ def phones_text_masking(xs_pad: paddle.Tensor,
...
@@ -464,14 +464,15 @@ def phones_text_masking(xs_pad: paddle.Tensor,
set
(
range
(
length
))
-
set
(
masked_phn_idxs
[
0
].
tolist
()))
set
(
range
(
length
))
-
set
(
masked_phn_idxs
[
0
].
tolist
()))
np
.
random
.
shuffle
(
unmasked_phn_idxs
)
np
.
random
.
shuffle
(
unmasked_phn_idxs
)
masked_text_idxs
=
unmasked_phn_idxs
[:
text_mask_num_lower
]
masked_text_idxs
=
unmasked_phn_idxs
[:
text_mask_num_lower
]
text_masked_pos
[
idx
][
masked_text_idxs
]
=
1
text_masked_pos
[
idx
,
masked_text_idxs
]
=
1
masked_start
=
align_start
[
idx
][
masked_phn_idxs
].
tolist
()
masked_start
=
align_start
[
idx
][
masked_phn_idxs
].
tolist
()
masked_end
=
align_end
[
idx
][
masked_phn_idxs
].
tolist
()
masked_end
=
align_end
[
idx
][
masked_phn_idxs
].
tolist
()
for
s
,
e
in
zip
(
masked_start
,
masked_end
):
for
s
,
e
in
zip
(
masked_start
,
masked_end
):
masked_pos
[
idx
,
s
:
e
]
=
1
masked_pos
[
idx
,
s
:
e
]
=
1
non_eos_mask
=
paddle
.
reshape
(
src_mask
,
paddle
.
shape
(
xs_pad
)[:
2
])
non_eos_mask
=
paddle
.
reshape
(
src_mask
,
shape
=
paddle
.
shape
(
xs_pad
)[:
2
])
masked_pos
=
masked_pos
*
non_eos_mask
masked_pos
=
masked_pos
*
non_eos_mask
non_eos_text_mask
=
paddle
.
reshape
(
text_mask
,
paddle
.
shape
(
xs_pad
)[:
2
])
non_eos_text_mask
=
paddle
.
reshape
(
text_mask
,
shape
=
paddle
.
shape
(
text_pad
)[:
2
])
text_masked_pos
=
text_masked_pos
*
non_eos_text_mask
text_masked_pos
=
text_masked_pos
*
non_eos_text_mask
masked_pos
=
paddle
.
cast
(
masked_pos
,
'bool'
)
masked_pos
=
paddle
.
cast
(
masked_pos
,
'bool'
)
text_masked_pos
=
paddle
.
cast
(
text_masked_pos
,
'bool'
)
text_masked_pos
=
paddle
.
cast
(
text_masked_pos
,
'bool'
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录