提交 4e714887 编写于 作者: L liuyibing01

Merge branch 'master' into 'master'

add clarinet

See merge request !27
# Clarinet
Paddle implementation of clarinet in dynamic graph, a convolutional network based vocoder. The implementation is based on the paper [ClariNet: Parallel Wave Generation in End-to-End Text-to-Speech](arxiv.org/abs/1807.07281).
## Dataset
We experiment with the LJSpeech dataset. Download and unzip [LJSpeech](https://keithito.com/LJ-Speech-Dataset/).
```bash
wget https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2
tar xjvf LJSpeech-1.1.tar.bz2
```
## Project Structure
```text
├── data.py data_processing
├── configs/ (example) configuration file
├── synthesis.py script to synthesize waveform from mel_spectrogram
├── train.py script to train a model
└── utils.py utility functions
```
## Train
Train the model using train.py, follow the usage displayed by `python train.py --help`.
```text
usage: train.py [-h] [--config CONFIG] [--device DEVICE] [--output OUTPUT]
[--data DATA] [--resume RESUME] [--wavenet WAVENET]
train a clarinet model with LJspeech and a trained wavenet model.
optional arguments:
-h, --help show this help message and exit
--config CONFIG path of the config file.
--device DEVICE device to use.
--output OUTPUT path to save student.
--data DATA path of LJspeech dataset.
--resume RESUME checkpoint to load from.
--wavenet WAVENET wavenet checkpoint to use.
```
1. `--config` is the configuration file to use. The provided configurations can be used directly. And you can change some values in the configuration file and train the model with a different config.
2. `--data` is the path of the LJSpeech dataset, the extracted folder from the downloaded archive (the folder which contains metadata.txt).
3. `--resume` is the path of the checkpoint. If it is provided, the model would load the checkpoint before trainig.
4. `--output` is the directory to save results, all result are saved in this directory. The structure of the output directory is shown below.
```text
├── checkpoints # checkpoint
├── states # audio files generated at validation
└── log # tensorboard log
```
5. `--device` is the device (gpu id) to use for training. `-1` means CPU.
6. `--wavenet` is the path of the wavenet checkpoint to load. if you do not specify `--resume`, then this must be provided.
Before you start training a clarinet model, you should have trained a wavenet model with single gaussian as output distribution. Make sure the config for teacher matches that for the trained model.
example script:
```bash
python train.py --config=./configs/clarinet_ljspeech.yaml --data=./LJSpeech-1.1/ --output=experiment --device=0 --conditioner=wavenet_checkpoint/conditioner --conditioner=wavenet_checkpoint/teacher
```
You can monitor training log via tensorboard, using the script below.
```bash
cd experiment/log
tensorboard --logdir=.
```
## Synthesis
```text
usage: synthesis.py [-h] [--config CONFIG] [--device DEVICE] [--data DATA]
checkpoint output
train a clarinet model with LJspeech and a trained wavenet model.
positional arguments:
checkpoint checkpoint to load from.
output path to save student.
optional arguments:
-h, --help show this help message and exit
--config CONFIG path of the config file.
--device DEVICE device to use.
--data DATA path of LJspeech dataset.
```
1. `--config` is the configuration file to use. You should use the same configuration with which you train you model.
2. `--data` is the path of the LJspeech dataset. A dataset is not needed for synthesis, but since the input is mel spectrogram, we need to get mel spectrogram from audio files.
3. `checkpoint` is the checkpoint to load.
4. `output_path` is the directory to save results. The output path contains the generated audio files (`*.wav`).
5. `--device` is the device (gpu id) to use for training. `-1` means CPU.
example script:
```bash
python synthesis.py --config=./configs/wavenet_single_gaussian.yaml --data=./LJSpeech-1.1/ --device=0 experiment/checkpoints/step_500000 generated
```
data:
batch_size: 4
train_clip_seconds: 0.5
sample_rate: 22050
hop_length: 256
win_length: 1024
n_fft: 2048
n_mels: 80
valid_size: 16
conditioner:
upsampling_factors: [16, 16]
teacher:
n_loop: 10
n_layer: 3
filter_size: 2
residual_channels: 128
loss_type: "mog"
output_dim: 3
log_scale_min: -9
student:
n_loops: [10, 10, 10, 10, 10, 10]
n_layers: [1, 1, 1, 1, 1, 1]
filter_size: 3
residual_channels: 64
log_scale_min: -7
stft:
n_fft: 2048
win_length: 1024
hop_length: 256
loss:
lmd: 4
train:
learning_rate: 0.0005
anneal_rate: 0.5
anneal_interval: 200000
gradient_max_norm: 100.0
checkpoint_interval: 1000
eval_interval: 1000
max_iterations: 2000000
data:
root: "/workspace/datasets/LJSpeech-1.1/"
batch_size: 4
train_clip_seconds: 0.5
sample_rate: 22050
......
data:
root: "/workspace/datasets/LJSpeech-1.1/"
batch_size: 4
train_clip_seconds: 0.5
sample_rate: 22050
......
data:
root: "/workspace/datasets/LJSpeech-1.1/"
batch_size: 4
train_clip_seconds: 0.5
sample_rate: 22050
......
......@@ -56,7 +56,7 @@ def eval_model(model, valid_loader, output_dir, sample_rate):
audio_clips, mel_specs, audio_starts = batch
wav_var = model.synthesis(mel_specs)
wav_np = wav_var.numpy()[0]
sf.write(wav_np, path, samplerate=sample_rate)
sf.write(path, wav_np, samplerate=sample_rate)
print("generated {}".format(path))
......
......@@ -134,7 +134,7 @@ class SliceDataset(DatasetMixin):
format(len(order), len(dataset)))
self._order = order
def len(self):
def __len__(self):
return self._size
def get_example(self, i):
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .net import *
from .parallel_wavenet import *
\ No newline at end of file
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import numpy as np
from scipy import signal
from tqdm import trange
import paddle.fluid.layers as F
import paddle.fluid.dygraph as dg
import paddle.fluid.initializer as I
import paddle.fluid.layers.distributions as D
from parakeet.modules.weight_norm import Conv2DTranspose
from parakeet.models.wavenet import crop, WaveNet, UpsampleNet
from parakeet.models.clarinet.parallel_wavenet import ParallelWaveNet
from parakeet.models.clarinet.utils import conv2d
# Gaussian IAF model
class Clarinet(dg.Layer):
def __init__(self,
encoder,
teacher,
student,
stft,
min_log_scale=-6.0,
lmd=4.0):
super(Clarinet, self).__init__()
self.lmd = lmd
self.encoder = encoder
self.teacher = teacher
self.student = student
self.min_log_scale = min_log_scale
self.stft = stft
def forward(self, audio, mel, audio_start, clip_kl=True):
"""Compute loss for a distill model
Arguments:
audio {Variable} -- shape(batch_size, time_steps), target waveform.
mel {Variable} -- shape(batch_size, condition_dim, time_steps // hop_length), original mel spectrogram, not upsampled yet.
audio_starts {Variable} -- shape(batch_size, ), the index of the start sample.
clip_kl (bool) -- whether to clip kl divergence if it is greater than 10.0.
Returns:
Variable -- shape(1,), loss
"""
batch_size, audio_length = audio.shape # audio clip's length
z = F.gaussian_random(audio.shape)
condition = self.encoder(mel) # (B, C, T)
condition_slice = crop(condition, audio_start, audio_length)
x, s_means, s_scales = self.student(z, condition_slice) # all [0: T]
s_means = s_means[:, 1:] # (B, T-1), time steps [1: T]
s_scales = s_scales[:, 1:] # (B, T-1), time steps [1: T]
s_clipped_scales = F.clip(s_scales, self.min_log_scale, 100.)
# teacher outputs single gaussian
y = self.teacher(x[:, :-1], condition_slice[:, :, 1:])
_, t_means, t_scales = F.split(y, 3, -1) # time steps [1: T]
t_means = F.squeeze(t_means, [-1]) # (B, T-1), time steps [1: T]
t_scales = F.squeeze(t_scales, [-1]) # (B, T-1), time steps [1: T]
t_clipped_scales = F.clip(t_scales, self.min_log_scale, 100.)
s_distribution = D.Normal(s_means, F.exp(s_clipped_scales))
t_distribution = D.Normal(t_means, F.exp(t_clipped_scales))
# kl divergence loss, so we only need to sample once? no MC
kl = s_distribution.kl_divergence(t_distribution)
if clip_kl:
kl = F.clip(kl, -100., 10.)
# context size dropped
kl = F.reduce_mean(kl[:, self.teacher.context_size:])
# major diff here
regularization = F.mse_loss(t_scales[:, self.teacher.context_size:],
s_scales[:, self.teacher.context_size:])
# introduce information from real target
spectrogram_frame_loss = F.mse_loss(
self.stft.magnitude(audio), self.stft.magnitude(x))
loss = kl + self.lmd * regularization + spectrogram_frame_loss
loss_dict = {
"loss": loss,
"kl_divergence": kl,
"regularization": regularization,
"stft_loss": spectrogram_frame_loss
}
return loss_dict
@dg.no_grad
def synthesis(self, mel):
"""Synthesize waveform conditioned on the mel spectrogram.
Arguments:
mel {Variable} -- shape(batch_size, frequqncy_bands, frames)
Returns:
Variable -- shape(batch_size, frames * upsample_factor)
"""
condition = self.encoder(mel)
samples_shape = (condition.shape[0], condition.shape[-1])
z = F.gaussian_random(samples_shape)
x, s_means, s_scales = self.student(z, condition)
return x
class STFT(dg.Layer):
def __init__(self, n_fft, hop_length, win_length, window="hanning"):
super(STFT, self).__init__()
self.hop_length = hop_length
self.n_bin = 1 + n_fft // 2
self.n_fft = n_fft
# calculate window
window = signal.get_window(window, win_length)
if n_fft != win_length:
pad = (n_fft - win_length) // 2
window = np.pad(window, ((pad, pad), ), 'constant')
# calculate weights
r = np.arange(0, n_fft)
M = np.expand_dims(r, -1) * np.expand_dims(r, 0)
w_real = np.reshape(window *
np.cos(2 * np.pi * M / n_fft)[:self.n_bin],
(self.n_bin, 1, 1, self.n_fft)).astype("float32")
w_imag = np.reshape(window *
np.sin(-2 * np.pi * M / n_fft)[:self.n_bin],
(self.n_bin, 1, 1, self.n_fft)).astype("float32")
w = np.concatenate([w_real, w_imag], axis=0)
self.weight = dg.to_variable(w)
def forward(self, x):
# x(batch_size, time_steps)
# pad it first with reflect mode
pad_start = F.reverse(x[:, 1:1 + self.n_fft // 2], axis=1)
pad_stop = F.reverse(x[:, -(1 + self.n_fft // 2):-1], axis=1)
x = F.concat([pad_start, x, pad_stop], axis=-1)
# to BC1T, C=1
x = F.unsqueeze(x, axes=[1, 2])
out = conv2d(x, self.weight, stride=(1, self.hop_length))
real, imag = F.split(out, 2, dim=1) # BC1T
return real, imag
def power(self, x):
real, imag = self(x)
power = real**2 + imag**2
return power
def magnitude(self, x):
power = self.power(x)
magnitude = F.sqrt(power)
return magnitude
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import time
import itertools
import numpy as np
import paddle.fluid.layers as F
import paddle.fluid.dygraph as dg
import paddle.fluid.initializer as I
import paddle.fluid.layers.distributions as D
from parakeet.modules.weight_norm import Linear, Conv1D, Conv1DCell, Conv2DTranspose
from parakeet.models.wavenet import WaveNet
class ParallelWaveNet(dg.Layer):
def __init__(self, n_loops, n_layers, residual_channels, condition_dim,
filter_size):
super(ParallelWaveNet, self).__init__()
self.flows = dg.LayerList()
for n_loop, n_layer in zip(n_loops, n_layers):
# teacher's log_scale_min does not matter herem, -100 is a dummy value
self.flows.append(
WaveNet(n_loop, n_layer, residual_channels, 3, condition_dim,
filter_size, "mog", -100.0))
def forward(self, z, condition=None):
"""Inverse Autoregressive Flow. Several wavenets.
Arguments:
z {Variable} -- shape(batch_size, time_steps), hidden variable, sampled from a standard normal distribution.
Keyword Arguments:
condition {Variable} -- shape(batch_size, condition_dim, time_steps), condition, basically upsampled mel spectrogram. (default: {None})
Returns:
Variable -- shape(batch_size, time_steps), transformed z.
Variable -- shape(batch_size, time_steps), output distribution's mu.
Variable -- shape(batch_size, time_steps), output distribution's log_std.
"""
for i, flow in enumerate(self.flows):
theta = flow(z, condition) # w, mu, log_std [0: T]
w, mu, log_std = F.split(theta, 3, dim=-1) # (B, T, 1) for each
mu = F.squeeze(mu, [-1]) #[0: T]
log_std = F.squeeze(log_std, [-1]) #[0: T]
z = z * F.exp(log_std) + mu #[0: T]
if i == 0:
out_mu = mu
out_log_std = log_std
else:
out_mu = out_mu * F.exp(log_std) + mu
out_log_std += log_std
return z, out_mu, out_log_std
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle import fluid
from paddle.fluid.core import ops
@fluid.framework.dygraph_only
def conv2d(input,
weight,
stride=(1, 1),
padding=((0, 0), (0, 0)),
dilation=(1, 1),
groups=1,
use_cudnn=True,
data_format="NCHW"):
padding = tuple(pad for pad_dim in padding for pad in pad_dim)
inputs = {
'Input': [input],
'Filter': [weight],
}
attrs = {
'strides': stride,
'paddings': padding,
'dilations': dilation,
'groups': groups,
'use_cudnn': use_cudnn,
'use_mkldnn': False,
'fuse_relu_before_depthwise_conv': False,
"padding_algorithm": "EXPLICIT",
"data_format": data_format,
}
outputs = ops.conv2d(inputs, attrs)
out = outputs["Output"][0]
return out
\ No newline at end of file
......@@ -57,7 +57,7 @@ class UpsampleNet(dg.Layer):
"""
def __init__(self, upscale_factors=[16, 16]):
super().__init__()
super(UpsampleNet, self).__init__()
self.upscale_factors = list(upscale_factors)
self.upsample_convs = dg.LayerList()
for i, factor in enumerate(upscale_factors):
......@@ -92,7 +92,7 @@ class UpsampleNet(dg.Layer):
# AutoRegressive Model
class ConditionalWavenet(dg.Layer):
def __init__(self, encoder: UpsampleNet, decoder: WaveNet):
super().__init__()
super(ConditionalWavenet, self).__init__()
self.encoder = encoder
self.decoder = decoder
......
......@@ -39,7 +39,7 @@ def dequantize(quantized, n_bands):
class ResidualBlock(dg.Layer):
def __init__(self, residual_channels, condition_dim, filter_size,
dilation):
super().__init__()
super(ResidualBlock, self).__init__()
dilated_channels = 2 * residual_channels
# following clarinet's implementation, we do not have parametric residual
# & skip connection.
......@@ -135,7 +135,7 @@ class ResidualBlock(dg.Layer):
class ResidualNet(dg.Layer):
def __init__(self, n_loop, n_layer, residual_channels, condition_dim,
filter_size):
super().__init__()
super(ResidualNet, self).__init__()
# double the dilation at each layer in a loop(n_loop layers)
dilations = [2**i for i in range(n_loop)] * n_layer
self.context_size = 1 + sum(dilations)
......@@ -198,7 +198,7 @@ class ResidualNet(dg.Layer):
class WaveNet(dg.Layer):
def __init__(self, n_loop, n_layer, residual_channels, output_dim,
condition_dim, filter_size, loss_type, log_scale_min):
super().__init__()
super(WaveNet, self).__init__()
if loss_type not in ["softmax", "mog"]:
raise ValueError("loss_type {} is not supported".format(loss_type))
if loss_type == "softmax":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册