PaddlePaddle / Parakeet
Commit 9fe6ad11 — Training with multi-GPU

Authored Dec 17, 2019 by lifuchen; committed by chenfeiyu on Dec 17, 2019.
Parent: 8a9bbc26
Showing 10 changed files with 390 additions and 147 deletions (+390 −147)
parakeet/audio/__init__.py (+1 −0)
parakeet/audio/audio.py (+261 −0)
parakeet/data/datacargo.py (+13 −4)
parakeet/models/transformerTTS/config/train_postnet.yaml (+1 −1)
parakeet/models/transformerTTS/config/train_transformer.yaml (+1 −1)
parakeet/models/transformerTTS/data.py (+29 −0)
parakeet/models/transformerTTS/module.py (+3 −2)
parakeet/models/transformerTTS/network.py (+9 −9)
parakeet/models/transformerTTS/train_postnet.py (+31 −59)
parakeet/models/transformerTTS/train_transformer.py (+41 −71)
parakeet/audio/__init__.py (new file, 0 → 100644)

from .audio import AudioProcessor
\ No newline at end of file
parakeet/audio/audio.py (new file, 0 → 100644)

import librosa
import soundfile as sf
import numpy as np
import scipy.io
import scipy.signal


class AudioProcessor(object):
    def __init__(self,
                 sample_rate=None,       # int, sampling rate
                 num_mels=None,          # int, bands of mel spectrogram
                 min_level_db=None,      # float, minimum level db
                 ref_level_db=None,      # float, reference level db
                 n_fft=None,             # int: number of samples in a frame for stft
                 win_length=None,        # int: the same meaning with n_fft
                 hop_length=None,        # int: number of samples between neighboring frame
                 power=None,             # float: power to raise before griffin-lim
                 preemphasis=None,       # float: preemphasis coefficident
                 signal_norm=None,       #
                 symmetric_norm=False,   # bool, apply clip norm in [-max_norm, max_norm]
                 max_norm=None,          # float, max norm
                 mel_fmin=None,          # int: mel spectrogram's minimum frequency
                 mel_fmax=None,          # int: mel spectrogram's maximum frequency
                 clip_norm=True,         # bool: clip spectrogram's norm
                 griffin_lim_iters=None, # int:
                 do_trim_silence=False,  # bool: trim silience
                 sound_norm=False,
                 **kwargs):
        self.sample_rate = sample_rate
        self.num_mels = num_mels
        self.min_level_db = min_level_db
        self.ref_level_db = ref_level_db

        # stft related
        self.n_fft = n_fft
        self.win_length = win_length or n_fft
        # hop length defaults to 1/4 window_length
        self.hop_length = hop_length or 0.25 * self.win_length

        self.power = power
        self.preemphasis = float(preemphasis)

        self.griffin_lim_iters = griffin_lim_iters
        self.signal_norm = signal_norm
        self.symmetric_norm = symmetric_norm

        # mel transform related
        self.mel_fmin = mel_fmin
        self.mel_fmax = mel_fmax

        self.max_norm = 1.0 if max_norm is None else float(max_norm)
        self.clip_norm = clip_norm
        self.do_trim_silence = do_trim_silence

        self.sound_norm = sound_norm
        self.num_freq, self.frame_length_ms, self.frame_shift_ms = self._stft_parameters()

    def _stft_parameters(self):
        """compute frame length and hop length in ms"""
        frame_length_ms = self.win_length * 1. / self.sample_rate
        frame_shift_ms = self.hop_length * 1. / self.sample_rate
        num_freq = 1 + self.n_fft // 2
        return num_freq, frame_length_ms, frame_shift_ms

    def __repr__(self):
        """object repr"""
        cls_name_str = self.__class__.__name__
        members = vars(self)
        dict_str = "\n".join(["  {}: {},".format(k, v) for k, v in members.items()])
        repr_str = "{}(\n{})\n".format(cls_name_str, dict_str)
        return repr_str

    def save_wav(self, path, wav):
        """save audio with scipy.io.wavfile in 16bit integers"""
        wav_norm = wav * (32767 / max(0.01, np.max(np.abs(wav))))
        scipy.io.wavfile.write(path, self.sample_rate, wav_norm.as_type(np.int16))

    def load_wav(self, path, sr=None):
        """load wav -> trim_silence -> rescale"""
        x, sr = librosa.load(path, sr=None)
        assert self.sample_rate == sr, "audio sample rate: {}Hz != processor sample rate: {}Hz".format(sr, self.sample_rate)
        if self.do_trim_silence:
            try:
                x = self.trim_silence(x)
            except ValueError:
                print(" [!] File cannot be trimmed for silence - {}".format(path))
        if self.sound_norm:
            x = x / x.max() * 0.9  # why 0.9 ?
        return x

    def trim_silence(self, wav):
        """Trim soilent parts with a threshold and 0.01s margin"""
        margin = int(self.sample_rate * 0.01)
        wav = wav[margin:-margin]
        trimed_wav = librosa.effects.trim(wav, top_db=60, frame_length=self.win_length, hop_length=self.hop_length)[0]
        return trimed_wav

    def apply_preemphasis(self, x):
        if self.preemphasis == 0.:
            raise RuntimeError(" !! Preemphasis coefficient should be positive. ")
        return scipy.signal.lfilter([1., -self.preemphasis], [1.], x)

    def apply_inv_preemphasis(self, x):
        if self.preemphasis == 0.:
            raise RuntimeError(" !! Preemphasis coefficient should be positive. ")
        return scipy.signal.lfilter([1.], [1., -self.preemphasis], x)

    def _amplitude_to_db(self, x):
        amplitude_min = np.exp(self.min_level_db / 20 * np.log(10))
        return 20 * np.log10(np.maximum(amplitude_min, x))

    @staticmethod
    def _db_to_amplitude(x):
        return np.power(10., 0.05 * x)

    def _linear_to_mel(self, spectrogram):
        _mel_basis = self._build_mel_basis()
        return np.dot(_mel_basis, spectrogram)

    def _mel_to_linear(self, mel_spectrogram):
        inv_mel_basis = np.linalg.pinv(self._build_mel_basis())
        return np.maximum(1e-10, np.dot(inv_mel_basis, mel_spectrogram))

    def _build_mel_basis(self):
        """return mel basis for mel scale"""
        if self.mel_fmax is not None:
            assert self.mel_fmax <= self.sample_rate // 2
        return librosa.filters.mel(self.sample_rate, self.n_fft, n_mels=self.num_mels, fmin=self.mel_fmin, fmax=self.mel_fmax)

    def _normalize(self, S):
        """put values in [0, self.max_norm] or [-self.max_norm, self.max_norm]"""
        if self.signal_norm:
            S_norm = (S - self.min_level_db) / (-self.min_level_db)
            if self.symmetric_norm:
                S_norm = ((2 * self.max_norm) * S_norm) - self.max_norm
                if self.clip_norm:
                    S_norm = np.clip(S_norm, -self.max_norm, self.max_norm)
                return S_norm
            else:
                S_norm = self.max_norm * S_norm
                if self.clip_norm:
                    S_norm = np.clip(S_norm, 0, self.max_norm)
                return S_norm
        else:
            return S

    def _denormalize(self, S):
        """denormalize values"""
        S_denorm = S
        if self.signal_norm:
            if self.symmetric_norm:
                if self.clip_norm:
                    S_denorm = np.clip(S_denorm, -self.max_norm, self.max_norm)
                S_denorm = (S_denorm + self.max_norm) * (-self.min_level_db) / (2 * self.max_norm) + self.min_level_db
                return S_denorm
            else:
                if self.clip_norm:
                    S_denorm = np.clip(S_denorm, 0, self.max_norm)
                S_denorm = S_denorm * (-self.min_level_db) / self.max_norm + self.min_level_db
                return S_denorm
        else:
            return S

    def _stft(self, y):
        return librosa.stft(y=y, n_fft=self.n_fft, win_length=self.win_length, hop_length=self.hop_length)

    def _istft(self, S):
        return librosa.istft(S, hop_length=self.hop_length, win_length=self.win_length)

    def spectrogram(self, y):
        """compute linear spectrogram(amplitude)
        preemphasis -> stft -> mag -> amplitude_to_db -> minus_ref_level_db -> normalize
        """
        if self.preemphasis:
            D = self._stft(self.apply_preemphasis(y))
        else:
            D = self._stft(y)
        S = self._amplitude_to_db(np.abs(D)) - self.ref_level_db
        return self._normalize(S)

    def melspectrogram(self, y):
        """compute linear spectrogram(amplitude)
        preemphasis -> stft -> mag -> mel_scale -> amplitude_to_db -> minus_ref_level_db -> normalize
        """
        if self.preemphasis:
            D = self._stft(self.apply_preemphasis(y))
        else:
            D = self._stft(y)
        S = self._amplitude_to_db(self._linear_to_mel(np.abs(D))) - self.ref_level_db
        return self._normalize(S)

    def inv_spectrogram(self, spectrogram):
        """convert spectrogram back to waveform using griffin_lim in librosa"""
        S = self._denormalize(spectrogram)
        S = self._db_to_amplitude(S + self.ref_level_db)
        if self.preemphasis:
            return self.apply_inv_preemphasis(self._griffin_lim(S ** self.power))
        return self._griffin_lim(S ** self.power)

    def inv_melspectrogram(self, mel_spectrogram):
        S = self._denormalize(mel_spectrogram)
        S = self._db_to_amplitude(S + self.ref_level_db)
        S = self._linear_to_mel(np.abs(S))
        if self.preemphasis:
            return self.apply_inv_preemphasis(self._griffin_lim(S ** self.power))
        return self._griffin_lim(S ** self.power)

    def out_linear_to_mel(self, linear_spec):
        """convert output linear spec to mel spec"""
        S = self._denormalize(linear_spec)
        S = self._db_to_amplitude(S + self.ref_level_db)
        S = self._linear_to_mel(np.abs(S))
        S = self._amplitude_to_db(S) - self.ref_level_db
        mel = self._normalize(S)
        return mel

    def _griffin_lim(self, S):
        angles = np.exp(2j * np.pi * np.random.rand(*S.shape))
        S_complex = np.abs(S).astype(np.complex)
        y = self._istft(S_complex * angles)
        for _ in range(self.griffin_lim_iters):
            angles = np.exp(1j * np.angle(self._stft(y)))
            y = self._istft(S_complex * angles)
        return y

    @staticmethod
    def mulaw_encode(wav, qc):
        mu = 2 ** qc - 1
        # wav_abs = np.minimum(np.abs(wav), 1.0)
        signal = np.sign(wav) * np.log(1 + mu * np.abs(wav)) / np.log(1. + mu)
        # Quantize signal to the specified number of levels.
        signal = (signal + 1) / 2 * mu + 0.5
        return np.floor(signal,)

    @staticmethod
    def mulaw_decode(wav, qc):
        """Recovers waveform from quantized values."""
        mu = 2 ** qc - 1
        x = np.sign(wav) / mu * ((1 + mu) ** np.abs(wav) - 1)
        return x

    @staticmethod
    def encode_16bits(x):
        return np.clip(x * 2 ** 15, -2 ** 15, 2 ** 15 - 1).astype(np.int16)

    @staticmethod
    def quantize(x, bits):
        return (x + 1.) * (2 ** bits - 1) / 2

    @staticmethod
    def dequantize(x, bits):
        return 2 * x / (2 ** bits - 1) - 1
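Note: the AudioProcessor above bundles the whole feature pipeline (preemphasis -> stft -> magnitude -> optional mel scale -> amplitude_to_db -> subtract ref_level_db -> normalize, with Griffin-Lim for inversion). A minimal usage sketch follows; the hyperparameter values are illustrative assumptions, not settings shipped by this commit.

from parakeet.audio import AudioProcessor

# All values below are assumed for illustration (roughly LJSpeech-style settings).
processor = AudioProcessor(
    sample_rate=22050,
    num_mels=80,
    min_level_db=-100,
    ref_level_db=20,
    n_fft=1024,
    win_length=1024,
    hop_length=256,
    power=1.2,
    preemphasis=0.97,
    signal_norm=True,
    symmetric_norm=False,
    max_norm=1.0,
    mel_fmin=0,
    mel_fmax=8000,
    clip_norm=True,
    griffin_lim_iters=60,
    do_trim_silence=True,
    sound_norm=False)

wav = processor.load_wav("LJ001-0001.wav")   # hypothetical 22050 Hz file; load_wav asserts the rate matches
mel = processor.melspectrogram(wav)          # (num_mels, num_frames), normalized dB features
linear = processor.spectrogram(wav)          # (1 + n_fft // 2, num_frames) linear spectrogram
wav_hat = processor.inv_spectrogram(linear)  # Griffin-Lim reconstruction of the waveform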
parakeet/data/datacargo.py

...
@@ -2,7 +2,8 @@ from .sampler import SequentialSampler, RandomSampler, BatchSampler
 class DataCargo(object):
-    def __init__(self, dataset, batch_size=1, sampler=None,
-                 shuffle=False, batch_sampler=None, drop_last=False):
+    def __init__(self, dataset, batch_size=1, sampler=None,
+                 shuffle=False, batch_sampler=None, collate_fn=None, drop_last=False):
         self.dataset = dataset
         if batch_sampler is not None:
...
@@ -21,13 +22,20 @@ class DataCargo(object):
                 sampler = RandomSampler(dataset)
             else:
                 sampler = SequentialSampler(dataset)
+            # auto_collation without custom batch_sampler
             batch_sampler = BatchSampler(sampler, batch_size, drop_last)
+        else:
+            batch_sampler = BatchSampler(sampler, batch_size, drop_last)
+        self.batch_sampler = batch_sampler
+
+        if collate_fn is None:
+            collate_fn = dataset._batch_examples
+        self.collate_fn = collate_fn

         self.batch_size = batch_size
         self.drop_last = drop_last
         self.sampler = sampler
-        self.batch_sampler = batch_sampler

     def __iter__(self):
         return DataIterator(self)
...
@@ -57,6 +65,7 @@ class DataIterator(object):
         self._index_sampler = loader._index_sampler
         self._sampler_iter = iter(self._index_sampler)
+        self.collate_fn = loader.collate_fn

     def __iter__(self):
         return self
...
@@ -64,7 +73,7 @@ class DataIterator(object):
     def __next__(self):
         index = self._next_index()  # may raise StopIteration, TODO(chenfeiyu): use dynamic batch size
         minibatch = [self._dataset[i] for i in index]  # we can abstract it, too to use dynamic batch size
-        minibatch = self._dataset._batch_examples(minibatch)  # list[Example] -> Batch
+        minibatch = self.collate_fn(minibatch)
         return minibatch

     def _next_index(self):
...
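Note: this change makes the batching function injectable. When collate_fn is omitted, DataCargo still falls back to dataset._batch_examples, so existing callers keep working; new callers (such as data.py below) can pass their own function. A sketch of the intended call pattern, where my_dataset and collate_examples are hypothetical stand-ins for a Parakeet dataset and a batching function like batch_examples:

from parakeet.data.datacargo import DataCargo

my_dataset = ...   # hypothetical dataset object implementing __getitem__ and __len__

def collate_examples(examples):
    # hypothetical collate function: turn a list of examples into one batch
    # (padding, stacking into arrays, ...)
    return examples

cargo = DataCargo(my_dataset, batch_size=32, shuffle=True,
                  collate_fn=collate_examples, drop_last=True)
for batch in cargo:   # each batch is whatever collate_fn returned
    pass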
parakeet/models/transformerTTS/config/train_postnet.yaml

...
@@ -20,7 +20,7 @@ epochs: 10000
 lr: 0.001
 save_step: 500
 use_gpu: True
-use_data_parallel: False
+use_data_parallel: True
 data_path: ../../../dataset/LJSpeech-1.1
 save_path: ./checkpoint
...
parakeet/models/transformerTTS/config/train_transformer.yaml

...
@@ -21,7 +21,7 @@ lr: 0.001
 save_step: 500
 image_step: 2000
 use_gpu: True
-use_data_parallel: False
+use_data_parallel: True
 data_path: ../../../dataset/LJSpeech-1.1
 save_path: ./checkpoint
...
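Note: flipping use_data_parallel to True routes both training scripts through Fluid's dygraph data-parallel path, which the train_*.py diffs below spell out in full. Condensed into two helper sketches (the helper names are ours; the calls are the ones used in the scripts):

import paddle.fluid.dygraph as dg

def wrap_data_parallel(model):
    # one trainer process per GPU; prepare_context() sets up the communication context
    strategy = dg.parallel.prepare_context()
    return dg.parallel.DataParallel(model, strategy)

def backward_data_parallel(model, loss):
    loss = model.scale_loss(loss)      # rescale so gradients average correctly across trainers
    loss.backward()
    model.apply_collective_grads()     # all-reduce gradients before the optimizer step
    return loss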
parakeet/models/transformerTTS/data.py (new file, 0 → 100644)

from pathlib import Path
import numpy as np
from paddle import fluid
from parakeet.data.sampler import DistributedSampler
from parakeet.data.datacargo import DataCargo
from preprocess import batch_examples, LJSpeech, batch_examples_postnet


class LJSpeechLoader:
    def __init__(self, config, nranks, rank, is_postnet=False):
        place = fluid.CUDAPlace(rank) if config.use_gpu else fluid.CPUPlace()

        LJSPEECH_ROOT = Path(config.data_path)
        dataset = LJSpeech(LJSPEECH_ROOT)
        sampler = DistributedSampler(len(dataset), nranks, rank)

        assert config.batch_size % nranks == 0
        each_bs = config.batch_size // nranks
        if is_postnet:
            dataloader = DataCargo(dataset, sampler=sampler, batch_size=each_bs, shuffle=True,
                                   collate_fn=batch_examples_postnet, drop_last=True)
        else:
            dataloader = DataCargo(dataset, sampler=sampler, batch_size=each_bs, shuffle=True,
                                   collate_fn=batch_examples, drop_last=True)

        self.reader = fluid.io.DataLoader.from_generator(
            capacity=32,
            iterable=True,
            use_double_buffer=True,
            return_list=True)
        self.reader.set_batch_generator(dataloader, place)
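Note: LJSpeechLoader shards the data per process: DistributedSampler slices the dataset by rank, and the global batch size must divide evenly by the number of trainers (the assert above), so every rank reads batch_size // nranks examples per step. A sketch of how the training scripts below construct it, assuming cfg is the parsed YAML config:

import paddle.fluid.dygraph as dg
from data import LJSpeechLoader

# cfg: parsed YAML config (see train_*.py below) with use_gpu, data_path, batch_size, ...
nranks = dg.parallel.Env().nranks if cfg.use_data_parallel else 1
local_rank = dg.parallel.Env().local_rank if cfg.use_data_parallel else 0

loader = LJSpeechLoader(cfg, nranks, local_rank, is_postnet=False)
reader = loader.reader   # a fluid DataLoader bound to this rank's shard of LJSpeech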
parakeet/models/transformerTTS/module.py

...
@@ -130,7 +130,7 @@ class EncoderPrenet(dg.Layer):
         self.projection = FC(self.full_name(), num_hidden, num_hidden)

     def forward(self, x):
-        x = self.embedding(fluid.layers.unsqueeze(x, axes=[-1]))  #(batch_size, seq_len, embending_size)
+        x = self.embedding(x)  #(batch_size, seq_len, embending_size)
         x = layers.transpose(x, [0, 2, 1])
         x = layers.dropout(layers.relu(self.batch_norm1(self.conv1(x))), 0.2)
         x = layers.dropout(layers.relu(self.batch_norm2(self.conv2(x))), 0.2)
...
@@ -211,9 +211,10 @@ class ScaledDotProductAttention(dg.Layer):
         # Mask key to ignore padding
         if mask is not None:
             attention = attention * mask
-            mask = (mask == 0).astype(float) * (-2 ** 32 + 1)
+            mask = (mask == 0).astype(np.float32) * (-2 ** 32 + 1)
             attention = attention + mask

         attention = layers.softmax(attention)
         # Mask query to ignore padding
         # Not sure how to work
...
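Note: the ScaledDotProductAttention fix only changes the mask's dtype (Python float -> np.float32), but the masking trick in the surrounding lines deserves a word: padded key positions are multiplied out of the scores and then pushed to a huge negative value before softmax, so their attention weights collapse to zero. A standalone numpy sketch of the same arithmetic with made-up scores:

import numpy as np

scores = np.array([2.0, 1.0, 0.5, 0.1], dtype=np.float32)   # raw attention scores for one query
mask = np.array([1, 1, 0, 0], dtype=np.float32)             # 1 = real key, 0 = padding

scores = scores * mask                                       # zero the padded scores
scores = scores + (mask == 0).astype(np.float32) * (-2 ** 32 + 1)   # push padded slots far below the rest
weights = np.exp(scores - scores.max())
weights /= weights.sum()                                     # softmax; padded positions end up with ~0 weight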
parakeet/models/transformerTTS/network.py

...
@@ -7,9 +7,9 @@ class Encoder(dg.Layer):
     def __init__(self, name_scope, embedding_size, num_hidden, config):
         super(Encoder, self).__init__(name_scope)
         self.num_hidden = num_hidden
-        param = fluid.ParamAttr(name='alpha')
-        self.alpha = self.create_parameter(param, shape=(1, ), dtype='float32', initializer=fluid.initializer.Constant(value=1.0))
+        param = fluid.ParamAttr(name='alpha',
+                                default_initializer=fluid.initializer.ConstantInitializer(value=1.0))
+        self.alpha = self.create_parameter(param, shape=(1, ), dtype='float32')
         self.pos_inp = get_sinusoid_encoding_table(1024, self.num_hidden, padding_idx=0)
         self.pos_emb = dg.Embedding(name_scope=self.full_name(),
                                     size=[1024, num_hidden],
...
@@ -31,8 +31,8 @@ class Encoder(dg.Layer):
     def forward(self, x, positional):
         if fluid.framework._dygraph_tracer()._train_mode:
-            query_mask = (positional != 0).astype(float)
-            mask = (positional != 0).astype(float)
+            query_mask = (positional != 0).astype(np.float32)
+            mask = (positional != 0).astype(np.float32)
             mask = fluid.layers.expand(fluid.layers.unsqueeze(mask, [1]), [1, x.shape[1], 1])
         else:
             query_mask, mask = None, None
...
@@ -42,7 +42,7 @@ class Encoder(dg.Layer):
         # Get positional encoding
-        positional = self.pos_emb(fluid.layers.unsqueeze(positional, axes=[-1]))
+        positional = self.pos_emb(positional)
         x = positional * self.alpha + x  #(N, T, C)
...
@@ -102,14 +102,14 @@ class Decoder(dg.Layer):
         if fluid.framework._dygraph_tracer()._train_mode:
             #zeros = np.zeros(positional.shape, dtype=np.float32)
-            m_mask = (positional != 0).astype(float)
+            m_mask = (positional != 0).astype(np.float32)
             mask = np.repeat(np.expand_dims(m_mask.numpy() == 0, axis=1), decoder_len, axis=1)
             mask = mask + np.repeat(np.expand_dims(np.triu(np.ones([decoder_len, decoder_len]), 1), axis=0), batch_size, axis=0)
             mask = fluid.layers.cast(dg.to_variable(mask == 0), np.float32)

             # (batch_size, decoder_len, decoder_len)
-            zero_mask = fluid.layers.expand(fluid.layers.unsqueeze((c_mask != 0).astype(float), axes=2), [1, 1, decoder_len])
+            zero_mask = fluid.layers.expand(fluid.layers.unsqueeze((c_mask != 0).astype(np.float32), axes=2), [1, 1, decoder_len])

             # (batch_size, decoder_len, seq_len)
             zero_mask = fluid.layers.transpose(zero_mask, [0, 2, 1])
...
@@ -125,7 +125,7 @@ class Decoder(dg.Layer):
         query = self.linear(query)

         # Get position embedding
-        positional = self.pos_emb(fluid.layers.unsqueeze(positional, axes=[-1]))
+        positional = self.pos_emb(positional)
         query = positional * self.alpha + query
         #positional dropout
...
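Note: the Decoder changes above are dtype casts, but the surrounding context lines build the decoder's self-attention mask by combining a padding mask with an upper-triangular causal mask, then flipping the result into a keep/block indicator. A standalone numpy sketch of that construction with illustrative shapes:

import numpy as np

batch_size, decoder_len = 2, 4
m_mask = np.array([[1, 1, 1, 0],
                   [1, 1, 0, 0]], dtype=np.float32)   # 1 = real frame, 0 = padded frame

# padding part: block attention to padded key positions
pad = np.repeat(np.expand_dims(m_mask == 0, axis=1), decoder_len, axis=1)
# causal part: block attention to future positions (strictly above the diagonal)
causal = np.repeat(np.expand_dims(np.triu(np.ones([decoder_len, decoder_len]), 1), axis=0),
                   batch_size, axis=0)

mask = pad + causal                       # nonzero wherever attention must be blocked
mask = (mask == 0).astype(np.float32)     # 1 = allowed, 0 = blocked, matching the cast above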
parakeet/models/transformerTTS/train_postnet.py

 from network import *
-from preprocess import batch_examples_postnet, LJSpeech
 from tensorboardX import SummaryWriter
 import os
 from tqdm import tqdm
-from parakeet.data.datacargo import DataCargo
 from pathlib import Path
 import jsonargparse
 from parse import add_config_options_to_parser
 from pprint import pprint
+from data import LJSpeechLoader

 class MyDataParallel(dg.parallel.DataParallel):
     """
...
@@ -27,21 +26,15 @@ class MyDataParallel(dg.parallel.DataParallel):
             object.__getattribute__(self, "_sub_layers")["_layers"], key)

-def main():
-    parser = jsonargparse.ArgumentParser(description="Train postnet model", formatter_class='default_argparse')
-    add_config_options_to_parser(parser)
-    cfg = parser.parse_args('-c ./config/train_postnet.yaml'.split())
-
-    local_rank = dg.parallel.Env().local_rank
+def main(cfg):
+    local_rank = dg.parallel.Env().local_rank if cfg.use_data_parallel else 0
+    nranks = dg.parallel.Env().nranks if cfg.use_data_parallel else 1

     if local_rank == 0:
         # Print the whole config setting.
         pprint(jsonargparse.namespace_to_dict(cfg))

-    LJSPEECH_ROOT = Path(cfg.data_path)
-    dataset = LJSpeech(LJSPEECH_ROOT)
-    dataloader = DataCargo(dataset, batch_size=cfg.batch_size, shuffle=True, collate_fn=batch_examples_postnet, drop_last=True)
     global_step = 0
     place = (fluid.CUDAPlace(dg.parallel.Env().dev_id)
              if cfg.use_data_parallel else fluid.CUDAPlace(0)
...
@@ -50,35 +43,10 @@ def main():
     if not os.path.exists(cfg.log_dir):
         os.mkdir(cfg.log_dir)
     path = os.path.join(cfg.log_dir, 'postnet')

-    writer = SummaryWriter(path)
+    writer = SummaryWriter(path) if local_rank == 0 else None

     with dg.guard(place):
-        # dataloader
-        input_fields = {
-            'names': ['mel', 'mag'],
-            'shapes': [[cfg.batch_size, None, 80], [cfg.batch_size, None, 257]],
-            'dtypes': ['float32', 'float32'],
-            'lod_levels': [0, 0]
-        }
-        inputs = [fluid.data(name=input_fields['names'][i],
-                             shape=input_fields['shapes'][i],
-                             dtype=input_fields['dtypes'][i],
-                             lod_level=input_fields['lod_levels'][i])
-                  for i in range(len(input_fields['names']))]
-        reader = fluid.io.DataLoader.from_generator(feed_list=inputs, capacity=32, iterable=True, use_double_buffer=True, return_list=True)
-
         model = ModelPostNet('postnet', cfg)
         model.train()
...
@@ -94,9 +62,10 @@ def main():
             strategy = dg.parallel.prepare_context()
             model = MyDataParallel(model, strategy)

+        reader = LJSpeechLoader(cfg, nranks, local_rank, is_postnet=True).reader()
         for epoch in range(cfg.epochs):
-            reader.set_batch_generator(dataloader, place)
-            pbar = tqdm(reader())
+            pbar = tqdm(reader)
             for i, data in enumerate(pbar):
                 pbar.set_description('Processing at epoch %d' % epoch)
                 mel, mag = data
...
@@ -109,17 +78,18 @@ def main():
                 loss = layers.mean(layers.abs(layers.elementwise_sub(mag_pred, mag)))

                 if cfg.use_data_parallel:
                     loss = model.scale_loss(loss)
-                writer.add_scalars('training_loss', {'loss': loss.numpy(),}, global_step)
-                loss.backward()
-                if cfg.use_data_parallel:
+                    loss.backward()
                     model.apply_collective_grads()
+                else:
+                    loss.backward()
                 optimizer.minimize(loss, grad_clip=fluid.dygraph_grad_clip.GradClipByGlobalNorm(1))
                 model.clear_gradients()

+                if local_rank == 0:
+                    writer.add_scalars('training_loss', {'loss': loss.numpy(),}, global_step)
                 if global_step % cfg.save_step == 0:
                     if not os.path.exists(cfg.save_path):
                         os.mkdir(cfg.save_path)
...
@@ -127,9 +97,11 @@ def main():
                         dg.save_dygraph(model.state_dict(), save_path)
                         dg.save_dygraph(optimizer.state_dict(), save_path)

+        if local_rank == 0:
+            writer.close()

 if __name__ == '__main__':
-    main()
\ No newline at end of file
+    parser = jsonargparse.ArgumentParser(description="Train postnet model", formatter_class='default_argparse')
+    add_config_options_to_parser(parser)
+    cfg = parser.parse_args('-c ./config/train_postnet.yaml'.split())
+    main(cfg)
\ No newline at end of file
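Note: a pattern this commit introduces in both training scripts is that only rank 0 owns the SummaryWriter and the checkpoint directory, so several trainer processes do not race to write the same files. A condensed sketch of that guard (the function names, log path and tag here are illustrative):

from tensorboardX import SummaryWriter

def make_writer(local_rank, log_path):
    # only the first trainer process creates logs
    return SummaryWriter(log_path) if local_rank == 0 else None

def log_loss(writer, local_rank, loss_value, global_step):
    if local_rank == 0:
        writer.add_scalars('training_loss', {'loss': loss_value}, global_step)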
parakeet/models/transformerTTS/train_transformer.py

-from preprocess import batch_examples, LJSpeech
 import os
 from tqdm import tqdm
 import paddle.fluid.dygraph as dg
 import paddle.fluid.layers as layers
 from network import *
 from tensorboardX import SummaryWriter
-from parakeet.data.datacargo import DataCargo
 from pathlib import Path
 import jsonargparse
 from parse import add_config_options_to_parser
 from pprint import pprint
 from matplotlib import cm
+from data import LJSpeechLoader

 class MyDataParallel(dg.parallel.DataParallel):
     """
...
@@ -30,21 +29,14 @@ class MyDataParallel(dg.parallel.DataParallel):
             object.__getattribute__(self, "_sub_layers")["_layers"], key)

-def main():
-    parser = jsonargparse.ArgumentParser(description="Train TransformerTTS model", formatter_class='default_argparse')
-    add_config_options_to_parser(parser)
-    cfg = parser.parse_args('-c ./config/train_transformer.yaml'.split())
-
-    local_rank = dg.parallel.Env().local_rank
+def main(cfg):
+    local_rank = dg.parallel.Env().local_rank if cfg.use_data_parallel else 0
+    nranks = dg.parallel.Env().nranks if cfg.use_data_parallel else 1

     if local_rank == 0:
         # Print the whole config setting.
         pprint(jsonargparse.namespace_to_dict(cfg))

-    LJSPEECH_ROOT = Path(cfg.data_path)
-    dataset = LJSpeech(LJSPEECH_ROOT)
-    dataloader = DataCargo(dataset, batch_size=cfg.batch_size, shuffle=True, collate_fn=batch_examples, drop_last=True)
     global_step = 0
     place = (fluid.CUDAPlace(dg.parallel.Env().dev_id)
              if cfg.use_data_parallel else fluid.CUDAPlace(0)
...
@@ -57,39 +49,13 @@ def main():
     writer = SummaryWriter(path) if local_rank == 0 else None

     with dg.guard(place):
-        if cfg.use_data_parallel:
-            strategy = dg.parallel.prepare_context()
-        # dataloader
-        input_fields = {
-            'names': ['character', 'mel', 'mel_input', 'pos_text', 'pos_mel', 'text_len'],
-            'shapes': [[cfg.batch_size, None], [cfg.batch_size, None, 80], [cfg.batch_size, None, 80],
-                       [cfg.batch_size, 1], [cfg.batch_size, 1], [cfg.batch_size, 1]],
-            'dtypes': ['float32', 'float32', 'float32', 'int64', 'int64', 'int64'],
-            'lod_levels': [0, 0, 0, 0, 0, 0]
-        }
-        inputs = [fluid.data(name=input_fields['names'][i],
-                             shape=input_fields['shapes'][i],
-                             dtype=input_fields['dtypes'][i],
-                             lod_level=input_fields['lod_levels'][i])
-                  for i in range(len(input_fields['names']))]
-        reader = fluid.io.DataLoader.from_generator(feed_list=inputs, capacity=32, iterable=True, use_double_buffer=True, return_list=True)
-
         model = Model('transtts', cfg)
         model.train()
         optimizer = fluid.optimizer.AdamOptimizer(learning_rate=dg.NoamDecay(1 / (4000 * (cfg.lr ** 2)), 4000))
+        reader = LJSpeechLoader(cfg, nranks, local_rank).reader()

         if cfg.checkpoint_path is not None:
             model_dict, opti_dict = fluid.dygraph.load_dygraph(cfg.checkpoint_path)
             model.set_dict(model_dict)
...
@@ -97,11 +63,11 @@ def main():
             print("load checkpoint!!!")

         if cfg.use_data_parallel:
+            strategy = dg.parallel.prepare_context()
             model = MyDataParallel(model, strategy)

         for epoch in range(cfg.epochs):
-            reader.set_batch_generator(dataloader, place)
-            pbar = tqdm(reader())
+            pbar = tqdm(reader)
             for i, data in enumerate(pbar):
                 pbar.set_description('Processing at epoch %d' % epoch)
                 character, mel, mel_input, pos_text, pos_mel, text_length = data
...
@@ -114,9 +80,7 @@ def main():
                 post_mel_loss = layers.mean(layers.abs(layers.elementwise_sub(postnet_pred, mel)))
                 loss = mel_loss + post_mel_loss

-                if cfg.use_data_parallel:
-                    loss = model.scale_loss(loss)
-                writer.add_scalars('training_loss', {
-                    'mel_loss': mel_loss.numpy(),
-                    'post_mel_loss': post_mel_loss.numpy(),
+                if local_rank == 0:
+                    writer.add_scalars('training_loss', {
+                        'mel_loss': mel_loss.numpy(),
+                        'post_mel_loss': post_mel_loss.numpy(),
...
@@ -145,9 +109,12 @@ def main():
                             x = np.uint8(cm.viridis(prob.numpy()[j * 16]) * 255)
                             writer.add_image('Attention_dec_%d_0' % global_step, x, i * 4 + j, dataformats="HWC")

-                loss.backward()
                 if cfg.use_data_parallel:
+                    loss = model.scale_loss(loss)
+                    loss.backward()
                     model.apply_collective_grads()
+                else:
+                    loss.backward()
                 optimizer.minimize(loss, grad_clip=fluid.dygraph_grad_clip.GradClipByGlobalNorm(1))
                 model.clear_gradients()
...
@@ -163,4 +130,7 @@ def main():

 if __name__ == '__main__':
-    main()
\ No newline at end of file
+    parser = jsonargparse.ArgumentParser(description="Train TransformerTTS model", formatter_class='default_argparse')
+    add_config_options_to_parser(parser)
+    cfg = parser.parse_args('-c ./config/train_transformer.yaml'.split())
+    main(cfg)
\ No newline at end of file