Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
0fc79f47
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
0fc79f47
编写于
3月 30, 2022
作者:
小湉湉
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add CNNDecoder, test=tts
上级
b5315657
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
417 addition
and
10 deletion
+417
-10
examples/csmsc/tts3/conf/cnndecoder.yaml
examples/csmsc/tts3/conf/cnndecoder.yaml
+107
-0
examples/csmsc/tts3/local/synthesize_streaming.sh
examples/csmsc/tts3/local/synthesize_streaming.sh
+92
-0
examples/csmsc/tts3/run_cnndecoder.sh
examples/csmsc/tts3/run_cnndecoder.sh
+48
-0
paddlespeech/t2s/models/fastspeech2/fastspeech2.py
paddlespeech/t2s/models/fastspeech2/fastspeech2.py
+37
-10
paddlespeech/t2s/modules/transformer/encoder.py
paddlespeech/t2s/modules/transformer/encoder.py
+133
-0
未找到文件。
examples/csmsc/tts3/conf/cnndecoder.yaml
0 → 100644
浏览文件 @
0fc79f47
# use CNND
###########################################################
# FEATURE EXTRACTION SETTING #
###########################################################
fs
:
24000
# sr
n_fft
:
2048
# FFT size (samples).
n_shift
:
300
# Hop size (samples). 12.5ms
win_length
:
1200
# Window length (samples). 50ms
# If set to null, it will be the same as fft_size.
window
:
"
hann"
# Window function.
# Only used for feats_type != raw
fmin
:
80
# Minimum frequency of Mel basis.
fmax
:
7600
# Maximum frequency of Mel basis.
n_mels
:
80
# The number of mel basis.
# Only used for the model using pitch features (e.g. FastSpeech2)
f0min
:
80
# Minimum f0 for pitch extraction.
f0max
:
400
# Maximum f0 for pitch extraction.
###########################################################
# DATA SETTING #
###########################################################
batch_size
:
64
num_workers
:
4
###########################################################
# MODEL SETTING #
###########################################################
model
:
adim
:
384
# attention dimension
aheads
:
2
# number of attention heads
elayers
:
4
# number of encoder layers
eunits
:
1536
# number of encoder ff units
dlayers
:
4
# number of decoder layers
dunits
:
1536
# number of decoder ff units
positionwise_layer_type
:
conv1d
# type of position-wise layer
positionwise_conv_kernel_size
:
3
# kernel size of position wise conv layer
duration_predictor_layers
:
2
# number of layers of duration predictor
duration_predictor_chans
:
256
# number of channels of duration predictor
duration_predictor_kernel_size
:
3
# filter size of duration predictor
postnet_layers
:
5
# number of layers of postnset
postnet_filts
:
5
# filter size of conv layers in postnet
postnet_chans
:
256
# number of channels of conv layers in postnet
use_scaled_pos_enc
:
True
# whether to use scaled positional encoding
encoder_normalize_before
:
True
# whether to perform layer normalization before the input
decoder_normalize_before
:
True
# whether to perform layer normalization before the input
reduction_factor
:
1
# reduction factor
encoder_type
:
transformer
# encoder type
decoder_type
:
cnndecoder
# decoder type
init_type
:
xavier_uniform
# initialization type
init_enc_alpha
:
1.0
# initial value of alpha of encoder scaled position encoding
init_dec_alpha
:
1.0
# initial value of alpha of decoder scaled position encoding
transformer_enc_dropout_rate
:
0.2
# dropout rate for transformer encoder layer
transformer_enc_positional_dropout_rate
:
0.2
# dropout rate for transformer encoder positional encoding
transformer_enc_attn_dropout_rate
:
0.2
# dropout rate for transformer encoder attention layer
cnn_dec_dropout_rate
:
0.2
# dropout rate for cnn decoder layer
cnn_postnet_dropout_rate
:
0.2
cnn_postnet_resblock_kernel_sizes
:
[
256
,
256
]
# kernel sizes for residual block of cnn_postnet
cnn_postnet_kernel_size
:
5
# kernel size of cnn_postnet
cnn_decoder_embedding_dim
:
256
pitch_predictor_layers
:
5
# number of conv layers in pitch predictor
pitch_predictor_chans
:
256
# number of channels of conv layers in pitch predictor
pitch_predictor_kernel_size
:
5
# kernel size of conv leyers in pitch predictor
pitch_predictor_dropout
:
0.5
# dropout rate in pitch predictor
pitch_embed_kernel_size
:
1
# kernel size of conv embedding layer for pitch
pitch_embed_dropout
:
0.0
# dropout rate after conv embedding layer for pitch
stop_gradient_from_pitch_predictor
:
True
# whether to stop the gradient from pitch predictor to encoder
energy_predictor_layers
:
2
# number of conv layers in energy predictor
energy_predictor_chans
:
256
# number of channels of conv layers in energy predictor
energy_predictor_kernel_size
:
3
# kernel size of conv leyers in energy predictor
energy_predictor_dropout
:
0.5
# dropout rate in energy predictor
energy_embed_kernel_size
:
1
# kernel size of conv embedding layer for energy
energy_embed_dropout
:
0.0
# dropout rate after conv embedding layer for energy
stop_gradient_from_energy_predictor
:
False
# whether to stop the gradient from energy predictor to encoder
###########################################################
# UPDATER SETTING #
###########################################################
updater
:
use_masking
:
True
# whether to apply masking for padded part in loss calculation
###########################################################
# OPTIMIZER SETTING #
###########################################################
optimizer
:
optim
:
adam
# optimizer type
learning_rate
:
0.001
# learning rate
###########################################################
# TRAINING SETTING #
###########################################################
max_epoch
:
1000
num_snapshots
:
5
###########################################################
# OTHER SETTING #
###########################################################
seed
:
10086
examples/csmsc/tts3/local/synthesize_streaming.sh
0 → 100755
浏览文件 @
0fc79f47
#!/bin/bash
config_path
=
$1
train_output_path
=
$2
ckpt_name
=
$3
stage
=
0
stop_stage
=
0
# pwgan
if
[
${
stage
}
-le
0
]
&&
[
${
stop_stage
}
-ge
0
]
;
then
FLAGS_allocator_strategy
=
naive_best_fit
\
FLAGS_fraction_of_gpu_memory_to_use
=
0.01
\
python3
${
BIN_DIR
}
/../synthesize_streaming.py
\
--am
=
fastspeech2_csmsc
\
--am_config
=
${
config_path
}
\
--am_ckpt
=
${
train_output_path
}
/checkpoints/
${
ckpt_name
}
\
--am_stat
=
dump/train/speech_stats.npy
\
--voc
=
pwgan_csmsc
\
--voc_config
=
pwg_baker_ckpt_0.4/pwg_default.yaml
\
--voc_ckpt
=
pwg_baker_ckpt_0.4/pwg_snapshot_iter_400000.pdz
\
--voc_stat
=
pwg_baker_ckpt_0.4/pwg_stats.npy
\
--lang
=
zh
\
--text
=
${
BIN_DIR
}
/../sentences.txt
\
--output_dir
=
${
train_output_path
}
/test_e2e
\
--phones_dict
=
dump/phone_id_map.txt
\
--inference_dir
=
${
train_output_path
}
/inference
fi
# for more GAN Vocoders
# multi band melgan
if
[
${
stage
}
-le
1
]
&&
[
${
stop_stage
}
-ge
1
]
;
then
FLAGS_allocator_strategy
=
naive_best_fit
\
FLAGS_fraction_of_gpu_memory_to_use
=
0.01
\
python3
${
BIN_DIR
}
/../synthesize_streaming.py
\
--am
=
fastspeech2_csmsc
\
--am_config
=
${
config_path
}
\
--am_ckpt
=
${
train_output_path
}
/checkpoints/
${
ckpt_name
}
\
--am_stat
=
dump/train/speech_stats.npy
\
--voc
=
mb_melgan_csmsc
\
--voc_config
=
mb_melgan_csmsc_ckpt_0.1.1/default.yaml
\
--voc_ckpt
=
mb_melgan_csmsc_ckpt_0.1.1/snapshot_iter_1000000.pdz
\
--voc_stat
=
mb_melgan_csmsc_ckpt_0.1.1/feats_stats.npy
\
--lang
=
zh
\
--text
=
${
BIN_DIR
}
/../sentences.txt
\
--output_dir
=
${
train_output_path
}
/test_e2e
\
--phones_dict
=
dump/phone_id_map.txt
\
--inference_dir
=
${
train_output_path
}
/inference
fi
# the pretrained models haven't release now
# style melgan
# style melgan's Dygraph to Static Graph is not ready now
if
[
${
stage
}
-le
2
]
&&
[
${
stop_stage
}
-ge
2
]
;
then
FLAGS_allocator_strategy
=
naive_best_fit
\
FLAGS_fraction_of_gpu_memory_to_use
=
0.01
\
python3
${
BIN_DIR
}
/../synthesize_streaming.py
\
--am
=
fastspeech2_csmsc
\
--am_config
=
${
config_path
}
\
--am_ckpt
=
${
train_output_path
}
/checkpoints/
${
ckpt_name
}
\
--am_stat
=
dump/train/speech_stats.npy
\
--voc
=
style_melgan_csmsc
\
--voc_config
=
style_melgan_csmsc_ckpt_0.1.1/default.yaml
\
--voc_ckpt
=
style_melgan_csmsc_ckpt_0.1.1/snapshot_iter_1500000.pdz
\
--voc_stat
=
style_melgan_csmsc_ckpt_0.1.1/feats_stats.npy
\
--lang
=
zh
\
--text
=
${
BIN_DIR
}
/../sentences.txt
\
--output_dir
=
${
train_output_path
}
/test_e2e
\
--phones_dict
=
dump/phone_id_map.txt
# --inference_dir=${train_output_path}/inference
fi
# hifigan
if
[
${
stage
}
-le
3
]
&&
[
${
stop_stage
}
-ge
3
]
;
then
echo
"in hifigan syn_e2e"
FLAGS_allocator_strategy
=
naive_best_fit
\
FLAGS_fraction_of_gpu_memory_to_use
=
0.01
\
python3
${
BIN_DIR
}
/../synthesize_streaming.py
\
--am
=
fastspeech2_csmsc
\
--am_config
=
${
config_path
}
\
--am_ckpt
=
${
train_output_path
}
/checkpoints/
${
ckpt_name
}
\
--am_stat
=
dump/train/speech_stats.npy
\
--voc
=
hifigan_csmsc
\
--voc_config
=
hifigan_csmsc_ckpt_0.1.1/default.yaml
\
--voc_ckpt
=
hifigan_csmsc_ckpt_0.1.1/snapshot_iter_2500000.pdz
\
--voc_stat
=
hifigan_csmsc_ckpt_0.1.1/feats_stats.npy
\
--lang
=
zh
\
--text
=
${
BIN_DIR
}
/../sentences.txt
\
--output_dir
=
${
train_output_path
}
/test_e2e
\
--phones_dict
=
dump/phone_id_map.txt
\
--inference_dir
=
${
train_output_path
}
/inference
fi
examples/csmsc/tts3/run_cnndecoder.sh
0 → 100755
浏览文件 @
0fc79f47
#!/bin/bash
set
-e
source
path.sh
gpus
=
0,1
stage
=
0
stop_stage
=
100
conf_path
=
conf/cnndecoder.yaml
train_output_path
=
exp/cnndecoder
ckpt_name
=
snapshot_iter_153.pdz
# with the following command, you can choose the stage range you want to run
# such as `./run.sh --stage 0 --stop-stage 0`
# this can not be mixed use with `$1`, `$2` ...
source
${
MAIN_ROOT
}
/utils/parse_options.sh
||
exit
1
if
[
${
stage
}
-le
0
]
&&
[
${
stop_stage
}
-ge
0
]
;
then
# prepare data
./local/preprocess.sh
${
conf_path
}
||
exit
-1
fi
if
[
${
stage
}
-le
1
]
&&
[
${
stop_stage
}
-ge
1
]
;
then
# train model, all `ckpt` under `train_output_path/checkpoints/` dir
CUDA_VISIBLE_DEVICES
=
${
gpus
}
./local/train.sh
${
conf_path
}
${
train_output_path
}
||
exit
-1
fi
if
[
${
stage
}
-le
2
]
&&
[
${
stop_stage
}
-ge
2
]
;
then
# synthesize, vocoder is pwgan
CUDA_VISIBLE_DEVICES
=
${
gpus
}
./local/synthesize.sh
${
conf_path
}
${
train_output_path
}
${
ckpt_name
}
||
exit
-1
fi
if
[
${
stage
}
-le
3
]
&&
[
${
stop_stage
}
-ge
3
]
;
then
# synthesize_e2e, vocoder is pwgan
CUDA_VISIBLE_DEVICES
=
${
gpus
}
./local/synthesize_e2e.sh
${
conf_path
}
${
train_output_path
}
${
ckpt_name
}
||
exit
-1
fi
if
[
${
stage
}
-le
4
]
&&
[
${
stop_stage
}
-ge
4
]
;
then
# inference with static model
CUDA_VISIBLE_DEVICES
=
${
gpus
}
./local/inference.sh
${
train_output_path
}
||
exit
-1
fi
if
[
${
stage
}
-le
5
]
&&
[
${
stop_stage
}
-ge
5
]
;
then
# synthesize_e2e, vocoder is pwgan
CUDA_VISIBLE_DEVICES
=
${
gpus
}
./local/synthesize_streaming.sh
${
conf_path
}
${
train_output_path
}
${
ckpt_name
}
||
exit
-1
fi
paddlespeech/t2s/models/fastspeech2/fastspeech2.py
浏览文件 @
0fc79f47
...
@@ -14,6 +14,7 @@
...
@@ -14,6 +14,7 @@
# Modified from espnet(https://github.com/espnet/espnet)
# Modified from espnet(https://github.com/espnet/espnet)
"""Fastspeech2 related modules for paddle"""
"""Fastspeech2 related modules for paddle"""
from
typing
import
Dict
from
typing
import
Dict
from
typing
import
List
from
typing
import
Sequence
from
typing
import
Sequence
from
typing
import
Tuple
from
typing
import
Tuple
from
typing
import
Union
from
typing
import
Union
...
@@ -32,6 +33,8 @@ from paddlespeech.t2s.modules.predictor.duration_predictor import DurationPredic
...
@@ -32,6 +33,8 @@ from paddlespeech.t2s.modules.predictor.duration_predictor import DurationPredic
from
paddlespeech.t2s.modules.predictor.length_regulator
import
LengthRegulator
from
paddlespeech.t2s.modules.predictor.length_regulator
import
LengthRegulator
from
paddlespeech.t2s.modules.predictor.variance_predictor
import
VariancePredictor
from
paddlespeech.t2s.modules.predictor.variance_predictor
import
VariancePredictor
from
paddlespeech.t2s.modules.tacotron2.decoder
import
Postnet
from
paddlespeech.t2s.modules.tacotron2.decoder
import
Postnet
from
paddlespeech.t2s.modules.transformer.encoder
import
CNNDecoder
from
paddlespeech.t2s.modules.transformer.encoder
import
CNNPostnet
from
paddlespeech.t2s.modules.transformer.encoder
import
ConformerEncoder
from
paddlespeech.t2s.modules.transformer.encoder
import
ConformerEncoder
from
paddlespeech.t2s.modules.transformer.encoder
import
TransformerEncoder
from
paddlespeech.t2s.modules.transformer.encoder
import
TransformerEncoder
...
@@ -97,6 +100,12 @@ class FastSpeech2(nn.Layer):
...
@@ -97,6 +100,12 @@ class FastSpeech2(nn.Layer):
zero_triu
:
bool
=
False
,
zero_triu
:
bool
=
False
,
conformer_enc_kernel_size
:
int
=
7
,
conformer_enc_kernel_size
:
int
=
7
,
conformer_dec_kernel_size
:
int
=
31
,
conformer_dec_kernel_size
:
int
=
31
,
# for CNN Decoder
cnn_dec_dropout_rate
:
float
=
0.2
,
cnn_postnet_dropout_rate
:
float
=
0.2
,
cnn_postnet_resblock_kernel_sizes
:
List
[
int
]
=
[
256
,
256
],
cnn_postnet_kernel_size
:
int
=
5
,
cnn_decoder_embedding_dim
:
int
=
256
,
# duration predictor
# duration predictor
duration_predictor_layers
:
int
=
2
,
duration_predictor_layers
:
int
=
2
,
duration_predictor_chans
:
int
=
384
,
duration_predictor_chans
:
int
=
384
,
...
@@ -392,6 +401,13 @@ class FastSpeech2(nn.Layer):
...
@@ -392,6 +401,13 @@ class FastSpeech2(nn.Layer):
activation_type
=
conformer_activation_type
,
activation_type
=
conformer_activation_type
,
use_cnn_module
=
use_cnn_in_conformer
,
use_cnn_module
=
use_cnn_in_conformer
,
cnn_module_kernel
=
conformer_dec_kernel_size
,
)
cnn_module_kernel
=
conformer_dec_kernel_size
,
)
elif
decoder_type
==
'cnndecoder'
:
self
.
decoder
=
CNNDecoder
(
emb_dim
=
adim
,
odim
=
odim
,
kernel_size
=
cnn_postnet_kernel_size
,
dropout_rate
=
cnn_dec_dropout_rate
,
resblock_kernel_sizes
=
cnn_postnet_resblock_kernel_sizes
)
else
:
else
:
raise
ValueError
(
f
"
{
decoder_type
}
is not supported."
)
raise
ValueError
(
f
"
{
decoder_type
}
is not supported."
)
...
@@ -399,14 +415,21 @@ class FastSpeech2(nn.Layer):
...
@@ -399,14 +415,21 @@ class FastSpeech2(nn.Layer):
self
.
feat_out
=
nn
.
Linear
(
adim
,
odim
*
reduction_factor
)
self
.
feat_out
=
nn
.
Linear
(
adim
,
odim
*
reduction_factor
)
# define postnet
# define postnet
self
.
postnet
=
(
None
if
postnet_layers
==
0
else
Postnet
(
if
decoder_type
==
'cnndecoder'
:
idim
=
idim
,
self
.
postnet
=
CNNPostnet
(
odim
=
odim
,
odim
=
odim
,
n_layers
=
postnet_layers
,
kernel_size
=
cnn_postnet_kernel_size
,
n_chans
=
postnet_chans
,
dropout_rate
=
cnn_postnet_dropout_rate
,
n_filts
=
postnet_filts
,
resblock_kernel_sizes
=
cnn_postnet_resblock_kernel_sizes
)
use_batch_norm
=
use_batch_norm
,
else
:
dropout_rate
=
postnet_dropout_rate
,
))
self
.
postnet
=
(
None
if
postnet_layers
==
0
else
Postnet
(
idim
=
idim
,
odim
=
odim
,
n_layers
=
postnet_layers
,
n_chans
=
postnet_chans
,
n_filts
=
postnet_filts
,
use_batch_norm
=
use_batch_norm
,
dropout_rate
=
postnet_dropout_rate
,
))
nn
.
initializer
.
set_global_initializer
(
None
)
nn
.
initializer
.
set_global_initializer
(
None
)
...
@@ -562,6 +585,7 @@ class FastSpeech2(nn.Layer):
...
@@ -562,6 +585,7 @@ class FastSpeech2(nn.Layer):
[
olen
//
self
.
reduction_factor
for
olen
in
olens
.
numpy
()])
[
olen
//
self
.
reduction_factor
for
olen
in
olens
.
numpy
()])
else
:
else
:
olens_in
=
olens
olens_in
=
olens
# (B, 1, T)
h_masks
=
self
.
_source_mask
(
olens_in
)
h_masks
=
self
.
_source_mask
(
olens_in
)
else
:
else
:
h_masks
=
None
h_masks
=
None
...
@@ -569,8 +593,11 @@ class FastSpeech2(nn.Layer):
...
@@ -569,8 +593,11 @@ class FastSpeech2(nn.Layer):
zs
,
_
=
self
.
decoder
(
hs
,
h_masks
)
zs
,
_
=
self
.
decoder
(
hs
,
h_masks
)
# (B, Lmax, odim)
# (B, Lmax, odim)
before_outs
=
self
.
feat_out
(
zs
).
reshape
(
if
self
.
decoder_type
==
'cnndecoder'
:
(
paddle
.
shape
(
zs
)[
0
],
-
1
,
self
.
odim
))
before_outs
=
zs
else
:
before_outs
=
self
.
feat_out
(
zs
).
reshape
(
(
paddle
.
shape
(
zs
)[
0
],
-
1
,
self
.
odim
))
# postnet -> (B, Lmax//r * r, odim)
# postnet -> (B, Lmax//r * r, odim)
if
self
.
postnet
is
None
:
if
self
.
postnet
is
None
:
...
...
paddlespeech/t2s/modules/transformer/encoder.py
浏览文件 @
0fc79f47
...
@@ -515,3 +515,136 @@ class ConformerEncoder(BaseEncoder):
...
@@ -515,3 +515,136 @@ class ConformerEncoder(BaseEncoder):
if
self
.
intermediate_layers
is
not
None
:
if
self
.
intermediate_layers
is
not
None
:
return
xs
,
masks
,
intermediate_outputs
return
xs
,
masks
,
intermediate_outputs
return
xs
,
masks
return
xs
,
masks
class
Conv1dResidualBlock
(
nn
.
Layer
):
"""
Special module for simplified version of Encoder class.
"""
def
__init__
(
self
,
idim
:
int
=
256
,
odim
:
int
=
256
,
kernel_size
:
int
=
5
,
dropout_rate
:
float
=
0.2
):
super
().
__init__
()
self
.
main_block
=
nn
.
Sequential
(
nn
.
Conv1D
(
idim
,
odim
,
kernel_size
=
kernel_size
,
padding
=
kernel_size
//
2
),
nn
.
ReLU
(),
nn
.
BatchNorm1D
(
odim
),
nn
.
Dropout
(
p
=
dropout_rate
))
self
.
conv1d_residual
=
nn
.
Conv1D
(
idim
,
odim
,
kernel_size
=
1
)
def
forward
(
self
,
xs
):
"""Encode input sequence.
Args:
xs (Tensor): Input tensor (#batch, idim, T).
Returns:
Tensor: Output tensor (#batch, odim, T).
"""
outputs
=
self
.
main_block
(
xs
)
outputs
=
self
.
conv1d_residual
(
xs
)
+
outputs
return
outputs
class
CNNDecoder
(
nn
.
Layer
):
"""
Much simplified decoder than the original one with Prenet.
"""
def
__init__
(
self
,
emb_dim
:
int
=
256
,
odim
:
int
=
80
,
kernel_size
:
int
=
5
,
dropout_rate
:
float
=
0.2
,
resblock_kernel_sizes
:
List
[
int
]
=
[
256
,
256
],
):
super
().
__init__
()
input_shape
=
emb_dim
out_sizes
=
resblock_kernel_sizes
out_sizes
.
append
(
out_sizes
[
-
1
])
in_sizes
=
[
input_shape
]
+
out_sizes
[:
-
1
]
self
.
residual_blocks
=
nn
.
LayerList
([
Conv1dResidualBlock
(
idim
=
in_channels
,
odim
=
out_channels
,
kernel_size
=
kernel_size
,
dropout_rate
=
dropout_rate
,
)
for
in_channels
,
out_channels
in
zip
(
in_sizes
,
out_sizes
)
])
self
.
conv1d
=
nn
.
Conv1D
(
in_channels
=
out_sizes
[
-
1
],
out_channels
=
odim
,
kernel_size
=
1
)
def
forward
(
self
,
xs
,
masks
=
None
):
"""Encode input sequence.
Args:
xs (Tensor): Input tensor (#batch, time, idim).
masks (Tensor): Mask tensor (#batch, 1, time).
Returns:
Tensor: Output tensor (#batch, time, odim).
"""
# print("input.shape in CNNDecoder:",xs.shape)
# exchange the temporal dimension and the feature dimension
xs
=
xs
.
transpose
([
0
,
2
,
1
])
if
masks
is
not
None
:
xs
=
xs
*
masks
for
layer
in
self
.
residual_blocks
:
outputs
=
layer
(
xs
)
if
masks
is
not
None
:
# input_mask B * 1 * T
outputs
=
outputs
*
masks
xs
=
outputs
outputs
=
self
.
conv1d
(
outputs
)
if
masks
is
not
None
:
outputs
=
outputs
*
masks
outputs
=
outputs
.
transpose
([
0
,
2
,
1
])
# print("outputs.shape in CNNDecoder:",outputs.shape)
return
outputs
,
masks
class
CNNPostnet
(
nn
.
Layer
):
def
__init__
(
self
,
odim
:
int
=
80
,
kernel_size
:
int
=
5
,
dropout_rate
:
float
=
0.2
,
resblock_kernel_sizes
:
List
[
int
]
=
[
256
,
256
],
):
super
().
__init__
()
out_sizes
=
resblock_kernel_sizes
in_sizes
=
[
odim
]
+
out_sizes
[:
-
1
]
self
.
residual_blocks
=
nn
.
LayerList
([
Conv1dResidualBlock
(
idim
=
in_channels
,
odim
=
out_channels
,
kernel_size
=
kernel_size
,
dropout_rate
=
dropout_rate
)
for
in_channels
,
out_channels
in
zip
(
in_sizes
,
out_sizes
)
])
self
.
conv1d
=
nn
.
Conv1D
(
in_channels
=
out_sizes
[
-
1
],
out_channels
=
odim
,
kernel_size
=
1
)
def
forward
(
self
,
xs
,
masks
=
None
):
"""Encode input sequence.
Args:
xs (Tensor): Input tensor (#batch, odim, time).
masks (Tensor): Mask tensor (#batch, 1, time).
Returns:
Tensor: Output tensor (#batch, odim, time).
"""
# print("xs.shape in CNNPostnet:",xs.shape)
for
layer
in
self
.
residual_blocks
:
outputs
=
layer
(
xs
)
if
masks
is
not
None
:
# input_mask B * 1 * T
outputs
=
outputs
*
masks
xs
=
outputs
outputs
=
self
.
conv1d
(
outputs
)
if
masks
is
not
None
:
outputs
=
outputs
*
masks
# print("outputs.shape in CNNPostnet:",outputs.shape)
return
outputs
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录