Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Parakeet
提交
c170b5a2
P
Parakeet
项目概览
PaddlePaddle
/
Parakeet
通知
11
Star
3
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
19
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Parakeet
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
19
Issue
19
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
c170b5a2
编写于
6月 12, 2020
作者:
L
liuyibing01
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'commit' into 'master'
fix some bugs of transformer_tts and fastspeech. See merge request !58
上级
33ed693c
681d34b9
变更
7
显示空白变更内容
内联
并排
Showing
7 changed file
with
152 addition
and
148 deletion
+152
-148
examples/fastspeech/alignments/get_alignments.py
examples/fastspeech/alignments/get_alignments.py
+1
-6
examples/fastspeech/configs/ljspeech.yaml
examples/fastspeech/configs/ljspeech.yaml
+1
-1
examples/fastspeech/train.py
examples/fastspeech/train.py
+51
-48
examples/transformer_tts/README.md
examples/transformer_tts/README.md
+1
-1
examples/transformer_tts/configs/ljspeech.yaml
examples/transformer_tts/configs/ljspeech.yaml
+1
-1
examples/transformer_tts/train_transformer.py
examples/transformer_tts/train_transformer.py
+95
-90
parakeet/models/fastspeech/length_regulator.py
parakeet/models/fastspeech/length_regulator.py
+2
-1
未找到文件。
examples/fastspeech/alignments/get_alignments.py
浏览文件 @
c170b5a2
...
...
@@ -115,15 +115,10 @@ def alignments(args):
mel_input
=
fluid
.
layers
.
unsqueeze
(
dg
.
to_variable
(
mel_input
),
[
0
])
mel_lens
=
mel_input
.
shape
[
1
]
dec_slf_mask
=
get_triu_tensor
(
mel_input
,
mel_input
).
astype
(
np
.
float32
)
dec_slf_mask
=
np
.
expand_dims
(
dec_slf_mask
,
axis
=
0
)
dec_slf_mask
=
fluid
.
layers
.
cast
(
dg
.
to_variable
(
dec_slf_mask
!=
0
),
np
.
float32
)
*
(
-
2
**
32
+
1
)
pos_mel
=
np
.
arange
(
1
,
mel_input
.
shape
[
1
]
+
1
)
pos_mel
=
fluid
.
layers
.
unsqueeze
(
dg
.
to_variable
(
pos_mel
),
[
0
])
mel_pred
,
postnet_pred
,
attn_probs
,
stop_preds
,
attn_enc
,
attn_dec
=
model
(
text
,
mel_input
,
pos_text
,
pos_mel
,
dec_slf_mask
)
text
,
mel_input
,
pos_text
,
pos_mel
)
mel_input
=
fluid
.
layers
.
concat
(
[
mel_input
,
postnet_pred
[:,
-
1
:,
:]],
axis
=
1
)
...
...
examples/fastspeech/configs/ljspeech.yaml
浏览文件 @
c170b5a2
...
...
@@ -29,5 +29,5 @@ train:
grad_clip_thresh
:
0.1
#the threshold of grad clip.
checkpoint_interval
:
1000
max_
epochs
:
1
0000
max_
iteration
:
50
0000
examples/fastspeech/train.py
浏览文件 @
c170b5a2
...
...
@@ -62,7 +62,8 @@ def main(args):
cfg
=
yaml
.
load
(
f
,
Loader
=
yaml
.
Loader
)
global_step
=
0
place
=
fluid
.
CUDAPlace
(
local_rank
)
if
args
.
use_gpu
else
fluid
.
CPUPlace
()
place
=
fluid
.
CUDAPlace
(
dg
.
parallel
.
Env
()
.
dev_id
)
if
args
.
use_gpu
else
fluid
.
CPUPlace
()
fluid
.
enable_dygraph
(
place
)
if
not
os
.
path
.
exists
(
args
.
output
):
...
...
@@ -88,7 +89,8 @@ def main(args):
cfg
[
'train'
][
'batch_size'
],
nranks
,
local_rank
,
shuffle
=
True
).
reader
()
shuffle
=
True
).
reader
iterator
=
iter
(
tqdm
(
reader
))
# Load parameters.
global_step
=
io
.
load_parameters
(
...
...
@@ -103,12 +105,14 @@ def main(args):
strategy
=
dg
.
parallel
.
prepare_context
()
model
=
fluid
.
dygraph
.
parallel
.
DataParallel
(
model
,
strategy
)
for
epoch
in
range
(
cfg
[
'train'
][
'max_epochs'
]):
pbar
=
tqdm
(
reader
)
while
global_step
<=
cfg
[
'train'
][
'max_iteration'
]:
try
:
batch
=
next
(
iterator
)
except
StopIteration
as
e
:
iterator
=
iter
(
tqdm
(
reader
))
batch
=
next
(
iterator
)
for
i
,
data
in
enumerate
(
pbar
):
pbar
.
set_description
(
'Processing at epoch %d'
%
epoch
)
(
character
,
mel
,
pos_text
,
pos_mel
,
alignment
)
=
data
(
character
,
mel
,
pos_text
,
pos_mel
,
alignment
)
=
batch
global_step
+=
1
...
...
@@ -120,8 +124,7 @@ def main(args):
mel_postnet_loss
=
layers
.
mse_loss
(
mel_output_postnet
,
mel
)
duration_loss
=
layers
.
mean
(
layers
.
abs
(
layers
.
elementwise_sub
(
duration_predictor_output
,
alignment
)))
layers
.
elementwise_sub
(
duration_predictor_output
,
alignment
)))
total_loss
=
mel_loss
+
mel_postnet_loss
+
duration_loss
if
local_rank
==
0
:
...
...
@@ -147,8 +150,8 @@ def main(args):
if
local_rank
==
0
and
global_step
%
cfg
[
'train'
][
'checkpoint_interval'
]
==
0
:
io
.
save_parameters
(
os
.
path
.
join
(
args
.
output
,
'checkpoints'
),
global_step
,
model
,
optimizer
)
os
.
path
.
join
(
args
.
output
,
'checkpoints'
),
global_step
,
model
,
optimizer
)
if
local_rank
==
0
:
writer
.
close
()
...
...
examples/transformer_tts/README.md
浏览文件 @
c170b5a2
...
...
@@ -53,7 +53,7 @@ During synthesis, results are saved in `${output}/samples` and tensorboard log i
TransformerTTS model can be trained by running
``train_transformer.py``
.
```
bash
python train_trasformer.py
\
python train_tra
n
sformer.py
\
--use_gpu
=
1
\
--data
=
${
DATAPATH
}
\
--output
=
'./experiment'
\
...
...
examples/transformer_tts/configs/ljspeech.yaml
浏览文件 @
c170b5a2
...
...
@@ -31,7 +31,7 @@ train:
checkpoint_interval
:
1000
image_interval
:
2000
max_
epochs
:
1
0000
max_
iteration
:
50
0000
examples/transformer_tts/train_transformer.py
浏览文件 @
c170b5a2
...
...
@@ -102,16 +102,21 @@ def main(args):
cfg
[
'train'
][
'batch_size'
],
nranks
,
local_rank
,
shuffle
=
True
).
reader
()
shuffle
=
True
).
reader
for
epoch
in
range
(
cfg
[
'train'
][
'max_epochs'
]):
pbar
=
tqdm
(
reader
)
for
i
,
data
in
enumerate
(
pbar
):
pbar
.
set_description
(
'Processing at epoch %d'
%
epoch
)
character
,
mel
,
mel_input
,
pos_text
,
pos_mel
=
data
iterator
=
iter
(
tqdm
(
reader
))
global_step
+=
1
while
global_step
<=
cfg
[
'train'
][
'max_iteration'
]:
try
:
batch
=
next
(
iterator
)
except
StopIteration
as
e
:
iterator
=
iter
(
tqdm
(
reader
))
batch
=
next
(
iterator
)
character
,
mel
,
mel_input
,
pos_text
,
pos_mel
=
batch
mel_pred
,
postnet_pred
,
attn_probs
,
stop_preds
,
attn_enc
,
attn_dec
=
model
(
character
,
mel_input
,
pos_text
,
pos_mel
)
...
...
@@ -134,8 +139,7 @@ def main(args):
},
global_step
)
if
cfg
[
'network'
][
'stop_token'
]:
writer
.
add_scalar
(
'stop_loss'
,
stop_loss
.
numpy
(),
global_step
)
writer
.
add_scalar
(
'stop_loss'
,
stop_loss
.
numpy
(),
global_step
)
if
parallel
:
writer
.
add_scalars
(
'alphas'
,
{
...
...
@@ -157,7 +161,7 @@ def main(args):
for
j
in
range
(
cfg
[
'network'
][
'decoder_num_head'
]):
x
=
np
.
uint8
(
cm
.
viridis
(
prob
.
numpy
()[
j
*
cfg
[
'train'
][
'batch_size'
]
//
2
])
*
255
)
'batch_size'
]
//
nranks
])
*
255
)
writer
.
add_image
(
'Attention_%d_0'
%
global_step
,
x
,
...
...
@@ -168,7 +172,7 @@ def main(args):
for
j
in
range
(
cfg
[
'network'
][
'encoder_num_head'
]):
x
=
np
.
uint8
(
cm
.
viridis
(
prob
.
numpy
()[
j
*
cfg
[
'train'
][
'batch_size'
]
//
2
])
*
255
)
'batch_size'
]
//
nranks
])
*
255
)
writer
.
add_image
(
'Attention_enc_%d_0'
%
global_step
,
x
,
...
...
@@ -179,7 +183,7 @@ def main(args):
for
j
in
range
(
cfg
[
'network'
][
'decoder_num_head'
]):
x
=
np
.
uint8
(
cm
.
viridis
(
prob
.
numpy
()[
j
*
cfg
[
'train'
][
'batch_size'
]
//
2
])
*
255
)
'batch_size'
]
//
nranks
])
*
255
)
writer
.
add_image
(
'Attention_dec_%d_0'
%
global_step
,
x
,
...
...
@@ -199,8 +203,9 @@ def main(args):
if
local_rank
==
0
and
global_step
%
cfg
[
'train'
][
'checkpoint_interval'
]
==
0
:
io
.
save_parameters
(
os
.
path
.
join
(
args
.
output
,
'checkpoints'
),
global_step
,
model
,
optimizer
)
os
.
path
.
join
(
args
.
output
,
'checkpoints'
),
global_step
,
model
,
optimizer
)
global_step
+=
1
if
local_rank
==
0
:
writer
.
close
()
...
...
parakeet/models/fastspeech/length_regulator.py
浏览文件 @
c170b5a2
...
...
@@ -94,7 +94,8 @@ class LengthRegulator(dg.Layer):
else
:
duration_predictor_output
=
layers
.
round
(
duration_predictor_output
)
output
=
self
.
LR
(
x
,
duration_predictor_output
,
alpha
)
mel_pos
=
dg
.
to_variable
(
np
.
arange
(
1
,
output
.
shape
[
1
]
+
1
))
mel_pos
=
dg
.
to_variable
(
np
.
arange
(
1
,
output
.
shape
[
1
]
+
1
)).
astype
(
np
.
int64
)
mel_pos
=
layers
.
unsqueeze
(
mel_pos
,
[
0
])
return
output
,
mel_pos
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录