Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
fb0acd40
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 1 年 前同步成功
通知
207
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
fb0acd40
编写于
1月 24, 2022
作者:
小湉湉
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add wavernn, test=tts
上级
49fd55dc
变更
18
展开全部
隐藏空白更改
内联
并排
Showing
18 changed file
with
1736 addition
and
0 deletion
+1736
-0
examples/csmsc/voc6/conf/default.yaml
examples/csmsc/voc6/conf/default.yaml
+68
-0
examples/csmsc/voc6/local/preprocess.sh
examples/csmsc/voc6/local/preprocess.sh
+15
-0
examples/csmsc/voc6/local/synthesize.sh
examples/csmsc/voc6/local/synthesize.sh
+14
-0
examples/csmsc/voc6/local/train.sh
examples/csmsc/voc6/local/train.sh
+9
-0
examples/csmsc/voc6/path.sh
examples/csmsc/voc6/path.sh
+13
-0
examples/csmsc/voc6/run.sh
examples/csmsc/voc6/run.sh
+33
-0
paddlespeech/t2s/datasets/__init__.py
paddlespeech/t2s/datasets/__init__.py
+1
-0
paddlespeech/t2s/datasets/csmsc.py
paddlespeech/t2s/datasets/csmsc.py
+56
-0
paddlespeech/t2s/datasets/vocoder_batch_fn.py
paddlespeech/t2s/datasets/vocoder_batch_fn.py
+125
-0
paddlespeech/t2s/exps/wavernn/__init__.py
paddlespeech/t2s/exps/wavernn/__init__.py
+13
-0
paddlespeech/t2s/exps/wavernn/preprocess.py
paddlespeech/t2s/exps/wavernn/preprocess.py
+157
-0
paddlespeech/t2s/exps/wavernn/synthesize.py
paddlespeech/t2s/exps/wavernn/synthesize.py
+89
-0
paddlespeech/t2s/exps/wavernn/train.py
paddlespeech/t2s/exps/wavernn/train.py
+192
-0
paddlespeech/t2s/models/__init__.py
paddlespeech/t2s/models/__init__.py
+1
-0
paddlespeech/t2s/models/wavernn/__init__.py
paddlespeech/t2s/models/wavernn/__init__.py
+15
-0
paddlespeech/t2s/models/wavernn/wavernn.py
paddlespeech/t2s/models/wavernn/wavernn.py
+592
-0
paddlespeech/t2s/models/wavernn/wavernn_updater.py
paddlespeech/t2s/models/wavernn/wavernn_updater.py
+203
-0
paddlespeech/t2s/modules/losses.py
paddlespeech/t2s/modules/losses.py
+140
-0
未找到文件。
examples/csmsc/voc6/conf/default.yaml
0 → 100644
浏览文件 @
fb0acd40
###########################################################
# FEATURE EXTRACTION SETTING #
###########################################################
fs
:
24000
# Sampling rate.
n_fft
:
2048
# FFT size (samples).
n_shift
:
300
# Hop size (samples). 12.5ms
win_length
:
1200
# Window length (samples). 50ms
# If set to null, it will be the same as fft_size.
window
:
"
hann"
# Window function.
n_mels
:
80
# Number of mel basis.
fmin
:
80
# Minimum freq in mel basis calculation. (Hz)
fmax
:
7600
# Maximum frequency in mel basis calculation. (Hz)
mu_law
:
True
# Recommended to suppress noise if using raw bits
peak_norm
:
True
###########################################################
# MODEL SETTING #
###########################################################
model
:
rnn_dims
:
512
# Hidden dims of RNN Layers.
fc_dims
:
512
bits
:
9
# Bit depth of signal
aux_context_window
:
2
aux_channels
:
80
# Number of channels for auxiliary feature conv.
# Must be the same as num_mels.
upsample_scales
:
[
4
,
5
,
3
,
5
]
# Upsampling scales. Product of these must be the same as hop size, same with pwgan here
compute_dims
:
128
res_out_dims
:
128
res_blocks
:
10
mode
:
RAW
# either 'RAW' (softmax on raw bits) or 'MOL' (sample from mixture of logistics)
inference
:
gen_batched
:
True
# whether to generate samples in batch mode
target
:
12000
# target number of samples to be generated in each batch entry
overlap
:
600
# number of samples for crossfading between batches
###########################################################
# DATA LOADER SETTING #
###########################################################
batch_size
:
64
# Batch size.
batch_max_steps
:
4500
# Length of each audio in batch. Make sure it is divisible by n_shift (the hop size).
num_workers
:
2
# Number of workers in DataLoader.
valid_size
:
50
###########################################################
# OPTIMIZER SETTING #
###########################################################
grad_clip
:
4.0
learning_rate
:
1.0e-4
###########################################################
# INTERVAL SETTING #
###########################################################
train_max_steps
:
400000
# Number of training steps.
save_interval_steps
:
5000
# Interval steps to save checkpoint.
eval_interval_steps
:
1000
# Interval steps to evaluate the network.
gen_eval_samples_interval_steps
:
5000
# the iteration interval of generating valid samples
generate_num
:
5
# number of samples to generate at each checkpoint
###########################################################
# OTHER SETTING #
###########################################################
num_snapshots
:
10
# max number of snapshots to keep while training
seed
:
42
# random seed for paddle, random, and np.random
examples/csmsc/voc6/local/preprocess.sh
0 → 100755
浏览文件 @
fb0acd40
#!/bin/bash

# Usage: ./local/preprocess.sh <config_path>
# Expects BIN_DIR to be set in the environment.

stage=0
stop_stage=100

config_path=$1

if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
    # Run feature extraction; results are written to ./dump.
    # Paths are quoted so directories containing spaces do not get word-split.
    # NOTE: the "~" in --input may reach the script literally; preprocess.py
    # calls expanduser() on it.
    python3 "${BIN_DIR}/preprocess.py" \
        --input=~/datasets/BZNSYP/ \
        --output=dump \
        --dataset=csmsc \
        --config="${config_path}" \
        --num-cpu=20
fi
examples/csmsc/voc6/local/synthesize.sh
0 → 100755
浏览文件 @
fb0acd40
#!/bin/bash

# Usage: ./local/synthesize.sh <config_path> <train_output_path> <ckpt_name> <test_input>
# Expects BIN_DIR to be set in the environment.

config_path=$1
train_output_path=$2
ckpt_name=$3
test_input=$4

# The two FLAGS_* variables are passed only to this python3 invocation.
# Paths are quoted so directories containing spaces do not get word-split.
FLAGS_allocator_strategy=naive_best_fit \
FLAGS_fraction_of_gpu_memory_to_use=0.01 \
python3 "${BIN_DIR}/synthesize.py" \
    --config="${config_path}" \
    --checkpoint="${train_output_path}/checkpoints/${ckpt_name}" \
    --input="${test_input}" \
    --output-dir="${train_output_path}/test"
