Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
7b2ecb6e
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
“ec3bb70c7ca09d49c10c07b72dd7acaa7369b765”上不存在“mobile/src/operators/reshape_op.cpp”
提交
7b2ecb6e
编写于
12月 03, 2021
作者:
小湉湉
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add style_melgan, test=tts
上级
075aeee7
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
46 addition
and
40 deletion
+46
-40
paddlespeech/t2s/models/melgan/style_melgan_updater.py
paddlespeech/t2s/models/melgan/style_melgan_updater.py
+46
-40
未找到文件。
paddlespeech/t2s/models/melgan/style_melgan_updater.py
浏览文件 @
7b2ecb6e
...
@@ -40,8 +40,9 @@ class StyleMelGANUpdater(StandardUpdater):
...
@@ -40,8 +40,9 @@ class StyleMelGANUpdater(StandardUpdater):
criterions
:
Dict
[
str
,
Layer
],
criterions
:
Dict
[
str
,
Layer
],
schedulers
:
Dict
[
str
,
LRScheduler
],
schedulers
:
Dict
[
str
,
LRScheduler
],
dataloader
:
DataLoader
,
dataloader
:
DataLoader
,
discriminator_train_start_steps
:
int
,
generator_train_start_steps
:
int
=
0
,
lambda_adv
:
float
,
discriminator_train_start_steps
:
int
=
100000
,
lambda_adv
:
float
=
1.0
,
lambda_aux
:
float
=
1.0
,
lambda_aux
:
float
=
1.0
,
output_dir
:
Path
=
None
):
output_dir
:
Path
=
None
):
self
.
models
=
models
self
.
models
=
models
...
@@ -63,11 +64,12 @@ class StyleMelGANUpdater(StandardUpdater):
...
@@ -63,11 +64,12 @@ class StyleMelGANUpdater(StandardUpdater):
self
.
dataloader
=
dataloader
self
.
dataloader
=
dataloader
self
.
generator_train_start_steps
=
generator_train_start_steps
self
.
discriminator_train_start_steps
=
discriminator_train_start_steps
self
.
discriminator_train_start_steps
=
discriminator_train_start_steps
self
.
lambda_adv
=
lambda_adv
self
.
lambda_adv
=
lambda_adv
self
.
lambda_aux
=
lambda_aux
self
.
lambda_aux
=
lambda_aux
self
.
state
=
UpdaterState
(
iteration
=
0
,
epoch
=
0
)
self
.
state
=
UpdaterState
(
iteration
=
0
,
epoch
=
0
)
self
.
train_iterator
=
iter
(
self
.
dataloader
)
self
.
train_iterator
=
iter
(
self
.
dataloader
)
log_file
=
output_dir
/
'worker_{}.log'
.
format
(
dist
.
get_rank
())
log_file
=
output_dir
/
'worker_{}.log'
.
format
(
dist
.
get_rank
())
...
@@ -79,42 +81,45 @@ class StyleMelGANUpdater(StandardUpdater):
...
@@ -79,42 +81,45 @@ class StyleMelGANUpdater(StandardUpdater):
def
update_core
(
self
,
batch
):
def
update_core
(
self
,
batch
):
self
.
msg
=
"Rank: {}, "
.
format
(
dist
.
get_rank
())
self
.
msg
=
"Rank: {}, "
.
format
(
dist
.
get_rank
())
losses_dict
=
{}
losses_dict
=
{}
# parse batch
# parse batch
wav
,
mel
=
batch
wav
,
mel
=
batch
# Generator
# Generator
# (B, out_channels, T ** prod(upsample_scales))
if
self
.
state
.
iteration
>
self
.
generator_train_start_steps
:
wav_
=
self
.
generator
(
mel
)
# (B, out_channels, T ** prod(upsample_scales))
wav_
=
self
.
generator
(
mel
)
# initialize
# initialize
gen_loss
=
0.0
gen_loss
=
0.0
aux_loss
=
0.0
# full band M
ulti-resolution stft loss
# full band m
ulti-resolution stft loss
sc_loss
,
mag_loss
=
self
.
criterion_stft
(
wav_
,
wav
)
sc_loss
,
mag_loss
=
self
.
criterion_stft
(
wav_
,
wav
)
gen
_loss
+=
sc_loss
+
mag_loss
aux
_loss
+=
sc_loss
+
mag_loss
report
(
"train/spectral_convergence_loss"
,
float
(
sc_loss
))
report
(
"train/spectral_convergence_loss"
,
float
(
sc_loss
))
report
(
"train/log_stft_magnitude_loss"
,
float
(
mag_loss
))
report
(
"train/log_stft_magnitude_loss"
,
float
(
mag_loss
))
losses_dict
[
"spectral_convergence_loss"
]
=
float
(
sc_loss
)
losses_dict
[
"spectral_convergence_loss"
]
=
float
(
sc_loss
)
losses_dict
[
"log_stft_magnitude_loss"
]
=
float
(
mag_loss
)
losses_dict
[
"log_stft_magnitude_loss"
]
=
float
(
mag_loss
)
gen_loss
*=
self
.
lambda_aux
gen_loss
+=
aux_loss
*
self
.
lambda_aux
## Adversarial loss
# adversarial loss
if
self
.
state
.
iteration
>
self
.
discriminator_train_start_steps
:
if
self
.
state
.
iteration
>
self
.
discriminator_train_start_steps
:
p_
=
self
.
discriminator
(
wav_
)
p_
=
self
.
discriminator
(
wav_
)
adv_loss
=
self
.
criterion_gen_adv
(
p_
)
adv_loss
=
self
.
criterion_gen_adv
(
p_
)
report
(
"train/adversarial_loss"
,
float
(
adv_loss
))
report
(
"train/adversarial_loss"
,
float
(
adv_loss
))
losses_dict
[
"adversarial_loss"
]
=
float
(
adv_loss
)
losses_dict
[
"adversarial_loss"
]
=
float
(
adv_loss
)
gen_loss
+=
self
.
lambda_adv
*
adv_loss
report
(
"train/generator_loss"
,
float
(
gen_loss
))
gen_loss
+=
self
.
lambda_adv
*
adv_loss
losses_dict
[
"generator_loss"
]
=
float
(
gen_loss
)
self
.
optimizer_g
.
clear_grad
(
)
report
(
"train/generator_loss"
,
float
(
gen_loss
)
)
gen_loss
.
backward
(
)
losses_dict
[
"generator_loss"
]
=
float
(
gen_loss
)
self
.
optimizer_g
.
step
()
self
.
optimizer_g
.
clear_grad
()
self
.
scheduler_g
.
step
()
gen_loss
.
backward
()
self
.
optimizer_g
.
step
()
self
.
scheduler_g
.
step
()
# Discriminator
# Discriminator
if
self
.
state
.
iteration
>
self
.
discriminator_train_start_steps
:
if
self
.
state
.
iteration
>
self
.
discriminator_train_start_steps
:
...
@@ -148,7 +153,7 @@ class StyleMelGANEvaluator(StandardEvaluator):
...
@@ -148,7 +153,7 @@ class StyleMelGANEvaluator(StandardEvaluator):
models
:
Dict
[
str
,
Layer
],
models
:
Dict
[
str
,
Layer
],
criterions
:
Dict
[
str
,
Layer
],
criterions
:
Dict
[
str
,
Layer
],
dataloader
:
DataLoader
,
dataloader
:
DataLoader
,
lambda_adv
:
float
,
lambda_adv
:
float
=
1.0
,
lambda_aux
:
float
=
1.0
,
lambda_aux
:
float
=
1.0
,
output_dir
:
Path
=
None
):
output_dir
:
Path
=
None
):
self
.
models
=
models
self
.
models
=
models
...
@@ -161,6 +166,7 @@ class StyleMelGANEvaluator(StandardEvaluator):
...
@@ -161,6 +166,7 @@ class StyleMelGANEvaluator(StandardEvaluator):
self
.
criterion_dis_adv
=
criterions
[
"dis_adv"
]
self
.
criterion_dis_adv
=
criterions
[
"dis_adv"
]
self
.
dataloader
=
dataloader
self
.
dataloader
=
dataloader
self
.
lambda_adv
=
lambda_adv
self
.
lambda_adv
=
lambda_adv
self
.
lambda_aux
=
lambda_aux
self
.
lambda_aux
=
lambda_aux
...
@@ -171,26 +177,27 @@ class StyleMelGANEvaluator(StandardEvaluator):
...
@@ -171,26 +177,27 @@ class StyleMelGANEvaluator(StandardEvaluator):
self
.
msg
=
""
self
.
msg
=
""
def
evaluate_core
(
self
,
batch
):
def
evaluate_core
(
self
,
batch
):
# logging.debug("Evaluate: ")
self
.
msg
=
"Evaluate: "
self
.
msg
=
"Evaluate: "
losses_dict
=
{}
losses_dict
=
{}
wav
,
mel
=
batch
wav
,
mel
=
batch
# Generator
# Generator
# (B, out_channels, T ** prod(upsample_scales))
# (B, out_channels, T ** prod(upsample_scales))
wav_
=
self
.
generator
(
mel
)
wav_
=
self
.
generator
(
mel
)
## Adversarial loss
# initialize
gen_loss
=
0.0
aux_loss
=
0.0
# adversarial loss
p_
=
self
.
discriminator
(
wav_
)
p_
=
self
.
discriminator
(
wav_
)
adv_loss
=
self
.
criterion_gen_adv
(
p_
)
adv_loss
=
self
.
criterion_gen_adv
(
p_
)
report
(
"eval/adversarial_loss"
,
float
(
adv_loss
))
report
(
"eval/adversarial_loss"
,
float
(
adv_loss
))
losses_dict
[
"adversarial_loss"
]
=
float
(
adv_loss
)
losses_dict
[
"adversarial_loss"
]
=
float
(
adv_loss
)
gen_loss
=
self
.
lambda_adv
*
adv_loss
# initialize
gen_loss
+=
self
.
lambda_adv
*
adv_loss
aux_loss
=
0.0
#
M
ulti-resolution stft loss
#
m
ulti-resolution stft loss
sc_loss
,
mag_loss
=
self
.
criterion_stft
(
wav_
,
wav
)
sc_loss
,
mag_loss
=
self
.
criterion_stft
(
wav_
,
wav
)
aux_loss
+=
sc_loss
+
mag_loss
aux_loss
+=
sc_loss
+
mag_loss
report
(
"eval/spectral_convergence_loss"
,
float
(
sc_loss
))
report
(
"eval/spectral_convergence_loss"
,
float
(
sc_loss
))
...
@@ -198,8 +205,7 @@ class StyleMelGANEvaluator(StandardEvaluator):
...
@@ -198,8 +205,7 @@ class StyleMelGANEvaluator(StandardEvaluator):
losses_dict
[
"spectral_convergence_loss"
]
=
float
(
sc_loss
)
losses_dict
[
"spectral_convergence_loss"
]
=
float
(
sc_loss
)
losses_dict
[
"log_stft_magnitude_loss"
]
=
float
(
mag_loss
)
losses_dict
[
"log_stft_magnitude_loss"
]
=
float
(
mag_loss
)
aux_loss
*=
self
.
lambda_aux
gen_loss
+=
aux_loss
*
self
.
lambda_aux
gen_loss
+=
aux_loss
report
(
"eval/generator_loss"
,
float
(
gen_loss
))
report
(
"eval/generator_loss"
,
float
(
gen_loss
))
losses_dict
[
"generator_loss"
]
=
float
(
gen_loss
)
losses_dict
[
"generator_loss"
]
=
float
(
gen_loss
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录