提交 155dfe63 编写于 作者: C chenfeiyu 提交者: liuyibing01

add deepvoice3 model and example

上级 04d7f8b5
import os
import csv
from pathlib import Path
import numpy as np
import pandas as pd
import librosa
from scipy import signal, io
from parakeet.data import DatasetMixin, TransformDataset, FilterDataset
from parakeet.g2p.en import text_to_sequence, sequence_to_text
class LJSpeechMetaData(DatasetMixin):
def __init__(self, root):
self.root = Path(root)
self._wav_dir = self.root.joinpath("wavs")
csv_path = self.root.joinpath("metadata.csv")
self._table = pd.read_csv(
csv_path,
sep="|",
header=None,
quoting=csv.QUOTE_NONE,
names=["fname", "raw_text", "normalized_text"])
def get_example(self, i):
fname, raw_text, normalized_text = self._table.iloc[i]
fname = str(self._wav_dir.joinpath(fname + ".wav"))
return fname, raw_text, normalized_text
def __len__(self):
return len(self._table)
class Transform(object):
def __init__(self,
replace_pronounciation_prob=0.,
sample_rate=22050,
preemphasis=.97,
n_fft=1024,
win_length=1024,
hop_length=256,
fmin=125,
fmax=7600,
n_mels=80,
min_level_db=-100,
ref_level_db=20,
max_norm=0.999,
clip_norm=True):
self.replace_pronounciation_prob = replace_pronounciation_prob
self.sample_rate = sample_rate
self.preemphasis = preemphasis
self.n_fft = n_fft
self.win_length = win_length
self.hop_length = hop_length
self.fmin = fmin
self.fmax = fmax
self.n_mels = n_mels
self.min_level_db = min_level_db
self.ref_level_db = ref_level_db
self.max_norm = max_norm
self.clip_norm = clip_norm
def __call__(self, in_data):
fname, _, normalized_text = in_data
# text processing
mix_grapheme_phonemes = text_to_sequence(
normalized_text, self.replace_pronounciation_prob)
text_length = len(mix_grapheme_phonemes)
# CAUTION: positions start from 1
speaker_id = None
# wave processing
wav, _ = librosa.load(fname, sr=self.sample_rate)
# preemphasis
y = signal.lfilter([1., -self.preemphasis], [1.], wav)
# STFT
D = librosa.stft(y=y,
n_fft=self.n_fft,
win_length=self.win_length,
hop_length=self.hop_length)
S = np.abs(D)
# to db and normalize to 0-1
amplitude_min = np.exp(self.min_level_db / 20 * np.log(10)) # 1e-5
S_norm = 20 * np.log10(np.maximum(amplitude_min,
S)) - self.ref_level_db
S_norm = (S_norm - self.min_level_db) / (-self.min_level_db)
S_norm = self.max_norm * S_norm
if self.clip_norm:
S_norm = np.clip(S_norm, 0, self.max_norm)
# mel scale and to db and normalize to 0-1,
# CAUTION: pass linear scale S, not dbscaled S
S_mel = librosa.feature.melspectrogram(S=S,
n_mels=self.n_mels,
fmin=self.fmin,
fmax=self.fmax,
power=1.)
S_mel = 20 * np.log10(np.maximum(amplitude_min,
S_mel)) - self.ref_level_db
S_mel_norm = (S_mel - self.min_level_db) / (-self.min_level_db)
S_mel_norm = self.max_norm * S_mel_norm
if self.clip_norm:
S_mel_norm = np.clip(S_mel_norm, 0, self.max_norm)
# num_frames
n_frames = S_mel_norm.shape[-1] # CAUTION: original number of frames
return (mix_grapheme_phonemes, text_length, speaker_id, S_norm,
S_mel_norm, n_frames)
class DataCollector(object):
def __init__(self, downsample_factor=4, r=1):
self.downsample_factor = int(downsample_factor)
self.frames_per_step = int(r)
self._factor = int(downsample_factor * r)
# CAUTION: small diff here
self._pad_begin = int(downsample_factor * r)
def __call__(self, examples):
batch_size = len(examples)
# lengths
text_lengths = np.array([example[1]
for example in examples]).astype(np.int64)
frames = np.array([example[5]
for example in examples]).astype(np.int64)
max_text_length = int(np.max(text_lengths))
max_frames = int(np.max(frames))
if max_frames % self._factor != 0:
max_frames += (self._factor - max_frames % self._factor)
max_frames += self._pad_begin
max_decoder_length = max_frames // self._factor
# pad time sequence
text_sequences = []
lin_specs = []
mel_specs = []
done_flags = []
for example in examples:
(mix_grapheme_phonemes, text_length, speaker_id, S_norm,
S_mel_norm, num_frames) = example
text_sequences.append(
np.pad(mix_grapheme_phonemes,
(0, max_text_length - text_length)))
lin_specs.append(
np.pad(S_norm,
((0, 0), (self._pad_begin,
max_frames - self._pad_begin - num_frames))))
mel_specs.append(
np.pad(S_mel_norm,
((0, 0), (self._pad_begin,
max_frames - self._pad_begin - num_frames))))
done_flags.append(
np.pad(np.zeros((int(np.ceil(num_frames // self._factor)), )),
(0, max_decoder_length -
int(np.ceil(num_frames // self._factor))),
constant_values=1))
text_sequences = np.array(text_sequences).astype(np.int64)
lin_specs = np.transpose(np.array(lin_specs),
(0, 2, 1)).astype(np.float32)
mel_specs = np.transpose(np.array(mel_specs),
(0, 2, 1)).astype(np.float32)
done_flags = np.array(done_flags).astype(np.float32)
# text positions
text_mask = (np.arange(1, 1 + max_text_length) <= np.expand_dims(
text_lengths, -1)).astype(np.int64)
text_positions = np.arange(1, 1 + max_text_length) * text_mask
# decoder_positions
decoder_positions = np.tile(
np.expand_dims(np.arange(1, 1 + max_decoder_length), 0),
(batch_size, 1))
return (text_sequences, text_lengths, text_positions, mel_specs,
lin_specs, frames, decoder_positions, done_flags)
import os
import csv
from pathlib import Path
import numpy as np
import pandas as pd
import librosa
from scipy import signal, io
from parakeet.data import DatasetMixin, TransformDataset, FilterDataset
from parakeet.g2p.en import text_to_sequence, sequence_to_text
class LJSpeechMetaData(DatasetMixin):
def __init__(self, root):
self.root = Path(root)
csv_path = self.root.joinpath("train.txt")
self._table = pd.read_csv(
csv_path,
sep="|",
header=None,
quoting=csv.QUOTE_NONE,
names=["lin_spec", "mel_spec", "n_frames", "text"])
def get_example(self, i):
lin_spec, mel_spec, n_frames, text = self._table.iloc[i]
lin_spec = str(self.root.joinpath(lin_spec))
mel_spec = str(self.root.joinpath(mel_spec))
return lin_spec, mel_spec, n_frames, text + "\n"
def __len__(self):
return len(self._table)
class Transform(object):
def __init__(self, replace_pronounciation_prob=0.):
self.replace_pronounciation_prob = replace_pronounciation_prob
def __call__(self, in_data):
lin_spec, mel_spec, n_frames, text = in_data
# text processing
mix_grapheme_phonemes = text_to_sequence(
text, self.replace_pronounciation_prob)
text_length = len(mix_grapheme_phonemes)
# CAUTION: positions start from 1
speaker_id = None
S_norm = np.load(lin_spec).T.astype(np.float32)
S_mel_norm = np.load(mel_spec).T.astype(np.float32)
n_frames = S_mel_norm.shape[-1] # CAUTION: original number of frames
return (mix_grapheme_phonemes, text_length, speaker_id, S_norm,
S_mel_norm, n_frames)
class DataCollector(object):
def __init__(self, downsample_factor=4, r=1):
self.downsample_factor = int(downsample_factor)
self.frames_per_step = int(r)
self._factor = int(downsample_factor * r)
self._pad_begin = int(r) # int(downsample_factor * r)
def __call__(self, examples):
batch_size = len(examples)
# lengths
text_lengths = np.array([example[1]
for example in examples]).astype(np.int64)
frames = np.array([example[5]
for example in examples]).astype(np.int64)
max_text_length = int(np.max(text_lengths))
max_frames = int(np.max(frames))
if max_frames % self._factor != 0:
max_frames += (self._factor - max_frames % self._factor)
max_frames += self._factor
max_decoder_length = max_frames // self._factor
# pad time sequence
text_sequences = []
lin_specs = []
mel_specs = []
done_flags = []
for example in examples:
(mix_grapheme_phonemes, text_length, speaker_id, S_norm,
S_mel_norm, num_frames) = example
text_sequences.append(
np.pad(mix_grapheme_phonemes,
(0, max_text_length - text_length)))
lin_specs.append(
np.pad(S_norm,
((0, 0), (self._pad_begin,
max_frames - self._pad_begin - num_frames))))
mel_specs.append(
np.pad(S_mel_norm,
((0, 0), (self._pad_begin,
max_frames - self._pad_begin - num_frames))))
done_flags.append(
np.pad(np.zeros((int(np.ceil(num_frames // self._factor)), )),
(0, max_decoder_length -
int(np.ceil(num_frames // self._factor))),
constant_values=1))
text_sequences = np.array(text_sequences).astype(np.int64)
lin_specs = np.transpose(np.array(lin_specs),
(0, 2, 1)).astype(np.float32)
mel_specs = np.transpose(np.array(mel_specs),
(0, 2, 1)).astype(np.float32)
done_flags = np.array(done_flags).astype(np.float32)
# text positions
text_mask = (np.arange(1, 1 + max_text_length) <= np.expand_dims(
text_lengths, -1)).astype(np.int64)
text_positions = np.arange(1, 1 + max_text_length) * text_mask
# decoder_positions
decoder_positions = np.tile(
np.expand_dims(np.arange(1, 1 + max_decoder_length), 0),
(batch_size, 1))
return (text_sequences, text_lengths, text_positions, mel_specs,
lin_specs, frames, decoder_positions, done_flags)
meta_data:
min_text_length: 20
transform:
# text
replace_pronunciation_prob: 0.5
# spectrogram
sample_rate: 22050
max_norm: 0.999
preemphasis: 0.97
n_fft: 1024
win_length: 1024
hop_length: 256
# mel
fmin: 125
fmax: 7600
n_mels: 80
# db scale
min_level_db: -100
ref_level_db: 20
loss:
masked_loss_weight: 0.5
priority_freq: 3000
priority_freq_weight: 0.0
binary_divergence_weight: 0.1
guided_attention_sigma: 0.2
synthesis:
max_steps: 512
power: 1.4
n_iter: 32
model:
# speaker_embedding
n_speakers: 1
speaker_embed_dim: 16
speaker_embedding_weight_std: 0.01
max_positions: 512
dropout: 0.050000000000000044
# encoder
text_embed_dim: 256
embedding_weight_std: 0.1
freeze_embedding: false
padding_idx: 0
encoder_channels: 256
# decoder
query_position_rate: 1.0
key_position_rate: 1.29
trainable_positional_encodings: false
kernel_size: 3
decoder_channels: 512
downsample_factor: 4
outputs_per_step: 1
# attention
key_position_rate: true
value_position_rate: true
force_monotonic_attention: true
window_backward: -1
window_ahead: 3
use_memory_mask: true
# converter
use_decoder_state_for_postnet_input: true
converter_channels: 256
optimizer:
beta1: 0.5
beta2: 0.9
epsilon: 1e-6
lr_scheduler:
warmup_steps: 4000
peak_learning_rate: 5e-4
train:
batch_size: 16
epochs: 2000
report_interval: 100
snap_interval: 1000
eval_interval: 10000
save_interval: 10000
Scientists at the CERN laboratory say they have discovered a new particle.
There's a way to measure the acute emotional intelligence that has never gone out of style.
President Trump met with other leaders at the Group of 20 conference.
Generative adversarial network or variational auto-encoder.
Please call Stella.
Some have accepted this as a miracle without any physical explanation.
\ No newline at end of file
import os
import argparse
import ruamel.yamls
import numpy as np
import soundfile as sf
from paddle import fluid
import paddle.fluid.layers as F
import paddle.fluid.dygraph as dg
from tensorboardX import SummaryWriter
from parakeet.g2p import en
from parakeet.utils.layer_tools import summary
from parakeet.modules.weight_norm import WeightNormWrapper
from utils import make_model, eval_model, plot_alignment
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Synthsize waveform with a checkpoint.")
parser.add_argument("-c", "--config", type=str, help="experiment config.")
parser.add_argument("checkpoint", type=str, help="checkpoint to load.")
parser.add_argument("text", type=str, help="text file to synthesize")
parser.add_argument("output_path", type=str, help="path to save results")
args = parser.parse_args()
with open(args.config, 'rt') as f:
config = ruamel.yaml.safe_load(f)
if args.device == -1:
place = fluid.CPUPlace()
else:
place = fluid.CUDAPlace(args.device)
with dg.guard(place):
# =========================model=========================
transform_config = config["transform"]
replace_pronounciation_prob = transform_config[
"replace_pronunciation_prob"]
sample_rate = transform_config["sample_rate"]
preemphasis = transform_config["preemphasis"]
n_fft = transform_config["n_fft"]
n_mels = transform_config["n_mels"]
model_config = config["model"]
downsample_factor = model_config["downsample_factor"]
r = model_config["outputs_per_step"]
n_speakers = model_config["n_speakers"]
speaker_dim = model_config["speaker_embed_dim"]
speaker_embed_std = model_config["speaker_embedding_weight_std"]
n_vocab = en.n_vocab
embed_dim = model_config["text_embed_dim"]
linear_dim = 1 + n_fft // 2
use_decoder_states = model_config[
"use_decoder_state_for_postnet_input"]
filter_size = model_config["kernel_size"]
encoder_channels = model_config["encoder_channels"]
decoder_channels = model_config["decoder_channels"]
converter_channels = model_config["converter_channels"]
dropout = model_config["dropout"]
padding_idx = model_config["padding_idx"]
embedding_std = model_config["embedding_weight_std"]
max_positions = model_config["max_positions"]
freeze_embedding = model_config["freeze_embedding"]
trainable_positional_encodings = model_config[
"trainable_positional_encodings"]
use_memory_mask = model_config["use_memory_mask"]
query_position_rate = model_config["query_position_rate"]
key_position_rate = model_config["key_position_rate"]
window_behind = model_config["window_behind"]
window_ahead = model_config["window_ahead"]
key_projection = model_config["key_projection"]
value_projection = model_config["value_projection"]
dv3 = make_model(n_speakers, speaker_dim, speaker_embed_std, embed_dim,
padding_idx, embedding_std, max_positions, n_vocab,
freeze_embedding, filter_size, encoder_channels,
n_mels, decoder_channels, r,
trainable_positional_encodings, use_memory_mask,
query_position_rate, key_position_rate, window_behind,
window_ahead, key_projection, value_projection,
downsample_factor, linear_dim, use_decoder_states,
converter_channels, dropout)
state, _ = dg.load_dygraph(args.checkpoint)
dv3.set_dict(state)
for layer in dv3.sublayers():
if isinstance(layer, WeightNormWrapper):
layer.remove_weight_norm()
if not os.path.exists(args.output_path):
os.makedirs(args.output_path)
transform_config = config["transform"]
c = transform_config["replace_pronunciation_prob"]
sample_rate = transform_config["sample_rate"]
min_level_db = transform_config["min_level_db"]
ref_level_db = transform_config["ref_level_db"]
preemphasis = transform_config["preemphasis"]
win_length = transform_config["win_length"]
hop_length = transform_config["hop_length"]
synthesis_config = config["synthesis"]
power = synthesis_config["power"]
n_iter = synthesis_config["n_iter"]
with open(args.text, "rt", encoding="utf-8") as f:
lines = f.readlines()
for idx, line in enumerate(lines):
text = line[:-1]
dv3.eval()
wav, attn = eval_model(dv3, text, replace_pronounciation_prob,
min_level_db, ref_level_db, power,
n_iter, win_length, hop_length,
preemphasis)
plot_alignment(
attn,
os.path.join(args.output_path, "test_{}.png".format(idx)))
sf.write(
os.path.join(args.output_path, "test_{}.wav".format(idx)),
wav, sample_rate)
import os
import argparse
import ruamel.yamls
import numpy as np
from matplotlib import cm
import matplotlib.pyplot as plt
import tqdm
import librosa
from librosa import display
import soundfile as sf
from tensorboardX import SummaryWriter
from paddle import fluid
import paddle.fluid.layers as F
import paddle.fluid.dygraph as dg
from parakeet.g2p import en
from parakeet.models.deepvoice3.encoder import ConvSpec
from parakeet.data import FilterDataset, TransformDataset, FilterDataset
from parakeet.data import DataCargo, PartialyRandomizedSimilarTimeLengthSampler, SequentialSampler
from parakeet.models.deepvoice3 import Encoder, Decoder, Converter, DeepVoice3
from parakeet.models.deepvoice3.loss import TTSLoss
from parakeet.utils.layer_tools import summary
from data import LJSpeechMetaData, DataCollector, Transform
from utils import make_model, eval_model, plot_alignment, plot_alignments, save_state, make_output_tree
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Train a deepvoice 3 model with LJSpeech dataset.")
parser.add_argument("-c", "--config", type=str, help="experimrnt config")
parser.add_argument("-s",
"--data",
type=str,
default="/workspace/datasets/LJSpeech-1.1/",
help="The path of the LJSpeech dataset.")
parser.add_argument("-r", "--resume", type=str, help="checkpoint to load")
parser.add_argument("-o",
"--output",
type=str,
default="result",
help="The directory to save result.")
parser.add_argument("-g",
"--device",
type=int,
default=-1,
help="device to use")
args, _ = parser.parse_known_args()
with open(args.config, 'rt') as f:
config = ruamel.yaml.safe_load(f)
# =========================dataset=========================
# construct meta data
data_root = args.data
meta = LJSpeechMetaData(data_root)
# filter it!
min_text_length = config["meta_data"]["min_text_length"]
meta = FilterDataset(meta, lambda x: len(x[2]) >= min_text_length)
# transform meta data into meta data
transform_config = config["transform"]
replace_pronounciation_prob = transform_config[
"replace_pronunciation_prob"]
sample_rate = transform_config["sample_rate"]
preemphasis = transform_config["preemphasis"]
n_fft = transform_config["n_fft"]
win_length = transform_config["win_length"]
hop_length = transform_config["hop_length"]
fmin = transform_config["fmin"]
fmax = transform_config["fmax"]
n_mels = transform_config["n_mels"]
min_level_db = transform_config["min_level_db"]
ref_level_db = transform_config["ref_level_db"]
max_norm = transform_config["max_norm"]
clip_norm = transform_config["clip_norm"]
transform = Transform(replace_pronounciation_prob, sample_rate,
preemphasis, n_fft, win_length, hop_length, fmin,
fmax, n_mels, min_level_db, ref_level_db, max_norm,
clip_norm)
ljspeech = TransformDataset(meta, transform)
# =========================dataiterator=========================
# use meta data's text length as a sort key for the sampler
train_config = config["train"]
batch_size = train_config["batch_size"]
text_lengths = [len(example[2]) for example in meta]
sampler = PartialyRandomizedSimilarTimeLengthSampler(
text_lengths, batch_size)
# some hyperparameters affect how we process data, so create a data collector!
model_config = config["model"]
downsample_factor = model_config["downsample_factor"]
r = model_config["outputs_per_step"]
collector = DataCollector(downsample_factor=downsample_factor, r=r)
ljspeech_loader = DataCargo(ljspeech,
batch_fn=collector,
batch_size=batch_size,
sampler=sampler)
# =========================model=========================
if args.device == -1:
place = fluid.CPUPlace()
else:
place = fluid.CUDAPlace(args.device)
with dg.guard(place):
# =========================model=========================
n_speakers = model_config["n_speakers"]
speaker_dim = model_config["speaker_embed_dim"]
speaker_embed_std = model_config["speaker_embedding_weight_std"]
n_vocab = en.n_vocab
embed_dim = model_config["text_embed_dim"]
linear_dim = 1 + n_fft // 2
use_decoder_states = model_config[
"use_decoder_state_for_postnet_input"]
filter_size = model_config["kernel_size"]
encoder_channels = model_config["encoder_channels"]
decoder_channels = model_config["decoder_channels"]
converter_channels = model_config["converter_channels"]
dropout = model_config["dropout"]
padding_idx = model_config["padding_idx"]
embedding_std = model_config["embedding_weight_std"]
max_positions = model_config["max_positions"]
freeze_embedding = model_config["freeze_embedding"]
trainable_positional_encodings = model_config[
"trainable_positional_encodings"]
use_memory_mask = model_config["use_memory_mask"]
query_position_rate = model_config["query_position_rate"]
key_position_rate = model_config["key_position_rate"]
window_behind = model_config["window_behind"]
window_ahead = model_config["window_ahead"]
key_projection = model_config["key_projection"]
value_projection = model_config["value_projection"]
dv3 = make_model(n_speakers, speaker_dim, speaker_embed_std, embed_dim,
padding_idx, embedding_std, max_positions, n_vocab,
freeze_embedding, filter_size, encoder_channels,
n_mels, decoder_channels, r,
trainable_positional_encodings, use_memory_mask,
query_position_rate, key_position_rate, window_behind,
window_ahead, key_projection, value_projection,
downsample_factor, linear_dim, use_decoder_states,
converter_channels, dropout)
# =========================loss=========================
loss_config = config["loss"]
masked_weight = loss_config["masked_loss_weight"]
priority_freq = loss_config["priority_freq"] # Hz
priority_bin = int(priority_freq / (0.5 * sample_rate) * linear_dim)
priority_freq_weight = loss_config["priority_freq_weight"]
binary_divergence_weight = loss_config["binary_divergence_weight"]
guided_attention_sigma = loss_config["guided_attention_sigma"]
criterion = TTSLoss(masked_weight=masked_weight,
priority_bin=priority_bin,
priority_weight=priority_freq_weight,
binary_divergence_weight=binary_divergence_weight,
guided_attention_sigma=guided_attention_sigma,
downsample_factor=downsample_factor,
r=r)
# =========================lr_scheduler=========================
lr_config = config["lr_scheduler"]
warmup_steps = lr_config["warmup_steps"]
peak_learning_rate = lr_config["peak_learning_rate"]
lr_scheduler = dg.NoamDecay(
1 / (warmup_steps * (peak_learning_rate)**2), warmup_steps)
# =========================optimizer=========================
optim_config = config["optimizer"]
beta1 = optim_config["beta1"]
beta2 = optim_config["beta2"]
epsilon = optim_config["epsilon"]
optim = fluid.optimizer.Adam(lr_scheduler,
beta1,
beta2,
epsilon=epsilon,
parameter_list=dv3.parameters())
gradient_clipper = fluid.dygraph_grad_clip.GradClipByGlobalNorm(0.1)
# =========================link(dataloader, paddle)=========================
# CAUTION: it does not return a DataLoader
loader = fluid.io.DataLoader.from_generator(capacity=10,
return_list=True)
loader.set_batch_generator(ljspeech_loader, places=place)
# tensorboard & checkpoint preparation
output_dir = args.output
ckpt_dir = os.path.join(output_dir, "checkpoints")
log_dir = os.path.join(output_dir, "log")
state_dir = os.path.join(output_dir, "states")
make_output_tree(output_dir)
writer = SummaryWriter(logdir=log_dir)
# load model parameters
resume_path = args.resume
if resume_path is not None:
state, _ = dg.load_dygraph(args.resume)
dv3.set_dict(state)
# =========================train=========================
epoch = train_config["epochs"]
report_interval = train_config["report_interval"]
snap_interval = train_config["snap_interval"]
save_interval = train_config["save_interval"]
eval_interval = train_config["eval_interval"]
global_step = 1
average_loss = {"mel": 0, "lin": 0, "done": 0, "attn": 0}
for j in range(1, 1 + epoch):
epoch_loss = {"mel": 0., "lin": 0., "done": 0., "attn": 0.}
for i, batch in tqdm.tqdm(enumerate(loader, 1)):
dv3.train() # CAUTION: don't forget to switch to train
(text_sequences, text_lengths, text_positions, mel_specs,
lin_specs, frames, decoder_positions, done_flags) = batch
downsampled_mel_specs = F.strided_slice(
mel_specs,
axes=[1],
starts=[0],
ends=[mel_specs.shape[1]],
strides=[downsample_factor])
mel_outputs, linear_outputs, alignments, done = dv3(
text_sequences, text_positions, text_lengths, None,
downsampled_mel_specs, decoder_positions)
losses = criterion(mel_outputs, linear_outputs, done,
alignments, downsampled_mel_specs,
lin_specs, done_flags, text_lengths, frames)
l = criterion.compose_loss(losses)
l.backward()
optim.minimize(l, grad_clip=gradient_clipper)
dv3.clear_gradients()
# ==================all kinds of tedious things=================
for k in epoch_loss.keys():
epoch_loss[k] += losses[k].numpy()[0]
average_loss[k] += losses[k].numpy()[0]
# record step loss into tensorboard
step_loss = {k: v.numpy()[0] for k, v in losses.items()}
for k, v in step_loss.items():
writer.add_scalar(k, v, global_step)
# TODO: clean code
# train state saving, the first sentence in the batch
if global_step % snap_interval == 0:
linear_outputs_np = linear_outputs.numpy()[0].T
denoramlized = np.clip(linear_outputs_np, 0, 1) \
* (-min_level_db) \
+ min_level_db
lin_scaled = np.exp(
(denoramlized + ref_level_db) / 20 * np.log(10))
synthesis_config = config["synthesis"]
power = synthesis_config["power"]
n_iter = synthesis_config["n_iter"]
wav = librosa.griffinlim(lin_scaled**power,
n_iter=n_iter,
hop_length=hop_length,
win_length=win_length)
save_state(state_dir,
global_step,
mel_input=mel_specs.numpy()[0].T,
mel_output=mel_outputs.numpy()[0].T,
lin_input=lin_specs.numpy()[0].T,
lin_output=linear_outputs.numpy()[0].T,
alignments=alignments.numpy()[:, 0, :, :],
wav=wav)
# evaluation
if global_step % eval_interval == 0:
sentences = [
"Scientists at the CERN laboratory say they have discovered a new particle.",
"There's a way to measure the acute emotional intelligence that has never gone out of style.",
"President Trump met with other leaders at the Group of 20 conference.",
"Generative adversarial network or variational auto-encoder.",
"Please call Stella.",
"Some have accepted this as a miracle without any physical explanation.",
]
for idx, sent in sentences:
wav, attn = eval_model(dv3, sent,
replace_pronounciation_prob,
min_level_db, ref_level_db,
power, n_iter, win_length,
hop_length, preemphasis)
wav_path = os.path.join(
state_dir, "waveform",
"eval_sample_{:09d}.wav".format(global_step))
sf.write(wav_path, wav, sample_rate)
attn_path = os.path.join(
state_dir, "alignments",
"eval_sample_attn_{:09d}.png".format(global_step))
plot_alignment(attn, attn_path)
# save checkpoint
if global_step % save_interval == 0:
dg.save_dygraph(dv3.state_dict(),
os.path.join(ckpt_dir, "dv3"))
dg.save_dygraph(optim.state_dict(),
os.path.join(ckpt_dir, "dv3"))
# report average loss
if global_step % report_interval == 0:
for k in epoch_loss.keys():
average_loss[k] /= report_interval
print("[average_loss] ",
"global_step: {}".format(global_step), average_loss)
average_loss = {"mel": 0, "lin": 0, "done": 0, "attn": 0}
global_step += 1
# epoch report
for k in epoch_loss.keys():
epoch_loss[k] /= i
print("[epoch_loss] ", "epoch: {}".format(j), epoch_loss)
\ No newline at end of file
import os
import argparse
import numpy as np
import pandas as pd
from matplotlib import cm
import matplotlib.pyplot as plt
import tqdm
import librosa
from scipy import signal
from librosa import display
import soundfile as sf
from tensorboardX import SummaryWriter
from paddle import fluid
import paddle.fluid.layers as F
import paddle.fluid.dygraph as dg
from parakeet.g2p import en
from parakeet.models.Rdeepvoice3.encoder import ConvSpec
from parakeet.data import FilterDataset, TransformDataset, FilterDataset, DatasetMixin
from parakeet.data import DataCargo, PartialyRandomizedSimilarTimeLengthSampler, SequentialSampler
from parakeet.models.Rdeepvoice3 import Encoder, Decoder, Converter, DeepVoice3
from parakeet.models.Rdeepvoice3.loss import TTSLoss
from parakeet.modules.weight_norm_wrapper import WeightNormWrapper
from parakeet.utils.layer_tools import summary
from data_validate import LJSpeechMetaData, DataCollector, Transform
from utils import make_model, eval_model, plot_alignment, plot_alignments, save_state
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Train a deepvoice 3 model with LJSpeech")
parser.add_argument("-o",
"--output",
type=str,
default="result",
help="The directory to save result.")
parser.add_argument("-d",
"--data",
type=str,
default="/workspace/datasets/ljs_dv3",
help="The path of the LJSpeech dataset.")
parser.add_argument("-r", "--resume", type=str, help="checkpoint to load")
args, _ = parser.parse_known_args()
# =========================dataset=========================
data_root = args.data
meta = LJSpeechMetaData(data_root) # construct meta data
#meta = FilterDataset(meta, lambda x: len(x[3]) >= 20) # filter it!
transform = Transform()
ljspeech = TransformDataset(meta, transform)
# =========================dataiterator=========================
# use meta data's text length as a sort key
# which is used in sampler
text_lengths = [len(example[3]) for example in meta]
# some hyperparameters affect how we process data, so create a data collector!
collector = DataCollector(downsample_factor=4., r=1)
ljspeech_loader = DataCargo(ljspeech,
batch_fn=collector,
batch_size=16,
sampler=SequentialSampler(ljspeech))
# sampler=PartialyRandomizedSimilarTimeLengthSampler(text_lengths,
# batch_size=32))
# ljspeech_iterator = ljspeech_loader() # if you want to inspect it!
# for i in range(3):
# batch = next(ljspeech_iterator)
# print(batch)
# =========================model=========================
sample_rate = 22050
n_speakers = 1
speaker_dim = 16
n_vocab = en.n_vocab
embed_dim = 256
mel_dim = 80
downsample_factor = 4
r = 1
linear_dim = 1 + 1024 // 2
use_decoder_states = True
filter_size = 3
encoder_channels = 512
decoder_channels = 256
converter_channels = 256
dropout = 0. #0.050000000000000044
place = fluid.CPUPlace()
with dg.guard(place):
# =========================model=========================
dv3 = make_model(n_speakers, speaker_dim, n_vocab, embed_dim, mel_dim,
downsample_factor, r, linear_dim, use_decoder_states,
filter_size, encoder_channels, decoder_channels,
converter_channels, dropout)
# =========================loss=========================
priority_freq = 3000 # Hz
priority_bin = int(priority_freq / (0.5 * sample_rate) * linear_dim)
criterion = TTSLoss(masked_weight=.5,
priority_bin=priority_bin,
priority_weight=.0,
binary_divergence_weight=.1,
guided_attention_sigma=.2,
downsample_factor=downsample_factor,
r=r)
# summary(dv3)
# =========================lr_scheduler=========================
warmup_steps = 4000
peak_learning_rate = 5e-4
lr_scheduler = dg.NoamDecay(d_model=1 / (warmup_steps *
(peak_learning_rate)**2),
warmup_steps=warmup_steps)
# =========================optimizer=========================
beta1, beta2 = 0.5, 0.9
epsilon = 1e-6
optim = fluid.optimizer.Adam(lr_scheduler,
beta1,
beta2,
epsilon=1e-6,
parameter_list=dv3.parameters())
# =========================link(dataloader, paddle)=========================
# CAUTION: it does not return a DataLoader
loader = fluid.io.DataLoader.from_generator(capacity=10,
return_list=True)
loader.set_batch_generator(ljspeech_loader, places=place)
# tensorboard & checkpoint preparation
output_dir = args.output
ckpt_dir = os.path.join(output_dir, "checkpoints")
state_dir = os.path.join(output_dir, "states")
log_dir = os.path.join(output_dir, "log")
for x in [ckpt_dir, state_dir]:
if not os.path.exists(x):
os.makedirs(x)
for x in ["alignments", "waveform", "lin_spec", "mel_spec"]:
p = os.path.join(state_dir, x)
if not os.path.exists(p):
os.makedirs(p)
writer = SummaryWriter(logdir=log_dir)
# DEBUG
resume_path = args.resume
if resume_path is not None:
state, _ = dg.load_dygraph(args.resume)
dv3.set_dict(state)
# =========================train=========================
epoch = 3000
global_step = 1
average_loss = {"mel": 0, "lin": 0, "done": 0, "attn": 0}
epoch_loss = {"mel": 0, "lin": 0, "done": 0, "attn": 0}
for j in range(epoch):
for i, batch in tqdm.tqdm(enumerate(loader)):
dv3.train() # switch to train
(text_sequences, text_lengths, text_positions, mel_specs,
lin_specs, frames, decoder_positions, done_flags) = batch
downsampled_mel_specs = F.strided_slice(
mel_specs,
axes=[1],
starts=[0],
ends=[mel_specs.shape[1]],
strides=[downsample_factor])
mel_outputs, linear_outputs, alignments, done = dv3(
text_sequences, text_positions, text_lengths, None,
downsampled_mel_specs, decoder_positions)
# print("========")
# print("text lengths: {}".format(text_lengths.numpy()))
# print("n frames: {}".format(frames.numpy()))
# print("[mel] mel's shape: {}; "
# "downsampled mel's shape: {}; "
# "output's shape: {}".format(mel_specs.shape,
# downsampled_mel_specs.shape,
# mel_outputs.shape))
# print("[lin] lin's shape: {}; "
# "output's shape{}".format(lin_specs.shape,
# linear_outputs.shape))
# print("[attn]: alignments's shape: {}".format(alignments.shape))
# print("[done]: input done flag's shape: {}; "
# "output done flag's shape: {}".format(
# done_flags.shape, done.shape))
losses = criterion(mel_outputs, linear_outputs, done,
alignments, downsampled_mel_specs,
lin_specs, done_flags, text_lengths, frames)
for k in epoch_loss.keys():
epoch_loss[k] += losses[k].numpy()[0]
average_loss[k] += losses[k].numpy()[0]
global_step += 1
# train state saving, the first sentence in the batch
if global_step > 0 and global_step % 10 == 0:
linear_outputs_np = linear_outputs.numpy()[0].T
denoramlized = np.clip(linear_outputs_np, 0,
1) * 100. - 100.
lin_scaled = np.exp((denoramlized + 20) / 20 * np.log(10))
wav = librosa.griffinlim(lin_scaled**1.4,
n_iter=32,
hop_length=256,
win_length=1024)
save_state(state_dir,
global_step,
mel_input=mel_specs.numpy()[0].T,
mel_output=mel_outputs.numpy()[0].T,
lin_input=lin_specs.numpy()[0].T,
lin_output=linear_outputs.numpy()[0].T,
alignments=alignments.numpy()[:, 0, :, :],
wav=wav)
# evaluation
if global_step > 0 and global_step % 10 == 0:
wav, attn = eval_model(
dv3,
"Printing, in the only sense with which we are at present concerned, differs from most if not from all the arts and crafts represented in the Exhibition"
)
wav_path = os.path.join(
state_dir, "waveform",
"eval_sample_{}.wav".format(global_step))
sf.write(wav_path, wav, 22050)
attn_path = os.path.join(
state_dir, "alignments",
"eval_sample_attn_{}.png".format(global_step))
plot_alignment(attn, attn_path)
# for tensorboard writer, if you want more, write more
# cause you are in the process
step_loss = {k: v.numpy()[0] for k, v in losses.items()}
for k, v in step_loss.items():
writer.add_scalar(k, v, global_step)
# save checkpoint
if global_step % 1000 == 0:
for i, attn_layer in enumerate(
alignments.numpy()[:, 0, :, :]):
plt.figure()
plt.imshow(attn_layer)
plt.xlabel("encoder_timesteps")
plt.ylabel("decoder_timesteps")
plt.savefig("results3/step_{}_layer_{}.png".format(
global_step, i),
format="png")
plt.close()
# print(step_loss)
if global_step % 100 == 0:
for k in epoch_loss.keys():
average_loss[k] /= 100
print("[average_loss] ",
"global_step: {}".format(global_step), average_loss)
average_loss = {"mel": 0, "lin": 0, "done": 0, "attn": 0}
l = criterion.compose_loss(losses)
l.backward()
# print("loss: ", l.numpy()[0])
optim.minimize(
l,
grad_clip=fluid.dygraph_grad_clip.GradClipByGlobalNorm(
0.1))
dv3.clear_gradients()
if global_step % 10000 == 0:
dg.save_dygraph(dv3.state_dict(),
os.path.join(ckpt_dir, "dv3"))
dg.save_dygraph(optim.state_dict(),
os.path.join(ckpt_dir, "dv3"))
for k in epoch_loss.keys():
epoch_loss[k] /= (i + 1)
print("[epoch_loss] ", "epoch: {}".format(j + 1), epoch_loss)
epoch_loss = {"mel": 0, "lin": 0, "done": 0, "attn": 0}
import os
import numpy as np
import matplotlib.pyplot as plt
import librosa
from scipy import signal
from librosa import display
import soundfile as sf
from paddle import fluid
import paddle.fluid.dygraph as dg
import paddle.fluid.initializer as I
from parakeet.g2p import en
from parakeet.models.deepvoice3.encoder import ConvSpec
from parakeet.models.deepvoice3 import Encoder, Decoder, Converter, DeepVoice3, WindowRange
from parakeet.utils.layer_tools import freeze
@fluid.framework.dygraph_only
def make_model(n_speakers, speaker_dim, speaker_embed_std, embed_dim,
padding_idx, embedding_std, max_positions, n_vocab,
freeze_embedding, filter_size, encoder_channels, mel_dim,
decoder_channels, r, trainable_positional_encodings,
use_memory_mask, query_position_rate, key_position_rate,
window_behind, window_ahead, key_projection, value_projection,
downsample_factor, linear_dim, use_decoder_states,
converter_channels, dropout):
"""just a simple function to create a deepvoice 3 model"""
if n_speakers > 1:
spe = dg.Embedding((n_speakers, speaker_dim),
param_attr=I.Normal(scale=speaker_embed_std))
else:
spe = None
h = encoder_channels
k = filter_size
encoder_convolutions = (
ConvSpec(h, k, 1),
ConvSpec(h, k, 3),
ConvSpec(h, k, 9),
ConvSpec(h, k, 27),
ConvSpec(h, k, 1),
ConvSpec(h, k, 3),
ConvSpec(h, k, 9),
ConvSpec(h, k, 27),
ConvSpec(h, k, 1),
ConvSpec(h, k, 3),
)
enc = Encoder(n_vocab,
embed_dim,
n_speakers,
speaker_dim,
padding_idx=padding_idx,
embedding_weight_std=embedding_std,
convolutions=encoder_convolutions,
max_positions=max_positions,
dropout=dropout)
if freeze_embedding:
freeze(enc.embed)
h = decoder_channels
prenet_convolutions = (ConvSpec(h, k, 1), ConvSpec(h, k, 3))
attentive_convolutions = (
ConvSpec(h, k, 1),
ConvSpec(h, k, 3),
ConvSpec(h, k, 9),
ConvSpec(h, k, 27),
ConvSpec(h, k, 1),
)
attention = [True, False, False, False, True]
force_monotonic_attention = [True, False, False, False, True]
dec = Decoder(n_speakers,
speaker_dim,
embed_dim,
mel_dim,
r=r,
max_positions=max_positions,
padding_idx=padding_idx,
preattention=prenet_convolutions,
convolutions=attentive_convolutions,
attention=attention,
dropout=dropout,
use_memory_mask=use_memory_mask,
force_monotonic_attention=force_monotonic_attention,
query_position_rate=query_position_rate,
key_position_rate=key_position_rate,
window_range=WindowRange(window_behind, window_ahead),
key_projection=key_projection,
value_projection=value_projection)
if not trainable_positional_encodings:
freeze(dec.embed_keys_positions)
freeze(dec.embed_query_positions)
h = converter_channels
postnet_convolutions = (
ConvSpec(h, k, 1),
ConvSpec(h, k, 3),
ConvSpec(2 * h, k, 1),
ConvSpec(2 * h, k, 3),
)
cvt = Converter(n_speakers,
speaker_dim,
dec.state_dim if use_decoder_states else mel_dim,
linear_dim,
time_upsampling=downsample_factor,
convolutions=postnet_convolutions,
dropout=dropout)
dv3 = DeepVoice3(enc, dec, cvt, spe, use_decoder_states)
return dv3
@fluid.framework.dygraph_only
def eval_model(model, text, replace_pronounciation_prob, min_level_db,
ref_level_db, power, n_iter, win_length, hop_length,
preemphasis):
"""generate waveform from text using a deepvoice 3 model"""
text = np.array(en.text_to_sequence(text, p=replace_pronounciation_prob),
dtype=np.int64)
length = len(text)
print("text sequence's length: {}".format(length))
text_positions = np.arange(1, 1 + length)
text = np.expand_dims(text, 0)
text_positions = np.expand_dims(text_positions, 0)
mel_outputs, linear_outputs, alignments, done = model.transduce(
dg.to_variable(text), dg.to_variable(text_positions))
linear_outputs_np = linear_outputs.numpy()[0].T # (C, T)
print("linear_outputs's shape: ", linear_outputs_np.shape)
denoramlized = np.clip(linear_outputs_np, 0,
1) * (-min_level_db) + min_level_db
lin_scaled = np.exp((denoramlized + ref_level_db) / 20 * np.log(10))
wav = librosa.griffinlim(lin_scaled**power,
n_iter=n_iter,
hop_length=hop_length,
win_length=win_length)
wav = signal.lfilter([1.], [1., -preemphasis], wav)
print("alignmnets' shape:", alignments.shape)
alignments_np = alignments.numpy()[0].T
return wav, alignments_np
def make_output_tree(output_dir):
print("creating output tree: {}".format(output_dir))
ckpt_dir = os.path.join(output_dir, "checkpoints")
state_dir = os.path.join(output_dir, "states")
log_dir = os.path.join(output_dir, "log")
for x in [ckpt_dir, state_dir]:
if not os.path.exists(x):
os.makedirs(x)
for x in ["alignments", "waveform", "lin_spec", "mel_spec"]:
p = os.path.join(state_dir, x)
if not os.path.exists(p):
os.makedirs(p)
def plot_alignment(alignment, path, info=None):
"""
Plot an attention layer's alignment for a sentence.
alignment: shape(T_enc, T_dec), and T_enc is flipped
"""
fig, ax = plt.subplots()
im = ax.imshow(alignment,
aspect='auto',
origin='lower',
interpolation='none')
fig.colorbar(im, ax=ax)
xlabel = 'Decoder timestep'
if info is not None:
xlabel += '\n\n' + info
plt.xlabel(xlabel)
plt.ylabel('Encoder timestep')
plt.tight_layout()
plt.savefig(path)
plt.close()
def plot_alignments(alignments, save_dir, global_step):
"""
Plot alignments for a sentence when training, we just pick the first
sentence. Each layer is plot separately.
alignments: shape(N, T_dec, T_enc)
"""
n_layers = alignments.shape[0]
for i, alignment in enumerate(alignments):
alignment = alignment.T
path = os.path.join(save_dir, "layer_{}".format(i))
if not os.path.exists(path):
os.makedirs(path)
fname = os.path.join(path, "step_{:09d}".format(global_step))
plot_alignment(alignment, fname)
average_alignment = np.mean(alignments, axis=0).T
path = os.path.join(save_dir, "average")
if not os.path.exists(path):
os.makedirs(path)
fname = os.path.join(path, "step_{:09d}.png".format(global_step))
plot_alignment(average_alignment, fname)
def save_state(save_dir,
global_step,
mel_input=None,
mel_output=None,
lin_input=None,
lin_output=None,
alignments=None,
wav=None):
if mel_input is not None and mel_output is not None:
path = os.path.join(save_dir, "mel_spec")
if not os.path.exists(path):
os.makedirs(path)
plt.figure(figsize=(10, 3))
display.specshow(mel_input)
plt.colorbar()
plt.title("mel_input")
plt.savefig(
os.path.join(path,
"target_mel_spec_step{:09d}".format(global_step)))
plt.close()
plt.figure(figsize=(10, 3))
display.specshow(mel_output)
plt.colorbar()
plt.title("mel_input")
plt.savefig(
os.path.join(path,
"predicted_mel_spec_step{:09d}".format(global_step)))
plt.close()
if lin_input is not None and lin_output is not None:
path = os.path.join(save_dir, "lin_spec")
if not os.path.exists(path):
os.makedirs(path)
plt.figure(figsize=(10, 3))
display.specshow(lin_input)
plt.colorbar()
plt.title("mel_input")
plt.savefig(
os.path.join(path,
"target_lin_spec_step{:09d}".format(global_step)))
plt.close()
plt.figure(figsize=(10, 3))
display.specshow(lin_output)
plt.colorbar()
plt.title("mel_input")
plt.savefig(
os.path.join(path,
"predicted_lin_spec_step{:09d}".format(global_step)))
plt.close()
if alignments is not None and len(alignments.shape) == 3:
path = os.path.join(save_dir, "alignments")
if not os.path.exists(path):
os.makedirs(path)
plot_alignments(alignments, path, global_step)
if wav is not None:
path = os.path.join(save_dir, "waveform")
if not os.path.exists(path):
os.makedirs(path)
sf.write(
os.path.join(path, "sample_step_{:09d}.wav".format(global_step)),
wav, 22050)
# Deep Voice 3 with Paddle Fluid
[中文版](README_cn.md)
Paddle fluid implementation of DeepVoice 3, a convolutional network based text-to-speech synthesis model. The implementation is based on [Deep Voice 3: Scaling Text-to-Speech with Convolutional Sequence Learning](https://arxiv.org/abs/1710.07654).
We implement Deepvoice3 model in paddle fluid with dynamic graph, which is convenient for flexible network architectures.
## Installation
You additionally need to download punkt and cmudict for nltk, because we tokenize text with `punkt` and convert text into phonemes with `cmudict`.
```python
import nltk
nltk.download("punkt")
nltk.download("cmudict")
```
## Model Architecture
![DeepVoice3 model architecture](./_images/model_architecture.png)
The model consists of an encoder, a decoder and a converter (and a speaker embedding for multispeaker models). The encoder, together with the decoder forms the seq2seq part of the model, and the converter forms the postnet part.
## Project Structure
```text
├── audio.py # audio processing
├── compute_timestamp_ratio.py # script to compute position rate
├── conversion # parameter conversion from pytorch model
├── requirements.txt # requirements
├── hparams.py # HParam class for deepvoice3
├── hparam_tf # hyper parameter related stuffs
├── ljspeech.py # functions for ljspeech preprocessing
├── preprocess.py # preprocrssing script
├── presets # preset hyperparameters
├── deepvoice3_paddle # DeepVoice3 model implementation
├── eval_model.py # functions for model evaluation
├── synthesis.py # script for speech synthesis
├── train_model.py # functions for model training
└── train.py # script for model training
```
## Usage
There are many hyperparameters to be tuned depending on the specification of model and dataset you are working on. Hyperparameters that are known to work good are provided in the repository. See `presets` directory for details. Now we only provide preset with LJSpeech dataset (`deepvoice3_ljspeech.json`). Support for more models and datasets is pending.
Note that `preprocess.py`, `train.py` and `synthesis.py` all accept a `--preset` parameter. To ensure consistency, you should use the same preset for preprocessing, training and synthesizing.
Note that you can overwrite preset hyperparameters with command line argument `--hparams`, just pass several key-value pair in `${key}=${value}` format seperated by comma (`,`). For example `--hparams="batch_size=8, nepochs=500"` can overwrite default values in the preset json file.
Some hyperparameters are only related to training, like `batch_size`, `checkpoint_interval` and you can use different values for preprocessing and training. But hyperparameters related to data preprocessing, like `num_mels` and `ref_level_db`, should be kept the same for preprocessing and training.
For more details about hyperparameters, see `hparams.py`, which contains the definition of `hparams`. Priority order of hyperparameters is command line option `--hparams` > `--preset` json configuration file > definition of hparams in `hparams.py`.
### Dataset
Download and unzip [LJSpeech](https://keithito.com/LJ-Speech-Dataset/).
```bash
wget https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2
tar xjvf LJSpeech-1.1.tar.bz2
```
Preprocessing with `preprocess.py`.
```bash
python preprocess.py \
--preset=${preset_json_path} \
--hparams="hyper parameters you want to overwrite" \
${name} ${in_dir} ${out_dir}
```
Now `${name}$` only supports `ljspeech`. Support for other datasets is pending.
Assuming that you use `presers/deepvoice3_ljspeech.json` for LJSpeech and the path of the unziped dataset is `./data/LJSpeech-1.1`, then you can preprocess data with the following command.
```bash
python preprocess.py \
--preset=presets/deepvoice3_ljspeech.json \
ljspeech ./data/LJSpeech-1.1/ ./data/ljspeech
```
When this is done, you will see extracted features in `./data/ljspeech` including:
1. text and corresponding file names for the extracted features in `train.txt`.
2. mel-spectrogram in `ljspeech-mel-*.npy` .
3. linear-spectrogram in `ljspeech-spec-*.npy`.
### Train on single GPU
Training the whole model on one single GPU:
```bash
export CUDA_VISIBLE_DEVICES=0
python train.py --data-root=${data-root} --use-gpu \
--preset=${preset_json_path} \
--hparams="parameters you may want to override"
```
For more details about `train.py`, see `python train.py --help`.
#### load checkpoints
We provide a trained model ([dv3.single_frame](https://paddlespeech.bj.bcebos.com/Parakeet/dv3.single_frame.tar.gz)) for downloading, which is trained with the default preset. Unzip the downloaded file with `tar xzvf dv3.single_frame.tar.gz`, you will get `config.json`, `model.pdparams` and `README.md`. `config.json` is the preset json with which the model is trained, `model.pdparams` is the parameter file, and `README.md` is a brief introduction of the model.
You can load saved checkpoint and resume training with `--checkpoint` (You only need to provide the base name of the parameter file, eg. if you want to load `model.pdparams`, just use `--checkpoint=model`). If there is also a file with the same basename and extension name `.pdopt` in the same folder with the model file (i.e. `model.pdopt`, which is the optimizer file), it is also loaded automatically. If you wan to reset optimizer states, pass `--reset-optimizer` in addition.
#### train a part of the model
You can also train parts of the model while freezing other parts, by passing `--train-seq2seq-only` or `--train-postnet-only`. When training only parts of the model, other parts should be loaded from saved checkpoint.
To train only the `seq2seq` or `postnet`, you should load from a whole model with `--checkpoint` and keep the same configurations with which the checkpoint is trained. Note that when training only the `postnet`, you should set `use_decoder_state_for_postnet_input=false`, because when train only the postnet, the postnet takes the ground truth mel-spectrogram as input. Note that the default value for `use_decoder_state_for_postnet_input` is `True`.
example:
```bash
export CUDA_VISIBLE_DEVICES=0
python train.py --data-root=${data-root} --use-gpu \
--preset=${preset_json_path} \
--hparams="parameters you may want to override" \
--train-seq2seq-only \
--output=${directory_to_save_results}
```
### Training on multiple GPUs
Training on multiple GPUs with data parallel is enabled. You can run `train.py` with `paddle.distributed.launch` module. Here is the command line usage.
```bash
python -m paddle.distributed.launch \
--started_port ${port_of_the_first_worker} \
--selected_gpus ${logical_gpu_ids_to_choose} \
--log_dir ${path_of_write_log} \
training_script ...
```
`paddle.distributed.launch` parallelizes training in multiprocessing mode.`--selected_gpus` means the logical ids of the selected GPUs, and `started_port` means the port used by the first worker. Outputs of each process are saved in `--log_dir.` Then follows the command for training on a single GPU, except that you should pass `--use-data-paralle` in addition.
```bash
export CUDA_VISIBLE_DEVICES=2,3,4,5 # The IDs of visible physical devices
python -m paddle.distributed.launch \
--selected_gpus=0,1,2,3 --log_dir ${multi_gpu_log_dir} \
train.py --data-root=${data-root} \
--use-gpu --use-data-parallel \
--preset=${preset_json_path} \
--hparams="parameters you may want to override"
```
In the example above, we set only GPU `2, 3, 4, 5` to be visible. Then `--selected_gpus="0, 1, 2, 3"` means the logical ids of the selected gpus, which correponds to GPU `2, 3, 4, 5`.
Model checkpoints (`*.pdparams` for the model and `*.pdopt` for the optimizer) are saved in `${directory_to_save_results}/checkpoints` per 10000 steps by default. Layer-wise averaged attention alignments (.png) are saved in `${directory_to_save_results}/checkpoints/alignment_ave`. And alignments for each attention layer are saved in `${directory_to_save_results}/checkpoints/alignment_layer{attention_layer_num}` per 10000 steps for inspection.
Synthesis results of 6 sentences (hardcoded in `eval_model.py`) are saved in `${directory_to_save_results}/checkpoints/eval`, including `step{step_num}_text{text_id}_single_alignment.png` for averaged alignments and `step{step_num}_text{text_id}_single_predicted.wav` for the predicted waveforms.
### Monitor with Tensorboard
Logs with tensorboard are saved in `${directory_to_save_results}/log/` directory by default. You can monitor logs by tensorboard.
```bash
tensorboard --logdir=${log_dir} --host=$HOSTNAME --port=8888
```
### Synthesize from a checkpoint
Given a list of text, `synthesis.py` synthesize audio signals from a trained model.
```bash
python synthesis.py --use-gpu --preset=${preset_json_path} \
--hparams="parameters you may want to override" \
${checkpoint} ${text_list_file} ${dst_dir}
```
Example test_list.txt:
```text
Generative adversarial network or variational auto-encoder.
Once upon a time there was a dear little girl who was loved by every one who looked at her, but most of all by her grandmother, and there was nothing that she would not have given to the child.
A text-to-speech synthesis system typically consists of multiple stages, such as a text analysis frontend, an acoustic model and an audio synthesis module.
```
generated waveform files and alignment files are saved in `${dst_dir}`.
### Compute position ratio
According to [Deep Voice 3: Scaling Text-to-Speech with Convolutional Sequence Learning](https://arxiv.org/abs/1710.07654), the position rate is different for different datasets. There are 2 position rates, one for the query and the other for the key, which are referred to as $\omega_1$ and $\omega_2$ in th paper, and the corresponding names in preset json are `query_position_rate` and `key_position_rate`.
For example, the `query_position_rate` and `key_position_rate` for LJSpeech are `1.0` and `1.385`, respectively. Fix the `query_position_rate` as 1.0, the `key_position_rate` is computed with `compute_timestamp_ratio.py`. Run the command below, where `${data_root}` means the path of the preprocessed dataset.
```bash
python compute_timestamp_ratio.py --preset=${preset_json_path} \
--hparams="parameters you may want to override" ${data_root}
```
You will get outputs like this.
```text
100%|██████████████████████████████████████████████████████████| 13047/13047 [00:12<00:00, 1058.19it/s]
1345587 1863884.0 1.3851828235558161
```
Then set the `key_position_rate=1.385` and `query_position_rate=1.0` in the preset.
## Acknowledgement
We thankfully included and adapted some files from r9y9's [deepvoice3_pytorch](https://github.com/r9y9/deepvoice3_pytorch).
# Deep Voice 3 with Paddle Fluid
[English](README.md)
Paddle 实现的 Deepvoice3,一个基于卷积神经网络的语音合成 (Text to Speech) 模型。本实现基于 [Deep Voice 3: Scaling Text-to-Speech with Convolutional Sequence Learning](https://arxiv.org/abs/1710.07654)
本 Deepvoice3 实现使用 Paddle 动态图模式,这对于灵活的网络结构更为方便。
## 安装
### 安装 paddlepaddle 框架
为了更快的训练速度和更好的支持,我们推荐使用最新的开发版 paddle。用户可以最新编译的开发版 whl 包,也可以选择从源码编译 Paddle。
1. 下载最新编译的开发版 whl 包。可以从 [**多版本 wheel 包列表-dev**](https://www.paddlepaddle.org.cn/documentation/docs/zh/beginners_guide/install/Tables.html#whl-dev) 页面中选择合适的版本。
2. 从源码编译 Paddle. 参考[**从源码编译**](https://www.paddlepaddle.org.cn/documentation/docs/zh/beginners_guide/install/compile/fromsource.html) 页面。注意,如果你需要使用多卡训练,那么编译前需要设置选项 `-DWITH_DISTRIBUTE=ON`
### 其他依赖
使用 pip 安装其他依赖。
```bash
pip install -r requirements.txt
```
另外需要下载 nltk 的两个库,因为使用了 `punkt` 对文本进行 tokenization,并且使用了 `cmudict` 来将文本转为音位。
```python
import nltk
nltk.download("punkt")
nltk.download("cmudict")
```
## 模型结构
![DeepVoice3 模型结构](./_images/model_architecture.png)
模型包含 encoder, decoder, converter 几个部分,对于 multispeaker 数据集,还有一个 speaker embedding。其中 encoder 和 decoder 构成 seq2seq 部分,converter 构成 postnet 部分。
## 项目结构
```text
├── audio.py # 用于处理处理音频的函数
├── compute_timestamp_ratio.py # 计算 position rate 的脚本
├── conversion # 用于转换 pytorch 实现的参数
├── requirements.txt # 项目依赖
├── hparams.py # DeepVoice3 运行超参数配置类的定义
├── hparam_tf # 超参数相关
├── ljspeech.py # ljspeech 数据集预处理
├── preprocess.py # 通用预处理脚本
├── presets # 预设超参数配置
├── deepvoice3_paddle # DeepVoice3 模型实现的主要文件
├── eval_model.py # 模型测评相关函数
├── synthesis.py # 用于语音合成的脚本
├── train_model.py # 模型训练相关函数
└── train.py # 用于模型训练的脚本
```
## 使用方法
根据所使用的模型配置和数据集的不同,有不少超参数需要进行调节。我们提供已知结果较好的超参数设置,详见 `presets` 文件夹。目前我们只提供 LJSpeech 的预设配置 (`deepvoice3_ljspeech.json`)。后续将提供更多模型和数据集的预设配置。
`preprocess.py``train.py``synthesis.py` 都接受 `--preset` 参数。为了保持一致性,最好在数据预处理,模型训练和语音合成时使用相同的预设配置。
可以通过 `--hparams` 参数来覆盖预设的超参数配置,参数格式是逗号分隔的键值对 `${key}=${value}`,例如 `--hparams="batch_size=8, nepochs=500"`
部分参数只和训练有关,如 `batch_size`, `checkpoint_interval`, 用户在训练时可以使用不同的值。但部分参数和数据预处理相关,如 `num_mels``ref_level_db`, 这些参数在数据预处理和训练时候应该保持一致。
关于超参数设置更多细节可以参考 `hparams.py` ,其中定义了 hparams。超参数的优先级序列是:通过命令行参数 `--hparams` 传入的参数优先级高于通过 `--preset` 参数传入的 json 配置文件,高于 `hparams.py` 中的定义。
### 数据集
下载并解压 [LJSpeech](https://keithito.com/LJ-Speech-Dataset/) 数据集。
```bash
wget https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2
tar xjvf LJSpeech-1.1.tar.bz2
```
使用 `preprocess.py`进行预处理。
```bash
python preprocess.py \
--preset=${preset_json_path} \
--hparams="hyper parameters you want to overwrite" \
${name} ${in_dir} ${out_dir}
```
目前 `${name}$` 只支持 `ljspeech`。未来将会支持更多数据集。
假设你使用 `presers/deepvoice3_ljspeech.json` 作为处理 LJSpeech 的预设配置文件,并且解压后的数据集位于 `./data/LJSpeech-1.1`, 那么使用如下的命令进行数据预处理。
```bash
python preprocess.py \
--preset=presets/deepvoice3_ljspeech.json \
ljspeech ./data/LJSpeech-1.1/ ./data/ljspeech
```
数据处理完成后,你会在 `./data/ljspeech` 看到提取的特征,包含如下文件。
1. `train.txt`,包含文本和对应的音频特征的文件名。
2. `ljspeech-mel-*.npy`,包含 mel 频谱。
3. `ljspeech-spec-*.npy`,包含线性频谱。
### 使用 GPU 单卡训练
在单个 GPU 上训练整个模型的使用方法如下。
```bash
export CUDA_VISIBLE_DEVICES=0
python train.py --data-root=${data-root} --use-gpu \
--preset=${preset_json_path} \
--hparams="parameters you may want to override"
```
用于可以通过 `python train.py --help` 查看 `train.py` 的详细使用方法。
#### 加载保存的模型
我们提供了使用默认的配置文件训练的模型 [dv3.single_frame](https://paddlespeech.bj.bcebos.com/Parakeet/dv3.single_frame.tar.gz) 供用户下载。使用 `tar xzvf dv3.single_frame.tar.gz` 解压下载的文件,会得到 `config.json`, `model.pdparams` and `README.md`。其中 `config.json` 是模型训练时使用的配置文件,`model.pdparams` 是参数文件,`README.md` 是模型的简要说明。
用户可以通过 `--checkpoint` 参数加载保存的模型并恢复训练(注意:只需要传基础文件名,不需要扩展名,例如需要加载 `model.pdparams` 那么,只需要使用 `--checkpoint=model`)。如果同一个文件夹内有一个和参数文件基础文件名相同,而后缀为 `.pdopt` 的文件,(如 `model.pdopt`,即优化器文件),那么该文件也会被自动加载。如果你想要重置优化器的状态,在训练脚本加入 `--reset-optimizer` 参数。
#### 训练模型的一部分
用户可以通过 `--train-seq2seq-only` 或者 `--train-postnet-only` 来实现固定模型的其他部分,只训练需要训练的部分。但当只训练模型的一部分时,其他的部分需要从保存的模型中加载。
当只训练模型的 `seq2seq` 部分或者 `postnet` 部分时,需要使用 `--checkpoint` 加载整个模型并保持相同的配置。注意,当只训练 `postnet` 的时候,需要保证配置中的`use_decoder_state_for_postnet_input=false`,因为在这种情况下,postnet 使用真实的 mel 频谱作为输入。注意,`use_decoder_state_for_postnet_input` 的默认值是 `True`
示例:
```bash
export CUDA_VISIBLE_DEVICES=0
python train.py --data-root=${data-root} --use-gpu \
--preset=${preset_json_path} \
--hparams="parameters you may want to override" \
--train-seq2seq-only \
--output=${directory_to_save_results}
```
### 使用 GPU 多卡训练
本模型支持使用多个 GPU 通过数据并行的方式训练。方法是使用 `paddle.distributed.launch` 模块来启动 `train.py`
```bash
python -m paddle.distributed.launch \
--started_port ${port_of_the_first_worker} \
--selected_gpus ${logical_gpu_ids_to_choose} \
--log_dir ${path_to_write_log} \
training_script ...
```
paddle.distributed.launch 通过多进程的方式进行并行训练。`--selected_gpus` 指的是选择的 GPU 的逻辑序号,`started_port` 指的是 0 号显卡的使用的端口号,`--log_dir` 是日志保存的目录,每个进程的输出会在这个文件夹中保存为单独的文件。再在后面接上需要启动的脚本文件及其参数即可。这部分和单卡训练的脚本一致,但是需要传入 `--use-data-paralle` 以使用数据并行训练。示例命令如下。
```bash
export CUDA_VISIBLE_DEVICES=2,3,4,5 # The IDs of visible physical devices
python -m paddle.distributed.launch \
--selected_gpus=0,1,2,3 --log_dir ${multi_gpu_log_dir} \
train.py --data-root=${data-root} \
--use-gpu --use-data-parallel \
--preset=${preset_json_path} \
--hparams="parameters you may want to override" \
--output=${directory_to_save_results}
```
上述的示例中,设置了 `2, 3, 4, 5` 号显卡为可见的 GPU。然后 `--selected_gpus=0,1,2,3` 选择的是 GPU 的逻辑序号,分别对应于 `2, 3, 4, 5` 号卡。
模型 (模型参数保存为`*.pdparams` 文件,优化器被保存为 `*.pdopt` 文件)保存在 `${directory_to_save_results}/checkpoints` 文件夹中。多层平均的注意力机制对齐结果被保存为 `.png` 图片,默认保存在 `${directory_to_save_results}/checkpoints/alignment_ave` 中。每一层的注意力机制对齐结果默认被保存在 `${directory_to_save_results}/checkpoints/alignment_layer{attention_layer_num}`文件夹中。默认每 10000 步保存一次用于查看。
对 6 个给定的句子的语音合成结果保存在 `${directory_to_save_results}/checkpoints/eval` 中,包含多层平均平均的注意力机制对齐结果,这被保存为名为 `step{step_num}_text{text_id}_single_alignment.png` 的图片;以及合成的音频文件,保存为名为 `step{step_num}_text{text_id}_single_predicted.wav` 的音频。
### 使用 Tensorboard 查看训练
Tensorboard 训练日志被保存在 `${directory_to_save_results}/log/` 文件夹,可以通过 tensorboard 查看。使用方法如下。
```bash
tensorboard --logdir=${log_dir} --host=$HOSTNAME --port=8888
```
### 从保存的模型合成语音
给定一组文本,使用 `synthesis.py` 从一个训练好的模型来合成语音,使用方法如下。
```bash
python synthesis.py --use-gpu --preset=${preset_json_path} \
--hparams="parameters you may want to override" \
${checkpoint} ${text_list_file} ${dst_dir}
```
示例文本文件如下:
```text
Generative adversarial network or variational auto-encoder.
Once upon a time there was a dear little girl who was loved by every one who looked at her, but most of all by her grandmother, and there was nothing that she would not have given to the child.
A text-to-speech synthesis system typically consists of multiple stages, such as a text analysis frontend, an acoustic model and an audio synthesis module.
```
合成的结果包含注意力机制对齐结果和音频文件,保存于 `${dst_dir}`
### 计算 position rate
根据 [Deep Voice 3: Scaling Text-to-Speech with Convolutional Sequence Learning](https://arxiv.org/abs/1710.07654), 对于不同的数据集,会有不同的 position rate. 有两个不同的 position rate,一个用于 query 一个用于 key, 这在论文中称为 $\omega_1$ 和 $\omega_2$ ,在预设配置文件中的名字分别为 `query_position_rate``key_position_rate`
比如 LJSpeech 数据集的 `query_position_rate``key_position_rate` 分别为 `1.0``1.385`。固定 `query_position_rate` 为 1.0,`key_position_rate` 可以使用 `compute_timestamp_ratio.py` 计算,命令如下,其中 `${data_root}` 是预处理后的数据集路径。
```bash
python compute_timestamp_ratio.py --preset=${preset_json_path} \
--hparams="parameters you may want to override" ${data_root}
```
可以得到如下的结果。
```text
100%|██████████████████████████████████████████████████████████| 13047/13047 [00:12<00:00, 1058.19it/s]
1345587 1863884.0 1.3851828235558161
```
然后在预设配置文件中设置 `key_position_rate=1.385` 以及 `query_position_rate=1.0`
## 致谢
本实现包含及改写了 r9y9's 的 [deepvoice3_pytorch](https://github.com/r9y9/deepvoice3_pytorch) 中的部分文件,在此表示感谢。
from parakeet.models.deepvoice3.encoder import Encoder
from parakeet.models.deepvoice3.decoder import Decoder
from parakeet.models.deepvoice3.converter import Converter
from parakeet.models.deepvoice3.model import DeepVoice3
# this file is only used for continuous evaluation test!
import os
import sys
sys.path.append(os.environ['ceroot'])
from kpi import CostKpi
from kpi import DurationKpi
from kpi import AccKpi
each_epoch_duration_frame1_card1 = DurationKpi("each_epoch_duration_frame1_card1", 0.02, actived=True)
train_cost_frame1_card1 = CostKpi("train_cost_frame1_card1", 0.02, actived=True)
each_epoch_duration_frame4_card1 = DurationKpi("each_epoch_duration_frame4_card1", 0.05, actived=True)
train_cost_frame4_card1 = CostKpi("train_cost_frame4_card1", 0.02, actived=True)
tracking_kpis = [
each_epoch_duration_frame1_card1,
train_cost_frame1_card1,
each_epoch_duration_frame4_card1,
train_cost_frame4_card1,
]
def parse_log(log):
'''
This method should be implemented by model developers.
The suggestion:
each line in the log should be key, value, for example:
"
train_cost\t1.0
test_cost\t1.0
train_cost\t1.0
train_cost\t1.0
train_acc\t1.2
"
'''
for line in log.split('\n'):
fs = line.strip().split('\t')
print(fs)
if len(fs) == 3 and fs[0] == 'kpis':
kpi_name = fs[1]
kpi_value = float(fs[2])
yield kpi_name, kpi_value
def log_to_ce(log):
kpi_tracker = {}
for kpi in tracking_kpis:
kpi_tracker[kpi.name] = kpi
for (kpi_name, kpi_value) in parse_log(log):
print(kpi_name, kpi_value)
kpi_tracker[kpi_name].add_record(kpi_value)
kpi_tracker[kpi_name].persist()
if __name__ == '__main__':
log = sys.stdin.read()
log_to_ce(log)
import numpy as np
from collections import namedtuple
from paddle import fluid
import paddle.fluid.dygraph as dg
import paddle.fluid.layers as F
from parakeet.modules.weight_norm import Linear
WindowRange = namedtuple("WindowRange", ["backward", "ahead"])
class Attention(dg.Layer):
def __init__(self,
query_dim,
embed_dim,
dropout=0.0,
window_range=WindowRange(-1, 3),
key_projection=True,
value_projection=True):
super(Attention, self).__init__()
self.query_proj = Linear(query_dim, embed_dim)
if key_projection:
self.key_proj = Linear(embed_dim, embed_dim)
if value_projection:
self.value_proj = Linear(embed_dim, embed_dim)
self.out_proj = Linear(embed_dim, query_dim)
self.key_projection = key_projection
self.value_projection = value_projection
self.dropout = dropout
self.window_range = window_range
def forward(self, query, encoder_out, mask=None, last_attended=None):
"""
Compute pooled context representation and alignment scores.
Args:
query (Variable): shape(B, T_dec, C_q), the query tensor,
where C_q means the channel of query.
encoder_out (Tuple(Variable, Variable)):
keys (Variable): shape(B, T_enc, C_emb), the key
representation from an encoder, where C_emb means
text embedding size.
values (Variable): shape(B, T_enc, C_emb), the value
representation from an encoder, where C_emb means
text embedding size.
mask (Variable, optional): Shape(B, T_enc), mask generated with
valid text lengths.
last_attended (int, optional): The position that received most
attention at last timestep. This is only used at decoding.
Outpus:
x (Variable): Shape(B, T_dec, C_q), the context representation
pooled from attention mechanism.
attn_scores (Variable): shape(B, T_dec, T_enc), the alignment
tensor, where T_dec means the number of decoder time steps and
T_enc means number the number of decoder time steps.
"""
keys, values = encoder_out
residual = query
if self.value_projection:
values = self.value_proj(values)
if self.key_projection:
keys = self.key_proj(keys)
x = self.query_proj(query)
# TODO: check the code
x = F.matmul(x, keys, transpose_y=True)
# mask generated by sentence length
neg_inf = -1.e30
if mask is not None:
neg_inf_mask = F.scale(F.unsqueeze(mask, [1]), neg_inf)
x += neg_inf_mask
# if last_attended is provided, focus only on a window range around it
# to enforce monotonic attention.
# TODO: if last attended is a shape(B,) array
if last_attended is not None:
locality_mask = np.ones(shape=x.shape, dtype=np.float32)
backward, ahead = self.window_range
backward = last_attended + backward
ahead = last_attended + ahead
backward = max(backward, 0)
ahead = min(ahead, x.shape[-1])
locality_mask[:, :, backward:ahead] = 0.
locality_mask = dg.to_variable(locality_mask)
neg_inf_mask = F.scale(locality_mask, neg_inf)
x += neg_inf_mask
x = F.softmax(x)
attn_scores = x
x = F.dropout(x,
self.dropout,
dropout_implementation="upscale_in_train")
x = F.matmul(x, values)
encoder_length = keys.shape[1]
# CAUTION: is it wrong? let it be now
x = F.scale(x, encoder_length * np.sqrt(1.0 / encoder_length))
x = self.out_proj(x)
x = F.scale((x + residual), np.sqrt(0.5))
return x, attn_scores
# This file was copied from https://github.com/r9y9/deepvoice3_pytorch/tree/master/audio.py
# Copyright (c) 2017: Ryuichi Yamamoto.
import librosa
import librosa.filters
import math
import numpy as np
from scipy import signal
from hparams import hparams
from scipy.io import wavfile
import lws
def load_wav(path):
return librosa.core.load(path, sr=hparams.sample_rate)[0]
def save_wav(wav, path):
wav = wav * 32767 / max(0.01, np.max(np.abs(wav)))
wavfile.write(path, hparams.sample_rate, wav.astype(np.int16))
def preemphasis(x):
from nnmnkwii.preprocessing import preemphasis
return preemphasis(x, hparams.preemphasis)
def inv_preemphasis(x):
from nnmnkwii.preprocessing import inv_preemphasis
return inv_preemphasis(x, hparams.preemphasis)
def spectrogram(y):
D = _lws_processor().stft(preemphasis(y)).T
S = _amp_to_db(np.abs(D)) - hparams.ref_level_db
return _normalize(S)
def inv_spectrogram(spectrogram):
'''Converts spectrogram to waveform using librosa'''
S = _db_to_amp(_denormalize(spectrogram) +
hparams.ref_level_db) # Convert back to linear
processor = _lws_processor()
D = processor.run_lws(S.astype(np.float64).T**hparams.power)
y = processor.istft(D).astype(np.float32)
return inv_preemphasis(y)
def melspectrogram(y):
D = _lws_processor().stft(preemphasis(y)).T
S = _amp_to_db(_linear_to_mel(np.abs(D))) - hparams.ref_level_db
if not hparams.allow_clipping_in_normalization:
assert S.max() <= 0 and S.min() - hparams.min_level_db >= 0
return _normalize(S)
def _lws_processor():
return lws.lws(hparams.fft_size, hparams.hop_size, mode="speech")
# Conversions:
_mel_basis = None
def _linear_to_mel(spectrogram):
global _mel_basis
if _mel_basis is None:
_mel_basis = _build_mel_basis()
return np.dot(_mel_basis, spectrogram)
def _build_mel_basis():
if hparams.fmax is not None:
assert hparams.fmax <= hparams.sample_rate // 2
return librosa.filters.mel(hparams.sample_rate,
hparams.fft_size,
fmin=hparams.fmin,
fmax=hparams.fmax,
n_mels=hparams.num_mels)
def _amp_to_db(x):
min_level = np.exp(hparams.min_level_db / 20 * np.log(10))
return 20 * np.log10(np.maximum(min_level, x))
def _db_to_amp(x):
return np.power(10.0, x * 0.05)
def _normalize(S):
return np.clip((S - hparams.min_level_db) / -hparams.min_level_db, 0, 1)
def _denormalize(S):
return (np.clip(S, 0, 1) * -hparams.min_level_db) + hparams.min_level_db
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from deepvoice3 import DeepVoiceTTS, ConvSpec, WindowRange
def deepvoice3(n_vocab,
embed_dim=256,
mel_dim=80,
linear_dim=513,
r=4,
downsample_step=1,
n_speakers=1,
speaker_dim=16,
padding_idx=0,
dropout=(1 - 0.96),
filter_size=5,
encoder_channels=128,
decoder_channels=256,
converter_channels=256,
query_position_rate=1.0,
key_position_rate=1.29,
use_memory_mask=False,
trainable_positional_encodings=False,
force_monotonic_attention=True,
use_decoder_state_for_postnet_input=True,
max_positions=512,
embedding_weight_std=0.1,
speaker_embedding_weight_std=0.01,
freeze_embedding=False,
window_range=WindowRange(-1, 3),
key_projection=False,
value_projection=False):
time_upsampling = max(downsample_step, 1)
h = encoder_channels
k = filter_size
encoder_convolutions = (ConvSpec(h, k, 1), ConvSpec(h, k, 3),
ConvSpec(h, k, 9), ConvSpec(h, k, 27),
ConvSpec(h, k, 1), ConvSpec(h, k, 3),
ConvSpec(h, k, 9), ConvSpec(h, k, 27),
ConvSpec(h, k, 1), ConvSpec(h, k, 3))
h = decoder_channels
prenet_convolutions = (ConvSpec(h, k, 1), ConvSpec(h, k, 3))
attentive_convolutions = (ConvSpec(h, k, 1), ConvSpec(h, k, 3),
ConvSpec(h, k, 9), ConvSpec(h, k, 27),
ConvSpec(h, k, 1))
attention = [True, False, False, False, True]
h = converter_channels
postnet_convolutions = (ConvSpec(h, k, 1), ConvSpec(h, k, 3),
ConvSpec(2 * h, k, 1), ConvSpec(2 * h, k, 3))
model = DeepVoiceTTS(
"dv3", n_speakers, speaker_dim, speaker_embedding_weight_std, n_vocab,
embed_dim, padding_idx, embedding_weight_std, freeze_embedding,
encoder_convolutions, max_positions, padding_idx,
trainable_positional_encodings, mel_dim, r, prenet_convolutions,
attentive_convolutions, attention, use_memory_mask,
force_monotonic_attention, query_position_rate, key_position_rate,
window_range, key_projection, value_projection, linear_dim,
postnet_convolutions, time_upsampling, dropout,
use_decoder_state_for_postnet_input, "float32")
return model
def deepvoice3_multispeaker(n_vocab,
embed_dim=256,
mel_dim=80,
linear_dim=513,
r=4,
downsample_step=1,
n_speakers=1,
speaker_dim=16,
padding_idx=0,
dropout=(1 - 0.96),
filter_size=5,
encoder_channels=128,
decoder_channels=256,
converter_channels=256,
query_position_rate=1.0,
key_position_rate=1.29,
use_memory_mask=False,
trainable_positional_encodings=False,
force_monotonic_attention=True,
use_decoder_state_for_postnet_input=True,
max_positions=512,
embedding_weight_std=0.1,
speaker_embedding_weight_std=0.01,
freeze_embedding=False,
window_range=WindowRange(-1, 3),
key_projection=False,
value_projection=False):
time_upsampling = max(downsample_step, 1)
h = encoder_channels
k = filter_size
encoder_convolutions = (ConvSpec(h, k, 1), ConvSpec(h, k, 3),
ConvSpec(h, k, 9), ConvSpec(h, k, 27),
ConvSpec(h, k, 1), ConvSpec(h, k, 3),
ConvSpec(h, k, 9), ConvSpec(h, k, 27),
ConvSpec(h, k, 1), ConvSpec(h, k, 3))
h = decoder_channels
prenet_convolutions = (ConvSpec(h, k, 1))
attentive_convolutions = (ConvSpec(h, k, 1), ConvSpec(h, k, 3),
ConvSpec(h, k, 9), ConvSpec(h, k, 27),
ConvSpec(h, k, 1))
attention = [True, False, False, False, False]
h = converter_channels
postnet_convolutions = (ConvSpec(h, k, 1), ConvSpec(h, k, 3),
ConvSpec(2 * h, k, 1), ConvSpec(2 * h, k, 3))
model = DeepVoiceTTS(
"dv3", n_speakers, speaker_dim, speaker_embedding_weight_std, n_vocab,
embed_dim, padding_idx, embedding_weight_std, freeze_embedding,
encoder_convolutions, max_positions, padding_idx,
trainable_positional_encodings, mel_dim, r, prenet_convolutions,
attentive_convolutions, attention, use_memory_mask,
force_monotonic_attention, query_position_rate, key_position_rate,
window_range, key_projection, value_projection, linear_dim,
postnet_convolutions, time_upsampling, dropout,
use_decoder_state_for_postnet_input, "float32")
return model
# Part of code was adpated from https://github.com/r9y9/deepvoice3_pytorch/tree/master/compute_timestamp_ratio.py
# Copyright (c) 2017: Ryuichi Yamamoto.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import sys
import io
import numpy as np
# sys.path.append("../")
from hparams import hparams, hparams_debug_string
from data import TextDataSource, MelSpecDataSource
from nnmnkwii.datasets import FileSourceDataset
from tqdm import trange
from parakeet import g2p as frontend
def build_parser():
parser = argparse.ArgumentParser(
description="Compute output/input timestamp ratio.")
parser.add_argument(
"--hparams", type=str, default="", help="Hyper parameters.")
parser.add_argument(
"--preset",
type=str,
required=True,
help="Path of preset parameters (json).")
parser.add_argument("data_root", type=str, help="path of the dataset.")
return parser
if __name__ == "__main__":
parser = build_parser()
args, _ = parser.parse_known_args()
data_root = args.data_root
preset = args.preset
# Load preset if specified
if preset is not None:
with io.open(preset) as f:
hparams.parse_json(f.read())
# Override hyper parameters
hparams.parse(args.hparams)
assert hparams.name == "deepvoice3"
# Code below
X = FileSourceDataset(TextDataSource(data_root))
Mel = FileSourceDataset(MelSpecDataSource(data_root))
in_sizes = []
out_sizes = []
for i in trange(len(X)):
x, m = X[i], Mel[i]
if X.file_data_source.multi_speaker:
x = x[0]
in_sizes.append(x.shape[0])
out_sizes.append(m.shape[0])
in_sizes = np.array(in_sizes)
out_sizes = np.array(out_sizes)
input_timestamps = np.sum(in_sizes)
output_timestamps = np.sum(
out_sizes) / hparams.outputs_per_step / hparams.downsample_step
print(input_timestamps, output_timestamps,
output_timestamps / input_timestamps)
sys.exit(0)
import numpy as np
from paddle import fluid
import paddle.fluid.dygraph as dg
import paddle.fluid.layers as F
import paddle.fluid.initializer as I
from parakeet.modules.weight_norm import Conv1D, Conv1DCell, Conv2D, Linear
class Conv1DGLU(dg.Layer):
"""
A Convolution 1D block with GLU activation. It also applys dropout for the
input x. It fuses speaker embeddings through a FC activated by softsign. It
has residual connection from the input x, and scale the output by
np.sqrt(0.5).
"""
def __init__(self,
n_speakers,
speaker_dim,
in_channels,
num_filters,
filter_size=1,
dilation=1,
std_mul=4.0,
dropout=0.0,
causal=False,
residual=True):
super(Conv1DGLU, self).__init__()
# conv spec
self.in_channels = in_channels
self.n_speakers = n_speakers
self.speaker_dim = speaker_dim
self.num_filters = num_filters
self.filter_size = filter_size
self.dilation = dilation
# padding
self.causal = causal
# weight init and dropout
self.std_mul = std_mul
self.dropout = dropout
c_in = filter_size * in_channels
std = np.sqrt(std_mul * (1 - dropout) / c_in)
self.residual = residual
if residual:
assert (
in_channels == num_filters
), "this block uses residual connection"\
"the input_channes should equals num_filters"
self.conv = Conv1DCell(in_channels,
2 * num_filters,
filter_size,
dilation,
causal,
param_attr=I.Normal(scale=std))
if n_speakers > 1:
assert (speaker_dim is not None
), "speaker embed should not be null in multi-speaker case"
std = np.sqrt(1 / speaker_dim)
self.fc = Linear(speaker_dim,
num_filters,
param_attr=I.Normal(scale=std))
def forward(self, x, speaker_embed=None):
"""
Args:
x (Variable): Shape(B, C_in, T), the input of Conv1DGLU
layer, where B means batch_size, C_in means the input channels
T means input time steps.
speaker_embed_bct1 (Variable): Shape(B, C_sp), expanded
speaker embed, where C_sp means speaker embedding size. Note
that when using residual connection, the Conv1DGLU does not
change the number of channels, so out channels equals input
channels.
Returns:
x (Variable): Shape(B, C_out, T), the output of Conv1DGLU, where
C_out means the output channels of Conv1DGLU.
"""
residual = x
x = F.dropout(x,
self.dropout,
dropout_implementation="upscale_in_train")
x = self.conv(x)
content, gate = F.split(x, num_or_sections=2, dim=1)
if speaker_embed is not None:
sp = F.softsign(self.fc(speaker_embed))
content = F.elementwise_add(content, sp, axis=0)
# glu
x = F.sigmoid(gate) * content
if self.residual:
x = F.scale(x + residual, np.sqrt(0.5))
return x
def start_sequence(self):
self.conv.start_sequence()
def add_input(self, x_t, speaker_embed=None):
"""
Args:
x (Variable): Shape(B, C_in), the input of Conv1DGLU
layer, where B means batch_size, C_in means the input channels.
speaker_embed_bct1 (Variable): Shape(B, C_sp), expanded
speaker embed, where C_sp means speaker embedding size. Note
that when using residual connection, the Conv1DGLU does not
change the number of channels, so out channels equals input
channels.
Returns:
x (Variable): Shape(B, C_out), the output of Conv1DGLU, where
C_out means the output channels of Conv1DGLU.
"""
residual = x_t
x_t = F.dropout(x_t,
self.dropout,
dropout_implementation="upscale_in_train")
x_t = self.conv.add_input(x_t)
content_t, gate_t = F.split(x_t, num_or_sections=2, dim=1)
if speaker_embed is not None:
sp = F.softsign(self.fc(speaker_embed))
content_t = F.elementwise_add(content_t, sp, axis=0)
# glu
x_t = F.sigmoid(gate_t) * content_t
if self.residual:
x_t = F.scale(x_t + residual, np.sqrt(0.5))
return x_t
import numpy as np
from itertools import chain
import paddle.fluid.layers as F
import paddle.fluid.initializer as I
import paddle.fluid.dygraph as dg
from parakeet.modules.weight_norm import Conv1D, Conv1DTranspose, Conv2D, Conv2DTranspose, Linear
from parakeet.models.deepvoice3.conv1dglu import Conv1DGLU
from parakeet.models.deepvoice3.encoder import ConvSpec
def upsampling_4x_blocks(n_speakers, speaker_dim, target_channels, dropout):
# upsampling convolitions
upsampling_convolutions = [
Conv1DTranspose(target_channels,
target_channels,
2,
stride=2,
param_attr=I.Normal(np.sqrt(1 / target_channels))),
Conv1DGLU(n_speakers,
speaker_dim,
target_channels,
target_channels,
3,
dilation=1,
std_mul=1.,
dropout=dropout),
Conv1DGLU(n_speakers,
speaker_dim,
target_channels,
target_channels,
3,
dilation=3,
std_mul=4.,
dropout=dropout),
Conv1DTranspose(target_channels,
target_channels,
2,
stride=2,
param_attr=I.Normal(scale=np.sqrt(4. /
target_channels))),
Conv1DGLU(n_speakers,
speaker_dim,
target_channels,
target_channels,
3,
dilation=1,
std_mul=1.,
dropout=dropout),
Conv1DGLU(n_speakers,
speaker_dim,
target_channels,
target_channels,
3,
dilation=3,
std_mul=4.,
dropout=dropout)
]
return upsampling_convolutions
def upsampling_2x_blocks(n_speakers, speaker_dim, target_channels, dropout):
upsampling_convolutions = [
Conv1DTranspose(target_channels,
target_channels,
2,
stride=2,
param_attr=I.Normal(scale=np.sqrt(1. /
target_channels))),
Conv1DGLU(n_speakers,
speaker_dim,
target_channels,
target_channels,
3,
dilation=1,
std_mul=1.,
dropout=dropout),
Conv1DGLU(n_speakers,
speaker_dim,
target_channels,
target_channels,
3,
dilation=3,
std_mul=4.,
dropout=dropout)
]
return upsampling_convolutions
def upsampling_1x_blocks(n_speakers, speaker_dim, target_channels, dropout):
upsampling_convolutions = [
Conv1DGLU(n_speakers,
speaker_dim,
target_channels,
target_channels,
3,
dilation=3,
std_mul=4.,
dropout=dropout)
]
return upsampling_convolutions
class Converter(dg.Layer):
"""
Vocoder that transforms mel spectrogram (or ecoder hidden states)
to waveform.
"""
def __init__(self,
n_speakers,
speaker_dim,
in_channels,
linear_dim,
convolutions=(ConvSpec(256, 5, 1), ) * 4,
time_upsampling=1,
dropout=0.0):
super(Converter, self).__init__()
self.n_speakers = n_speakers
self.speaker_dim = speaker_dim
self.in_channels = in_channels
self.linear_dim = linear_dim
# CAUTION: this should equals the downsampling steps coefficient
self.time_upsampling = time_upsampling
self.dropout = dropout
target_channels = convolutions[0].out_channels
# conv proj to target channels
self.first_conv_proj = Conv1D(
in_channels,
target_channels,
1,
param_attr=I.Normal(scale=np.sqrt(1 / in_channels)))
# Idea from nyanko
if time_upsampling == 4:
self.upsampling_convolutions = dg.LayerList(
upsampling_4x_blocks(n_speakers, speaker_dim, target_channels,
dropout))
elif time_upsampling == 2:
self.upsampling_convolutions = dg.LayerList(
upsampling_2x_blocks(n_speakers, speaker_dim, target_channels,
dropout))
elif time_upsampling == 1:
self.upsampling_convolutions = dg.LayerList(
upsampling_1x_blocks(n_speakers, speaker_dim, target_channels,
dropout))
else:
raise ValueError(
"Upsampling factors other than {1, 2, 4} are Not supported.")
# post conv layers
std_mul = 4.0
in_channels = target_channels
self.convolutions = dg.LayerList()
for (out_channels, filter_size, dilation) in convolutions:
if in_channels != out_channels:
std = np.sqrt(std_mul / in_channels)
# CAUTION: relu
self.convolutions.append(
Conv1D(in_channels,
out_channels,
1,
act="relu",
param_attr=I.Normal(scale=std)))
in_channels = out_channels
std_mul = 2.0
self.convolutions.append(
Conv1DGLU(n_speakers,
speaker_dim,
in_channels,
out_channels,
filter_size,
dilation=dilation,
std_mul=std_mul,
dropout=dropout))
in_channels = out_channels
std_mul = 4.0
# final conv proj, channel transformed to linear dim
std = np.sqrt(std_mul * (1 - dropout) / in_channels)
# CAUTION: sigmoid
self.last_conv_proj = Conv1D(in_channels,
linear_dim,
1,
act="sigmoid",
param_attr=I.Normal(scale=std))
def forward(self, x, speaker_embed=None):
"""
Convert mel spectrogram or decoder hidden states to linear spectrogram.
Args:
x (Variable): Shape(B, T_mel, C_in), converter inputs, where
C_in means the input channel for the converter. Note that it
can be either C_mel (channel of mel spectrogram) or C_dec // r.
When use mel_spectrogram as the input of converter, C_in =
C_mel; and when use decoder states as the input of converter,
C_in = C_dec // r. In this scenario, decoder hidden states are
treated as if they were r outputs per decoder step and are
unpacked before passing to the converter.
speaker_embed (Variable, optional): shape(B, C_sp), speaker
embedding, where C_sp means the speaker embedding size.
Returns:
out (Variable): Shape(B, T_lin, C_lin), the output linear
spectrogram, where C_lin means the channel of linear
spectrogram and T_linear means the length(time steps) of linear
spectrogram. T_line = time_upsampling * T_mel, which depends
on the time_upsampling converter.
"""
x = F.transpose(x, [0, 2, 1])
x = self.first_conv_proj(x)
if speaker_embed is not None:
speaker_embed = F.dropout(
speaker_embed,
self.dropout,
dropout_implementation="upscale_in_train")
for layer in chain(self.upsampling_convolutions, self.convolutions):
if isinstance(layer, Conv1DGLU):
x = layer(x, speaker_embed)
else:
x = layer(x)
out = self.last_conv_proj(x)
out = F.transpose(out, [0, 2, 1])
return out
\ No newline at end of file
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import random
import io
import platform
from os.path import dirname, join
from nnmnkwii.datasets import FileSourceDataset, FileDataSource
from os.path import join, expanduser
import random
# import global hyper parameters
from hparams import hparams
from parakeet import g2p as frontend
import builder
_frontend = getattr(frontend, hparams.frontend)
def _pad(seq, max_len, constant_values=0):
return np.pad(seq, (0, max_len - len(seq)),
mode="constant",
constant_values=constant_values)
def _pad_2d(x, max_len, b_pad=0):
x = np.pad(x, [(b_pad, max_len - len(x) - b_pad), (0, 0)],
mode="constant",
constant_values=0)
return x
class TextDataSource(FileDataSource):
def __init__(self, data_root, speaker_id=None):
self.data_root = data_root
self.speaker_ids = None
self.multi_speaker = False
# If not None, filter by speaker_id
self.speaker_id = speaker_id
def collect_files(self):
meta = join(self.data_root, "train.txt")
with io.open(meta, "rt", encoding="utf-8") as f:
lines = f.readlines()
l = lines[0].split("|")
assert len(l) == 4 or len(l) == 5
self.multi_speaker = len(l) == 5
texts = list(map(lambda l: l.split("|")[3], lines))
if self.multi_speaker:
speaker_ids = list(map(lambda l: int(l.split("|")[-1]), lines))
# Filter by speaker_id
# using multi-speaker dataset as a single speaker dataset
if self.speaker_id is not None:
indices = np.array(speaker_ids) == self.speaker_id
texts = list(np.array(texts)[indices])
self.multi_speaker = False
return texts
return texts, speaker_ids
else:
return texts
def collect_features(self, *args):
if self.multi_speaker:
text, speaker_id = args
else:
text = args[0]
global _frontend
if _frontend is None:
_frontend = getattr(frontend, hparams.frontend)
seq = _frontend.text_to_sequence(
text, p=hparams.replace_pronunciation_prob)
if platform.system() == "Windows":
if hasattr(hparams, "gc_probability"):
_frontend = None # memory leaking prevention in Windows
if np.random.rand() < hparams.gc_probability:
gc.collect() # garbage collection enforced
print("GC done")
if self.multi_speaker:
return np.asarray(seq, dtype=np.int32), int(speaker_id)
else:
return np.asarray(seq, dtype=np.int32)
class _NPYDataSource(FileDataSource):
def __init__(self, data_root, col, speaker_id=None):
self.data_root = data_root
self.col = col
self.frame_lengths = []
self.speaker_id = speaker_id
def collect_files(self):
meta = join(self.data_root, "train.txt")
with io.open(meta, "rt", encoding="utf-8") as f:
lines = f.readlines()
l = lines[0].split("|")
assert len(l) == 4 or len(l) == 5
multi_speaker = len(l) == 5
self.frame_lengths = list(map(lambda l: int(l.split("|")[2]), lines))
paths = list(map(lambda l: l.split("|")[self.col], lines))
paths = list(map(lambda f: join(self.data_root, f), paths))
if multi_speaker and self.speaker_id is not None:
speaker_ids = list(map(lambda l: int(l.split("|")[-1]), lines))
# Filter by speaker_id
# using multi-speaker dataset as a single speaker dataset
indices = np.array(speaker_ids) == self.speaker_id
paths = list(np.array(paths)[indices])
self.frame_lengths = list(np.array(self.frame_lengths)[indices])
# aha, need to cast numpy.int64 to int
self.frame_lengths = list(map(int, self.frame_lengths))
return paths
def collect_features(self, path):
return np.load(path)
class MelSpecDataSource(_NPYDataSource):
def __init__(self, data_root, speaker_id=None):
super(MelSpecDataSource, self).__init__(data_root, 1, speaker_id)
class LinearSpecDataSource(_NPYDataSource):
def __init__(self, data_root, speaker_id=None):
super(LinearSpecDataSource, self).__init__(data_root, 0, speaker_id)
class PartialyRandomizedSimilarTimeLengthSampler(object):
"""Partially randmoized sampler
1. Sort by lengths
2. Pick a small patch and randomize it
3. Permutate mini-batchs
"""
def __init__(self,
lengths,
batch_size=16,
batch_group_size=None,
permutate=True):
self.sorted_indices = np.argsort(lengths)
self.lengths = np.array(lengths)[self.sorted_indices]
self.batch_size = batch_size
if batch_group_size is None:
batch_group_size = min(batch_size * 32, len(self.lengths))
if batch_group_size % batch_size != 0:
batch_group_size -= batch_group_size % batch_size
self.batch_group_size = batch_group_size
assert batch_group_size % batch_size == 0
self.permutate = permutate
def __iter__(self):
indices = self.sorted_indices.copy()
batch_group_size = self.batch_group_size
s, e = 0, 0
for i in range(len(indices) // batch_group_size):
s = i * batch_group_size
e = s + batch_group_size
random.shuffle(indices[s:e])
# Permutate batches
if self.permutate:
perm = np.arange(len(indices[:e]) // self.batch_size)
random.shuffle(perm)
indices[:e] = indices[:e].reshape(
-1, self.batch_size)[perm, :].reshape(-1)
# Handle last elements
s += batch_group_size
if s < len(indices):
random.shuffle(indices[s:])
return iter(indices)
def __len__(self):
return len(self.sorted_indices)
class Dataset(object):
def __init__(self, X, Mel, Y):
self.X = X
self.Mel = Mel
self.Y = Y
# alias
self.multi_speaker = X.file_data_source.multi_speaker
def __getitem__(self, idx):
if self.multi_speaker:
text, speaker_id = self.X[idx]
return text, self.Mel[idx], self.Y[idx], speaker_id
else:
return self.X[idx], self.Mel[idx], self.Y[idx]
def __len__(self):
return len(self.X)
def make_loader(dataset, batch_size, shuffle, sampler, create_batch_fn,
trainer_count, local_rank):
assert not (
shuffle and
sampler), "shuffle and sampler should not be valid in the same time."
num_samples = len(dataset)
def wrapper():
if sampler is None:
ids = range(num_samples)
if shuffle:
random.shuffle(ids)
else:
ids = sampler
batch, batches = [], []
for idx in ids:
batch.append(dataset[idx])
if len(batch) >= batch_size:
batches.append(batch)
batch = []
if len(batches) >= trainer_count:
yield create_batch_fn(batches[local_rank])
batches = []
if len(batch) > 0:
batches.append(batch)
if len(batches) >= trainer_count:
yield create_batch_fn(batches[local_rank])
return wrapper
def create_batch(batch):
"""Create batch"""
r = hparams.outputs_per_step
downsample_step = hparams.downsample_step
multi_speaker = len(batch[0]) == 4
# Lengths
input_lengths = [len(x[0]) for x in batch]
max_input_len = max(input_lengths)
input_lengths = np.array(input_lengths, dtype=np.int64)
target_lengths = [len(x[1]) for x in batch]
max_target_len = max(target_lengths)
target_lengths = np.array(target_lengths, dtype=np.int64)
if max_target_len % (r * downsample_step) != 0:
max_target_len += (r * downsample_step) - max_target_len % (
r * downsample_step)
assert max_target_len % (r * downsample_step) == 0
# Set 0 for zero beginning padding
# imitates initial decoder states
b_pad = r
max_target_len += b_pad * downsample_step
x_batch = np.array(
[_pad(x[0], max_input_len) for x in batch], dtype=np.int64)
x_batch = np.expand_dims(x_batch, axis=-1)
mel_batch = np.array(
[_pad_2d(
x[1], max_target_len, b_pad=b_pad) for x in batch],
dtype=np.float32)
# down sampling is done here
if downsample_step > 1:
mel_batch = mel_batch[:, 0::downsample_step, :]
mel_batch = np.expand_dims(np.transpose(mel_batch, axes=[0, 2, 1]), axis=2)
y_batch = np.array(
[_pad_2d(
x[2], max_target_len, b_pad=b_pad) for x in batch],
dtype=np.float32)
y_batch = np.expand_dims(np.transpose(y_batch, axes=[0, 2, 1]), axis=2)
# text positions
text_positions = np.array(
[_pad(np.arange(1, len(x[0]) + 1), max_input_len) for x in batch],
dtype=np.int64)
text_positions = np.expand_dims(text_positions, axis=-1)
max_decoder_target_len = max_target_len // r // downsample_step
# frame positions
s, e = 1, max_decoder_target_len + 1
frame_positions = np.tile(
np.expand_dims(
np.arange(
s, e, dtype=np.int64), axis=0), (len(batch), 1))
frame_positions = np.expand_dims(frame_positions, axis=-1)
# done flags
done = np.array([
_pad(
np.zeros(
len(x[1]) // r // downsample_step - 1, dtype=np.float32),
max_decoder_target_len,
constant_values=1) for x in batch
])
done = np.expand_dims(np.expand_dims(done, axis=1), axis=1)
if multi_speaker:
speaker_ids = np.expand_dims(np.array([x[3] for x in batch]), axis=-1)
return (x_batch, input_lengths, mel_batch, y_batch, text_positions,
frame_positions, done, target_lengths, speaker_ids)
else:
speaker_ids = None
return (x_batch, input_lengths, mel_batch, y_batch, text_positions,
frame_positions, done, target_lengths)
此差异已折叠。
此差异已折叠。
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from paddle import fluid
import paddle.fluid.dygraph as dg
from hparams import hparams, hparams_debug_string
from parakeet import g2p as frontend
from deepvoice3 import DeepVoiceTTS
def dry_run(model):
"""
Run the model once, just to get it initialized.
"""
model.train()
_frontend = getattr(frontend, hparams.frontend)
batch_size = 4
enc_length = 157
snd_sample_length = 500
r = hparams.outputs_per_step
downsample_step = hparams.downsample_step
n_speakers = hparams.n_speakers
# make sure snd_sample_length can be divided by r * downsample_step
linear_shift = r * downsample_step
snd_sample_length += linear_shift - snd_sample_length % linear_shift
decoder_length = snd_sample_length // downsample_step // r
mel_length = snd_sample_length // downsample_step
n_vocab = _frontend.n_vocab
max_pos = hparams.max_positions
spker_embed = hparams.speaker_embed_dim
linear_dim = model.linear_dim
mel_dim = hparams.num_mels
x = np.random.randint(
low=0, high=n_vocab, size=(batch_size, enc_length, 1), dtype="int64")
input_lengths = np.arange(
enc_length - batch_size + 1, enc_length + 1, dtype="int64")
mel = np.random.randn(batch_size, mel_dim, 1, mel_length).astype("float32")
y = np.random.randn(batch_size, linear_dim, 1,
snd_sample_length).astype("float32")
text_positions = np.tile(
np.arange(
0, enc_length, dtype="int64"), (batch_size, 1))
text_mask = text_positions > np.expand_dims(input_lengths, 1)
text_positions[text_mask] = 0
text_positions = np.expand_dims(text_positions, axis=-1)
frame_positions = np.tile(
np.arange(
1, decoder_length + 1, dtype="int64"), (batch_size, 1))
frame_positions = np.expand_dims(frame_positions, axis=-1)
done = np.zeros(shape=(batch_size, 1, 1, decoder_length), dtype="float32")
target_lengths = np.array([snd_sample_length] * batch_size).astype("int64")
speaker_ids = np.random.randint(
low=0, high=n_speakers, size=(batch_size, 1),
dtype="int64") if n_speakers > 1 else None
ismultispeaker = speaker_ids is not None
x = dg.to_variable(x)
input_lengths = dg.to_variable(input_lengths)
mel = dg.to_variable(mel)
y = dg.to_variable(y)
text_positions = dg.to_variable(text_positions)
frame_positions = dg.to_variable(frame_positions)
done = dg.to_variable(done)
target_lengths = dg.to_variable(target_lengths)
speaker_ids = dg.to_variable(
speaker_ids) if speaker_ids is not None else None
# these two fields are used as numpy ndarray
text_lengths = input_lengths.numpy()
decoder_lengths = target_lengths.numpy() // r // downsample_step
max_seq_len = max(text_lengths.max(), decoder_lengths.max())
if max_seq_len >= hparams.max_positions:
raise RuntimeError(
"max_seq_len ({}) >= max_posision ({})\n"
"Input text or decoder targget length exceeded the maximum length.\n"
"Please set a larger value for ``max_position`` in hyper parameters."
.format(max_seq_len, hparams.max_positions))
# cause paddle's embedding layer expect shape[-1] == 1
# first dry run runs the whole model
mel_outputs, linear_outputs, attn, done_hat = model(
x, input_lengths, mel, speaker_ids, text_positions, frame_positions)
num_parameters = 0
for k, v in model.state_dict().items():
print("{}|{}|{}".format(k, v.shape, np.prod(v.shape)))
num_parameters += np.prod(v.shape)
print("now model has {} parameters".format(len(model.state_dict())))
print("now model has {} elements".format(num_parameters))
import numpy as np
from collections import namedtuple
import paddle.fluid.layers as F
import paddle.fluid.initializer as I
import paddle.fluid.dygraph as dg
from parakeet.modules.weight_norm import Conv1D, Linear
from parakeet.models.deepvoice3.conv1dglu import Conv1DGLU
ConvSpec = namedtuple("ConvSpec", ["out_channels", "filter_size", "dilation"])
class Encoder(dg.Layer):
def __init__(self,
n_vocab,
embed_dim,
n_speakers,
speaker_dim,
padding_idx=None,
embedding_weight_std=0.1,
convolutions=(ConvSpec(64, 5, 1), ) * 7,
max_positions=512,
dropout=0.):
super(Encoder, self).__init__()
self.embedding_weight_std = embedding_weight_std
self.embed = dg.Embedding(
(n_vocab, embed_dim),
padding_idx=padding_idx,
param_attr=I.Normal(scale=embedding_weight_std))
self.dropout = dropout
if n_speakers > 1:
std = np.sqrt((1 - dropout) / speaker_dim) # CAUTION: keep_prob
self.sp_proj1 = Linear(speaker_dim,
embed_dim,
param_attr=I.Normal(scale=std))
self.sp_proj2 = Linear(speaker_dim,
embed_dim,
param_attr=I.Normal(scale=std))
self.n_speakers = n_speakers
self.convolutions = dg.LayerList()
in_channels = embed_dim
std_mul = 1.0
for (out_channels, filter_size, dilation) in convolutions:
# 1 * 1 convolution & relu
if in_channels != out_channels:
std = np.sqrt(std_mul / in_channels)
self.convolutions.append(
Conv1D(in_channels,
out_channels,
1,
act="relu",
param_attr=I.Normal(scale=std)))
in_channels = out_channels
std_mul = 2.0
self.convolutions.append(
Conv1DGLU(n_speakers,
speaker_dim,
in_channels,
out_channels,
filter_size,
dilation,
std_mul,
dropout,
causal=False,
residual=True))
in_channels = out_channels
std_mul = 4.0
std = np.sqrt(std_mul * (1 - dropout) / in_channels)
self.convolutions.append(
Conv1D(in_channels, embed_dim, 1, param_attr=I.Normal(scale=std)))
def forward(self, x, speaker_embed=None):
"""
Encode text sequence.
Args:
x (Variable): Shape(B, T_enc), dtype: int64. Ihe input text
indices. T_enc means the timesteps of decoder input x.
speaker_embed (Variable, optional): Shape(batch_size, speaker_dim),
dtype: float32. Speaker embeddings. This arg is not None only
when the model is a multispeaker model.
Returns:
keys (Variable), Shape(B, T_enc, C_emb), the encoded
representation for keys, where C_emb menas the text embedding
size.
values (Variable), Shape(B, T_enc, C_emb), the encoded
representation for values.
"""
x = self.embed(x)
x = F.dropout(x,
self.dropout,
dropout_implementation="upscale_in_train")
x = F.transpose(x, [0, 2, 1])
if self.n_speakers > 1 and speaker_embed is not None:
speaker_embed = F.dropout(
speaker_embed,
self.dropout,
dropout_implementation="upscale_in_train")
x = F.elementwise_add(x,
F.softsign(self.sp_proj1(speaker_embed)),
axis=0)
input_embed = x
for layer in self.convolutions:
if isinstance(layer, Conv1DGLU):
x = layer(x, speaker_embed)
else:
# layer is a Conv1D with (1,) filter wrapped by WeightNormWrapper
x = layer(x)
if self.n_speakers > 1 and speaker_embed is not None:
x = F.elementwise_add(x,
F.softsign(self.sp_proj2(speaker_embed)),
axis=0)
keys = x # (B, C, T)
values = F.scale(input_embed + x, scale=np.sqrt(0.5))
keys = F.transpose(keys, [0, 2, 1])
values = F.transpose(values, [0, 2, 1])
return keys, values
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
import os
from os.path import join, expanduser
from warnings import warn
from datetime import datetime
import matplotlib
# Force matplotlib not to use any Xwindows backend.
matplotlib.use("Agg")
from matplotlib import pyplot as plt
from matplotlib import cm
import audio
import numpy as np
from paddle import fluid
import paddle.fluid.dygraph as dg
import librosa.display
from tensorboardX import SummaryWriter
# import global hyper parameters
from hparams import hparams
from parakeet import g2p as frontend
_frontend = getattr(frontend, hparams.frontend)
def tts(model, text, p=0., speaker_id=None):
"""
Convert text to speech waveform given a deepvoice3 model.
Args:
model (DeepVoiceTTS): Model used to synthesize waveform.
text (str) : Input text to be synthesized
p (float) : Replace word to pronounciation if p > 0. Default is 0.
Returns:
waveform (numpy.ndarray): Shape(T_wav, ), predicted wave form, where
T_wav means the length of the synthesized wave form.
alignment (numpy.ndarray): Shape(T_dec, T_enc), predicted alignment
matrix, where T_dec means the time steps of decoder outputs, T_enc
means the time steps of encoder outoputs.
spectrogram (numpy.ndarray): Shape(T_lin, C_lin), predicted linear
spectrogram, where T__lin means the time steps of linear
spectrogram and C_lin mean sthe channels of linear spectrogram.
mel (numpy.ndarray): Shape(T_mel, C_mel), predicted mel spectrogram,
where T_mel means the time steps of mel spectrogram and C_mel means
the channels of mel spectrogram.
"""
model.eval()
sequence = np.array(_frontend.text_to_sequence(text, p=p)).astype("int64")
sequence = np.reshape(sequence, (1, -1, 1))
text_positions = np.arange(1, sequence.shape[1] + 1, dtype="int64")
text_positions = np.reshape(text_positions, (1, -1, 1))
sequence = dg.to_variable(sequence)
text_positions = dg.to_variable(text_positions)
speaker_ids = None if speaker_id is None else fluid.layers.fill_constant(
shape=[1, 1], value=speaker_id)
# sequence: shape(1, input_length, 1)
# text_positions: shape(1, input_length, 1)
# Greedy decoding
mel_outputs, linear_outputs, alignments, done = model.transduce(
sequence, text_positions, speaker_ids)
# reshape to the desired shape
linear_output = linear_outputs.numpy().squeeze().T
spectrogram = audio._denormalize(linear_output)
alignment = alignments.numpy()[0]
mel = mel_outputs.numpy().squeeze().T
mel = audio._denormalize(mel)
# Predicted audio signal
waveform = audio.inv_spectrogram(linear_output.T)
return waveform, alignment, spectrogram, mel
def prepare_spec_image(spectrogram):
"""
Prepare an image from spectrogram to be written to tensorboardX
summary writer.
Args:
spectrogram (numpy.ndarray): Shape(T, C), spectrogram to be
visualized, where T means the time steps of the spectrogram,
and C means the channels of the spectrogram.
Return:
np.ndarray: Shape(C, T, 4), the generated image of the spectrogram,
where T means the time steps of the spectrogram. It is treated
as the width of the image. And C means the channels of the
spectrogram, which is treated as the height of the image. And 4
means it is a 'ARGB' format.
"""
# [0, 1]
spectrogram = (spectrogram - np.min(spectrogram)) / (
np.max(spectrogram) - np.min(spectrogram))
spectrogram = np.flip(spectrogram, axis=1) # flip against freq axis
return np.uint8(cm.magma(spectrogram.T) * 255)
def plot_alignment(alignment, path, info=None):
fig, ax = plt.subplots()
im = ax.imshow(
alignment, aspect="auto", origin="lower", interpolation="none")
fig.colorbar(im, ax=ax)
xlabel = "Decoder timestep"
if info is not None:
xlabel += "\n\n" + info
plt.xlabel(xlabel)
plt.ylabel("Encoder timestep")
plt.tight_layout()
plt.savefig(path, format="png")
plt.close()
def time_string():
return datetime.now().strftime("%Y-%m-%d %H:%M")
def save_alignment(global_step, path, attn):
plot_alignment(
attn.T,
path,
info="{}, {}, step={}".format(hparams.builder,
time_string(), global_step))
def eval_model(global_step, writer, model, checkpoint_dir, ismultispeaker):
# hard coded text sequences
texts = [
"Scientists at the CERN laboratory say they have discovered a new particle.",
"There's a way to measure the acute emotional intelligence that has never gone out of style.",
"President Trump met with other leaders at the Group of 20 conference.",
"Generative adversarial network or variational auto-encoder.",
"Please call Stella.",
"Some have accepted this as a miracle without any physical explanation.",
]
eval_output_dir = join(checkpoint_dir, "eval")
if not os.path.exists(eval_output_dir):
os.makedirs(eval_output_dir)
print("[eval] Evaluating the model, results are saved in {}".format(
eval_output_dir))
model.eval()
# hard coded
speaker_ids = [0, 1, 10] if ismultispeaker else [None]
for speaker_id in speaker_ids:
speaker_str = ("multispeaker{}".format(speaker_id)
if speaker_id is not None else "single")
for idx, text in enumerate(texts):
signal, alignment, _, mel = tts(model,
text,
p=0,
speaker_id=speaker_id)
signal /= np.max(np.abs(signal))
# Alignment
path = join(eval_output_dir,
"step{:09d}_text{}_{}_alignment.png".format(
global_step, idx, speaker_str))
save_alignment(global_step, path, alignment)
tag = "eval_averaged_alignment_{}_{}".format(idx, speaker_str)
writer.add_image(
tag,
np.uint8(cm.viridis(np.flip(alignment, 1).T) * 255),
global_step,
dataformats='HWC')
# Mel
writer.add_image(
"(Eval) Predicted mel spectrogram text{}_{}".format(
idx, speaker_str),
prepare_spec_image(mel),
global_step,
dataformats='HWC')
# Audio
path = join(eval_output_dir,
"step{:09d}_text{}_{}_predicted.wav".format(
global_step, idx, speaker_str))
audio.save_wav(signal, path)
try:
writer.add_audio(
"(Eval) Predicted audio signal {}_{}".format(idx,
speaker_str),
signal,
global_step,
sample_rate=hparams.sample_rate)
except Exception as e:
warn(str(e))
pass
def save_states(global_step,
writer,
mel_outputs,
linear_outputs,
attn,
mel,
y,
input_lengths,
checkpoint_dir=None):
"""
Save states for the trainning process.
"""
print("[train] Saving intermediate states at step {}".format(global_step))
idx = min(1, len(input_lengths) - 1)
input_length = input_lengths[idx]
# Alignment, Multi-hop attention
if attn is not None and len(attn.shape) == 4:
attn = attn.numpy()
for i in range(attn.shape[0]):
alignment = attn[i]
alignment = alignment[idx]
tag = "alignment_layer{}".format(i + 1)
writer.add_image(
tag,
np.uint8(cm.viridis(np.flip(alignment, 1).T) * 255),
global_step,
dataformats='HWC')
alignment_dir = join(checkpoint_dir,
"alignment_layer{}".format(i + 1))
if not os.path.exists(alignment_dir):
os.makedirs(alignment_dir)
path = join(
alignment_dir,
"step{:09d}_layer_{}_alignment.png".format(global_step, i + 1))
save_alignment(global_step, path, alignment)
alignment_dir = join(checkpoint_dir, "alignment_ave")
if not os.path.exists(alignment_dir):
os.makedirs(alignment_dir)
path = join(alignment_dir,
"step{:09d}_alignment.png".format(global_step))
alignment = np.mean(attn, axis=0)[idx]
save_alignment(global_step, path, alignment)
tag = "averaged_alignment"
writer.add_image(
tag,
np.uint8(cm.viridis(np.flip(alignment, 1).T) * 255),
global_step,
dataformats="HWC")
if mel_outputs is not None:
mel_output = mel_outputs[idx].numpy().squeeze().T
mel_output = prepare_spec_image(audio._denormalize(mel_output))
writer.add_image(
"Predicted mel spectrogram",
mel_output,
global_step,
dataformats="HWC")
if linear_outputs is not None:
linear_output = linear_outputs[idx].numpy().squeeze().T
spectrogram = prepare_spec_image(audio._denormalize(linear_output))
writer.add_image(
"Predicted linear spectrogram",
spectrogram,
global_step,
dataformats="HWC")
signal = audio.inv_spectrogram(linear_output.T)
signal /= np.max(np.abs(signal))
path = join(checkpoint_dir,
"step{:09d}_predicted.wav".format(global_step))
try:
writer.add_audio(
"Predicted audio signal",
signal,
global_step,
sample_rate=hparams.sample_rate)
except Exception as e:
warn(str(e))
pass
audio.save_wav(signal, path)
if mel_outputs is not None:
mel_output = mel[idx].numpy().squeeze().T
mel_output = prepare_spec_image(audio._denormalize(mel_output))
writer.add_image(
"Target mel spectrogram",
mel_output,
global_step,
dataformats="HWC")
if linear_outputs is not None:
linear_output = y[idx].numpy().squeeze().T
spectrogram = prepare_spec_image(audio._denormalize(linear_output))
writer.add_image(
"Target linear spectrogram",
spectrogram,
global_step,
dataformats="HWC")
Source: hparam.py copied from tensorflow v1.12.0.
https://github.com/tensorflow/tensorflow/blob/v1.12.0/tensorflow/contrib/training/python/training/hparam.py
with the following:
wget https://github.com/tensorflow/tensorflow/raw/v1.12.0/tensorflow/contrib/training/python/training/hparam.py
Once all other tensorflow dependencies of these file are removed, the class keeps its goal. Functions not available due to this process are not used in this project.
# Part of code was adpated from https://github.com/r9y9/deepvoice3_pytorch/tree/master/hparams.py
# Copyright (c) 2017: Ryuichi Yamamoto.
import hparam_tf.hparam
# NOTE: If you want full control for model architecture. please take a look
# at the code and change whatever you want. Some hyper parameters are hardcoded.
# Default hyperparameters:
hparams = hparam_tf.hparam.HParams(
name="deepvoice3",
# Text:
# [en, jp]
frontend='en',
# Replace words to its pronunciation with fixed probability.
# e.g., 'hello' to 'HH AH0 L OW1'
# [en, jp]
# en: Word -> pronunciation using CMUDict
# jp: Word -> pronounciation usnig MeCab
# [0 ~ 1.0]: 0 means no replacement happens.
replace_pronunciation_prob=0.5,
# Convenient model builder
# [deepvoice3, deepvoice3_multispeaker, nyanko]
# Definitions can be found at deepvoice3_pytorch/builder.py
# deepvoice3: DeepVoice3 https://arxiv.org/abs/1710.07654
# deepvoice3_multispeaker: Multi-speaker version of DeepVoice3
# nyanko: https://arxiv.org/abs/1710.08969
builder="deepvoice3",
# Must be configured depends on the dataset and model you use
n_speakers=1,
speaker_embed_dim=16,
# Audio:
num_mels=80,
fmin=125,
fmax=7600,
fft_size=1024,
hop_size=256,
sample_rate=22050,
preemphasis=0.97,
min_level_db=-100,
ref_level_db=20,
# whether to rescale waveform or not.
# Let x is an input waveform, rescaled waveform y is given by:
# y = x / np.abs(x).max() * rescaling_max
rescaling=False,
rescaling_max=0.999,
# mel-spectrogram is normalized to [0, 1] for each utterance and clipping may
# happen depends on min_level_db and ref_level_db, causing clipping noise.
# If False, assertion is added to ensure no clipping happens.
allow_clipping_in_normalization=True,
# Model:
downsample_step=4, # must be 4 when builder="nyanko"
outputs_per_step=1, # must be 1 when builder="nyanko"
embedding_weight_std=0.1,
speaker_embedding_weight_std=0.01,
padding_idx=0,
# Maximum number of input text length
# try setting larger value if you want to give very long text input
max_positions=512,
dropout=1 - 0.95,
kernel_size=3,
text_embed_dim=128,
encoder_channels=256,
decoder_channels=256,
# Note: large converter channels requires significant computational cost
converter_channels=256,
query_position_rate=1.0,
# can be computed by `compute_timestamp_ratio.py`.
key_position_rate=1.385, # 2.37 for jsut
key_projection=False,
value_projection=False,
use_memory_mask=True,
trainable_positional_encodings=False,
freeze_embedding=False,
# If True, use decoder's internal representation for postnet inputs,
# otherwise use mel-spectrogram.
use_decoder_state_for_postnet_input=True,
# Data loader
random_seed=1234,
pin_memory=True,
# Set it to 1 when in Windows (MemoryError, THAllocator.c 0x5)
num_workers=2,
# Loss
masked_loss_weight=0.5, # (1-w)*loss + w * masked_loss
# heuristic: priotrize [0 ~ priotiry_freq] for linear loss
priority_freq=3000,
priority_freq_weight=0.0, # (1-w)*linear_loss + w*priority_linear_loss
# https://arxiv.org/pdf/1710.08969.pdf
# Adding the divergence to the loss stabilizes training, expecially for
# very deep (> 10 layers) networks.
# Binary div loss seems has approx 10x scale compared to L1 loss, so I choose 0.1.
binary_divergence_weight=0.1, # set 0 to disable
use_guided_attention=True,
guided_attention_sigma=0.2,
# Training:
batch_size=16,
adam_beta1=0.5,
adam_beta2=0.9,
adam_eps=1e-6,
amsgrad=False,
initial_learning_rate=5e-4, # 0.001,
lr_schedule="noam_learning_rate_decay",
lr_schedule_kwargs={},
nepochs=2000,
weight_decay=0.0,
clip_thresh=0.1,
# Save
checkpoint_interval=10000,
eval_interval=10000,
save_optimizer_state=True,
# Eval:
# this can be list for multple layers of attention
# e.g., [True, False, False, False, True]
force_monotonic_attention=True,
# Attention constraint for incremental decoding
window_ahead=3,
# 0 tends to prevent word repretetion, but sometime causes skip words
window_backward=1,
power=1.4, # Power to raise magnitudes to prior to phase retrieval
# GC:
# Forced garbage collection probability
# Use only when MemoryError continues in Windows (Disabled by default)
#gc_probability = 0.001,
# json_meta mode only
# 0: "use all",
# 1: "ignore only unmatched_alignment",
# 2: "fully ignore recognition",
ignore_recognition_level=2,
# when dealing with non-dedicated speech dataset(e.g. movie excerpts), setting min_text above 15 is desirable. Can be adjusted by dataset.
min_text=20,
# if true, data without phoneme alignment file(.lab) will be ignored
process_only_htk_aligned=False)
def hparams_debug_string():
values = hparams.values()
hp = [' %s: %s' % (name, values[name]) for name in sorted(values)]
return 'Hyperparameters:\n' + '\n'.join(hp)
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
export LD_LIBRARY_PATH=/fluid13_workspace/cuda-9.0/lib64/:/fluid13_workspace/cudnnv7.5_cuda9.0/lib64/:$LD_LIBRARY_PATH
#export PYTHONPATH=/dv3_workspace/paddle_for_dv3/build/python/
export PYTHONPATH=/fluid13_workspace/paddle_cherry_pick/build/python/:../
export CUDA_VISIBLE_DEVICES=7
GLOG_v=0 python -u train.py \
--use-gpu \
--reset-optimizer \
--preset=presets/deepvoice3_ljspeech.json \
--checkpoint-dir=checkpoint_single_1014 \
--data-root="/fluid13_workspace/dv3_workspace/deepvoice3_pytorch/data/ljspeech/" \
--hparams="batch_size=16"
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册