examples/csmsc/voc6/local/train.sh
0 → 100755
浏览文件 @
fb0acd40
#!/bin/bash

# Usage: ./local/train.sh <config_path> <train_output_path>
# Expects BIN_DIR to be set in the environment.

config_path=$1
train_output_path=$2

# Use python3 explicitly, like the preprocess and synthesize scripts do.
# Paths are quoted so directories containing spaces do not get word-split.
python3 "${BIN_DIR}/train.py" \
    --config="${config_path}" \
    --data=dump/ \
    --output-dir="${train_output_path}" \
    --ngpu=1
examples/csmsc/voc6/path.sh
0 → 100755
浏览文件 @
fb0acd40
#!/bin/bash
# Environment setup shared by the scripts of this example.

# Repository root: three directories above the current working directory.
export MAIN_ROOT="$(realpath "${PWD}/../../../")"

export PATH="${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}"
export LC_ALL=C

export PYTHONDONTWRITEBYTECODE=1
# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PYTHONPATH="${MAIN_ROOT}:${PYTHONPATH}"

# Directory holding the preprocess/train/synthesize entry points of the model.
MODEL=wavernn
export BIN_DIR="${MAIN_ROOT}/paddlespeech/t2s/exps/${MODEL}"
\ No newline at end of file
examples/csmsc/voc6/run.sh
0 → 100755
浏览文件 @
fb0acd40
#!/bin/bash

# Pipeline driver: stage 0 preprocesses, stage 1 trains, stage 2 synthesizes.
# The variables below can be overridden through utils/parse_options.sh.

set -e
source path.sh

gpus=0,1
stage=0
stop_stage=100

conf_path=conf/default.yaml
train_output_path=exp/default
test_input=dump/mel_test
ckpt_name=snapshot_iter_100000.pdz

source "${MAIN_ROOT}/utils/parse_options.sh" || exit 1

if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
    # prepare data
    ./local/preprocess.sh "${conf_path}" || exit -1
fi

if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
    # train the model
    CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh "${conf_path}" "${train_output_path}" || exit -1
fi

if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
    # copy some test mels from dump
    mkdir -p "${test_input}"
    cp -r dump/mel/00995*.npy "${test_input}"
    # synthesize
    CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh "${conf_path}" "${train_output_path}" "${ckpt_name}" "${test_input}" || exit -1
fi
paddlespeech/t2s/datasets/__init__.py
浏览文件 @
fb0acd40
...
...
@@ -12,4 +12,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from
.common
import
*
from
.csmsc
import
*
from
.ljspeech
import
*
paddlespeech/t2s/datasets/csmsc.py
0 → 100644
浏览文件 @
fb0acd40
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
from
pathlib
import
Path
from
paddle.io
import
Dataset
__all__
=
[
"CSMSCMetaData"
]
class CSMSCMetaData(Dataset):
    """Metadata of the CSMSC (baker) dataset.

    Every record is a list ``[wav_filepath, text, pinyin]``; the field names
    are available through :meth:`get_meta_info`.
    """

    def __init__(self, root):
        """
        :param root: the path of baker dataset
        """
        self.root = os.path.abspath(root)
        self.meta_info = ["file_path", "text", "pinyin"]

        # Build both paths from the same absolute root.
        metadata_path = os.path.join(self.root,
                                     "ProsodyLabeling/000001-010000.txt")
        wav_dir = os.path.join(self.root, "Wave")

        records = []
        with open(metadata_path, 'r', encoding='utf-8') as f:
            # Utterances are read as pairs of lines: the first line holds
            # "<utterance id> <text>", the second one is kept as the pinyin.
            while True:
                line1 = f.readline().strip()
                if not line1:
                    # An empty (or whitespace-only) line ends the listing.
                    break
                line2 = f.readline().strip()
                fields = line1.split()
                wav_filepath = os.path.join(wav_dir, fields[0] + '.wav')
                records.append([wav_filepath, fields[1], line2])
        self.records = records

    def __getitem__(self, i):
        """Return the i-th ``[wav_filepath, text, pinyin]`` record."""
        return self.records[i]

    def __len__(self):
        """Return the number of records."""
        return len(self.records)

    def get_meta_info(self):
        """Return the names of the fields stored in every record."""
        return self.meta_info
paddlespeech/t2s/datasets/vocoder_batch_fn.py
浏览文件 @
fb0acd40
...
...
@@ -11,8 +11,133 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
math
from
pathlib
import
Path
import
numpy
as
np
import
paddle
from
paddle.io
import
Dataset
def label_2_float(x, bits):
    """Map labels in ``[0, 2**bits - 1]`` linearly onto ``[-1, 1]``."""
    max_label = 2**bits - 1.
    return 2 * x / max_label - 1.
def float_2_label(x, bits):
    """Map values in ``[-1, 1]`` linearly onto ``[0, 2**bits - 1]``.

    The result is clipped to the label range but not rounded. ``x`` must
    support ``abs(x).max()`` and ``.clip`` (e.g. a numpy array).
    """
    assert abs(x).max() <= 1.0
    top = 2**bits - 1
    scaled = (x + 1.) * top / 2
    return scaled.clip(0, top)
def encode_mu_law(x, mu):
    """Mu-law compress ``x`` and quantize it to labels in ``[0, mu - 1]``.

    ``mu`` is the number of quantization levels; the labels are returned as
    floats (the output of ``np.floor``).
    """
    levels = mu - 1
    magnitude = np.log(1 + levels * np.abs(x))
    compressed = np.sign(x) * magnitude / np.log(1 + levels)
    # Shift from [-1, 1] to [0, levels] and round to the nearest label.
    return np.floor((compressed + 1) / 2 * levels + 0.5)
def decode_mu_law(y, mu, from_labels=True):
    """Invert mu-law compression.

    When ``from_labels`` is true, ``y`` is first converted with
    ``label_2_float`` using ``log2(mu)`` as the bit depth; otherwise it is
    used as given. ``y`` is passed to ``paddle.sign`` / ``paddle.abs``.
    """
    # TODO: get rid of log2 - makes no sense
    if from_labels:
        bits = math.log2(mu)
        y = label_2_float(y, bits)
    levels = mu - 1
    magnitude = (1 + levels)**paddle.abs(y) - 1
    return paddle.sign(y) / levels * magnitude
