Commit 98841ee4 authored by Kexin Zhao

clean code

Parent b15c3134
...@@ -163,6 +163,35 @@ class WeightedRandomSampler(Sampler): ...@@ -163,6 +163,35 @@ class WeightedRandomSampler(Sampler):
return self.num_samples return self.num_samples
class DistributedSampler(Sampler):
def __init__(self, dataset_size, num_trainers, rank, shuffle=True):
self.dataset_size = dataset_size
self.num_trainers = num_trainers
self.rank = rank
self.num_samples = int(np.ceil(dataset_size / num_trainers))
self.total_size = self.num_samples * num_trainers
assert self.total_size >= self.dataset_size
self.shuffle = shuffle
def __iter__(self):
indices = list(range(self.dataset_size))
if self.shuffle:
random.shuffle(indices)
# Append extra samples to make it evenly distributed on all trainers.
indices += indices[:(self.total_size - self.dataset_size)]
assert len(indices) == self.total_size
# Subset samples for each trainer.
indices = indices[self.rank:self.total_size:self.num_trainers]
assert len(indices) == self.num_samples
return iter(indices)
def __len__(self):
return self.num_samples
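A minimal usage sketch of the DistributedSampler added above (shuffle is disabled here only to make the output deterministic). Each trainer strides through a padded index list, so every rank receives the same number of samples:

# Hedged example: a 10-example dataset split across 3 trainers.
from parakeet.data.sampler import DistributedSampler  # import path per the data.py change below

for rank in range(3):
    sampler = DistributedSampler(dataset_size=10, num_trainers=3, rank=rank, shuffle=False)
    print(rank, list(sampler))
# rank 0 -> [0, 3, 6, 9]
# rank 1 -> [1, 4, 7, 0]   (indices 0 and 1 are reused as padding)
# rank 2 -> [2, 5, 8, 1]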
class BatchSampler(Sampler): class BatchSampler(Sampler):
r"""Wraps another sampler to yield a mini-batch of indices. r"""Wraps another sampler to yield a mini-batch of indices.
Args: Args:
...@@ -206,4 +235,4 @@ class BatchSampler(Sampler): ...@@ -206,4 +235,4 @@ class BatchSampler(Sampler):
if self.drop_last: if self.drop_last:
return len(self.sampler) // self.batch_size return len(self.sampler) // self.batch_size
else: else:
return (len(self.sampler) + self.batch_size - 1) // self.batch_size return (len(self.sampler) + self.batch_size - 1) // self.batch_size
\ No newline at end of file
valid_size: 16
train_clip_second: 0.5
sample_rate: 22050
fft_window_shift: 256
fft_window_size: 1024
fft_size: 2048
mel_bands: 80
seed: 1
batch_size: 8
test_every: 2000
save_every: 10000
max_iterations: 2000000
layers: 30
kernel_width: 2
dilation_block: [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]
residual_channels: 128
skip_channels: 128
loss_type: mix-gaussian-pdf
num_mixtures: 10
log_scale_min: -9.0
conditioner:
filter_sizes: [[32, 3], [32, 3]]
upsample_factors: [16, 16]
learning_rate: 0.001
gradient_max_norm: 100.0
anneal:
every: 200000
rate: 0.5
valid_size: 16
train_clip_second: 0.5
sample_rate: 22050
fft_window_shift: 256
fft_window_size: 1024
fft_size: 2048
mel_bands: 80
seed: 1
batch_size: 8
test_every: 2000
save_every: 10000
max_iterations: 2000000
layers: 30
kernel_width: 2
dilation_block: [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]
residual_channels: 128
skip_channels: 128
loss_type: softmax
num_channels: 2048
conditioner:
filter_sizes: [[32, 3], [32, 3]]
upsample_factors: [16, 16]
learning_rate: 0.001
gradient_max_norm: 100.0
anneal:
every: 200000
rate: 0.5
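The two YAML blocks above are alternative training configs for the same model: the first uses a mixture-of-Gaussians output head (loss_type: mix-gaussian-pdf with 10 mixtures), the second a 2048-way softmax over quantized samples. How the repo actually parses them lives in the elided utils module; a minimal sketch with PyYAML and a hypothetical attribute wrapper could look like this:

import yaml
from types import SimpleNamespace

def load_config(path):
    # Expose top-level keys as attributes, e.g. config.sample_rate, config.loss_type.
    with open(path) as f:
        return SimpleNamespace(**yaml.safe_load(f))

# config = load_config("wavenet_softmax.yaml")   # hypothetical file name
# assert config.fft_window_shift == 256 and config.loss_type == "softmax"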
import math
import os
import random import random
import librosa import librosa
...@@ -9,7 +7,7 @@ from paddle import fluid ...@@ -9,7 +7,7 @@ from paddle import fluid
import utils import utils
from parakeet.datasets import ljspeech from parakeet.datasets import ljspeech
from parakeet.data import dataset from parakeet.data import dataset
from parakeet.data.sampler import Sampler, BatchSampler, SequentialSampler from parakeet.data.sampler import DistributedSampler, BatchSampler
from parakeet.data.datacargo import DataCargo from parakeet.data.datacargo import DataCargo
...@@ -20,7 +18,7 @@ class Dataset(ljspeech.LJSpeech): ...@@ -20,7 +18,7 @@ class Dataset(ljspeech.LJSpeech):
self.fft_window_shift = config.fft_window_shift self.fft_window_shift = config.fft_window_shift
# Calculate context frames. # Calculate context frames.
frames_per_second = config.sample_rate // self.fft_window_shift frames_per_second = config.sample_rate // self.fft_window_shift
train_clip_frames = int(math.ceil( train_clip_frames = int(np.ceil(
config.train_clip_second * frames_per_second)) config.train_clip_second * frames_per_second))
context_frames = config.context_size // self.fft_window_shift context_frames = config.context_size // self.fft_window_shift
self.num_frames = train_clip_frames + context_frames self.num_frames = train_clip_frames + context_frames
...@@ -39,7 +37,7 @@ class Dataset(ljspeech.LJSpeech): ...@@ -39,7 +37,7 @@ class Dataset(ljspeech.LJSpeech):
assert loaded_sr == sr assert loaded_sr == sr
# Pad audio to the right size. # Pad audio to the right size.
frames = math.ceil(float(audio.size) / fft_window_shift) frames = int(np.ceil(float(audio.size) / fft_window_shift))
fft_padding = (fft_size - fft_window_shift) // 2 fft_padding = (fft_size - fft_window_shift) // 2
desired_length = frames * fft_window_shift + fft_padding * 2 desired_length = frames * fft_window_shift + fft_padding * 2
pad_amount = (desired_length - audio.size) // 2 pad_amount = (desired_length - audio.size) // 2
...@@ -125,35 +123,6 @@ class Subset(dataset.Dataset): ...@@ -125,35 +123,6 @@ class Subset(dataset.Dataset):
return len(self.indices) return len(self.indices)
class DistributedSampler(Sampler):
def __init__(self, dataset_size, num_trainers, rank, shuffle=True):
self.dataset_size = dataset_size
self.num_trainers = num_trainers
self.rank = rank
self.num_samples = int(math.ceil(dataset_size / num_trainers))
self.total_size = self.num_samples * num_trainers
assert self.total_size >= self.dataset_size
self.shuffle = shuffle
def __iter__(self):
indices = list(range(self.dataset_size))
if self.shuffle:
random.shuffle(indices)
# Append extra samples to make it evenly distributed on all trainers.
indices += indices[:(self.total_size - self.dataset_size)]
assert len(indices) == self.total_size
# Subset samples for each trainer.
indices = indices[self.rank:self.total_size:self.num_trainers]
assert len(indices) == self.num_samples
return iter(indices)
def __len__(self):
return self.num_samples
class LJSpeech: class LJSpeech:
def __init__(self, config, nranks, rank): def __init__(self, config, nranks, rank):
place = fluid.CUDAPlace(rank) if config.use_gpu else fluid.CPUPlace() place = fluid.CUDAPlace(rank) if config.use_gpu else fluid.CPUPlace()
......
import paddle
from paddle import fluid
import paddle.fluid.dygraph as dg
import numpy as np
import weight_norm
def Embedding(name_scope,
num_embeddings,
embed_dim,
padding_idx=None,
std=0.1,
dtype="float32"):
# param attrs
weight_attr = fluid.ParamAttr(initializer=fluid.initializer.Normal(
scale=std))
layer = dg.Embedding(
name_scope, (num_embeddings, embed_dim),
padding_idx=padding_idx,
param_attr=weight_attr,
dtype=dtype)
return layer
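A hedged sketch of how this factory is used on the softmax path of WaveNetModule further below (values taken from the softmax config: num_channels 2048, residual_channels 128); it has to run inside a dygraph guard like the rest of the model:

import paddle.fluid as fluid

with fluid.dygraph.guard():
    # Quantized sample indices (0..2047) are embedded into 128-dim vectors
    # before entering the dilated convolution stack.
    embedding_fc = Embedding("embedding_fc", num_embeddings=2048, embed_dim=128, std=0.1)
    # quantized: int64 Variable of shape [batch, time, 1]
    # layer_input = embedding_fc(quantized)   # -> [batch, time, 128]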
def FC(name_scope,
in_features,
size,
num_flatten_dims=1,
relu=False,
dropout=0.0,
act=None,
dtype="float32"):
"""
    A special Linear layer. When it is used with dropout, the weight is
    initialized as normal(0, std=np.sqrt((1 - dropout) / in_features)).
"""
# stds
if isinstance(in_features, int):
in_features = [in_features]
stds = [np.sqrt((1.0 - dropout) / in_feature) for in_feature in in_features]
if relu:
stds = [std * np.sqrt(2.0) for std in stds]
weight_inits = [
fluid.initializer.NormalInitializer(scale=std) for std in stds
]
bias_init = fluid.initializer.ConstantInitializer(0.0)
# param attrs
weight_attrs = [fluid.ParamAttr(initializer=init) for init in weight_inits]
bias_attr = fluid.ParamAttr(initializer=bias_init)
layer = weight_norm.FC(name_scope,
size,
num_flatten_dims=num_flatten_dims,
param_attr=weight_attrs,
bias_attr=bias_attr,
act=act,
dtype=dtype)
return layer
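The initialization rule in the docstring keeps activation variance roughly constant: with dropout probability p, std = sqrt((1 - p) / in_features), and an extra factor of sqrt(2) is applied when the layer feeds a ReLU (He-style initialization). A quick check of the numbers:

import numpy as np

in_features, dropout = 128, 0.05
std = np.sqrt((1.0 - dropout) / in_features)   # ~0.086 for a plain linear layer
std_relu = std * np.sqrt(2.0)                  # ~0.122 when relu=True
print(std, std_relu)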
def Conv1D(name_scope,
in_channels,
num_filters,
filter_size=2,
dilation=1,
groups=None,
causal=False,
std_mul=1.0,
dropout=0.0,
use_cudnn=True,
act=None,
dtype="float32"):
"""
    A special Conv1D layer. When it is used with dropout, the weight is
    initialized as
    normal(0, std=np.sqrt(std_mul * (1 - dropout) / (filter_size * in_channels))).
"""
# std
std = np.sqrt((std_mul * (1.0 - dropout)) / (filter_size * in_channels))
weight_init = fluid.initializer.NormalInitializer(loc=0.0, scale=std)
bias_init = fluid.initializer.ConstantInitializer(0.0)
# param attrs
weight_attr = fluid.ParamAttr(initializer=weight_init)
bias_attr = fluid.ParamAttr(initializer=bias_init)
layer = weight_norm.Conv1D(
name_scope,
num_filters,
filter_size,
dilation,
groups=groups,
causal=causal,
param_attr=weight_attr,
bias_attr=bias_attr,
use_cudnn=use_cudnn,
act=act,
dtype=dtype)
return layer
class Conv1D_GU(dg.Layer):
def __init__(self,
name_scope,
conditioner_dim,
in_channels,
num_filters,
filter_size,
dilation,
causal=False,
residual=True,
dtype="float32"):
super(Conv1D_GU, self).__init__(name_scope, dtype=dtype)
self.conditioner_dim = conditioner_dim
self.in_channels = in_channels
self.num_filters = num_filters
self.filter_size = filter_size
self.dilation = dilation
self.causal = causal
self.residual = residual
if residual:
assert (
in_channels == num_filters
), "this block uses residual connection"\
"the input_channels should equals num_filters"
self.conv = Conv1D(
self.full_name(),
in_channels,
2 * num_filters,
filter_size,
dilation,
causal=causal,
dtype=dtype)
self.fc = Conv1D(
self.full_name(),
conditioner_dim,
2 * num_filters,
filter_size=1,
dilation=1,
causal=False,
dtype=dtype)
def forward(self, x, skip=None, conditioner=None):
"""
Args:
            x (Variable): Shape(B, C_in, 1, T), the input of the Conv1D_GU
                layer, where B means batch_size, C_in means the input channels,
                and T means input time steps.
            conditioner (Variable): Shape(B, C_con, 1, T), expanded mel
                conditioner, where C_con is the conditioner hidden dim, which
                equals the number of mel bands. Note that when using a residual
                connection, the Conv1D_GU does not change the number of
                channels, so out channels equals input channels.
        Returns:
            x (Variable): Shape(B, C_out, 1, T), the output of Conv1D_GU, where
                C_out means the output channels of Conv1D_GU.
"""
residual = x
x = self.conv(x)
if conditioner is not None:
cond_bias = self.fc(conditioner)
x += cond_bias
content, gate = fluid.layers.split(x, num_or_sections=2, dim=1)
# Gated Unit.
x = fluid.layers.elementwise_mul(fluid.layers.sigmoid(gate),
fluid.layers.tanh(content))
if skip is None:
skip = x
else:
skip = fluid.layers.scale(skip + x, np.sqrt(0.5))
if self.residual:
x = fluid.layers.scale(residual + x, np.sqrt(0.5))
return x, skip
def add_input(self, x, skip=None, conditioner=None):
"""
Inputs:
x: shape(B, num_filters, 1, time_steps)
conditioner: shape(B, conditioner_dim, 1, time_steps)
Outputs:
out: shape(B, num_filters, 1, time_steps), where time_steps = 1
"""
residual = x
# add step input and produce step output
x = self.conv.add_input(x)
if conditioner is not None:
cond_bias = self.fc(conditioner)
x += cond_bias
content, gate = fluid.layers.split(x, num_or_sections=2, dim=1)
# Gated Unit.
x = fluid.layers.elementwise_mul(fluid.layers.sigmoid(gate),
fluid.layers.tanh(content))
if skip is None:
skip = x
else:
skip = fluid.layers.scale(skip + x, np.sqrt(0.5))
if self.residual:
x = fluid.layers.scale(residual + x, np.sqrt(0.5))
return x, skip
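The gating and rescaling arithmetic in forward() and add_input() can be checked in isolation. A hedged numpy sketch (plain arrays, not Paddle variables) of what one block computes per call:

import numpy as np

def gated_unit(conv_out, residual, skip=None):
    # conv_out carries 2*C channels: the first half is content, the second half is gate.
    content, gate = np.split(conv_out, 2, axis=0)
    x = (1.0 / (1.0 + np.exp(-gate))) * np.tanh(content)    # sigmoid(gate) * tanh(content)
    skip = x if skip is None else (skip + x) * np.sqrt(0.5)  # variance-preserving skip sum
    x = (residual + x) * np.sqrt(0.5)                        # residual connection
    return x, skip

C, T = 4, 5
x, skip = gated_unit(np.random.randn(2 * C, T), np.random.randn(C, T))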
def Conv2DTranspose(name_scope,
num_filters,
filter_size,
padding=0,
stride=1,
dilation=1,
use_cudnn=True,
act=None,
dtype="float32"):
val = 1.0 / (filter_size[0] * filter_size[1])
weight_init = fluid.initializer.ConstantInitializer(val)
weight_attr = fluid.ParamAttr(initializer=weight_init)
layer = weight_norm.Conv2DTranspose(
name_scope,
num_filters,
filter_size=filter_size,
padding=padding,
stride=stride,
dilation=dilation,
param_attr=weight_attr,
use_cudnn=use_cudnn,
act=act,
dtype=dtype)
return layer
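The constant initializer sets every deconvolution weight to the reciprocal of the filter area, so the conditioner upsampling starts from a smooth, uniform blend of neighbouring mel frames rather than random noise. With the filter_sizes from the configs above:

filter_size = [32, 3]
val = 1.0 / (filter_size[0] * filter_size[1])   # 1/96 ~= 0.0104, the initial weight value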
...@@ -4,12 +4,12 @@ import time ...@@ -4,12 +4,12 @@ import time
import librosa import librosa
import numpy as np import numpy as np
from paddle import fluid
import paddle.fluid.dygraph as dg import paddle.fluid.dygraph as dg
from paddle import fluid
import utils import utils
from data import LJSpeech from data import LJSpeech
from wavenet_modules import WaveNetModule, debug from wavenet_modules import WaveNetModule
class WaveNet(): class WaveNet():
...@@ -33,18 +33,6 @@ class WaveNet(): ...@@ -33,18 +33,6 @@ class WaveNet():
self.trainloader = dataset.trainloader self.trainloader = dataset.trainloader
self.validloader = dataset.validloader self.validloader = dataset.validloader
# if self.rank == 0:
# for i, (audios, mels, ids) in enumerate(self.validloader()):
# print("audios {}, mels {}, ids {}".format(audios.dtype, mels.dtype, ids.dtype))
# print("{}: rank {}, audios {}, mels {}, indices {} / {}".format(
# i, self.rank, audios.shape, mels.shape, ids.shape,
# ids.numpy()))
#
# for i, (audios, mels, ids) in enumerate(self.trainloader):
# print("{}: rank {}, audios {}, mels {}, indices {} / {}".format(
# i, self.rank, audios.shape, mels.shape, ids.shape,
# ids.numpy()))
wavenet = WaveNetModule("wavenet", config, self.rank) wavenet = WaveNetModule("wavenet", config, self.rank)
        # Dry run once to create and initialize all necessary parameters. # Dry run once to create and initialize all necessary parameters.
...@@ -139,8 +127,8 @@ class WaveNet(): ...@@ -139,8 +127,8 @@ class WaveNet():
self.wavenet.eval() self.wavenet.eval()
total_loss = [] total_loss = []
start_time = time.time()
sample_audios = [] sample_audios = []
start_time = time.time()
for audios, mels, audio_starts in self.validloader(): for audios, mels, audio_starts in self.validloader():
loss, sample_audio = self.wavenet(audios, mels, audio_starts, True) loss, sample_audio = self.wavenet(audios, mels, audio_starts, True)
total_loss.append(float(loss.numpy())) total_loss.append(float(loss.numpy()))
...@@ -160,11 +148,6 @@ class WaveNet(): ...@@ -160,11 +148,6 @@ class WaveNet():
tb.add_audio("Teacher-Forced-Audio-1", sample_audios[1].numpy(), tb.add_audio("Teacher-Forced-Audio-1", sample_audios[1].numpy(),
iteration, sample_rate=self.config.sample_rate) iteration, sample_rate=self.config.sample_rate)
def save(self, iteration):
utils.save_latest_parameters(self.checkpoint_dir, iteration,
self.wavenet, self.optimizer)
utils.save_latest_checkpoint(self.checkpoint_dir, iteration)
@dg.no_grad @dg.no_grad
def infer(self, iteration): def infer(self, iteration):
self.wavenet.eval() self.wavenet.eval()
...@@ -186,3 +169,8 @@ class WaveNet(): ...@@ -186,3 +169,8 @@ class WaveNet():
syn_audio.shape, syn_time)) syn_audio.shape, syn_time))
librosa.output.write_wav(filename, syn_audio, librosa.output.write_wav(filename, syn_audio,
sr=config.sample_rate) sr=config.sample_rate)
def save(self, iteration):
utils.save_latest_parameters(self.checkpoint_dir, iteration,
self.wavenet, self.optimizer)
utils.save_latest_checkpoint(self.checkpoint_dir, iteration)
import itertools import itertools
import math
import numpy as np import numpy as np
from paddle import fluid
import paddle.fluid.dygraph as dg import paddle.fluid.dygraph as dg
import ops from paddle import fluid
import weight_norm from parakeet.modules import conv, modules
def get_padding(filter_size, stride, padding_type='same'): def get_padding(filter_size, stride, padding_type='same'):
...@@ -16,22 +14,6 @@ def get_padding(filter_size, stride, padding_type='same'): ...@@ -16,22 +14,6 @@ def get_padding(filter_size, stride, padding_type='same'):
return padding return padding
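The body of get_padding is elided in this diff. For a transposed convolution whose output length should be exactly input_length * stride, one standard 'same'-style rule (which this helper may or may not match exactly) is to pad by (filter - stride) // 2 per dimension:

def same_padding_sketch(filter_size, stride):
    # Hypothetical rule: out = (in - 1) * stride - 2 * pad + filter == in * stride
    # when pad = (filter - stride) // 2 and filter - stride is even.
    return tuple((f - s) // 2 for f, s in zip(filter_size, stride))

print(same_padding_sketch((32, 3), (16, 1)))   # (8, 1) for the conditioner deconvs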
def debug(x, var_name, rank, verbose=False):
if not verbose and rank != 0:
return
dim = len(x.shape)
if not isinstance(x, np.ndarray):
x = x.numpy()
if dim == 1:
print("Rank {}".format(rank), var_name, "shape {}, value {}".format(x.shape, x))
elif dim == 2:
print("Rank {}".format(rank), var_name, "shape {}, value {}".format(x.shape, x[:, :5]))
elif dim == 3:
print("Rank {}".format(rank), var_name, "shape {}, value {}".format(x.shape, x[:, :5, 0]))
else:
print("Rank", rank, var_name, "shape", x.shape)
def extract_slices(x, audio_starts, audio_length, rank): def extract_slices(x, audio_starts, audio_length, rank):
slices = [] slices = []
for i in range(x.shape[0]): for i in range(x.shape[0]):
...@@ -58,7 +40,7 @@ class Conditioner(dg.Layer): ...@@ -58,7 +40,7 @@ class Conditioner(dg.Layer):
stride = (up_scale, 1) stride = (up_scale, 1)
padding = get_padding(filter_sizes[i], stride) padding = get_padding(filter_sizes[i], stride)
self.deconvs.append( self.deconvs.append(
ops.Conv2DTranspose( modules.Conv2DTranspose(
self.full_name(), self.full_name(),
num_filters=1, num_filters=1,
filter_size=filter_sizes[i], filter_size=filter_sizes[i],
...@@ -94,12 +76,13 @@ class WaveNetModule(dg.Layer): ...@@ -94,12 +76,13 @@ class WaveNetModule(dg.Layer):
print("context_size", self.context_size) print("context_size", self.context_size)
if config.loss_type == "softmax": if config.loss_type == "softmax":
self.embedding_fc = ops.Embedding( self.embedding_fc = modules.Embedding(
self.full_name(), self.full_name(),
num_embeddings=config.num_channels, num_embeddings=config.num_channels,
embed_dim=config.residual_channels) embed_dim=config.residual_channels,
std=0.1)
elif config.loss_type == "mix-gaussian-pdf": elif config.loss_type == "mix-gaussian-pdf":
self.embedding_fc = ops.FC( self.embedding_fc = modules.FC(
self.full_name(), self.full_name(),
in_features=1, in_features=1,
size=config.residual_channels, size=config.residual_channels,
...@@ -112,7 +95,7 @@ class WaveNetModule(dg.Layer): ...@@ -112,7 +95,7 @@ class WaveNetModule(dg.Layer):
self.dilated_causal_convs = [] self.dilated_causal_convs = []
for dilation in self.dilations: for dilation in self.dilations:
self.dilated_causal_convs.append( self.dilated_causal_convs.append(
ops.Conv1D_GU( modules.Conv1D_GU(
self.full_name(), self.full_name(),
conditioner_dim=config.mel_bands, conditioner_dim=config.mel_bands,
in_channels=config.residual_channels, in_channels=config.residual_channels,
...@@ -126,7 +109,7 @@ class WaveNetModule(dg.Layer): ...@@ -126,7 +109,7 @@ class WaveNetModule(dg.Layer):
for i, layer in enumerate(self.dilated_causal_convs): for i, layer in enumerate(self.dilated_causal_convs):
self.add_sublayer("dilated_causal_conv_{}".format(i), layer) self.add_sublayer("dilated_causal_conv_{}".format(i), layer)
self.fc1 = ops.FC( self.fc1 = modules.FC(
self.full_name(), self.full_name(),
in_features=config.residual_channels, in_features=config.residual_channels,
size=config.skip_channels, size=config.skip_channels,
...@@ -134,7 +117,7 @@ class WaveNetModule(dg.Layer): ...@@ -134,7 +117,7 @@ class WaveNetModule(dg.Layer):
relu=True, relu=True,
act="relu") act="relu")
self.fc2 = ops.FC( self.fc2 = modules.FC(
self.full_name(), self.full_name(),
in_features=config.skip_channels, in_features=config.skip_channels,
size=config.skip_channels, size=config.skip_channels,
...@@ -143,14 +126,14 @@ class WaveNetModule(dg.Layer): ...@@ -143,14 +126,14 @@ class WaveNetModule(dg.Layer):
act="relu") act="relu")
if config.loss_type == "softmax": if config.loss_type == "softmax":
self.fc3 = ops.FC( self.fc3 = modules.FC(
self.full_name(), self.full_name(),
in_features=config.skip_channels, in_features=config.skip_channels,
size=config.num_channels, size=config.num_channels,
num_flatten_dims=2, num_flatten_dims=2,
relu=False) relu=False)
elif config.loss_type == "mix-gaussian-pdf": elif config.loss_type == "mix-gaussian-pdf":
self.fc3 = ops.FC( self.fc3 = modules.FC(
self.full_name(), self.full_name(),
in_features=config.skip_channels, in_features=config.skip_channels,
size=3 * config.num_mixtures, size=3 * config.num_mixtures,
...@@ -175,8 +158,8 @@ class WaveNetModule(dg.Layer): ...@@ -175,8 +158,8 @@ class WaveNetModule(dg.Layer):
return samples return samples
def sample_mix_gaussian(self, mix_parameters): def sample_mix_gaussian(self, mix_parameters):
# mix_parameters reshape from [bs, 13799, 3 * num_mixtures] # mix_parameters reshape from [bs, len, 3 * num_mixtures]
# to [bs * 13799, 3 * num_mixtures]. # to [bs * len, 3 * num_mixtures].
batch, length, hidden = mix_parameters.shape batch, length, hidden = mix_parameters.shape
mix_param_2d = fluid.layers.reshape(mix_parameters, mix_param_2d = fluid.layers.reshape(mix_parameters,
[batch * length, hidden]) [batch * length, hidden])
...@@ -197,7 +180,7 @@ class WaveNetModule(dg.Layer): ...@@ -197,7 +180,7 @@ class WaveNetModule(dg.Layer):
mu_comp = fluid.layers.gather_nd(mu, comp_samples) mu_comp = fluid.layers.gather_nd(mu, comp_samples)
s_comp = fluid.layers.gather_nd(s, comp_samples) s_comp = fluid.layers.gather_nd(s, comp_samples)
# N(0, 1) Normal Sample. # N(0, 1) normal sample.
u = fluid.layers.gaussian_random(shape=[batch * length]) u = fluid.layers.gaussian_random(shape=[batch * length])
samples = mu_comp + u * s_comp samples = mu_comp + u * s_comp
samples = fluid.layers.clip(samples, min=-1.0, max=1.0) samples = fluid.layers.clip(samples, min=-1.0, max=1.0)
...@@ -205,8 +188,6 @@ class WaveNetModule(dg.Layer): ...@@ -205,8 +188,6 @@ class WaveNetModule(dg.Layer):
return samples return samples
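The sampling step above draws one audio sample per time step from a K-component Gaussian mixture: pick a component from softmax(logits_pi), then draw mu + u * s with u ~ N(0, 1), and clip to [-1, 1]. An equivalent numpy sketch of the same arithmetic:

import numpy as np

def sample_mix_gaussian_np(mix_parameters, log_scale_min=-9.0):
    # mix_parameters: [N, 3 * K] -> logits_pi, mu, log_s, each [N, K].
    logits_pi, mu, log_s = np.split(mix_parameters, 3, axis=-1)
    pi = np.exp(logits_pi) / np.exp(logits_pi).sum(axis=-1, keepdims=True)
    s = np.exp(np.clip(log_s, log_scale_min, 100.0))
    comp = np.array([np.random.choice(pi.shape[-1], p=row) for row in pi])
    rows = np.arange(mix_parameters.shape[0])
    samples = mu[rows, comp] + np.random.randn(len(rows)) * s[rows, comp]
    return np.clip(samples, -1.0, 1.0)

samples = sample_mix_gaussian_np(np.random.randn(6, 30))   # 6 steps, 10 mixtures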
def softmax_loss(self, targets, mix_parameters): def softmax_loss(self, targets, mix_parameters):
# targets: [bs, 13799] -> [bs, 11752]
# mix_params: [bs, 13799, 3] -> [bs, 11752, 3]
targets = targets[:, self.context_size:] targets = targets[:, self.context_size:]
mix_parameters = mix_parameters[:, self.context_size:, :] mix_parameters = mix_parameters[:, self.context_size:, :]
...@@ -216,22 +197,22 @@ class WaveNetModule(dg.Layer): ...@@ -216,22 +197,22 @@ class WaveNetModule(dg.Layer):
quantized = fluid.layers.cast( quantized = fluid.layers.cast(
(targets + 1.0) / 2.0 * num_channels, dtype="int64") (targets + 1.0) / 2.0 * num_channels, dtype="int64")
# per_sample_loss shape: [bs, 17952, 1] # per_sample_loss shape: [bs, len, 1]
per_sample_loss = fluid.layers.softmax_with_cross_entropy( per_sample_loss = fluid.layers.softmax_with_cross_entropy(
logits=mix_parameters, label=fluid.layers.unsqueeze(quantized, 2)) logits=mix_parameters, label=fluid.layers.unsqueeze(quantized, 2))
loss = fluid.layers.reduce_mean(per_sample_loss) loss = fluid.layers.reduce_mean(per_sample_loss)
#debug(loss, "softmax loss", self.rank)
return loss return loss
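The softmax path discretizes the waveform before the cross-entropy: a clipped sample in [-1, 0.99999] maps to an integer bin via (x + 1) / 2 * num_channels. A quick numpy check with num_channels = 2048 from the softmax config:

import numpy as np

num_channels = 2048
x = np.array([-1.0, 0.0, 0.99999])
quantized = ((x + 1.0) / 2.0 * num_channels).astype(np.int64)
print(quantized)   # [   0 1024 2047]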
def mixture_density_loss(self, targets, mix_parameters, log_scale_min): def mixture_density_loss(self, targets, mix_parameters, log_scale_min):
# targets: [bs, 13799] -> [bs, 11752] # targets: [bs, len]
# mix_params: [bs, 13799, 3] -> [bs, 11752, 3] # mix_params: [bs, len, 3 * num_mixture]
targets = targets[:, self.context_size:] targets = targets[:, self.context_size:]
mix_parameters = mix_parameters[:, self.context_size:, :] mix_parameters = mix_parameters[:, self.context_size:, :]
# log_s: [bs, 11752, num_mixture] # log_s: [bs, len, num_mixture]
logits_pi, mu, log_s = fluid.layers.split(mix_parameters, num_or_sections=3, dim=-1) logits_pi, mu, log_s = fluid.layers.split(
mix_parameters, num_or_sections=3, dim=-1)
pi = fluid.layers.softmax(logits_pi, axis=-1) pi = fluid.layers.softmax(logits_pi, axis=-1)
log_s = fluid.layers.clip(log_s, min=log_scale_min, max=100.0) log_s = fluid.layers.clip(log_s, min=log_scale_min, max=100.0)
...@@ -242,10 +223,9 @@ class WaveNetModule(dg.Layer): ...@@ -242,10 +223,9 @@ class WaveNetModule(dg.Layer):
targets = fluid.layers.expand(targets, [1, 1, self.config.num_mixtures]) targets = fluid.layers.expand(targets, [1, 1, self.config.num_mixtures])
x_std = inv_s * (targets - mu) x_std = inv_s * (targets - mu)
exponent = fluid.layers.exp(-0.5 * x_std * x_std) exponent = fluid.layers.exp(-0.5 * x_std * x_std)
# pdf_x: [bs, 11752, 1]
pdf_x = 1.0 / np.sqrt(2.0 * np.pi) * inv_s * exponent pdf_x = 1.0 / np.sqrt(2.0 * np.pi) * inv_s * exponent
pdf_x = pi * pdf_x pdf_x = pi * pdf_x
# pdf_x: [bs, 11752] # pdf_x: [bs, len]
pdf_x = fluid.layers.reduce_sum(pdf_x, dim=-1) pdf_x = fluid.layers.reduce_sum(pdf_x, dim=-1)
per_sample_loss = 0.0 - fluid.layers.log(pdf_x + 1e-9) per_sample_loss = 0.0 - fluid.layers.log(pdf_x + 1e-9)
...@@ -254,8 +234,6 @@ class WaveNetModule(dg.Layer): ...@@ -254,8 +234,6 @@ class WaveNetModule(dg.Layer):
return loss return loss
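Spelled out, the mixture-density loss is the negative log-likelihood of each target sample under the predicted Gaussian mixture, loss = mean(-log(sum_k pi_k * N(x | mu_k, s_k) + 1e-9)), with log_s clipped from below at log_scale_min for numerical stability. A numpy sketch of the same computation:

import numpy as np

def mdn_nll_np(x, logits_pi, mu, log_s, log_scale_min=-9.0):
    # x: [N] targets; logits_pi, mu, log_s: [N, K] mixture parameters.
    pi = np.exp(logits_pi) / np.exp(logits_pi).sum(axis=-1, keepdims=True)
    inv_s = np.exp(-np.clip(log_s, log_scale_min, 100.0))
    x_std = inv_s * (x[:, None] - mu)
    pdf = pi * (1.0 / np.sqrt(2.0 * np.pi)) * inv_s * np.exp(-0.5 * x_std ** 2)
    return np.mean(-np.log(pdf.sum(axis=-1) + 1e-9))

loss = mdn_nll_np(np.random.uniform(-1, 1, 8), *np.split(np.random.randn(8, 30), 3, axis=-1))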
def forward(self, audios, mels, audio_starts, sample=False): def forward(self, audios, mels, audio_starts, sample=False):
# audios: [bs, 13800], mels: [bs, full_frame_length, 80]
# audio_starts: [bs]
# Build conditioner based on mels. # Build conditioner based on mels.
full_conditioner = self.conditioner(mels) full_conditioner = self.conditioner(mels)
...@@ -264,15 +242,14 @@ class WaveNetModule(dg.Layer): ...@@ -264,15 +242,14 @@ class WaveNetModule(dg.Layer):
conditioner = extract_slices(full_conditioner, conditioner = extract_slices(full_conditioner,
audio_starts, audio_length, self.rank) audio_starts, audio_length, self.rank)
# input_audio, target_audio: [bs, 13799] # input_audio, target_audio: [bs, len]
input_audios = audios[:, :-1] input_audios = audios[:, :-1]
target_audios = audios[:, 1:] target_audios = audios[:, 1:]
# conditioner: [bs, 13799, 80] # conditioner: [bs, len, mel_bands]
conditioner = conditioner[:, 1:, :] conditioner = conditioner[:, 1:, :]
loss_type = self.config.loss_type loss_type = self.config.loss_type
# layer_input: [bs, 13799, 128]
if loss_type == "softmax": if loss_type == "softmax":
input_audios = fluid.layers.clip( input_audios = fluid.layers.clip(
input_audios, min=-1.0, max=0.99999) input_audios, min=-1.0, max=0.99999)
...@@ -280,31 +257,31 @@ class WaveNetModule(dg.Layer): ...@@ -280,31 +257,31 @@ class WaveNetModule(dg.Layer):
quantized = fluid.layers.cast( quantized = fluid.layers.cast(
(input_audios + 1.0) / 2.0 * self.config.num_channels, (input_audios + 1.0) / 2.0 * self.config.num_channels,
dtype="int64") dtype="int64")
layer_input = self.embedding_fc(fluid.layers.unsqueeze(quantized, 2)) layer_input = self.embedding_fc(
fluid.layers.unsqueeze(quantized, 2))
elif loss_type == "mix-gaussian-pdf": elif loss_type == "mix-gaussian-pdf":
layer_input = self.embedding_fc(fluid.layers.unsqueeze(input_audios, 2)) layer_input = self.embedding_fc(
fluid.layers.unsqueeze(input_audios, 2))
else: else:
raise ValueError( raise ValueError(
"loss_type {} is unsupported!".format(loss_type)) "loss_type {} is unsupported!".format(loss_type))
# layer_input: [bs, res_channel, 1, 13799] # layer_input: [bs, res_channel, 1, len]
layer_input = fluid.layers.unsqueeze(fluid.layers.transpose(layer_input, perm=[0, 2, 1]), 2) layer_input = fluid.layers.unsqueeze(
# conditioner: [bs, mel_bands, 1, 13799] fluid.layers.transpose(layer_input, perm=[0, 2, 1]), 2)
conditioner = fluid.layers.unsqueeze(fluid.layers.transpose(conditioner, perm=[0, 2, 1]), 2) # conditioner: [bs, mel_bands, 1, len]
conditioner = fluid.layers.unsqueeze(
fluid.layers.transpose(conditioner, perm=[0, 2, 1]), 2)
# layer_input: [bs, res_channel, 1, 13799]
# skip: [bs, res_channel, 1, 13799]
skip = None skip = None
for i, layer in enumerate(self.dilated_causal_convs): for i, layer in enumerate(self.dilated_causal_convs):
# layer_input: [bs, res_channel, 1, len]
# skip: [bs, res_channel, 1, len]
layer_input, skip = layer(layer_input, skip, conditioner) layer_input, skip = layer(layer_input, skip, conditioner)
#debug(layer_input, "layer_input_" + str(i), self.rank)
#debug(skip, "skip_" + str(i), self.rank)
# Reshape skip to [bs, 13799, res_channel]
skip = fluid.layers.transpose(fluid.layers.squeeze(skip, [2]), perm=[0, 2, 1])
#debug(skip, "skip", self.rank)
# mix_param: [bs, 13799, 3 * num_mixtures] # Reshape skip to [bs, len, res_channel]
skip = fluid.layers.transpose(
fluid.layers.squeeze(skip, [2]), perm=[0, 2, 1])
mix_parameters = self.fc3(self.fc2(self.fc1(skip))) mix_parameters = self.fc3(self.fc2(self.fc1(skip)))
# Sample teacher-forced audio. # Sample teacher-forced audio.
...@@ -317,12 +294,7 @@ class WaveNetModule(dg.Layer): ...@@ -317,12 +294,7 @@ class WaveNetModule(dg.Layer):
else: else:
raise ValueError( raise ValueError(
"loss_type {} is unsupported!".format(loss_type)) "loss_type {} is unsupported!".format(loss_type))
#debug(sample_audios, "sample_audios", self.rank)
# Calculate mix-gaussian density loss.
# padding is all zero.
# target_audio: [bs, 13799].
# mix_params: [bs, 13799, 3].
if loss_type == "softmax": if loss_type == "softmax":
loss = self.softmax_loss(target_audios, mix_parameters) loss = self.softmax_loss(target_audios, mix_parameters)
elif loss_type == "mix-gaussian-pdf": elif loss_type == "mix-gaussian-pdf":
...@@ -332,27 +304,16 @@ class WaveNetModule(dg.Layer): ...@@ -332,27 +304,16 @@ class WaveNetModule(dg.Layer):
raise ValueError( raise ValueError(
"loss_type {} is unsupported!".format(loss_type)) "loss_type {} is unsupported!".format(loss_type))
#print("Rank {}, loss {}".format(self.rank, loss.numpy()))
return loss, sample_audios return loss, sample_audios
def synthesize(self, mels): def synthesize(self, mels):
self.start_new_sequence() self.start_new_sequence()
print("input mels shape", mels.shape)
# mels: [bs=1, n_frames, 80]
# conditioner: [1, n_frames * samples_per_frame, 80]
# Should I move forward by one sample? No difference
# Append context frame to mels
bs, n_frames, mel_bands = mels.shape bs, n_frames, mel_bands = mels.shape
#num_pad_frames = int(np.ceil(self.context_size / self.config.fft_window_shift))
#silence = fluid.layers.zeros(shape=[bs, num_pad_frames, mel_bands], dtype="float32")
#inf_mels = fluid.layers.concat([silence, mels], axis=1)
#print("padded mels shape", inf_mels.shape)
#conditioner = self.conditioner(inf_mels)[:, self.context_size:, :]
conditioner = self.conditioner(mels) conditioner = self.conditioner(mels)
time_steps = conditioner.shape[1] time_steps = conditioner.shape[1]
print("Total steps", time_steps)
print("input mels shape", mels.shape)
print("Total synthesis steps", time_steps)
loss_type = self.config.loss_type loss_type = self.config.loss_type
audio_samples = [] audio_samples = []
...@@ -361,8 +322,8 @@ class WaveNetModule(dg.Layer): ...@@ -361,8 +322,8 @@ class WaveNetModule(dg.Layer):
if i % 100 == 0: if i % 100 == 0:
print("Step", i) print("Step", i)
# convert from real value sample to audio embedding. # Convert from real value sample to audio embedding.
# [bs, 1, 128] # audio_input: [bs, 1, channel]
if loss_type == "softmax": if loss_type == "softmax":
current_sample = fluid.layers.clip( current_sample = fluid.layers.clip(
current_sample, min=-1.0, max=0.99999) current_sample, min=-1.0, max=0.99999)
...@@ -377,21 +338,23 @@ class WaveNetModule(dg.Layer): ...@@ -377,21 +338,23 @@ class WaveNetModule(dg.Layer):
raise ValueError( raise ValueError(
"loss_type {} is unsupported!".format(loss_type)) "loss_type {} is unsupported!".format(loss_type))
# [bs, 128, 1, 1] # [bs, channel, 1, 1]
audio_input = fluid.layers.unsqueeze(fluid.layers.transpose(audio_input, perm=[0, 2, 1]), 2) audio_input = fluid.layers.unsqueeze(
# [bs, 80] fluid.layers.transpose(audio_input, perm=[0, 2, 1]), 2)
# [bs, mel_bands]
cond_input = conditioner[:, i, :] cond_input = conditioner[:, i, :]
# [bs, 80, 1, 1] # [bs, mel_bands, 1, 1]
cond_input = fluid.layers.reshape( cond_input = fluid.layers.reshape(
cond_input, cond_input.shape + [1, 1]) cond_input, cond_input.shape + [1, 1])
skip = None skip = None
for layer in self.dilated_causal_convs: for layer in self.dilated_causal_convs:
audio_input, skip = layer.add_input(audio_input, skip, cond_input) audio_input, skip = layer.add_input(
audio_input, skip, cond_input)
# [bs, 1, 128] # [bs, 1, channel]
skip = fluid.layers.transpose(fluid.layers.squeeze(skip, [2]), perm=[0, 2, 1]) skip = fluid.layers.transpose(
# [bs, 1, 3] fluid.layers.squeeze(skip, [2]), perm=[0, 2, 1])
mix_parameters = self.fc3(self.fc2(self.fc1(skip))) mix_parameters = self.fc3(self.fc2(self.fc1(skip)))
if loss_type == "softmax": if loss_type == "softmax":
sample = self.sample_softmax(mix_parameters) sample = self.sample_softmax(mix_parameters)
...@@ -407,17 +370,12 @@ class WaveNetModule(dg.Layer): ...@@ -407,17 +370,12 @@ class WaveNetModule(dg.Layer):
current_sample = fluid.layers.reshape(current_sample, current_sample = fluid.layers.reshape(current_sample,
current_sample.shape + [1, 1]) current_sample.shape + [1, 1])
# syn_audio: (num_samples,) # syn_audio: [num_samples]
syn_audio = fluid.layers.concat(audio_samples, axis=0).numpy() syn_audio = fluid.layers.concat(audio_samples, axis=0).numpy()
return syn_audio return syn_audio
def start_new_sequence(self): def start_new_sequence(self):
for layer in self.sublayers(): for layer in self.sublayers():
if isinstance(layer, weight_norm.Conv1D): if isinstance(layer, conv.Conv1D):
layer.start_new_sequence() layer.start_new_sequence()
def save(self, iteration):
utils.save_latest_parameters(self.checkpoint_dir, iteration,
self.wavenet, self.optimizer)
utils.save_latest_checkpoint(self.checkpoint_dir, iteration)
This diff is collapsed.
...@@ -26,6 +26,7 @@ def FC(name_scope, ...@@ -26,6 +26,7 @@ def FC(name_scope,
in_features, in_features,
size, size,
num_flatten_dims=1, num_flatten_dims=1,
relu=False,
dropout=0.0, dropout=0.0,
epsilon=1e-30, epsilon=1e-30,
act=None, act=None,
...@@ -39,7 +40,11 @@ def FC(name_scope, ...@@ -39,7 +40,11 @@ def FC(name_scope,
# stds # stds
if isinstance(in_features, int): if isinstance(in_features, int):
in_features = [in_features] in_features = [in_features]
stds = [np.sqrt((1 - dropout) / in_feature) for in_feature in in_features] stds = [np.sqrt((1 - dropout) / in_feature) for in_feature in in_features]
if relu:
stds = [std * np.sqrt(2.0) for std in stds]
weight_inits = [ weight_inits = [
fluid.initializer.NormalInitializer(scale=std) for std in stds fluid.initializer.NormalInitializer(scale=std) for std in stds
] ]
...@@ -456,3 +461,152 @@ class PositionEmbedding(dg.Layer): ...@@ -456,3 +461,152 @@ class PositionEmbedding(dg.Layer):
return out return out
else: else:
raise Exception("Then you can just use position rate at init") raise Exception("Then you can just use position rate at init")
class Conv1D_GU(dg.Layer):
def __init__(self,
name_scope,
conditioner_dim,
in_channels,
num_filters,
filter_size,
dilation,
causal=False,
residual=True,
dtype="float32"):
super(Conv1D_GU, self).__init__(name_scope, dtype=dtype)
self.conditioner_dim = conditioner_dim
self.in_channels = in_channels
self.num_filters = num_filters
self.filter_size = filter_size
self.dilation = dilation
self.causal = causal
self.residual = residual
if residual:
assert (
in_channels == num_filters
), "this block uses residual connection"\
"the input_channels should equals num_filters"
self.conv = Conv1D(
self.full_name(),
in_channels,
2 * num_filters,
filter_size,
dilation,
causal=causal,
dtype=dtype)
self.fc = Conv1D(
self.full_name(),
conditioner_dim,
2 * num_filters,
filter_size=1,
dilation=1,
causal=False,
dtype=dtype)
def forward(self, x, skip=None, conditioner=None):
"""
Args:
            x (Variable): Shape(B, C_in, 1, T), the input of the Conv1D_GU
                layer, where B means batch_size, C_in means the input channels,
                and T means input time steps.
            skip (Variable): Shape(B, C_in, 1, T), skip connection.
            conditioner (Variable): Shape(B, C_con, 1, T), expanded mel
                conditioner, where C_con is the conditioner hidden dim, which
                equals the number of mel bands. Note that when using a residual
                connection, the Conv1D_GU does not change the number of
                channels, so out channels equals input channels.
        Returns:
            x (Variable): Shape(B, C_out, 1, T), the output of Conv1D_GU, where
                C_out means the output channels of Conv1D_GU.
skip (Variable): Shape(B, C_out, 1, T), skip connection.
"""
residual = x
x = self.conv(x)
if conditioner is not None:
cond_bias = self.fc(conditioner)
x += cond_bias
content, gate = fluid.layers.split(x, num_or_sections=2, dim=1)
# Gated Unit.
x = fluid.layers.elementwise_mul(fluid.layers.sigmoid(gate),
fluid.layers.tanh(content))
if skip is None:
skip = x
else:
skip = fluid.layers.scale(skip + x, np.sqrt(0.5))
if self.residual:
x = fluid.layers.scale(residual + x, np.sqrt(0.5))
return x, skip
def add_input(self, x, skip=None, conditioner=None):
"""
Inputs:
x: shape(B, num_filters, 1, time_steps)
skip: shape(B, num_filters, 1, time_steps), skip connection
conditioner: shape(B, conditioner_dim, 1, time_steps)
Outputs:
x: shape(B, num_filters, 1, time_steps), where time_steps = 1
skip: skip connection, same shape as x
"""
residual = x
# add step input and produce step output
x = self.conv.add_input(x)
if conditioner is not None:
cond_bias = self.fc(conditioner)
x += cond_bias
content, gate = fluid.layers.split(x, num_or_sections=2, dim=1)
# Gated Unit.
x = fluid.layers.elementwise_mul(fluid.layers.sigmoid(gate),
fluid.layers.tanh(content))
if skip is None:
skip = x
else:
skip = fluid.layers.scale(skip + x, np.sqrt(0.5))
if self.residual:
x = fluid.layers.scale(residual + x, np.sqrt(0.5))
return x, skip
def Conv2DTranspose(name_scope,
num_filters,
filter_size,
padding=0,
stride=1,
dilation=1,
use_cudnn=True,
act=None,
dtype="float32"):
val = 1.0 / (filter_size[0] * filter_size[1])
weight_init = fluid.initializer.ConstantInitializer(val)
weight_attr = fluid.ParamAttr(initializer=weight_init)
layer = weight_norm.Conv2DTranspose(
name_scope,
num_filters,
filter_size=filter_size,
padding=padding,
stride=stride,
dilation=dilation,
param_attr=weight_attr,
use_cudnn=use_cudnn,
act=act,
dtype=dtype)
return layer