From 47915461fc2d2582d6fcdd03bd6d9935c44fbc97 Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Tue, 28 Jul 2020 17:34:38 +0000 Subject: [PATCH] Adapt waveflow to internal dataset --- examples/waveflow/data.py | 5 +---- parakeet/models/waveflow/waveflow_modules.py | 11 ++++++++--- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/examples/waveflow/data.py b/examples/waveflow/data.py index 33e2ee5..75d09b7 100644 --- a/examples/waveflow/data.py +++ b/examples/waveflow/data.py @@ -35,8 +35,7 @@ class Dataset(ljspeech.LJSpeech): fname, _, _ = metadatum wav_path = os.path.join(self.root, "wavs", fname + ".wav") - loaded_sr, audio = read(wav_path) - assert loaded_sr == self.config.sample_rate + audio, loaded_sr = librosa.load(wav_path, sr=self.config.sample_rate) return audio @@ -91,8 +90,6 @@ class Subset(DatasetMixin): mode='constant', constant_values=0) - # Normalize audio to the [-1, 1] range. - audio = audio.astype(np.float32) / 32768.0 mel = self.get_mel(audio) return audio, mel diff --git a/parakeet/models/waveflow/waveflow_modules.py b/parakeet/models/waveflow/waveflow_modules.py index 31b29dc..96c5715 100644 --- a/parakeet/models/waveflow/waveflow_modules.py +++ b/parakeet/models/waveflow/waveflow_modules.py @@ -62,9 +62,8 @@ class WaveFlowLoss: class Conditioner(dg.Layer): - def __init__(self, dtype): + def __init__(self, dtype, upsample_factors): super(Conditioner, self).__init__() - upsample_factors = [16, 16] self.upsample_conv2d = [] for s in upsample_factors: @@ -296,11 +295,13 @@ class WaveFlowModule(dg.Layer): self.n_flows = config.n_flows self.n_group = config.n_group self.n_layers = config.n_layers + self.upsample_factors = config.upsample_factors if hasattr( + config, "upsample_factors") else [16, 16] assert self.n_group % 2 == 0 assert self.n_flows % 2 == 0 self.dtype = "float16" if config.use_fp16 else "float32" - self.conditioner = Conditioner(self.dtype) + self.conditioner = Conditioner(self.dtype, self.upsample_factors) self.flows = [] for i in range(self.n_flows): flow = Flow(config) @@ -397,6 +398,10 @@ class WaveFlowModule(dg.Layer): if self.dtype == "float16": mel = fluid.layers.cast(mel, self.dtype) mel = self.conditioner.infer(mel) + # Prune out the tail of mel so that time/n_group == 0. + pruned_len = int(mel.shape[2] // self.n_group * self.n_group) + if mel.shape[2] > pruned_len: + mel = mel[:, :, :pruned_len] # From [bs, mel_bands, time] to [bs, mel_bands, n_group, time/n_group] mel = fluid.layers.transpose(unfold(mel, self.n_group), [0, 1, 3, 2]) -- GitLab