class WaveRNNDataset(Dataset):
    """A simple dataset adaptor over a preprocessed dump directory.

    ``root`` must contain ``metadata.csv`` (tab separated, utterance name in
    the first column) together with ``mel/<name>.npy`` and ``wav/<name>.npy``.
    """

    def __init__(self, root):
        """
        :param root: directory holding ``metadata.csv``, ``mel/`` and ``wav/``
        """
        self.root = Path(root).expanduser()

        records = []
        # Explicit encoding so reading does not depend on the process locale.
        with open(self.root / "metadata.csv", 'r', encoding='utf-8') as rf:
            for line in rf:
                # Drop the line terminator first so that a line without any
                # tab does not keep a trailing newline in the name.
                name = line.rstrip("\n").split("\t")[0]
                if not name:
                    # Blank line: there is no utterance to register.
                    continue
                mel_path = str(self.root / "mel" / (name + ".npy"))
                wav_path = str(self.root / "wav" / (name + ".npy"))
                records.append((mel_path, wav_path))
        self.records = records

    def __getitem__(self, i):
        """Load and return the ``(mel, wav)`` arrays of the i-th record."""
        mel_name, wav_name = self.records[i]
        mel = np.load(mel_name)
        wav = np.load(wav_name)
        return mel, wav

    def __len__(self):
        """Return the number of records."""
        return len(self.records)
class WaveRNNClip(object):
    """Collate function that cuts one random fixed-length clip per utterance.

    Every batch item is a ``(mel, quant)`` pair; ``mel`` is sliced along its
    last axis (frames) and ``quant`` along its first axis (samples).
    """

    def __init__(self,
                 mode: str='RAW',
                 batch_max_steps: int=4500,
                 hop_size: int=300,
                 aux_context_window: int=2,
                 bits: int=9):
        """
        :param mode: ``'MOL'`` forces 16-bit labels, any other value uses ``bits``
        :param batch_max_steps: number of audio samples in every clip
        :param hop_size: number of audio samples per mel frame
        :param aux_context_window: extra mel frames kept on each side of a clip
        :param bits: bit depth of the labels when ``mode`` is not ``'MOL'``
        """
        self.mode = mode
        # Number of mel frames per clip, including the context on both sides.
        self.mel_win = batch_max_steps // hop_size + 2 * aux_context_window
        self.batch_max_steps = batch_max_steps
        self.hop_size = hop_size
        self.aux_context_window = aux_context_window
        if self.mode == 'MOL':
            self.bits = 16
        else:
            self.bits = bits

    def __call__(self, batch):
        """Build one training batch.

        :param batch: list of ``(mel, quant)`` pairs
        :returns: ``(x, y, mels)``: ``x`` is the model input (labels scaled
            to ``[-1, 1]``), ``y`` is the target shifted by one sample (kept
            as int64 labels unless ``mode`` is ``'MOL'``), ``mels`` are the
            matching mel windows
        :raises ValueError: if an utterance has too few mel frames for a clip
        """
        # batch: [mel, quant]
        pad = self.aux_context_window
        # Exclusive upper bound of the random start frame of every utterance:
        # max_offset = n_frames - 2 - (mel_win + 2 * pad)
        max_offsets = [
            x[0].shape[-1] - 2 - (self.mel_win + 2 * pad) for x in batch
        ]
        for idx, max_offset in enumerate(max_offsets):
            if max_offset <= 0:
                n_frames = batch[idx][0].shape[-1]
                # np.random.randint would fail here with an opaque
                # "low >= high" error, so report the actual problem instead.
                raise ValueError(
                    "utterance {} of the batch is too short: it has {} mel "
                    "frames but at least {} are required".format(
                        idx, n_frames, n_frames - max_offset + 1))

        # the slice point of mel selecting randomly
        mel_offsets = [np.random.randint(0, offset) for offset in max_offsets]
        # the slice point of wav selecting randomly, which is behind `pad` frames
        sig_offsets = [(offset + pad) * self.hop_size
                       for offset in mel_offsets]

        # every mel window has mel_win = batch_max_steps // hop_size + 2 * pad frames
        mels = [
            x[0][:, mel_offsets[i]:mel_offsets[i] + self.mel_win]
            for i, x in enumerate(batch)
        ]
        # every label slice has batch_max_steps + 1 samples
        labels = [
            x[1][sig_offsets[i]:sig_offsets[i] + self.batch_max_steps + 1]
            for i, x in enumerate(batch)
        ]

        mels = np.stack(mels).astype(np.float32)
        labels = np.stack(labels).astype(np.int64)

        mels = paddle.to_tensor(mels)
        labels = paddle.to_tensor(labels, dtype='int64')

        # x is input, y is label (x shifted by one sample)
        x = labels[:, :self.batch_max_steps]
        y = labels[:, 1:]

        # Stored label formats:
        #   mode = RAW:
        #     mu_law = True:  quant: bits = 9, 0, 1, 2, ..., 509, 510, 511, int
        #     mu_law = False: quant: bits = 9, [0, 511], float
        #   mode = MOL:
        #     quant: bits = 16, [0, 65536], float

        # x is always rescaled from labels to [-1, 1]
        x = label_2_float(paddle.cast(x, dtype='float32'), self.bits)
        # y is rescaled to [-1, 1] only in MOL mode
        if self.mode == 'MOL':
            y = label_2_float(paddle.cast(y, dtype='float32'), self.bits)

        return x, y, mels
