PaddlePaddle / DeepSpeech
9125d71a
编写于
10月 29, 2021
作者:
小湉湉
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix pwg inference
上级
36d60a71
Showing 6 changed files with 155 additions and 15 deletions (+155 −15).
deepspeech/exps/deepspeech2/model.py                  +0   −1
examples/csmsc/voc3/conf/default.yaml                 +1   −1
examples/csmsc/voc3/conf/use_tanh.yaml                +139 −0
parakeet/models/melgan/melgan.py                      +4   −5
parakeet/models/parallel_wavegan/parallel_wavegan.py  +7   −7
parakeet/modules/residual_stack.py                    +4   −1
deepspeech/exps/deepspeech2/model.py

@@ -189,7 +189,6 @@ class DeepSpeech2Trainer(Trainer):
         self.lr_scheduler = lr_scheduler
         logger.info("Setup optimizer/lr_scheduler!")

     def setup_dataloader(self):
         config = self.config.clone()
         config.defrost()
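The clone()/defrost() calls in the context above are the yacs config idiom: mutate a private copy so the shared config can stay frozen. A minimal sketch, assuming yacs is the config library in use (as these method names suggest):

    from yacs.config import CfgNode

    config = CfgNode({"collator": CfgNode({"batch_size": 32})})
    config.freeze()                 # the shared config is read-only from here on

    local = config.clone()          # deep copy; the original stays untouched
    local.defrost()                 # make the copy mutable again
    local.collator.batch_size = 64  # per-run override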
examples/csmsc/voc3/conf/default.yaml

@@ -88,7 +88,7 @@ lambda_adv: 2.5 # Loss balancing coefficient for adversarial loss.
 ###########################################################
 batch_size: 64          # Batch size.
 batch_max_steps: 16200  # Length of each audio in batch. Make sure divisible by hop_size.
-num_workers: 4          # Number of workers in DataLoader.
+num_workers: 2          # Number of workers in DataLoader.

 ###########################################################
 #              OPTIMIZER & SCHEDULER SETTING              #
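As the batch_max_steps comment notes, each training clip must span a whole number of frames, i.e. batch_max_steps must be divisible by the hop size (n_shift: 300 in this recipe). A quick sanity check:

    batch_max_steps = 16200
    hop_size = 300  # n_shift from the feature extraction settings
    assert batch_max_steps % hop_size == 0
    print(batch_max_steps // hop_size)  # 54 frames per clip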
examples/csmsc/voc3/conf/use_tanh.yaml (new file, mode 100644)

# This is the hyperparameter configuration file for MelGAN.
# Please make sure this is adjusted for the CSMSC dataset. If you want to
# apply it to another dataset, you might need to carefully change some parameters.
# This configuration requires ~8GB memory and will finish within 7 days on Titan V.

# This configuration is based on full-band MelGAN, but the hop size and sampling
# rate differ from the paper (16kHz vs 24kHz). The number of iterations
# is not shown in the paper, so currently we train 1M iterations (not sure this is
# enough to converge). The optimizer setting is based on @dathudeptrai's advice.
# https://github.com/kan-bayashi/ParallelWaveGAN/issues/143#issuecomment-632539906

###########################################################
#                FEATURE EXTRACTION SETTING               #
###########################################################
fs: 24000         # Sampling rate.
n_fft: 2048       # FFT size. (in samples)
n_shift: 300      # Hop size. (in samples)
win_length: 1200  # Window length. (in samples)
                  # If set to null, it will be the same as fft_size.
window: "hann"    # Window function.
n_mels: 80        # Number of mel basis.
fmin: 80          # Minimum frequency in mel basis calculation. (Hz)
fmax: 7600        # Maximum frequency in mel basis calculation. (Hz)

###########################################################
#          GENERATOR NETWORK ARCHITECTURE SETTING         #
###########################################################
generator_params:
    in_channels: 80                       # Number of input channels.
    out_channels: 4                       # Number of output channels.
    kernel_size: 7                        # Kernel size of initial and final conv layers.
    channels: 384                         # Initial number of channels for conv layers.
    upsample_scales: [5, 5, 3]            # List of upsampling scales.
    stack_kernel_size: 3                  # Kernel size of dilated conv layers in residual stack.
    stacks: 4                             # Number of stacks in a single residual stack module.
    use_weight_norm: True                 # Whether to use weight normalization.
    use_causal_conv: False                # Whether to use causal convolution.
    use_final_nonlinear_activation: True  # If True, spectral_convergence_loss and sub_spectral_convergence_loss will be too large (e.g. 30).

###########################################################
#        DISCRIMINATOR NETWORK ARCHITECTURE SETTING       #
###########################################################
discriminator_params:
    in_channels: 1                     # Number of input channels.
    out_channels: 1                    # Number of output channels.
    scales: 3                          # Number of multi-scales.
    downsample_pooling: "AvgPool1D"    # Pooling type for the input downsampling.
    downsample_pooling_params:         # Parameters of the above pooling function.
        kernel_size: 4
        stride: 2
        padding: 1
        exclusive: True
    kernel_sizes: [5, 3]               # List of kernel sizes.
    channels: 16                       # Number of channels of the initial conv layer.
    max_downsample_channels: 512       # Maximum number of channels of downsampling layers.
    downsample_scales: [4, 4, 4]       # List of downsampling scales.
    nonlinear_activation: "LeakyReLU"  # Nonlinear activation function.
    nonlinear_activation_params:       # Parameters of the nonlinear activation function.
        negative_slope: 0.2
    use_weight_norm: True              # Whether to use weight norm.

###########################################################
#                    STFT LOSS SETTING                    #
###########################################################
use_stft_loss: true
stft_loss_params:
    fft_sizes: [1024, 2048, 512]   # List of FFT sizes for STFT-based loss.
    hop_sizes: [120, 240, 50]      # List of hop sizes for STFT-based loss.
    win_lengths: [600, 1200, 240]  # List of window lengths for STFT-based loss.
    window: "hann"                 # Window function for STFT-based loss.
use_subband_stft_loss: true
subband_stft_loss_params:
    fft_sizes: [384, 683, 171]   # List of FFT sizes for STFT-based loss.
    hop_sizes: [30, 60, 10]      # List of hop sizes for STFT-based loss.
    win_lengths: [150, 300, 60]  # List of window lengths for STFT-based loss.
    window: "hann"               # Window function for STFT-based loss.

###########################################################
#                 ADVERSARIAL LOSS SETTING                #
###########################################################
use_feat_match_loss: false  # Whether to use feature matching loss.
lambda_adv: 2.5             # Loss balancing coefficient for adversarial loss.

###########################################################
#                   DATA LOADER SETTING                   #
###########################################################
batch_size: 64          # Batch size.
batch_max_steps: 16200  # Length of each audio in batch. Make sure divisible by hop_size.
num_workers: 2          # Number of workers in DataLoader.

###########################################################
#              OPTIMIZER & SCHEDULER SETTING              #
###########################################################
generator_optimizer_params:
    epsilon: 1.0e-7            # Generator's epsilon.
    weight_decay: 0.0          # Generator's weight decay coefficient.
generator_grad_norm: -1        # Generator's gradient norm.
generator_scheduler_params:
    learning_rate: 1.0e-3      # Generator's learning rate.
    gamma: 0.5                 # Generator's scheduler gamma.
    milestones:                # At each milestone, lr will be multiplied by gamma.
        - 100000
        - 200000
        - 300000
        - 400000
        - 500000
        - 600000
discriminator_optimizer_params:
    epsilon: 1.0e-7            # Discriminator's epsilon.
    weight_decay: 0.0          # Discriminator's weight decay coefficient.
discriminator_grad_norm: -1    # Discriminator's gradient norm.
discriminator_scheduler_params:
    learning_rate: 1.0e-3      # Discriminator's learning rate.
    gamma: 0.5                 # Discriminator's scheduler gamma.
    milestones:                # At each milestone, lr will be multiplied by gamma.
        - 100000
        - 200000
        - 300000
        - 400000
        - 500000
        - 600000

###########################################################
#                     INTERVAL SETTING                    #
###########################################################
discriminator_train_start_steps: 200000  # Number of steps at which to start training the discriminator.
train_max_steps: 1000000                 # Number of training steps.
save_interval_steps: 50000               # Interval steps to save checkpoint.
eval_interval_steps: 1000                # Interval steps to evaluate the network.

###########################################################
#                      OTHER SETTING                      #
###########################################################
num_snapshots: 10  # Max number of snapshots to keep while training.
seed: 42           # Random seed for paddle, random, and np.random.
\ No newline at end of file
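default.yaml and use_tanh.yaml share the same schema, so a recipe can select either at train time. A minimal loading sketch, assuming a plain-YAML reader (the actual training scripts may wrap this in their own config object):

    import math
    import yaml

    with open("examples/csmsc/voc3/conf/use_tanh.yaml") as f:
        config = yaml.safe_load(f)

    scales = config["generator_params"]["upsample_scales"]  # [5, 5, 3]
    bands = config["generator_params"]["out_channels"]      # 4 sub-bands
    # 5 * 5 * 3 = 75x upsampling per band, times 4 bands = 300 samples = n_shift.
    assert math.prod(scales) * bands == config["n_shift"]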
parakeet/models/melgan/melgan.py

@@ -19,7 +19,6 @@ from typing import List
 import numpy as np
 import paddle
 from paddle import nn
-from paddle.fluid.layers import Normal

 from parakeet.modules.causal_conv import CausalConv1D
 from parakeet.modules.causal_conv import CausalConv1DTranspose

@@ -238,7 +237,7 @@ class MelGANGenerator(nn.Layer):
         """
         # Define a normal distribution with float parameters.
-        dist = Normal(loc=0.0, scale=0.02)
+        dist = paddle.distribution.Normal(loc=0.0, scale=0.02)

         def _reset_parameters(m):
             if isinstance(m, nn.Conv1D) or isinstance(m, nn.Conv1DTranspose):

@@ -290,8 +289,8 @@ class MelGANDiscriminator(nn.Layer):
         """Initialize MelGAN discriminator module.
         Parameters
         ----------
-        in_channels : int):
+        in_channels : int
             Number of input channels.
         out_channels : int
             Number of output channels.
         kernel_sizes : List[int]

@@ -531,7 +530,7 @@ class MelGANMultiScaleDiscriminator(nn.Layer):
         """
         # Define a normal distribution with float parameters.
-        dist = Normal(loc=0.0, scale=0.02)
+        dist = paddle.distribution.Normal(loc=0.0, scale=0.02)

         def _reset_parameters(m):
             if isinstance(m, nn.Conv1D) or isinstance(m, nn.Conv1DTranspose):
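The substantive change here replaces the legacy paddle.fluid.layers.Normal with the public paddle.distribution.Normal API, which is what allows the import deletion in the first hunk. A minimal sketch of the sampling pattern _reset_parameters relies on (the shape is illustrative, not taken from the model):

    import paddle

    dist = paddle.distribution.Normal(loc=0.0, scale=0.02)
    w = dist.sample([64, 80, 7])  # e.g. a Conv1D weight: (out_channels, in_channels, kernel_size)
    print(w.shape)                # [64, 80, 7]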
parakeet/models/parallel_wavegan/parallel_wavegan.py

@@ -495,25 +495,25 @@ class PWGGenerator(nn.Layer):
         self.apply(_remove_weight_norm)

-    def inference(self, c):
+    def inference(self, c=None):
         """Waveform generation. This function is used for single instance
         inference.

         Parameters
         ----------
-        c : Tensor
+        c : Tensor, optional
             Shape (T', C_aux), the auxiliary input, by default None
         x : Tensor, optional
             Shape (T, C_in), the noise waveform, by default None
             If not provided, a sample is drawn from a gaussian distribution.

         Returns
         -------
         Tensor
             Shape (T, C_out), the generated waveform
         """
-        # a sample is drawn from a gaussian distribution.
+        # when converting to a static graph, x can not be an input, see https://github.com/PaddlePaddle/Parakeet/pull/132/files
         x = paddle.randn(
             [1, self.in_channels, paddle.shape(c)[0] * self.upsample_factor])
-        # pseudo batch
-        c = paddle.transpose(c, [1, 0]).unsqueeze(0)
+        c = paddle.transpose(c, [1, 0]).unsqueeze(0)  # pseudo batch
         c = nn.Pad1D(self.aux_context_window, mode='replicate')(c)
         out = self(x, c).squeeze(0).transpose([1, 0])
         return out
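This hunk is the "fix pwg inference" of the commit title: the noise x is now always drawn inside inference, so a static-graph export needs only the mel input c. A hypothetical export sketch (the constructor arguments, checkpoint handling, and 80-dim mel shape are illustrative assumptions, not part of this commit):

    import paddle
    from parakeet.models.parallel_wavegan import PWGGenerator

    generator = PWGGenerator()  # assumes defaults that match your checkpoint
    generator.remove_weight_norm()
    generator.eval()

    # Only c is declared as an input; the noise is created inside inference,
    # so the exported static graph has a single input.
    static_fn = paddle.jit.to_static(
        generator.inference,
        input_spec=[paddle.static.InputSpec(shape=[None, 80], dtype='float32')])
    paddle.jit.save(static_fn, "pwg_inference")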
parakeet/modules/residual_stack.py

@@ -106,4 +106,7 @@ class ResidualStack(nn.Layer):
         Tensor
             Output tensor (B, channels, T).
         """
-        return self.stack(c) + self.skip_layer(c)
+        stack_output = self.stack(c)
+        skip_layer_output = self.skip_layer(c)
+        out = stack_output + skip_layer_output
+        return out
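The output is unchanged: still stack(c) + skip_layer(c). Naming the two branches just makes them individually inspectable, e.g. when debugging or tracing the graph. A toy stand-in with plain Conv1D layers (not the real stack/skip modules):

    import paddle
    from paddle import nn

    stack = nn.Conv1D(4, 4, 3, padding=1)  # stand-in for self.stack
    skip_layer = nn.Conv1D(4, 4, 1)        # stand-in for self.skip_layer

    c = paddle.randn([2, 4, 16])           # (B, channels, T)
    out = stack(c) + skip_layer(c)         # the residual combination being refactored
    print(out.shape)                       # [2, 4, 16]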