Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Parakeet
提交
9fe6ad11
P
Parakeet
项目概览
PaddlePaddle
/
Parakeet
通知
8
Star
3
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
19
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Parakeet
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
19
Issue
19
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
9fe6ad11
编写于
12月 17, 2019
作者:
L
lifuchen
提交者:
chenfeiyu
12月 17, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Training with multi-GPU
上级
8a9bbc26
变更
10
显示空白变更内容
内联
并排
Showing
10 changed file
with
390 addition
and
147 deletion
+390
-147
parakeet/audio/__init__.py
parakeet/audio/__init__.py
+1
-0
parakeet/audio/audio.py
parakeet/audio/audio.py
+261
-0
parakeet/data/datacargo.py
parakeet/data/datacargo.py
+13
-4
parakeet/models/transformerTTS/config/train_postnet.yaml
parakeet/models/transformerTTS/config/train_postnet.yaml
+1
-1
parakeet/models/transformerTTS/config/train_transformer.yaml
parakeet/models/transformerTTS/config/train_transformer.yaml
+1
-1
parakeet/models/transformerTTS/data.py
parakeet/models/transformerTTS/data.py
+29
-0
parakeet/models/transformerTTS/module.py
parakeet/models/transformerTTS/module.py
+3
-2
parakeet/models/transformerTTS/network.py
parakeet/models/transformerTTS/network.py
+9
-9
parakeet/models/transformerTTS/train_postnet.py
parakeet/models/transformerTTS/train_postnet.py
+31
-59
parakeet/models/transformerTTS/train_transformer.py
parakeet/models/transformerTTS/train_transformer.py
+41
-71
未找到文件。
parakeet/audio/__init__.py
0 → 100644
浏览文件 @
9fe6ad11
from
.audio
import
AudioProcessor
\ No newline at end of file
parakeet/audio/audio.py
0 → 100644
浏览文件 @
9fe6ad11
import
librosa
import
soundfile
as
sf
import
numpy
as
np
import
scipy.io
import
scipy.signal
class
AudioProcessor
(
object
):
def
__init__
(
self
,
sample_rate
=
None
,
# int, sampling rate
num_mels
=
None
,
# int, bands of mel spectrogram
min_level_db
=
None
,
# float, minimum level db
ref_level_db
=
None
,
# float, reference level dbn
n_fft
=
None
,
# int: number of samples in a frame for stft
win_length
=
None
,
# int: the same meaning with n_fft
hop_length
=
None
,
# int: number of samples between neighboring frame
power
=
None
,
# float:power to raise before griffin-lim
preemphasis
=
None
,
# float: preemphasis coefficident
signal_norm
=
None
,
#
symmetric_norm
=
False
,
# bool, apply clip norm in [-max_norm, max_form]
max_norm
=
None
,
# float, max norm
mel_fmin
=
None
,
# int: mel spectrogram's minimum frequency
mel_fmax
=
None
,
# int: mel spectrogram's maximum frequency
clip_norm
=
True
,
# bool: clip spectrogram's norm
griffin_lim_iters
=
None
,
# int:
do_trim_silence
=
False
,
# bool: trim silience
sound_norm
=
False
,
**
kwargs
):
self
.
sample_rate
=
sample_rate
self
.
num_mels
=
num_mels
self
.
min_level_db
=
min_level_db
self
.
ref_level_db
=
ref_level_db
# stft related
self
.
n_fft
=
n_fft
self
.
win_length
=
win_length
or
n_fft
# hop length defaults to 1/4 window_length
self
.
hop_length
=
hop_length
or
0.25
*
self
.
win_length
self
.
power
=
power
self
.
preemphasis
=
float
(
preemphasis
)
self
.
griffin_lim_iters
=
griffin_lim_iters
self
.
signal_norm
=
signal_norm
self
.
symmetric_norm
=
symmetric_norm
# mel transform related
self
.
mel_fmin
=
mel_fmin
self
.
mel_fmax
=
mel_fmax
self
.
max_norm
=
1.0
if
max_norm
is
None
else
float
(
max_norm
)
self
.
clip_norm
=
clip_norm
self
.
do_trim_silence
=
do_trim_silence
self
.
sound_norm
=
sound_norm
self
.
num_freq
,
self
.
frame_length_ms
,
self
.
frame_shift_ms
=
self
.
_stft_parameters
()
def
_stft_parameters
(
self
):
"""compute frame length and hop length in ms"""
frame_length_ms
=
self
.
win_length
*
1.
/
self
.
sample_rate
frame_shift_ms
=
self
.
hop_length
*
1.
/
self
.
sample_rate
num_freq
=
1
+
self
.
n_fft
//
2
return
num_freq
,
frame_length_ms
,
frame_shift_ms
def
__repr__
(
self
):
"""object repr"""
cls_name_str
=
self
.
__class__
.
__name__
members
=
vars
(
self
)
dict_str
=
"
\n
"
.
join
([
" {}: {},"
.
format
(
k
,
v
)
for
k
,
v
in
members
.
items
()])
repr_str
=
"{}(
\n
{})
\n
"
.
format
(
cls_name_str
,
dict_str
)
return
repr_str
def
save_wav
(
self
,
path
,
wav
):
"""save audio with scipy.io.wavfile in 16bit integers"""
wav_norm
=
wav
*
(
32767
/
max
(
0.01
,
np
.
max
(
np
.
abs
(
wav
))))
scipy
.
io
.
wavfile
.
write
(
path
,
self
.
sample_rate
,
wav_norm
.
as_type
(
np
.
int16
))
def
load_wav
(
self
,
path
,
sr
=
None
):
"""load wav -> trim_silence -> rescale"""
x
,
sr
=
librosa
.
load
(
path
,
sr
=
None
)
assert
self
.
sample_rate
==
sr
,
"audio sample rate: {}Hz != processor sample rate: {}Hz"
.
format
(
sr
,
self
.
sample_rate
)
if
self
.
do_trim_silence
:
try
:
x
=
self
.
trim_silence
(
x
)
except
ValueError
:
print
(
" [!] File cannot be trimmed for silence - {}"
.
format
(
path
))
if
self
.
sound_norm
:
x
=
x
/
x
.
max
()
*
0.9
# why 0.9 ?
return
x
def
trim_silence
(
self
,
wav
):
"""Trim soilent parts with a threshold and 0.01s margin"""
margin
=
int
(
self
.
sample_rate
*
0.01
)
wav
=
wav
[
margin
:
-
margin
]
trimed_wav
=
librosa
.
effects
.
trim
(
wav
,
top_db
=
60
,
frame_length
=
self
.
win_length
,
hop_length
=
self
.
hop_length
)[
0
]
return
trimed_wav
def
apply_preemphasis
(
self
,
x
):
if
self
.
preemphasis
==
0.
:
raise
RuntimeError
(
" !! Preemphasis coefficient should be positive. "
)
return
scipy
.
signal
.
lfilter
([
1.
,
-
self
.
preemphasis
],
[
1.
],
x
)
def
apply_inv_preemphasis
(
self
,
x
):
if
self
.
preemphasis
==
0.
:
raise
RuntimeError
(
" !! Preemphasis coefficient should be positive. "
)
return
scipy
.
signal
.
lfilter
([
1.
],
[
1.
,
-
self
.
preemphasis
],
x
)
def
_amplitude_to_db
(
self
,
x
):
amplitude_min
=
np
.
exp
(
self
.
min_level_db
/
20
*
np
.
log
(
10
))
return
20
*
np
.
log10
(
np
.
maximum
(
amplitude_min
,
x
))
@
staticmethod
def
_db_to_amplitude
(
x
):
return
np
.
power
(
10.
,
0.05
*
x
)
def
_linear_to_mel
(
self
,
spectrogram
):
_mel_basis
=
self
.
_build_mel_basis
()
return
np
.
dot
(
_mel_basis
,
spectrogram
)
def
_mel_to_linear
(
self
,
mel_spectrogram
):
inv_mel_basis
=
np
.
linalg
.
pinv
(
self
.
_build_mel_basis
())
return
np
.
maximum
(
1e-10
,
np
.
dot
(
inv_mel_basis
,
mel_spectrogram
))
def
_build_mel_basis
(
self
):
"""return mel basis for mel scale"""
if
self
.
mel_fmax
is
not
None
:
assert
self
.
mel_fmax
<=
self
.
sample_rate
//
2
return
librosa
.
filters
.
mel
(
self
.
sample_rate
,
self
.
n_fft
,
n_mels
=
self
.
num_mels
,
fmin
=
self
.
mel_fmin
,
fmax
=
self
.
mel_fmax
)
def
_normalize
(
self
,
S
):
"""put values in [0, self.max_norm] or [-self.max_norm, self,max_norm]"""
if
self
.
signal_norm
:
S_norm
=
(
S
-
self
.
min_level_db
)
/
(
-
self
.
min_level_db
)
if
self
.
symmetric_norm
:
S_norm
=
((
2
*
self
.
max_norm
)
*
S_norm
)
-
self
.
max_norm
if
self
.
clip_norm
:
S_norm
=
np
.
clip
(
S_norm
,
-
self
.
max_norm
,
self
.
max_norm
)
return
S_norm
else
:
S_norm
=
self
.
max_norm
*
S_norm
if
self
.
clip_norm
:
S_norm
=
np
.
clip
(
S_norm
,
0
,
self
.
max_norm
)
return
S_norm
else
:
return
S
def
_denormalize
(
self
,
S
):
"""denormalize values"""
S_denorm
=
S
if
self
.
signal_norm
:
if
self
.
symmetric_norm
:
if
self
.
clip_norm
:
S_denorm
=
np
.
clip
(
S_denorm
,
-
self
.
max_norm
,
self
.
max_norm
)
S_denorm
=
(
S_denorm
+
self
.
max_norm
)
*
(
-
self
.
min_level_db
)
/
(
2
*
self
.
max_norm
)
+
self
.
min_level_db
return
S_denorm
else
:
if
self
.
clip_norm
:
S_denorm
=
np
.
clip
(
S_denorm
,
0
,
self
.
max_norm
)
S_denorm
=
S_denorm
*
(
-
self
.
min_level_db
)
/
self
.
max_norm
+
self
.
min_level_db
return
S_denorm
else
:
return
S
def
_stft
(
self
,
y
):
return
librosa
.
stft
(
y
=
y
,
n_fft
=
self
.
n_fft
,
win_length
=
self
.
win_length
,
hop_length
=
self
.
hop_length
)
def
_istft
(
self
,
S
):
return
librosa
.
istft
(
S
,
hop_length
=
self
.
hop_length
,
win_length
=
self
.
win_length
)
def
spectrogram
(
self
,
y
):
"""compute linear spectrogram(amplitude)
preemphasis -> stft -> mag -> amplitude_to_db -> minus_ref_level_db -> normalize
"""
if
self
.
preemphasis
:
D
=
self
.
_stft
(
self
.
apply_preemphasis
(
y
))
else
:
D
=
self
.
_stft
(
y
)
S
=
self
.
_amplitude_to_db
(
np
.
abs
(
D
))
-
self
.
ref_level_db
return
self
.
_normalize
(
S
)
def
melspectrogram
(
self
,
y
):
"""compute linear spectrogram(amplitude)
preemphasis -> stft -> mag -> mel_scale -> amplitude_to_db -> minus_ref_level_db -> normalize
"""
if
self
.
preemphasis
:
D
=
self
.
_stft
(
self
.
apply_preemphasis
(
y
))
else
:
D
=
self
.
_stft
(
y
)
S
=
self
.
_amplitude_to_db
(
self
.
_linear_to_mel
(
np
.
abs
(
D
)))
-
self
.
ref_level_db
return
self
.
_normalize
(
S
)
def
inv_spectrogram
(
self
,
spectrogram
):
"""convert spectrogram back to waveform using griffin_lim in librosa"""
S
=
self
.
_denormalize
(
spectrogram
)
S
=
self
.
_db_to_amplitude
(
S
+
self
.
ref_level_db
)
if
self
.
preemphasis
:
return
self
.
apply_inv_preemphasis
(
self
.
_griffin_lim
(
S
**
self
.
power
))
return
self
.
_griffin_lim
(
S
**
self
.
power
)
def
inv_melspectrogram
(
self
,
mel_spectrogram
):
S
=
self
.
_denormalize
(
mel_spectrogram
)
S
=
self
.
_db_to_amplitude
(
S
+
self
.
ref_level_db
)
S
=
self
.
_linear_to_mel
(
np
.
abs
(
S
))
if
self
.
preemphasis
:
return
self
.
apply_inv_preemphasis
(
self
.
_griffin_lim
(
S
**
self
.
power
))
return
self
.
_griffin_lim
(
S
**
self
.
power
)
def
out_linear_to_mel
(
self
,
linear_spec
):
"""convert output linear spec to mel spec"""
S
=
self
.
_denormalize
(
linear_spec
)
S
=
self
.
_db_to_amplitude
(
S
+
self
.
ref_level_db
)
S
=
self
.
_linear_to_mel
(
np
.
abs
(
S
))
S
=
self
.
_amplitude_to_db
(
S
)
-
self
.
ref_level_db
mel
=
self
.
_normalize
(
S
)
return
mel
def
_griffin_lim
(
self
,
S
):
angles
=
np
.
exp
(
2j
*
np
.
pi
*
np
.
random
.
rand
(
*
S
.
shape
))
S_complex
=
np
.
abs
(
S
).
astype
(
np
.
complex
)
y
=
self
.
_istft
(
S_complex
*
angles
)
for
_
in
range
(
self
.
griffin_lim_iters
):
angles
=
np
.
exp
(
1j
*
np
.
angle
(
self
.
_stft
(
y
)))
y
=
self
.
_istft
(
S_complex
*
angles
)
return
y
@
staticmethod
def
mulaw_encode
(
wav
,
qc
):
mu
=
2
**
qc
-
1
# wav_abs = np.minimum(np.abs(wav), 1.0)
signal
=
np
.
sign
(
wav
)
*
np
.
log
(
1
+
mu
*
np
.
abs
(
wav
))
/
np
.
log
(
1.
+
mu
)
# Quantize signal to the specified number of levels.
signal
=
(
signal
+
1
)
/
2
*
mu
+
0.5
return
np
.
floor
(
signal
,)
@
staticmethod
def
mulaw_decode
(
wav
,
qc
):
"""Recovers waveform from quantized values."""
mu
=
2
**
qc
-
1
x
=
np
.
sign
(
wav
)
/
mu
*
((
1
+
mu
)
**
np
.
abs
(
wav
)
-
1
)
return
x
@
staticmethod
def
encode_16bits
(
x
):
return
np
.
clip
(
x
*
2
**
15
,
-
2
**
15
,
2
**
15
-
1
).
astype
(
np
.
int16
)
@
staticmethod
def
quantize
(
x
,
bits
):
return
(
x
+
1.
)
*
(
2
**
bits
-
1
)
/
2
@
staticmethod
def
dequantize
(
x
,
bits
):
return
2
*
x
/
(
2
**
bits
-
1
)
-
1
parakeet/data/datacargo.py
浏览文件 @
9fe6ad11
...
...
@@ -2,7 +2,8 @@ from .sampler import SequentialSampler, RandomSampler, BatchSampler
class
DataCargo
(
object
):
def
__init__
(
self
,
dataset
,
batch_size
=
1
,
sampler
=
None
,
shuffle
=
False
,
batch_sampler
=
None
,
drop_last
=
False
):
shuffle
=
False
,
batch_sampler
=
None
,
collate_fn
=
None
,
drop_last
=
False
):
self
.
dataset
=
dataset
if
batch_sampler
is
not
None
:
...
...
@@ -21,13 +22,20 @@ class DataCargo(object):
sampler
=
RandomSampler
(
dataset
)
else
:
sampler
=
SequentialSampler
(
dataset
)
# auto_collation without custom batch_sampler
batch_sampler
=
BatchSampler
(
sampler
,
batch_size
,
drop_last
)
else
:
batch_sampler
=
BatchSampler
(
sampler
,
batch_size
,
drop_last
)
self
.
batch_sampler
=
batch_sampler
if
collate_fn
is
None
:
collate_fn
=
dataset
.
_batch_examples
self
.
collate_fn
=
collate_fn
self
.
batch_size
=
batch_size
self
.
drop_last
=
drop_last
self
.
sampler
=
sampler
self
.
batch_sampler
=
batch_sampler
def
__iter__
(
self
):
return
DataIterator
(
self
)
...
...
@@ -57,6 +65,7 @@ class DataIterator(object):
self
.
_index_sampler
=
loader
.
_index_sampler
self
.
_sampler_iter
=
iter
(
self
.
_index_sampler
)
self
.
collate_fn
=
loader
.
collate_fn
def
__iter__
(
self
):
return
self
...
...
@@ -64,7 +73,7 @@ class DataIterator(object):
def
__next__
(
self
):
index
=
self
.
_next_index
()
# may raise StopIteration, TODO(chenfeiyu): use dynamic batch size
minibatch
=
[
self
.
_dataset
[
i
]
for
i
in
index
]
# we can abstract it, too to use dynamic batch size
minibatch
=
self
.
_dataset
.
_batch_examples
(
minibatch
)
# list[Example] -> Batch
minibatch
=
self
.
collate_fn
(
minibatch
)
return
minibatch
def
_next_index
(
self
):
...
...
parakeet/models/transformerTTS/config/train_postnet.yaml
浏览文件 @
9fe6ad11
...
...
@@ -20,7 +20,7 @@ epochs: 10000
lr
:
0.001
save_step
:
500
use_gpu
:
True
use_data_parallel
:
Fals
e
use_data_parallel
:
Tru
e
data_path
:
../../../dataset/LJSpeech-1.1
save_path
:
./checkpoint
...
...
parakeet/models/transformerTTS/config/train_transformer.yaml
浏览文件 @
9fe6ad11
...
...
@@ -21,7 +21,7 @@ lr: 0.001
save_step
:
500
image_step
:
2000
use_gpu
:
True
use_data_parallel
:
Fals
e
use_data_parallel
:
Tru
e
data_path
:
../../../dataset/LJSpeech-1.1
save_path
:
./checkpoint
...
...
parakeet/models/transformerTTS/data.py
0 → 100644
浏览文件 @
9fe6ad11
from
pathlib
import
Path
import
numpy
as
np
from
paddle
import
fluid
from
parakeet.data.sampler
import
DistributedSampler
from
parakeet.data.datacargo
import
DataCargo
from
preprocess
import
batch_examples
,
LJSpeech
,
batch_examples_postnet
class
LJSpeechLoader
:
def
__init__
(
self
,
config
,
nranks
,
rank
,
is_postnet
=
False
):
place
=
fluid
.
CUDAPlace
(
rank
)
if
config
.
use_gpu
else
fluid
.
CPUPlace
()
LJSPEECH_ROOT
=
Path
(
config
.
data_path
)
dataset
=
LJSpeech
(
LJSPEECH_ROOT
)
sampler
=
DistributedSampler
(
len
(
dataset
),
nranks
,
rank
)
assert
config
.
batch_size
%
nranks
==
0
each_bs
=
config
.
batch_size
//
nranks
if
is_postnet
:
dataloader
=
DataCargo
(
dataset
,
sampler
=
sampler
,
batch_size
=
each_bs
,
shuffle
=
True
,
collate_fn
=
batch_examples_postnet
,
drop_last
=
True
)
else
:
dataloader
=
DataCargo
(
dataset
,
sampler
=
sampler
,
batch_size
=
each_bs
,
shuffle
=
True
,
collate_fn
=
batch_examples
,
drop_last
=
True
)
self
.
reader
=
fluid
.
io
.
DataLoader
.
from_generator
(
capacity
=
32
,
iterable
=
True
,
use_double_buffer
=
True
,
return_list
=
True
)
self
.
reader
.
set_batch_generator
(
dataloader
,
place
)
parakeet/models/transformerTTS/module.py
浏览文件 @
9fe6ad11
...
...
@@ -130,7 +130,7 @@ class EncoderPrenet(dg.Layer):
self
.
projection
=
FC
(
self
.
full_name
(),
num_hidden
,
num_hidden
)
def
forward
(
self
,
x
):
x
=
self
.
embedding
(
fluid
.
layers
.
unsqueeze
(
x
,
axes
=
[
-
1
])
)
#(batch_size, seq_len, embending_size)
x
=
self
.
embedding
(
x
)
#(batch_size, seq_len, embending_size)
x
=
layers
.
transpose
(
x
,[
0
,
2
,
1
])
x
=
layers
.
dropout
(
layers
.
relu
(
self
.
batch_norm1
(
self
.
conv1
(
x
))),
0.2
)
x
=
layers
.
dropout
(
layers
.
relu
(
self
.
batch_norm2
(
self
.
conv2
(
x
))),
0.2
)
...
...
@@ -211,9 +211,10 @@ class ScaledDotProductAttention(dg.Layer):
# Mask key to ignore padding
if
mask
is
not
None
:
attention
=
attention
*
mask
mask
=
(
mask
==
0
).
astype
(
float
)
*
(
-
2
**
32
+
1
)
mask
=
(
mask
==
0
).
astype
(
np
.
float32
)
*
(
-
2
**
32
+
1
)
attention
=
attention
+
mask
attention
=
layers
.
softmax
(
attention
)
# Mask query to ignore padding
# Not sure how to work
...
...
parakeet/models/transformerTTS/network.py
浏览文件 @
9fe6ad11
...
...
@@ -7,9 +7,9 @@ class Encoder(dg.Layer):
def
__init__
(
self
,
name_scope
,
embedding_size
,
num_hidden
,
config
):
super
(
Encoder
,
self
).
__init__
(
name_scope
)
self
.
num_hidden
=
num_hidden
param
=
fluid
.
ParamAttr
(
name
=
'alpha'
)
self
.
alpha
=
self
.
create_parameter
(
param
,
shape
=
(
1
,
),
dtype
=
'float32'
,
default_initializer
=
fluid
.
initializer
.
ConstantInitializer
(
value
=
1.0
)
)
param
=
fluid
.
ParamAttr
(
name
=
'alpha'
,
initializer
=
fluid
.
initializer
.
Constant
(
value
=
1.0
))
self
.
alpha
=
self
.
create_parameter
(
param
,
shape
=
(
1
,
),
dtype
=
'float32'
)
self
.
pos_inp
=
get_sinusoid_encoding_table
(
1024
,
self
.
num_hidden
,
padding_idx
=
0
)
self
.
pos_emb
=
dg
.
Embedding
(
name_scope
=
self
.
full_name
(),
size
=
[
1024
,
num_hidden
],
...
...
@@ -31,8 +31,8 @@ class Encoder(dg.Layer):
def
forward
(
self
,
x
,
positional
):
if
fluid
.
framework
.
_dygraph_tracer
().
_train_mode
:
query_mask
=
(
positional
!=
0
).
astype
(
float
)
mask
=
(
positional
!=
0
).
astype
(
float
)
query_mask
=
(
positional
!=
0
).
astype
(
np
.
float32
)
mask
=
(
positional
!=
0
).
astype
(
np
.
float32
)
mask
=
fluid
.
layers
.
expand
(
fluid
.
layers
.
unsqueeze
(
mask
,[
1
]),
[
1
,
x
.
shape
[
1
],
1
])
else
:
query_mask
,
mask
=
None
,
None
...
...
@@ -42,7 +42,7 @@ class Encoder(dg.Layer):
# Get positional encoding
positional
=
self
.
pos_emb
(
fluid
.
layers
.
unsqueeze
(
positional
,
axes
=
[
-
1
])
)
positional
=
self
.
pos_emb
(
positional
)
x
=
positional
*
self
.
alpha
+
x
#(N, T, C)
...
...
@@ -102,14 +102,14 @@ class Decoder(dg.Layer):
if
fluid
.
framework
.
_dygraph_tracer
().
_train_mode
:
#zeros = np.zeros(positional.shape, dtype=np.float32)
m_mask
=
(
positional
!=
0
).
astype
(
float
)
m_mask
=
(
positional
!=
0
).
astype
(
np
.
float32
)
mask
=
np
.
repeat
(
np
.
expand_dims
(
m_mask
.
numpy
()
==
0
,
axis
=
1
),
decoder_len
,
axis
=
1
)
mask
=
mask
+
np
.
repeat
(
np
.
expand_dims
(
np
.
triu
(
np
.
ones
([
decoder_len
,
decoder_len
]),
1
),
axis
=
0
)
,
batch_size
,
axis
=
0
)
mask
=
fluid
.
layers
.
cast
(
dg
.
to_variable
(
mask
==
0
),
np
.
float32
)
# (batch_size, decoder_len, decoder_len)
zero_mask
=
fluid
.
layers
.
expand
(
fluid
.
layers
.
unsqueeze
((
c_mask
!=
0
).
astype
(
float
),
axes
=
2
),
[
1
,
1
,
decoder_len
])
zero_mask
=
fluid
.
layers
.
expand
(
fluid
.
layers
.
unsqueeze
((
c_mask
!=
0
).
astype
(
np
.
float32
),
axes
=
2
),
[
1
,
1
,
decoder_len
])
# (batch_size, decoder_len, seq_len)
zero_mask
=
fluid
.
layers
.
transpose
(
zero_mask
,
[
0
,
2
,
1
])
...
...
@@ -125,7 +125,7 @@ class Decoder(dg.Layer):
query
=
self
.
linear
(
query
)
# Get position embedding
positional
=
self
.
pos_emb
(
fluid
.
layers
.
unsqueeze
(
positional
,
axes
=
[
-
1
])
)
positional
=
self
.
pos_emb
(
positional
)
query
=
positional
*
self
.
alpha
+
query
#positional dropout
...
...
parakeet/models/transformerTTS/train_postnet.py
浏览文件 @
9fe6ad11
from
network
import
*
from
preprocess
import
batch_examples_postnet
,
LJSpeech
from
tensorboardX
import
SummaryWriter
import
os
from
tqdm
import
tqdm
from
parakeet.data.datacargo
import
DataCargo
from
pathlib
import
Path
import
jsonargparse
from
parse
import
add_config_options_to_parser
from
pprint
import
pprint
from
data
import
LJSpeechLoader
class
MyDataParallel
(
dg
.
parallel
.
DataParallel
):
"""
...
...
@@ -27,21 +26,15 @@ class MyDataParallel(dg.parallel.DataParallel):
object
.
__getattribute__
(
self
,
"_sub_layers"
)[
"_layers"
],
key
)
def
main
():
parser
=
jsonargparse
.
ArgumentParser
(
description
=
"Train postnet model"
,
formatter_class
=
'default_argparse'
)
add_config_options_to_parser
(
parser
)
cfg
=
parser
.
parse_args
(
'-c ./config/train_postnet.yaml'
.
split
())
def
main
(
cfg
):
local_rank
=
dg
.
parallel
.
Env
().
local_rank
local_rank
=
dg
.
parallel
.
Env
().
local_rank
if
cfg
.
use_data_parallel
else
0
nranks
=
dg
.
parallel
.
Env
().
nranks
if
cfg
.
use_data_parallel
else
1
if
local_rank
==
0
:
# Print the whole config setting.
pprint
(
jsonargparse
.
namespace_to_dict
(
cfg
))
LJSPEECH_ROOT
=
Path
(
cfg
.
data_path
)
dataset
=
LJSpeech
(
LJSPEECH_ROOT
)
dataloader
=
DataCargo
(
dataset
,
batch_size
=
cfg
.
batch_size
,
shuffle
=
True
,
collate_fn
=
batch_examples_postnet
,
drop_last
=
True
)
global_step
=
0
place
=
(
fluid
.
CUDAPlace
(
dg
.
parallel
.
Env
().
dev_id
)
if
cfg
.
use_data_parallel
else
fluid
.
CUDAPlace
(
0
)
...
...
@@ -50,35 +43,10 @@ def main():
if
not
os
.
path
.
exists
(
cfg
.
log_dir
):
os
.
mkdir
(
cfg
.
log_dir
)
path
=
os
.
path
.
join
(
cfg
.
log_dir
,
'postnet'
)
writer
=
SummaryWriter
(
path
)
with
dg
.
guard
(
place
):
# dataloader
input_fields
=
{
'names'
:
[
'mel'
,
'mag'
],
'shapes'
:
[[
cfg
.
batch_size
,
None
,
80
],
[
cfg
.
batch_size
,
None
,
257
]],
'dtypes'
:
[
'float32'
,
'float32'
],
'lod_levels'
:
[
0
,
0
]
}
inputs
=
[
fluid
.
data
(
name
=
input_fields
[
'names'
][
i
],
shape
=
input_fields
[
'shapes'
][
i
],
dtype
=
input_fields
[
'dtypes'
][
i
],
lod_level
=
input_fields
[
'lod_levels'
][
i
])
for
i
in
range
(
len
(
input_fields
[
'names'
]))
]
reader
=
fluid
.
io
.
DataLoader
.
from_generator
(
feed_list
=
inputs
,
capacity
=
32
,
iterable
=
True
,
use_double_buffer
=
True
,
return_list
=
True
)
writer
=
SummaryWriter
(
path
)
if
local_rank
==
0
else
None
with
dg
.
guard
(
place
):
model
=
ModelPostNet
(
'postnet'
,
cfg
)
model
.
train
()
...
...
@@ -94,9 +62,10 @@ def main():
strategy
=
dg
.
parallel
.
prepare_context
()
model
=
MyDataParallel
(
model
,
strategy
)
reader
=
LJSpeechLoader
(
cfg
,
nranks
,
local_rank
,
is_postnet
=
True
).
reader
()
for
epoch
in
range
(
cfg
.
epochs
):
reader
.
set_batch_generator
(
dataloader
,
place
)
pbar
=
tqdm
(
reader
())
pbar
=
tqdm
(
reader
)
for
i
,
data
in
enumerate
(
pbar
):
pbar
.
set_description
(
'Processing at epoch %d'
%
epoch
)
mel
,
mag
=
data
...
...
@@ -109,17 +78,18 @@ def main():
loss
=
layers
.
mean
(
layers
.
abs
(
layers
.
elementwise_sub
(
mag_pred
,
mag
)))
if
cfg
.
use_data_parallel
:
loss
=
model
.
scale_loss
(
loss
)
writer
.
add_scalars
(
'training_loss'
,{
'loss'
:
loss
.
numpy
(),
},
global_step
)
loss
.
backward
()
if
cfg
.
use_data_parallel
:
model
.
apply_collective_grads
()
else
:
loss
.
backward
()
optimizer
.
minimize
(
loss
,
grad_clip
=
fluid
.
dygraph_grad_clip
.
GradClipByGlobalNorm
(
1
))
model
.
clear_gradients
()
if
local_rank
==
0
:
writer
.
add_scalars
(
'training_loss'
,{
'loss'
:
loss
.
numpy
(),
},
global_step
)
if
global_step
%
cfg
.
save_step
==
0
:
if
not
os
.
path
.
exists
(
cfg
.
save_path
):
os
.
mkdir
(
cfg
.
save_path
)
...
...
@@ -127,9 +97,11 @@ def main():
dg
.
save_dygraph
(
model
.
state_dict
(),
save_path
)
dg
.
save_dygraph
(
optimizer
.
state_dict
(),
save_path
)
if
local_rank
==
0
:
writer
.
close
()
if
__name__
==
'__main__'
:
main
()
\ No newline at end of file
parser
=
jsonargparse
.
ArgumentParser
(
description
=
"Train postnet model"
,
formatter_class
=
'default_argparse'
)
add_config_options_to_parser
(
parser
)
cfg
=
parser
.
parse_args
(
'-c ./config/train_postnet.yaml'
.
split
())
main
(
cfg
)
\ No newline at end of file
parakeet/models/transformerTTS/train_transformer.py
浏览文件 @
9fe6ad11
from
preprocess
import
batch_examples
,
LJSpeech
import
os
from
tqdm
import
tqdm
import
paddle.fluid.dygraph
as
dg
import
paddle.fluid.layers
as
layers
from
network
import
*
from
tensorboardX
import
SummaryWriter
from
parakeet.data.datacargo
import
DataCargo
from
pathlib
import
Path
import
jsonargparse
from
parse
import
add_config_options_to_parser
from
pprint
import
pprint
from
matplotlib
import
cm
from
data
import
LJSpeechLoader
class
MyDataParallel
(
dg
.
parallel
.
DataParallel
):
"""
...
...
@@ -30,21 +29,14 @@ class MyDataParallel(dg.parallel.DataParallel):
object
.
__getattribute__
(
self
,
"_sub_layers"
)[
"_layers"
],
key
)
def
main
():
parser
=
jsonargparse
.
ArgumentParser
(
description
=
"Train TransformerTTS model"
,
formatter_class
=
'default_argparse'
)
add_config_options_to_parser
(
parser
)
cfg
=
parser
.
parse_args
(
'-c ./config/train_transformer.yaml'
.
split
())
local_rank
=
dg
.
parallel
.
Env
().
local_rank
def
main
(
cfg
):
local_rank
=
dg
.
parallel
.
Env
().
local_rank
if
cfg
.
use_data_parallel
else
0
nranks
=
dg
.
parallel
.
Env
().
nranks
if
cfg
.
use_data_parallel
else
1
if
local_rank
==
0
:
# Print the whole config setting.
pprint
(
jsonargparse
.
namespace_to_dict
(
cfg
))
LJSPEECH_ROOT
=
Path
(
cfg
.
data_path
)
dataset
=
LJSpeech
(
LJSPEECH_ROOT
)
dataloader
=
DataCargo
(
dataset
,
batch_size
=
cfg
.
batch_size
,
shuffle
=
True
,
collate_fn
=
batch_examples
,
drop_last
=
True
)
global_step
=
0
place
=
(
fluid
.
CUDAPlace
(
dg
.
parallel
.
Env
().
dev_id
)
if
cfg
.
use_data_parallel
else
fluid
.
CUDAPlace
(
0
)
...
...
@@ -57,39 +49,13 @@ def main():
writer
=
SummaryWriter
(
path
)
if
local_rank
==
0
else
None
with
dg
.
guard
(
place
):
if
cfg
.
use_data_parallel
:
strategy
=
dg
.
parallel
.
prepare_context
()
# dataloader
input_fields
=
{
'names'
:
[
'character'
,
'mel'
,
'mel_input'
,
'pos_text'
,
'pos_mel'
,
'text_len'
],
'shapes'
:
[[
cfg
.
batch_size
,
None
],
[
cfg
.
batch_size
,
None
,
80
],
[
cfg
.
batch_size
,
None
,
80
],
[
cfg
.
batch_size
,
1
],
[
cfg
.
batch_size
,
1
],
[
cfg
.
batch_size
,
1
]],
'dtypes'
:
[
'float32'
,
'float32'
,
'float32'
,
'int64'
,
'int64'
,
'int64'
],
'lod_levels'
:
[
0
,
0
,
0
,
0
,
0
,
0
]
}
inputs
=
[
fluid
.
data
(
name
=
input_fields
[
'names'
][
i
],
shape
=
input_fields
[
'shapes'
][
i
],
dtype
=
input_fields
[
'dtypes'
][
i
],
lod_level
=
input_fields
[
'lod_levels'
][
i
])
for
i
in
range
(
len
(
input_fields
[
'names'
]))
]
reader
=
fluid
.
io
.
DataLoader
.
from_generator
(
feed_list
=
inputs
,
capacity
=
32
,
iterable
=
True
,
use_double_buffer
=
True
,
return_list
=
True
)
model
=
Model
(
'transtts'
,
cfg
)
model
.
train
()
optimizer
=
fluid
.
optimizer
.
AdamOptimizer
(
learning_rate
=
dg
.
NoamDecay
(
1
/
(
4000
*
(
cfg
.
lr
**
2
)),
4000
))
reader
=
LJSpeechLoader
(
cfg
,
nranks
,
local_rank
).
reader
()
if
cfg
.
checkpoint_path
is
not
None
:
model_dict
,
opti_dict
=
fluid
.
dygraph
.
load_dygraph
(
cfg
.
checkpoint_path
)
model
.
set_dict
(
model_dict
)
...
...
@@ -97,11 +63,11 @@ def main():
print
(
"load checkpoint!!!"
)
if
cfg
.
use_data_parallel
:
strategy
=
dg
.
parallel
.
prepare_context
()
model
=
MyDataParallel
(
model
,
strategy
)
for
epoch
in
range
(
cfg
.
epochs
):
reader
.
set_batch_generator
(
dataloader
,
place
)
pbar
=
tqdm
(
reader
())
pbar
=
tqdm
(
reader
)
for
i
,
data
in
enumerate
(
pbar
):
pbar
.
set_description
(
'Processing at epoch %d'
%
epoch
)
character
,
mel
,
mel_input
,
pos_text
,
pos_mel
,
text_length
=
data
...
...
@@ -114,9 +80,7 @@ def main():
post_mel_loss
=
layers
.
mean
(
layers
.
abs
(
layers
.
elementwise_sub
(
postnet_pred
,
mel
)))
loss
=
mel_loss
+
post_mel_loss
if
cfg
.
use_data_parallel
:
loss
=
model
.
scale_loss
(
loss
)
if
local_rank
==
0
:
writer
.
add_scalars
(
'training_loss'
,
{
'mel_loss'
:
mel_loss
.
numpy
(),
'post_mel_loss'
:
post_mel_loss
.
numpy
(),
...
...
@@ -145,9 +109,12 @@ def main():
x
=
np
.
uint8
(
cm
.
viridis
(
prob
.
numpy
()[
j
*
16
])
*
255
)
writer
.
add_image
(
'Attention_dec_%d_0'
%
global_step
,
x
,
i
*
4
+
j
,
dataformats
=
"HWC"
)
loss
.
backward
()
if
cfg
.
use_data_parallel
:
loss
=
model
.
scale_loss
(
loss
)
loss
.
backward
()
model
.
apply_collective_grads
()
else
:
loss
.
backward
()
optimizer
.
minimize
(
loss
,
grad_clip
=
fluid
.
dygraph_grad_clip
.
GradClipByGlobalNorm
(
1
))
model
.
clear_gradients
()
...
...
@@ -163,4 +130,7 @@ def main():
if
__name__
==
'__main__'
:
main
()
\ No newline at end of file
parser
=
jsonargparse
.
ArgumentParser
(
description
=
"Train TransformerTTS model"
,
formatter_class
=
'default_argparse'
)
add_config_options_to_parser
(
parser
)
cfg
=
parser
.
parse_args
(
'-c ./config/train_transformer.yaml'
.
split
())
main
(
cfg
)
\ No newline at end of file
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录