class
Clip
(
object
):
...
...
paddlespeech/t2s/exps/wavernn/__init__.py
0 → 100644
浏览文件 @
fb0acd40
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
paddlespeech/t2s/exps/wavernn/preprocess.py
0 → 100644
浏览文件 @
fb0acd40
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
argparse
import
os
from
multiprocessing
import
cpu_count
from
multiprocessing
import
Pool
from
pathlib
import
Path
import
librosa
import
numpy
as
np
import
pandas
as
pd
import
tqdm
import
yaml
from
yacs.config
import
CfgNode
from
paddlespeech.t2s.data.get_feats
import
LogMelFBank
from
paddlespeech.t2s.datasets
import
CSMSCMetaData
from
paddlespeech.t2s.datasets
import
LJSpeechMetaData
from
paddlespeech.t2s.datasets.vocoder_batch_fn
import
encode_mu_law
from
paddlespeech.t2s.datasets.vocoder_batch_fn
import
float_2_label
class Transform(object):
    """Turn one metadata record into the mel / quantized-audio files on disk.

    Instances are callable so they can be mapped over a dataset.
    """

    def __init__(self, output_dir: Path, config):
        """
        :param output_dir: directory that receives the ``wav`` and ``mel`` folders
        :param config: config node providing ``fs``, ``peak_norm``, ``mu_law``,
            ``model.bits``, ``model.mode`` and the mel extraction settings
        :raises RuntimeError: if ``config.model.mode`` is neither 'RAW' nor 'MOL'
        """
        # Validate first so that a bad config does not leave directories behind.
        mode = config.model.mode
        if mode not in ('RAW', 'MOL'):
            raise RuntimeError('Unknown mode value - {}'.format(mode))

        self.fs = config.fs
        self.peak_norm = config.peak_norm
        self.bits = config.model.bits
        self.mode = mode
        self.mu_law = config.mu_law

        self.wav_dir = output_dir / "wav"
        self.mel_dir = output_dir / "mel"
        self.wav_dir.mkdir(parents=True, exist_ok=True)
        self.mel_dir.mkdir(parents=True, exist_ok=True)

        self.mel_extractor = LogMelFBank(
            sr=config.fs,
            n_fft=config.n_fft,
            hop_length=config.n_shift,
            win_length=config.win_length,
            window=config.window,
            n_mels=config.n_mels,
            fmin=config.fmin,
            fmax=config.fmax)

    def __call__(self, example):
        """Process one record and save its features.

        :param example: 3-item record whose first item is the wav file path
        :returns: ``(base_name, mel.shape[-1], audio.shape[-1])``
        """
        wav_path, _, _ = example

        base_name = os.path.splitext(os.path.basename(wav_path))[0]
        wav, _ = librosa.load(wav_path, sr=self.fs)
        peak = np.abs(wav).max()
        # A silent file has peak == 0; dividing by it would fill wav with NaN.
        if peak > 0 and (self.peak_norm or peak > 1.0):
            wav /= peak
        mel = self.mel_extractor.get_log_mel_fbank(wav).T

        if self.mode == 'RAW':
            if self.mu_law:
                quant = encode_mu_law(wav, mu=2**self.bits)
            else:
                quant = float_2_label(wav, bits=self.bits)
        elif self.mode == 'MOL':
            quant = float_2_label(wav, bits=16)

        mel = mel.astype(np.float32)
        audio = quant.astype(np.int64)

        # np.save appends the ".npy" suffix to the given base name.
        np.save(str(self.wav_dir / base_name), audio)
        np.save(str(self.mel_dir / base_name), mel)

        return base_name, mel.shape[-1], audio.shape[-1]
def create_dataset(config,
                   input_dir,
                   output_dir,
                   nprocs: int=1,
                   dataset_type: str="ljspeech"):
    """Preprocess a whole dataset and write its ``metadata.csv``.

    :param config: config node forwarded to ``Transform``
    :param input_dir: root directory of the raw dataset
    :param output_dir: directory that receives the features and metadata
    :param nprocs: number of worker processes
    :param dataset_type: ``'ljspeech'`` selects ``LJSpeechMetaData``; any other
        value selects ``CSMSCMetaData``
    """
    input_dir = Path(input_dir).expanduser()
    # LJSpeechMetaData.records: [filename, normalized text, speaker name(ljspeech)]
    # CSMSCMetaData.records: [filename, normalized text, pinyin]
    if dataset_type == 'ljspeech':
        dataset = LJSpeechMetaData(input_dir)
    else:
        dataset = CSMSCMetaData(input_dir)

    output_dir = Path(output_dir).expanduser()
    # parents=True so that a nested output path can be created in one go.
    output_dir.mkdir(parents=True, exist_ok=True)

    transform = Transform(output_dir, config)

    file_names = []
    # The context manager shuts the worker processes down once all results
    # have been consumed, instead of leaving the pool open.
    with Pool(processes=nprocs) as pool:
        for info in tqdm.tqdm(
                pool.imap(transform, dataset), total=len(dataset)):
            base_name, mel_len, audio_len = info
            file_names.append((base_name, mel_len, audio_len))

    meta_data = pd.DataFrame.from_records(file_names)
    meta_data.to_csv(
        str(output_dir / "metadata.csv"), sep="\t", index=None, header=None)
    print("saved meta data in to {}".format(
        os.path.join(output_dir, "metadata.csv")))

    print("Done!")
if __name__ == "__main__":
    # Command line entry point: parse the arguments, load the YAML config and
    # run the preprocessing.
    parser = argparse.ArgumentParser(description="create dataset")
    # The three paths are mandatory; without them the script used to fail
    # later with an unrelated TypeError.
    parser.add_argument(
        "--config",
        type=str,
        required=True,
        help="config file to overwrite default config.")
    parser.add_argument(
        "--input",
        type=str,
        required=True,
        help="path of the ljspeech dataset")
    parser.add_argument(
        "--output",
        type=str,
        required=True,
        help="path to save output dataset")
    parser.add_argument(
        "--num-cpu",
        type=int,
        default=cpu_count() // 2,
        help="number of process.")
    parser.add_argument(
        "--dataset",
        type=str,
        default="ljspeech",
        help="The dataset to preprocess, ljspeech or csmsc")
    args = parser.parse_args()

    # Reject unsupported datasets before doing any work.
    if args.dataset != "ljspeech" and args.dataset != "csmsc":
        raise RuntimeError('Unknown dataset - {}'.format(args.dataset))

    with open(args.config, 'rt') as f:
        config = CfgNode(yaml.safe_load(f))

    create_dataset(
        config,
        input_dir=args.input,
        output_dir=args.output,
        nprocs=args.num_cpu,
        dataset_type=args.dataset)
paddlespeech/t2s/exps/wavernn/synthesize.py
0 → 100644
浏览文件 @
fb0acd40
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
argparse
import
os
from
pathlib
import
Path
import
numpy
as
np
import
paddle
import
soundfile
as
sf
import
yaml
from
paddle
import
distributed
as
dist
from
yacs.config
import
CfgNode
from
paddlespeech.t2s.models.wavernn
import
WaveRNN
def main():
    """Synthesize one wav file per ``.npy`` mel spectrogram with a WaveRNN checkpoint."""
    parser = argparse.ArgumentParser(description="Synthesize with WaveRNN.")
    # All paths are mandatory; without them the script used to fail later
    # with an unrelated TypeError.
    parser.add_argument(
        "--config", type=str, required=True, help="WaveRNN config file.")
    parser.add_argument(
        "--checkpoint", type=str, required=True, help="snapshot to load.")
    parser.add_argument(
        "--input",
        type=str,
        required=True,
        help="path of directory containing mel spectrogram (in .npy format)")
    parser.add_argument(
        "--output-dir", type=str, required=True, help="output dir.")
    parser.add_argument(
        "--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.")

    args = parser.parse_args()

    with open(args.config) as f:
        config = CfgNode(yaml.safe_load(f))

    print("========Args========")
    print(yaml.safe_dump(vars(args)))
    print("========Config========")
    print(config)
    print(
        f"master see the word size: {dist.get_world_size()}, from pid: {os.getpid()}"
    )

    if args.ngpu == 0:
        paddle.set_device("cpu")
    elif args.ngpu > 0:
        paddle.set_device("gpu")
    else:
        print("ngpu should >= 0 !")

    model = WaveRNN(
        hop_length=config.n_shift, sample_rate=config.fs, **config["model"])
    state_dict = paddle.load(args.checkpoint)
    model.set_state_dict(state_dict["main_params"])

    model.eval()

    mel_dir = Path(args.input).expanduser()
    output_dir = Path(args.output_dir).expanduser()
    output_dir.mkdir(parents=True, exist_ok=True)

    # Only pick up the .npy files so that stray files in the input directory
    # do not abort the run inside np.load.
    for file_path in sorted(mel_dir.glob("*.npy")):
        mel = np.load(str(file_path))
        mel = paddle.to_tensor(mel)
        mel = mel.transpose([1, 0])
        # input shape is (T', C_aux)
        audio = model.generate(
            c=mel,
            batched=config.inference.gen_batched,
            target=config.inference.target,
            overlap=config.inference.overlap,
            mu_law=config.mu_law,
            gen_display=True)
        audio_path = output_dir / (os.path.splitext(file_path.name)[0] + ".wav")
        # Pass a plain string: it is accepted by every soundfile version.
        sf.write(str(audio_path), audio.numpy(), samplerate=config.fs)
        print("[synthesize] {} -> {}".format(file_path, audio_path))


