提交 8c22397b 编写于 作者: K Kexin Zhao

add working synthesis code

上级 f6f0a2ca
...@@ -79,17 +79,13 @@ class Subset(dataset.Dataset): ...@@ -79,17 +79,13 @@ class Subset(dataset.Dataset):
mode='constant', constant_values=0) mode='constant', constant_values=0)
# Normalize audio. # Normalize audio.
audio = audio / MAX_WAV_VALUE audio = audio.astype(np.float32) / MAX_WAV_VALUE
mel = self.get_mel(audio) mel = self.get_mel(audio)
#print("mel = {}, dtype {}, shape {}".format(mel, mel.dtype, mel.shape))
return audio, mel return audio, mel
def _batch_examples(self, batch): def _batch_examples(self, batch):
audio_batch = []
mel_batch = []
for audio, mel in batch:
audio_batch
audios = [sample[0] for sample in batch] audios = [sample[0] for sample in batch]
mels = [sample[1] for sample in batch] mels = [sample[1] for sample in batch]
......
...@@ -8,11 +8,11 @@ import paddle.fluid.dygraph as dg ...@@ -8,11 +8,11 @@ import paddle.fluid.dygraph as dg
from paddle import fluid from paddle import fluid
import utils import utils
from wavenet import WaveNet from waveflow import WaveFlow
def add_options_to_parser(parser): def add_options_to_parser(parser):
parser.add_argument('--model', type=str, default='wavenet', parser.add_argument('--model', type=str, default='waveflow',
help="general name of the model") help="general name of the model")
parser.add_argument('--name', type=str, parser.add_argument('--name', type=str,
help="specific name of the training model") help="specific name of the training model")
...@@ -30,7 +30,7 @@ def add_options_to_parser(parser): ...@@ -30,7 +30,7 @@ def add_options_to_parser(parser):
parser.add_argument('--output', type=str, default="./syn_audios", parser.add_argument('--output', type=str, default="./syn_audios",
help="path to write synthesized audio files") help="path to write synthesized audio files")
parser.add_argument('--sample', type=int, parser.add_argument('--sample', type=int, default=None,
help="which of the valid samples to synthesize audio") help="which of the valid samples to synthesize audio")
...@@ -54,7 +54,7 @@ def synthesize(config): ...@@ -54,7 +54,7 @@ def synthesize(config):
print("Random Seed: ", seed) print("Random Seed: ", seed)
# Build model. # Build model.
model = WaveNet(config, checkpoint_dir) model = WaveFlow(config, checkpoint_dir)
model.build(training=False) model.build(training=False)
# Obtain the current iteration. # Obtain the current iteration.
......
...@@ -2,7 +2,8 @@ import itertools ...@@ -2,7 +2,8 @@ import itertools
import os import os
import time import time
import librosa #import librosa
from scipy.io.wavfile import write
import numpy as np import numpy as np
import paddle.fluid.dygraph as dg import paddle.fluid.dygraph as dg
from paddle import fluid from paddle import fluid
...@@ -156,17 +157,38 @@ class WaveFlow(): ...@@ -156,17 +157,38 @@ class WaveFlow():
output = "{}/{}/iter-{}".format(config.output, config.name, iteration) output = "{}/{}/iter-{}".format(config.output, config.name, iteration)
os.makedirs(output, exist_ok=True) os.makedirs(output, exist_ok=True)
filename = "{}/valid_{}.wav".format(output, sample) mels_list = [mels for _, mels in self.validloader()]
print("Synthesize sample {}, save as {}".format(sample, filename)) if sample is not None:
mels_list = [mels_list[sample]]
mels_list = [mels for _, mels, _ in self.validloader()] audio_times = []
start_time = time.time() inf_times = []
syn_audio = self.waveflow.synthesize(mels_list[sample]) for sample, mel in enumerate(mels_list):
syn_time = time.time() - start_time filename = "{}/valid_{}.wav".format(output, sample)
print("audio shape {}, synthesis time {}".format( print("Synthesize sample {}, save as {}".format(sample, filename))
syn_audio.shape, syn_time))
librosa.output.write_wav(filename, syn_audio, start_time = time.time()
sr=config.sample_rate) audio = self.waveflow.synthesize(mel)
syn_time = time.time() - start_time
audio_time = audio.shape[0] / 22050
print("audio time {}, synthesis time {}, speedup: {}".format(
audio_time, syn_time, audio_time / syn_time))
#librosa.output.write_wav(filename, syn_audio,
# sr=config.sample_rate)
audio = audio.numpy() * 32768.0
audio = audio.astype('int16')
write(filename, config.sample_rate, audio)
audio_times.append(audio_time)
inf_times.append(syn_time)
total_audio = sum(audio_times)
total_inf = sum(inf_times)
print("Total audio: {}, total inf time {}, speedup: {}".format(
total_audio, total_inf, total_audio / total_inf))
def save(self, iteration): def save(self, iteration):
utils.save_latest_parameters(self.checkpoint_dir, iteration, utils.save_latest_parameters(self.checkpoint_dir, iteration,
......
...@@ -75,6 +75,16 @@ class Conditioner(dg.Layer): ...@@ -75,6 +75,16 @@ class Conditioner(dg.Layer):
return fluid.layers.squeeze(x, [1]) return fluid.layers.squeeze(x, [1])
def infer(self, x):
x = fluid.layers.unsqueeze(x, 1)
for layer in self.upsample_conv2d:
x = layer(x)
# Trim conv artifacts.
time_cutoff = layer._filter_size[1] - layer._stride[1]
x = fluid.layers.leaky_relu(x[:, :, :, :-time_cutoff], alpha=0.4)
return fluid.layers.squeeze(x, [1])
class Flow(dg.Layer): class Flow(dg.Layer):
def __init__(self, name_scope, config): def __init__(self, name_scope, config):
...@@ -183,6 +193,14 @@ class Flow(dg.Layer): ...@@ -183,6 +193,14 @@ class Flow(dg.Layer):
return self.end(output) return self.end(output)
def debug(x, msg):
y = x.numpy()
print(msg + " :\n", y)
print("shape: ", y.shape)
print("dtype: ", y.dtype)
print("")
class WaveFlowModule(dg.Layer): class WaveFlowModule(dg.Layer):
def __init__(self, name_scope, config): def __init__(self, name_scope, config):
super(WaveFlowModule, self).__init__(name_scope) super(WaveFlowModule, self).__init__(name_scope)
...@@ -217,7 +235,7 @@ class WaveFlowModule(dg.Layer): ...@@ -217,7 +235,7 @@ class WaveFlowModule(dg.Layer):
if mel.shape[2] > pruned_len: if mel.shape[2] > pruned_len:
mel = mel[:, :, :pruned_len] mel = mel[:, :, :pruned_len]
# From [bs, mel_bands, time] to [bs, mel_bands, n_group, time/n_group] # From [bs, mel_bands, time] to [bs, mel_bands, n_group, time/n_group]
mel = fluid.layers.transpose(unfold(mel, self.n_group), [0, 1, 3, 2]) mel = fluid.layers.transpose(unfold(mel, self.n_group), [0, 1, 3, 2])
# From [bs, time] to [bs, n_group, time/n_group] # From [bs, time] to [bs, n_group, time/n_group]
audio = fluid.layers.transpose(unfold(audio, self.n_group), [0, 2, 1]) audio = fluid.layers.transpose(unfold(audio, self.n_group), [0, 2, 1])
...@@ -247,8 +265,54 @@ class WaveFlowModule(dg.Layer): ...@@ -247,8 +265,54 @@ class WaveFlowModule(dg.Layer):
return z, log_s_list return z, log_s_list
def synthesize(self, mels): def synthesize(self, mel, sigma=1.0):
pass #debug(mel, "mel")
mel = self.conditioner.infer(mel)
#debug(mel, "mel after conditioner")
# From [bs, mel_bands, time] to [bs, mel_bands, n_group, time/n_group]
mel = fluid.layers.transpose(unfold(mel, self.n_group), [0, 1, 3, 2])
#debug(mel, "after group")
audio = fluid.layers.gaussian_random(
shape=[mel.shape[0], 1, mel.shape[2], mel.shape[3]], std=sigma)
#debug(audio, "audio")
for i in reversed(range(self.n_flows)):
# Permute over the height dimension.
audio_slices = [audio[:, :, j, :] for j in self.perms[i]]
audio = fluid.layers.stack(audio_slices, axis=2)
mel_slices = [mel[:, :, j, :] for j in self.perms[i]]
mel = fluid.layers.stack(mel_slices, axis=2)
audio_list = []
audio_0 = audio[:, :, :1, :]
audio_list.append(audio_0)
for h in range(1, self.n_group):
# inputs: [bs, 1, h, time/n_group]
inputs = fluid.layers.concat(audio_list, axis=2)
conds = mel[:, :, 1:(h+1), :]
outputs = self.flows[i](inputs, conds)
log_s = outputs[:, :1, (h-1):h, :]
b = outputs[:, 1:, (h-1):h, :]
audio_h = (audio[:, :, h:(h+1), :] - b) / fluid.layers.exp(log_s)
audio_list.append(audio_h)
audio = fluid.layers.concat(audio_list, axis=2)
#print("audio.shape =", audio.shape)
# Assume batch size = 1
# audio: [n_group, time/n_group]
audio = fluid.layers.squeeze(audio, [0, 1])
# audio: [time]
audio = fluid.layers.reshape(
fluid.layers.transpose(audio, [1, 0]), [-1])
#print("audio.shape =", audio.shape)
return audio
def start_new_sequence(self): def start_new_sequence(self):
for layer in self.sublayers(): for layer in self.sublayers():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册