if __name__ == "__main__":
    main()
paddlespeech/t2s/exps/wavernn/train.py
0 → 100644
浏览文件 @
fb0acd40
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
argparse
import
os
import
shutil
from
pathlib
import
Path
import
paddle
import
yaml
from
paddle
import
DataParallel
from
paddle
import
distributed
as
dist
from
paddle.io
import
DataLoader
from
paddle.io
import
DistributedBatchSampler
from
paddle.optimizer
import
Adam
from
yacs.config
import
CfgNode
from
paddlespeech.t2s.data
import
dataset
from
paddlespeech.t2s.datasets.vocoder_batch_fn
import
WaveRNNClip
from
paddlespeech.t2s.datasets.vocoder_batch_fn
import
WaveRNNDataset
from
paddlespeech.t2s.models.wavernn
import
WaveRNN
from
paddlespeech.t2s.models.wavernn
import
WaveRNNEvaluator
from
paddlespeech.t2s.models.wavernn
import
WaveRNNUpdater
from
paddlespeech.t2s.modules.losses
import
discretized_mix_logistic_loss
from
paddlespeech.t2s.training.extensions.snapshot
import
Snapshot
from
paddlespeech.t2s.training.extensions.visualizer
import
VisualDL
from
paddlespeech.t2s.training.seeding
import
seed_everything
from
paddlespeech.t2s.training.trainer
import
Trainer
def train_sp(args, config):
    """Train WaveRNN in the current process (one process per device).

    :param args: parsed command line arguments (``ngpu``, ``data``,
        ``output_dir``, ``config``)
    :param config: config node loaded from the YAML config file
    :raises RuntimeError: if ``config.model.mode`` is neither 'RAW' nor 'MOL'
    """
    # decides device type and whether to run in parallel
    # setup running environment correctly
    world_size = paddle.distributed.get_world_size()
    if (not paddle.is_compiled_with_cuda()) or args.ngpu == 0:
        paddle.set_device("cpu")
    else:
        paddle.set_device("gpu")
        if world_size > 1:
            paddle.distributed.init_parallel_env()

    # set the random seed, it is a must for multiprocess training
    seed_everything(config.seed)

    print(
        f"rank: {dist.get_rank()}, pid: {os.getpid()}, parent_pid: {os.getppid()}",
    )

    wavernn_dataset = WaveRNNDataset(args.data)

    # hold out the last `valid_size` utterances for validation
    train_dataset, dev_dataset = dataset.split(
        wavernn_dataset, len(wavernn_dataset) - config.valid_size)

    batch_fn = WaveRNNClip(
        mode=config.model.mode,
        aux_context_window=config.model.aux_context_window,
        hop_size=config.n_shift,
        batch_max_steps=config.batch_max_steps,
        bits=config.model.bits)

    # collate function and dataloader
    train_sampler = DistributedBatchSampler(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        drop_last=True)
    dev_sampler = DistributedBatchSampler(
        dev_dataset,
        batch_size=config.batch_size,
        shuffle=False,
        drop_last=False)
    print("samplers done!")

    train_dataloader = DataLoader(
        train_dataset,
        batch_sampler=train_sampler,
        collate_fn=batch_fn,
        num_workers=config.num_workers)

    dev_dataloader = DataLoader(
        dev_dataset,
        collate_fn=batch_fn,
        batch_sampler=dev_sampler,
        num_workers=config.num_workers)

    # un-clipped utterances, one per batch, used to generate audio samples
    valid_generate_loader = DataLoader(dev_dataset, batch_size=1)

    print("dataloaders done!")

    model = WaveRNN(
        hop_length=config.n_shift, sample_rate=config.fs, **config["model"])
    if world_size > 1:
        model = DataParallel(model)
    print("model done!")

    if config.model.mode == 'RAW':
        criterion = paddle.nn.CrossEntropyLoss(axis=1)
    elif config.model.mode == 'MOL':
        criterion = discretized_mix_logistic_loss
    else:
        # The exception used to be built but never raised, so training went
        # on with criterion=None and failed much later.
        raise RuntimeError('Unknown model mode value - {}'.format(
            config.model.mode))
    print("criterions done!")

    clip = paddle.nn.ClipGradByGlobalNorm(config.grad_clip)
    optimizer = Adam(
        parameters=model.parameters(),
        learning_rate=config.learning_rate,
        grad_clip=clip)

    print("optimizer done!")

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    if dist.get_rank() == 0:
        config_name = Path(args.config).name
        # copy conf to output_dir
        shutil.copyfile(args.config, output_dir / config_name)

    updater = WaveRNNUpdater(
        model=model,
        optimizer=optimizer,
        criterion=criterion,
        dataloader=train_dataloader,
        output_dir=output_dir,
        mode=config.model.mode)

    evaluator = WaveRNNEvaluator(
        model=model,
        dataloader=dev_dataloader,
        criterion=criterion,
        output_dir=output_dir,
        valid_generate_loader=valid_generate_loader,
        config=config)

    trainer = Trainer(
        updater,
        stop_trigger=(config.train_max_steps, "iteration"),
        out=output_dir)

    # evaluation and visualization only run on the rank-0 process
    if dist.get_rank() == 0:
        trainer.extend(
            evaluator, trigger=(config.eval_interval_steps, 'iteration'))
        trainer.extend(VisualDL(output_dir), trigger=(1, 'iteration'))
    trainer.extend(
        Snapshot(max_size=config.num_snapshots),
        trigger=(config.save_interval_steps, 'iteration'))

    print("Trainer Done!")
    trainer.run()
def main():
    """Parse the command line and the config file, then start training."""
    # parse args and config and redirect to train_sp
    parser = argparse.ArgumentParser(description="Train a WaveRNN model.")
    # The three paths are mandatory; without them the script used to fail
    # later with an unrelated TypeError.
    parser.add_argument(
        "--config",
        type=str,
        required=True,
        help="config file to overwrite default config.")
    parser.add_argument("--data", type=str, required=True, help="input")
    parser.add_argument(
        "--output-dir", type=str, required=True, help="output dir.")
    parser.add_argument(
        "--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.")

    args = parser.parse_args()

    with open(args.config, 'rt') as f:
        config = CfgNode(yaml.safe_load(f))

    print("========Args========")
    print(yaml.safe_dump(vars(args)))
    print("========Config========")
    print(config)
    print(
        f"master see the word size: {dist.get_world_size()}, from pid: {os.getpid()}"
    )

    # dispatch: one spawned process per GPU, or run in this process
    if args.ngpu > 1:
        dist.spawn(train_sp, (args, config), nprocs=args.ngpu)
    else:
        train_sp(args, config)


if __name__ == "__main__":
    main()
paddlespeech/t2s/models/__init__.py
浏览文件 @
fb0acd40
...
...
@@ -20,3 +20,4 @@ from .speedyspeech import *
from
.tacotron2
import
*
from
.transformer_tts
import
*
from
.waveflow
import
*
from
.wavernn
import
*
paddlespeech/t2s/models/wavernn/__init__.py
0 → 100644
浏览文件 @
fb0acd40
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
.wavernn
import
*
from
.wavernn_updater
import
*
paddlespeech/t2s/models/wavernn/wavernn.py
0 → 100644
浏览文件 @
fb0acd40
此差异已折叠。
点击以展开。
paddlespeech/t2s/models/wavernn/wavernn_updater.py
0 → 100644
浏览文件 @
fb0acd40
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
logging
from
pathlib
import
Path
import
paddle
import
soundfile
as
sf
from
paddle
import
distributed
as
dist
from
paddle.io
import
DataLoader
from
paddle.nn
import
Layer
from
paddle.optimizer
import
Optimizer
from
paddlespeech.t2s.datasets.vocoder_batch_fn
import
decode_mu_law
from
paddlespeech.t2s.datasets.vocoder_batch_fn
import
label_2_float
from
paddlespeech.t2s.training.extensions.evaluator
import
StandardEvaluator
from
paddlespeech.t2s.training.reporter
import
report
from
paddlespeech.t2s.training.updaters.standard_updater
import
StandardUpdater
# Log record layout: timestamp, level, source file and line, then the message.
logging.basicConfig(
    format='%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s',
    datefmt='[%Y-%m-%d %H:%M:%S]')
# Module-level logger; WaveRNNUpdater and WaveRNNEvaluator each attach a
# per-rank file handler to it in their constructors.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
def calculate_grad_norm(parameters, norm_type: float = 2):
    '''Calculate the overall gradient norm of a model's parameters.

    The ``norm_type``-norm of every available gradient is computed first and
    the same norm is then taken over the stacked per-parameter norms.

    Parameters
    ----------
    parameters : iterable
        Model parameters. Entries whose ``grad`` is None are skipped.
    norm_type : float
        Order of the norm, forwarded to ``paddle.norm`` (default 2).

    Returns
    ------------
    Tensor
        grad_norm. A zero tensor is returned when no parameter has a
        gradient.
    '''
    # paddle.to_tensor is kept from the original implementation
    # (NOTE(review): presumably for Paddle versions where ``grad`` is not
    # already a Tensor -- confirm before removing).
    per_param_norms = [
        paddle.norm(paddle.to_tensor(p.grad), norm_type) for p in parameters
        if p.grad is not None
    ]
    if not per_param_norms:
        # paddle.stack cannot take an empty list; report a zero norm instead.
        return paddle.zeros([1])
    # Use the requested order for the outer norm as well, so the result is a
    # consistent p-norm over all gradients rather than always the default.
    total_norm = paddle.norm(paddle.stack(per_param_norms), norm_type)
    return total_norm
# Module-level step counter. WaveRNNUpdater.update_core() writes it after each
# training step; WaveRNNEvaluator reads it to decide when to generate
# validation samples and to build the saved file names in gen_valid_samples().
ITERATION = 0
class WaveRNNUpdater(StandardUpdater):
    """Updater that performs one WaveRNN optimisation step per batch.

    Parameters
    ----------
    model : Layer
        Network called as ``model(wav, mel)``.
    optimizer : Optimizer
        Optimizer stepping the model's parameters.
    criterion : Layer
        Loss called as ``criterion(y_hat, y)``.
    dataloader : DataLoader
        Training data; each batch unpacks to ``(wav, y, mel)``.
    init_state :
        Initial updater state, forwarded to ``StandardUpdater``.
    output_dir : Path
        Directory that receives ``worker_<rank>.log``.
    mode : str
        ``'RAW'`` or ``'MOL'``; selects how prediction and target are
        prepared for the criterion.
    """

    def __init__(self,
                 model: Layer,
                 optimizer: Optimizer,
                 criterion: Layer,
                 dataloader: DataLoader,
                 init_state=None,
                 output_dir: Path=None,
                 mode='RAW'):
        # Forward the caller's init_state; it used to be replaced by None.
        super().__init__(model, optimizer, dataloader, init_state=init_state)

        self.criterion = criterion
        # One log file per distributed rank.
        log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
        self.filehandler = logging.FileHandler(str(log_file))
        logger.addHandler(self.filehandler)
        self.logger = logger
        self.msg = ""
        self.mode = mode

    def update_core(self, batch):
        """Run forward, backward and optimizer step for one batch.

        Reports ``train/loss`` and ``train/grad_norm``, appends both to
        ``self.msg`` and publishes the step number through the module-level
        ``ITERATION`` counter.
        """
        self.msg = "Rank: {}, ".format(dist.get_rank())
        losses_dict = {}
        # parse batch
        self.model.train()
        self.optimizer.clear_grad()

        wav, y, mel = batch

        y_hat = self.model(wav, mel)
        if self.mode == 'RAW':
            y_hat = y_hat.transpose([0, 2, 1]).unsqueeze(-1)
        elif self.mode == 'MOL':
            # Cast the *target*. Assigning the cast to y_hat would replace the
            # model output with the target, so the loss would compare y with
            # itself and no gradient would reach the model.
            y = paddle.cast(y, dtype='float32')

        y = y.unsqueeze(-1)
        loss = self.criterion(y_hat, y)
        loss.backward()
        # Measured after backward() and before step(), i.e. on the gradients
        # that the optimizer is about to apply.
        grad_norm = float(
            calculate_grad_norm(self.model.parameters(), norm_type=2))

        self.optimizer.step()

        loss_value = float(loss)
        report("train/loss", loss_value)
        report("train/grad_norm", grad_norm)
        losses_dict["loss"] = loss_value
        losses_dict["grad_norm"] = grad_norm
        self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
                              for k, v in losses_dict.items())
        # Shared with WaveRNNEvaluator for sample scheduling and file names.
        global ITERATION
        ITERATION = self.state.iteration + 1
class WaveRNNEvaluator(StandardEvaluator):
    """Evaluator that reports the validation loss and writes audio samples.

    Parameters
    ----------
    model : Layer
        Network called as ``model(wav, mel)`` and ``model.generate(...)``.
    criterion : Layer
        Loss called as ``criterion(y_hat, y)``.
    dataloader : DataLoader
        Validation data; each batch unpacks to ``(wav, y, mel)``.
    output_dir : Path
        Directory that receives ``worker_<rank>.log`` and the
        ``valid_samples`` sub-directory.
    valid_generate_loader :
        Iterable yielding ``(mel, wav)`` pairs used for sample generation.
    config :
        Configuration node; ``model.mode``, ``model.bits``, ``mu_law``,
        ``fs``, ``generate_num``, ``gen_eval_samples_interval_steps`` and the
        ``inference`` section are read from it.
    """

    def __init__(self,
                 model: Layer,
                 criterion: Layer,
                 dataloader: DataLoader,
                 output_dir: Path=None,
                 valid_generate_loader=None,
                 config=None):
        super().__init__(model, dataloader)

        # One log file per distributed rank.
        log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
        self.filehandler = logging.FileHandler(str(log_file))
        logger.addHandler(self.filehandler)
        self.logger = logger
        self.msg = ""

        self.criterion = criterion
        self.valid_generate_loader = valid_generate_loader

        self.config = config
        self.mode = config.model.mode

        self.valid_samples_dir = output_dir / "valid_samples"
        self.valid_samples_dir.mkdir(parents=True, exist_ok=True)
        # Iteration for which samples were last generated; evaluate_core()
        # runs once per batch, so this keeps generation to once per iteration.
        self._last_gen_iteration = None

    def evaluate_core(self, batch):
        """Compute and report the loss for one validation batch.

        Also triggers ``gen_valid_samples`` when the current training
        iteration is a multiple of ``gen_eval_samples_interval_steps``.
        """
        self.msg = "Evaluate: "
        losses_dict = {}
        # parse batch
        wav, y, mel = batch
        y_hat = self.model(wav, mel)

        if self.mode == 'RAW':
            y_hat = y_hat.transpose([0, 2, 1]).unsqueeze(-1)
        elif self.mode == 'MOL':
            # Cast the *target*. Assigning the cast to y_hat would replace the
            # model output with the target and make the reported loss
            # independent of the model.
            y = paddle.cast(y, dtype='float32')

        y = y.unsqueeze(-1)
        loss = self.criterion(y_hat, y)
        loss_value = float(loss)
        report("eval/loss", loss_value)

        losses_dict["loss"] = loss_value

        # Step number published by WaveRNNUpdater.update_core().
        self.iteration = ITERATION
        due = self.iteration % self.config.gen_eval_samples_interval_steps == 0
        # The output file names depend only on (iteration, index), so
        # regenerating for every batch of the same iteration would just
        # overwrite the same files.
        if due and getattr(self, '_last_gen_iteration',
                           None) != self.iteration:
            self._last_gen_iteration = self.iteration
            self.gen_valid_samples()

        self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
                              for k, v in losses_dict.items())
        self.logger.info(self.msg)

    def gen_valid_samples(self):
        """Write target and generated wavs for up to ``generate_num`` items.

        For every ``(mel, wav)`` pair the decoded target is saved as
        ``<iteration>_steps_<i>_target.wav`` and the model output as
        ``<iteration>_steps_<i>_<batch_str>.wav`` in ``valid_samples_dir``.
        """
        for i, (mel, wav) in enumerate(self.valid_generate_loader):
            if i >= self.config.generate_num:
                break
            print('\n| Generating: {}/{}'.format(i + 1,
                                                 self.config.generate_num))
            wav = wav[0]
            if self.mode == 'MOL':
                bits = 16
            else:
                bits = self.config.model.bits
            # Convert the stored target back to a float waveform.
            if self.config.mu_law and self.mode != 'MOL':
                wav = decode_mu_law(wav, 2**bits, from_labels=True)
            else:
                wav = label_2_float(wav, bits)
            # str() for consistency with gen_save_path below.
            origin_save_path = str(
                self.valid_samples_dir / '{}_steps_{}_target.wav'.format(
                    self.iteration, i))
            sf.write(origin_save_path, wav.numpy(), samplerate=self.config.fs)

            if self.config.inference.gen_batched:
                batch_str = 'gen_batched_target{}_overlap{}'.format(
                    self.config.inference.target,
                    self.config.inference.overlap)
            else:
                batch_str = 'gen_not_batched'
            gen_save_path = str(self.valid_samples_dir /
                                '{}_steps_{}_{}.wav'.format(self.iteration, i,
                                                            batch_str))
            # (1, C_aux, T) -> (T, C_aux)
            mel = mel.squeeze(0).transpose([1, 0])
            gen_sample = self.model.generate(
                mel, self.config.inference.gen_batched,
                self.config.inference.target, self.config.inference.overlap,
                self.config.mu_law)
            sf.write(
                gen_save_path, gen_sample.numpy(), samplerate=self.config.fs)
paddlespeech/t2s/modules/losses.py
浏览文件 @
fb0acd40
...
...
@@ -14,6 +14,7 @@
import
math
import
librosa
import
numpy
as
np
import
paddle
from
paddle
import
nn
from
paddle.fluid.layers
import
sequence_mask
...
...
@@ -23,6 +24,145 @@ from scipy import signal
from
paddlespeech.t2s.modules.nets_utils
import
make_non_pad_mask
# Losses for WaveRNN
def log_sum_exp(x):
    """Overflow-safe ``log(sum(exp(x)))`` taken over the last axis of ``x``.

    The per-row maximum is subtracted before exponentiating and added back
    after the logarithm.
    """
    # TF ordering
    last_axis = len(x.shape) - 1
    peak = paddle.max(x, axis=last_axis)
    # Same maximum with the reduced axis kept, so it broadcasts against x.
    peak_keepdim = paddle.max(x, axis=last_axis, keepdim=True)
    shifted_total = paddle.sum(paddle.exp(x - peak_keepdim), axis=last_axis)
    return peak + paddle.log(shifted_total)
# It is adapted from https://github.com/r9y9/wavenet_vocoder/blob/master/wavenet_vocoder/mixture.py
def discretized_mix_logistic_loss(y_hat,
                                  y,
                                  num_classes=65536,
                                  log_scale_min=None,
                                  reduce=True):
    """Negative log-likelihood under a discretized mixture of logistics.

    Parameters
    ----------
    y_hat : Tensor
        3-D mixture parameters (B x T x C) whose last axis has size
        ``3 * num_mixtures``, laid out as
        ``[logit_probs, means, log_scales]``.
    y : Tensor
        Target, expanded to the shape of ``means`` (B x T x 1 ->
        B x T x num_mixtures).
    num_classes : int
        Number of quantisation levels; sets the half bin width
        ``1 / (num_classes - 1)``.
    log_scale_min : float
        Lower clip for the log scales; defaults to ``log(1e-14)``.
    reduce : bool
        If True return the mean loss, otherwise the per-element loss with a
        trailing axis of size 1.

    Returns
    ----------
    Tensor
        The loss.
    """
    if log_scale_min is None:
        log_scale_min = float(np.log(1e-14))

    # The channel count is read straight from the last axis; a transpose
    # round-trip only to inspect axis 1 is unnecessary.
    assert y_hat.dim() == 3
    assert y_hat.shape[-1] % 3 == 0
    nr_mix = y_hat.shape[-1] // 3

    # unpack parameters. (B, T, num_mixtures) x 3
    logit_probs = y_hat[:, :, :nr_mix]
    means = y_hat[:, :, nr_mix:2 * nr_mix]
    log_scales = paddle.clip(
        y_hat[:, :, 2 * nr_mix:3 * nr_mix], min=log_scale_min)

    # B x T x 1 -> B x T x num_mixtures
    y = y.expand_as(means)
    centered_y = paddle.cast(y, dtype=paddle.get_default_dtype()) - means
    inv_stdv = paddle.exp(-log_scales)
    # Logistic CDF evaluated at the upper and lower edge of the target's bin.
    plus_in = inv_stdv * (centered_y + 1. / (num_classes - 1))
    cdf_plus = F.sigmoid(plus_in)
    min_in = inv_stdv * (centered_y - 1. / (num_classes - 1))
    cdf_min = F.sigmoid(min_in)

    # log probability for the lowest bin (before scaling)
    # log(sigmoid(x)) = x - softplus(x), with softplus(x) = log(1 + e^{x})
    log_cdf_plus = plus_in - F.softplus(plus_in)

    # log probability for the highest bin (before scaling)
    # log(1 - sigmoid(x)) = -softplus(x)
    log_one_minus_cdf_min = -F.softplus(min_in)

    # probability for all other cases
    cdf_delta = cdf_plus - cdf_min

    mid_in = inv_stdv * centered_y
    # log density at the bin centre; used below as the fallback when
    # cdf_delta is too small to take a stable logarithm
    log_pdf_mid = mid_in - log_scales - 2. * F.softplus(mid_in)

    # TODO: cdf_delta <= 1e-5 actually can happen. How can we choose the value
    # for num_classes=65536 case? 1e-7? not sure..
    inner_inner_cond = cdf_delta > 1e-5
    inner_inner_cond = paddle.cast(
        inner_inner_cond, dtype=paddle.get_default_dtype())

    # Boolean masks are cast to float and blended arithmetically.
    inner_inner_out = inner_inner_cond * paddle.log(
        paddle.clip(cdf_delta, min=1e-12)) + (1. - inner_inner_cond) * (
            log_pdf_mid - np.log((num_classes - 1) / 2))

    inner_cond = y > 0.999
    inner_cond = paddle.cast(inner_cond, dtype=paddle.get_default_dtype())

    inner_out = inner_cond * log_one_minus_cdf_min + (1. - inner_cond
                                                      ) * inner_inner_out

    cond = y < -0.999
    cond = paddle.cast(cond, dtype=paddle.get_default_dtype())

    log_probs = cond * log_cdf_plus + (1. - cond) * inner_out
    # Add the log mixture weights before summing over components.
    log_probs = log_probs + F.log_softmax(logit_probs, -1)

    if reduce:
        return -paddle.mean(log_sum_exp(log_probs))
    else:
        return -log_sum_exp(log_probs).unsqueeze(-1)
def sample_from_discretized_mix_logistic(y, log_scale_min=None):
    """
    Sample from discretized mixture of logistic distributions

    Parameters
    ----------
    y : Tensor
        (B, C, T), where C is ``3 * num_mixtures`` laid out as
        ``[logit_probs, means, log_scales]``.
    log_scale_min : float
        Log scale minimum value; defaults to ``log(1e-14)``.

    Returns
    ----------
    Tensor
        sample in range of [-1, 1].
    """
    if log_scale_min is None:
        log_scale_min = float(np.log(1e-14))

    assert y.shape[1] % 3 == 0
    nr_mix = y.shape[1] // 3

    # (B, T, C)
    y = y.transpose([0, 2, 1])
    logit_probs = y[:, :, :nr_mix]

    # sample mixture indicator from softmax:
    # argmax of logits perturbed by -log(-log(U)) noise
    temp = paddle.uniform(
        logit_probs.shape, dtype=logit_probs.dtype, min=1e-5, max=1.0 - 1e-5)
    temp = logit_probs - paddle.log(-paddle.log(temp))
    argmax = paddle.argmax(temp, axis=-1)

    # (B, T) -> (B, T, nr_mix)
    one_hot = F.one_hot(argmax, nr_mix)
    one_hot = paddle.cast(one_hot, dtype=paddle.get_default_dtype())

    # select logistic parameters of the chosen component
    means = paddle.sum(y[:, :, nr_mix:2 * nr_mix] * one_hot, axis=-1)
    log_scales = paddle.clip(
        paddle.sum(y[:, :, 2 * nr_mix:3 * nr_mix] * one_hot, axis=-1),
        min=log_scale_min)
    # sample from logistic & clip to interval
    # we don't actually round to the nearest 8bit value when sampling
    # dtype is passed explicitly, as in the uniform call above
    u = paddle.uniform(
        means.shape, dtype=means.dtype, min=1e-5, max=1.0 - 1e-5)
    # inverse logistic CDF: log(u) - log(1 - u)
    x = means + paddle.exp(log_scales) * (paddle.log(u) - paddle.log(1. - u))
    # Upper bound must be +1; max=-1 would collapse every sample to -1.
    x = paddle.clip(x, min=-1., max=1.)

    return x
# Loss for new Tacotron2
class
GuidedAttentionLoss
(
nn
.
Layer
):
"""Guided attention loss function module.
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录