diff --git a/examples/deepvoice3/README.md b/examples/deepvoice3/README.md
index b5d45464977ec7c66be17d1089bb26127bb10ea2..5a0909e86300d17b882a22a338f489a93ed0d6ab 100644
--- a/examples/deepvoice3/README.md
+++ b/examples/deepvoice3/README.md
@@ -18,7 +18,7 @@ You can choose to install via pypi or clone the repository and install manually.
 1. Install via pypi.
    ```bash
-   pip install parakeet
+   pip install paddle-parakeet
    ```
 2. Install manually.
@@ -102,6 +102,19 @@ optional arguments:
 5. `--device` is the device (gpu id) to use for training. `-1` means CPU.
 
+Example script:
+
+```bash
+python train.py --config=./ljspeech.yaml --data=./LJSpeech-1.1/ --output=experiment --device=0
+```
+
+You can monitor the training log with TensorBoard using the commands below.
+
+```bash
+cd experiment/log
+tensorboard --logdir=.
+```
+
 ## Synthesis
 ```text
 usage: synthesis.py [-h] [-c CONFIG] [-g DEVICE] checkpoint text output_path
@@ -127,3 +140,9 @@ optional arguments:
 4. `output_path` is the directory to save results. The output path contains the generated audio files (`*.wav`) and attention plots (*.png) for each sentence.
 5. `--device` is the device (gpu id) to use for training. `-1` means CPU.
 
+Example script:
+
+```bash
+python synthesis.py --config=./ljspeech.yaml --device=0 experiment/checkpoints/model_step_005000000 sentences.txt generated
+```
+
diff --git a/examples/deepvoice3/ljspeech.yaml b/examples/deepvoice3/ljspeech.yaml
index bd17d2ecd175da75926d552178015ce1fee357c8..8aa6b5a63734343f44a3545709feb687e3797153 100644
--- a/examples/deepvoice3/ljspeech.yaml
+++ b/examples/deepvoice3/ljspeech.yaml
@@ -85,7 +85,6 @@ train:
   batch_size: 16
   epochs: 2000
-  report_interval: 100
   snap_interval: 1000
   eval_interval: 10000
   save_interval: 10000
 
diff --git a/examples/deepvoice3/train.py b/examples/deepvoice3/train.py
index ee64fea2bbca801e2aa4643d0cbb37d9054b7bf7..636032dab570c0f94172d8a75239e989b978549e 100644
--- a/examples/deepvoice3/train.py
+++ b/examples/deepvoice3/train.py
@@ -22,7 +22,7 @@ from parakeet.models.deepvoice3.loss import TTSLoss
 from parakeet.utils.layer_tools import summary
 
 from data import LJSpeechMetaData, DataCollector, Transform
-from utils import make_model, eval_model, plot_alignment, plot_alignments, save_state, make_output_tree
+from utils import make_model, eval_model, save_state, make_output_tree, plot_alignment
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
@@ -176,6 +176,11 @@ if __name__ == "__main__":
                                      parameter_list=dv3.parameters())
     gradient_clipper = fluid.dygraph_grad_clip.GradClipByGlobalNorm(0.1)
 
+    # generation
+    synthesis_config = config["synthesis"]
+    power = synthesis_config["power"]
+    n_iter = synthesis_config["n_iter"]
+
     # =========================link(dataloader, paddle)=========================
     # CAUTION: it does not return a DataLoader
     loader = fluid.io.DataLoader.from_generator(capacity=10,
@@ -198,16 +203,14 @@ if __name__ == "__main__":
 
     # =========================train=========================
     epoch = train_config["epochs"]
-    report_interval = train_config["report_interval"]
     snap_interval = train_config["snap_interval"]
     save_interval = train_config["save_interval"]
     eval_interval = train_config["eval_interval"]
 
     global_step = 1
-    average_loss = {"mel": 0, "lin": 0, "done": 0, "attn": 0}
 
     for j in range(1, 1 + epoch):
-        epoch_loss = {"mel": 0., "lin": 0., "done": 0., "attn": 0.}
+        epoch_loss = 0.
         for i, batch in tqdm.tqdm(enumerate(loader, 1)):
             dv3.train()  # CAUTION: don't forget to switch to train
             (text_sequences, text_lengths, text_positions, mel_specs,
@@ -225,7 +228,7 @@ if __name__ == "__main__":
             losses = criterion(mel_outputs, linear_outputs, done, alignments,
                                downsampled_mel_specs, lin_specs, done_flags,
                                text_lengths, frames)
-            l = criterion.compose_loss(losses)
+            l = losses["loss"]
             l.backward()
             # record learning rate before updating
             writer.add_scalar("learning_rate",
@@ -235,41 +238,31 @@ if __name__ == "__main__":
             optim.clear_gradients()
 
             # ==================all kinds of tedious things=================
-            for k in epoch_loss.keys():
-                epoch_loss[k] += losses[k].numpy()[0]
-                average_loss[k] += losses[k].numpy()[0]
-
-            # record step loss into tensorboard
+            epoch_loss += l.numpy()[0]
             step_loss = {k: v.numpy()[0] for k, v in losses.items()}
-            print(step_loss)
             for k, v in step_loss.items():
                 writer.add_scalar(k, v, global_step)
 
             # TODO: clean code
             # train state saving, the first sentence in the batch
             if global_step % snap_interval == 0:
-                linear_outputs_np = linear_outputs.numpy()[0].T
-                denoramlized = np.clip(linear_outputs_np, 0, 1) \
-                    * (-min_level_db) \
-                    + min_level_db
-                lin_scaled = np.exp(
-                    (denoramlized + ref_level_db) / 20 * np.log(10))
-                synthesis_config = config["synthesis"]
-                power = synthesis_config["power"]
-                n_iter = synthesis_config["n_iter"]
-                wav = librosa.griffinlim(lin_scaled**power,
-                                         n_iter=n_iter,
-                                         hop_length=hop_length,
-                                         win_length=win_length)
-
                 save_state(state_dir,
+                           writer,
                            global_step,
-                           mel_input=mel_specs.numpy()[0].T,
-                           mel_output=mel_outputs.numpy()[0].T,
-                           lin_input=lin_specs.numpy()[0].T,
-                           lin_output=linear_outputs.numpy()[0].T,
-                           alignments=alignments.numpy()[:, 0, :, :],
-                           wav=wav)
+                           mel_input=downsampled_mel_specs,
+                           mel_output=mel_outputs,
+                           lin_input=lin_specs,
+                           lin_output=linear_outputs,
+                           alignments=alignments,
+                           win_length=win_length,
+                           hop_length=hop_length,
+                           min_level_db=min_level_db,
+                           ref_level_db=ref_level_db,
+                           power=power,
+                           n_iter=n_iter,
+                           preemphasis=preemphasis,
+                           sample_rate=sample_rate)
 
             # evaluation
             if global_step % eval_interval == 0:
@@ -291,28 +284,31 @@ if __name__ == "__main__":
                         state_dir, "waveform",
                         "eval_sample_{:09d}.wav".format(global_step))
                     sf.write(wav_path, wav, sample_rate)
+                    writer.add_audio("eval_sample_{}".format(idx),
+                                     wav,
+                                     global_step,
+                                     sample_rate=sample_rate)
                     attn_path = os.path.join(
                         state_dir, "alignments",
                         "eval_sample_attn_{:09d}.png".format(global_step))
                     plot_alignment(attn, attn_path)
+                    writer.add_image("eval_sample_attn{}".format(idx),
+                                     cm.viridis(attn),
+                                     global_step,
+                                     dataformats="HWC")
 
             # save checkpoint
             if global_step % save_interval == 0:
-                dg.save_dygraph(dv3.state_dict(),
-                                os.path.join(ckpt_dir, "dv3"))
-                dg.save_dygraph(optim.state_dict(),
-                                os.path.join(ckpt_dir, "dv3"))
-
-            # report average loss
-            if global_step % report_interval == 0:
-                for k in epoch_loss.keys():
-                    average_loss[k] /= report_interval
-                print("[average_loss] ",
-                      "global_step: {}".format(global_step), average_loss)
-                average_loss = {"mel": 0, "lin": 0, "done": 0, "attn": 0}
+                dg.save_dygraph(
+                    dv3.state_dict(),
+                    os.path.join(ckpt_dir,
+                                 "model_step_{}".format(global_step)))
+                dg.save_dygraph(
+                    optim.state_dict(),
+                    os.path.join(ckpt_dir,
+                                 "model_step_{}".format(global_step)))
 
             global_step += 1
 
         # epoch report
-        for k in epoch_loss.keys():
-            epoch_loss[k] /= i
-        print("[epoch_loss] ", "epoch: {}".format(j), epoch_loss)
\ No newline at end of file
+        writer.add_scalar("epoch_average_loss", epoch_loss / i, j)
+        epoch_loss = 0.
\ No newline at end of file
diff --git a/examples/deepvoice3/utils.py b/examples/deepvoice3/utils.py
index 4e9f5cfcfd4cdbae32816b05b7859b954318bbc1..02118af8a62bad0502fdd05be267737205af31fe 100644
--- a/examples/deepvoice3/utils.py
+++ b/examples/deepvoice3/utils.py
@@ -1,5 +1,6 @@
 import os
 import numpy as np
+from matplotlib import cm
 import matplotlib.pyplot as plt
 import librosa
 from scipy import signal
@@ -125,21 +126,32 @@ def eval_model(model, text, replace_pronounciation_prob, min_level_db,
     model.eval()
     mel_outputs, linear_outputs, alignments, done = model.transduce(
         dg.to_variable(text), dg.to_variable(text_positions))
+    linear_outputs_np = linear_outputs.numpy()[0].T  # (C, T)
+    wav = spec_to_waveform(linear_outputs_np, min_level_db, ref_level_db,
+                           power, n_iter, win_length, hop_length, preemphasis)
+    alignments_np = alignments.numpy()[0]  # batch_size = 1
     print("linear_outputs's shape: ", linear_outputs_np.shape)
+    print("alignments' shape:", alignments.shape)
+    return wav, alignments_np
+
 
-    denoramlized = np.clip(linear_outputs_np, 0,
-                           1) * (-min_level_db) + min_level_db
+def spec_to_waveform(spec, min_level_db, ref_level_db, power, n_iter,
+                     win_length, hop_length, preemphasis):
+    """Convert output linear spec to waveform using griffin-lim vocoder.
+
+    Args:
+        spec (ndarray): the output linear spectrogram, shape(C, T), where C means n_fft, T means frames.
+    """
+    denoramlized = np.clip(spec, 0, 1) * (-min_level_db) + min_level_db
     lin_scaled = np.exp((denoramlized + ref_level_db) / 20 * np.log(10))
     wav = librosa.griffinlim(lin_scaled**power,
                              n_iter=n_iter,
                              hop_length=hop_length,
                              win_length=win_length)
-    wav = signal.lfilter([1.], [1., -preemphasis], wav)
-
-    print("alignmnets' shape:", alignments.shape)
-    alignments_np = alignments.numpy()[0].T
-    return wav, alignments_np
+    if preemphasis > 0:
+        wav = signal.lfilter([1.], [1., -preemphasis], wav)
+    return wav
 
 
 def make_output_tree(output_dir):
@@ -157,88 +169,89 @@ def make_output_tree(output_dir):
             os.makedirs(p)
 
 
-def plot_alignment(alignment, path, info=None):
+def plot_alignment(alignment, path):
     """
     Plot an attention layer's alignment for a sentence.
-    alignment: shape(T_enc, T_dec), and T_enc is flipped
+    alignment: shape(T_dec, T_enc).
     """
-    fig, ax = plt.subplots()
-    im = ax.imshow(alignment,
-                   aspect='auto',
-                   origin='lower',
-                   interpolation='none')
-    fig.colorbar(im, ax=ax)
-    xlabel = 'Decoder timestep'
-    if info is not None:
-        xlabel += '\n\n' + info
-    plt.xlabel(xlabel)
-    plt.ylabel('Encoder timestep')
-    plt.tight_layout()
+    plt.figure()
+    plt.imshow(alignment)
+    plt.colorbar()
+    plt.xlabel('Encoder timestep')
+    plt.ylabel('Decoder timestep')
     plt.savefig(path)
     plt.close()
 
 
-def plot_alignments(alignments, save_dir, global_step):
-    """
-    Plot alignments for a sentence when training, we just pick the first
-    sentence. Each layer is plot separately.
-    alignments: shape(N, T_dec, T_enc)
-    """
-    n_layers = alignments.shape[0]
-    for i, alignment in enumerate(alignments):
-        alignment = alignment.T
-
-        path = os.path.join(save_dir, "layer_{}".format(i))
-        if not os.path.exists(path):
-            os.makedirs(path)
-        fname = os.path.join(path, "step_{:09d}".format(global_step))
-        plot_alignment(alignment, fname)
-
-    average_alignment = np.mean(alignments, axis=0).T
-    path = os.path.join(save_dir, "average")
-    if not os.path.exists(path):
-        os.makedirs(path)
-    fname = os.path.join(path, "step_{:09d}.png".format(global_step))
-    plot_alignment(average_alignment, fname)
-
-
 def save_state(save_dir,
+               writer,
                global_step,
                mel_input=None,
                mel_output=None,
                lin_input=None,
                lin_output=None,
                alignments=None,
-               wav=None):
+               win_length=1024,
+               hop_length=256,
+               min_level_db=-100,
+               ref_level_db=20,
+               power=1.4,
+               n_iter=32,
+               preemphasis=0.97,
+               sample_rate=22050):
+    """Save intermediate training results for the first sentence in the batch,
+    including mel_spec (predicted, target), lin_spec (predicted, target), attention and waveform.
+
+    Args:
+        save_dir (str): directory to save results.
+        writer (SummaryWriter): tensorboardX summary writer
+        global_step (int): global step.
+        mel_input (Variable, optional): Defaults to None. Shape(B, T_mel, C_mel)
+        mel_output (Variable, optional): Defaults to None. Shape(B, T_mel, C_mel)
+        lin_input (Variable, optional): Defaults to None. Shape(B, T_lin, C_lin)
+        lin_output (Variable, optional): Defaults to None. Shape(B, T_lin, C_lin)
+        alignments (Variable, optional): Defaults to None. Shape(N, B, T_dec, T_enc)
+    """
     if mel_input is not None and mel_output is not None:
-        path = os.path.join(save_dir, "mel_spec")
-        if not os.path.exists(path):
-            os.makedirs(path)
+        mel_input = mel_input[0].numpy().T
+        mel_output = mel_output[0].numpy().T
+        path = os.path.join(save_dir, "mel_spec")
 
         plt.figure(figsize=(10, 3))
         display.specshow(mel_input)
         plt.colorbar()
         plt.title("mel_input")
         plt.savefig(
             os.path.join(path,
-                         "target_mel_spec_step{:09d}".format(global_step)))
+                         "target_mel_spec_step{:09d}.png".format(global_step)))
         plt.close()
+        writer.add_image("target/mel_spec",
+                         cm.viridis(mel_input),
+                         global_step,
+                         dataformats="HWC")
+
         plt.figure(figsize=(10, 3))
         display.specshow(mel_output)
         plt.colorbar()
-        plt.title("mel_input")
+        plt.title("mel_output")
         plt.savefig(
-            os.path.join(path,
-                         "predicted_mel_spec_step{:09d}".format(global_step)))
+            os.path.join(
+                path, "predicted_mel_spec_step{:09d}.png".format(global_step)))
         plt.close()
+        writer.add_image("predicted/mel_spec",
+                         cm.viridis(mel_output),
+                         global_step,
+                         dataformats="HWC")
+
     if lin_input is not None and lin_output is not None:
+        lin_input = lin_input[0].numpy().T
+        lin_output = lin_output[0].numpy().T
         path = os.path.join(save_dir, "lin_spec")
-        if not os.path.exists(path):
-            os.makedirs(path)
 
         plt.figure(figsize=(10, 3))
         display.specshow(lin_input)
@@ -246,28 +259,50 @@ def save_state(save_dir,
         plt.title("mel_input")
         plt.savefig(
             os.path.join(path,
-                         "target_lin_spec_step{:09d}".format(global_step)))
+                         "target_lin_spec_step{:09d}.png".format(global_step)))
        plt.close()
+        writer.add_image("target/lin_spec",
+                         cm.viridis(lin_input),
+                         global_step,
+                         dataformats="HWC")
+
         plt.figure(figsize=(10, 3))
         display.specshow(lin_output)
         plt.colorbar()
         plt.title("mel_input")
         plt.savefig(
-            os.path.join(path,
-                         "predicted_lin_spec_step{:09d}".format(global_step)))
+            os.path.join(
+                path,
"predicted_lin_spec_step{:09d}.png".format(global_step))) plt.close() - if alignments is not None and len(alignments.shape) == 3: - path = os.path.join(save_dir, "alignments") - if not os.path.exists(path): - os.makedirs(path) - plot_alignments(alignments, path, global_step) + writer.add_image("predicted/lin_spec", + cm.viridis(lin_output), + global_step, + dataformats="HWC") - if wav is not None: + if alignments is not None and len(alignments.shape) == 4: + path = os.path.join(save_dir, "alignments") + alignments = alignments[:, 0, :, :].numpy() + for idx, attn_layer in enumerate(alignments): + save_path = os.path.join( + path, + "train_attn_layer_{}_step_{}.png".format(idx, global_step)) + plot_alignment(attn_layer, save_path) + + writer.add_image("train_attn/layer_{}".format(idx), + cm.viridis(attn_layer), + global_step, + dataformats="HWC") + + if lin_output is not None: + wav = spec_to_waveform(lin_output, min_level_db, ref_level_db, power, + n_iter, win_length, hop_length, preemphasis) path = os.path.join(save_dir, "waveform") - if not os.path.exists(path): - os.makedirs(path) - sf.write( - os.path.join(path, "sample_step_{:09d}.wav".format(global_step)), - wav, 22050) + save_path = os.path.join( + path, "train_sample_step_{:09d}.wav".format(global_step)) + sf.write(save_path, wav, sample_rate) + writer.add_audio("train_sample", + wav, + global_step, + sample_rate=sample_rate) diff --git a/parakeet/models/deepvoice3/decoder.py b/parakeet/models/deepvoice3/decoder.py index 09c62f6572481a5d3d226b66d645752769455281..8e6a46b4b43f4fde99f6202e722def0df2561884 100644 --- a/parakeet/models/deepvoice3/decoder.py +++ b/parakeet/models/deepvoice3/decoder.py @@ -79,25 +79,26 @@ def unfold_adjacent_frames(folded_frames, r): class Decoder(dg.Layer): - def __init__(self, - n_speakers, - speaker_dim, - embed_dim, - mel_dim, - r=1, - max_positions=512, - padding_idx=None, - preattention=(ConvSpec(128, 5, 1), ) * 4, - convolutions=(ConvSpec(128, 5, 1), ) * 4, - attention=True, - dropout=0.0, - use_memory_mask=False, - force_monotonic_attention=False, - query_position_rate=1.0, - key_position_rate=1.0, - window_range=WindowRange(-1, 3), - key_projection=True, - value_projection=True): + def __init__( + self, + n_speakers, + speaker_dim, + embed_dim, + mel_dim, + r=1, + max_positions=512, + padding_idx=None, # remove it! 
+            preattention=(ConvSpec(128, 5, 1), ) * 4,
+            convolutions=(ConvSpec(128, 5, 1), ) * 4,
+            attention=True,
+            dropout=0.0,
+            use_memory_mask=False,
+            force_monotonic_attention=False,
+            query_position_rate=1.0,
+            key_position_rate=1.0,
+            window_range=WindowRange(-1, 3),
+            key_projection=True,
+            value_projection=True):
         super(Decoder, self).__init__()
 
         self.dropout = dropout
@@ -109,21 +110,23 @@ class Decoder(dg.Layer):
         self.n_speakers = n_speakers
 
         conv_channels = convolutions[0].out_channels
+        # only when padding_idx is 0 can we easily handle it
         self.embed_keys_positions = PositionEmbedding(max_positions,
                                                       embed_dim,
-                                                      padding_idx=padding_idx)
+                                                      padding_idx=0)
         self.embed_query_positions = PositionEmbedding(max_positions,
                                                        conv_channels,
-                                                       padding_idx=padding_idx)
+                                                       padding_idx=0)
 
         if n_speakers > 1:
-            # CAUTION: mind the sigmoid
             std = np.sqrt((1 - dropout) / speaker_dim)
             self.speaker_proj1 = Linear(speaker_dim,
                                         1,
+                                        act="sigmoid",
                                         param_attr=I.Normal(scale=std))
             self.speaker_proj2 = Linear(speaker_dim,
                                         1,
+                                        act="sigmoid",
                                         param_attr=I.Normal(scale=std))
 
         # prenet
@@ -168,6 +171,7 @@ class Decoder(dg.Layer):
             ] * len(convolutions)
         else:
             self.force_monotonic_attention = force_monotonic_attention
+
         for x, y in zip(self.force_monotonic_attention, self.attention):
             if x is True and y is False:
                 raise ValueError("When not using attention, there is no "
@@ -249,7 +253,7 @@ class Decoder(dg.Layer):
             text_positions (Variable): shape(B, T_enc), dtype: int64.
                 Positions indices for text inputs for the encoder, where
                 T_enc means the encoder timesteps.
-            frame_positions (Variable): shape(B, T_dec // r), dtype:
+            frame_positions (Variable): shape(B, T_mel // r), dtype:
                 int64. Positions indices for each decoder time steps.
             speaker_embed: shape(batch_size, speaker_dim), speaker embedding,
                 only used for multispeaker model.
@@ -287,16 +291,14 @@ class Decoder(dg.Layer):
         if text_positions is not None:
             w = self.key_position_rate
             if self.n_speakers > 1:
-                w = w * F.squeeze(F.sigmoid(self.speaker_proj1(speaker_embed)),
-                                  [-1])
+                w = w * F.squeeze(self.speaker_proj1(speaker_embed), [-1])
             text_pos_embed = self.embed_keys_positions(text_positions, w)
             keys += text_pos_embed  # (B, T, C)
 
         if frame_positions is not None:
             w = self.query_position_rate
             if self.n_speakers > 1:
-                w = w * F.squeeze(F.sigmoid(self.speaker_proj2(speaker_embed)),
-                                  [-1])
+                w = w * F.squeeze(self.speaker_proj2(speaker_embed), [-1])
             frame_pos_embed = self.embed_query_positions(frame_positions, w)
         else:
             frame_pos_embed = None
@@ -387,8 +389,7 @@ class Decoder(dg.Layer):
         w = self.key_position_rate
         if self.n_speakers > 1:
             # shape (B, )
-            w = w * F.squeeze(F.sigmoid(self.speaker_proj1(speaker_embed)),
-                              [-1])
+            w = w * F.squeeze(self.speaker_proj1(speaker_embed), [-1])
         text_pos_embed = self.embed_keys_positions(text_positions, w)
         keys += text_pos_embed  # (B, T, C)
 
@@ -417,8 +418,7 @@ class Decoder(dg.Layer):
                              dtype="int64")
         w = self.query_position_rate
         if self.n_speakers > 1:
-            w = w * F.squeeze(F.sigmoid(self.speaker_proj2(speaker_embed)),
-                              [-1])
+            w = w * F.squeeze(self.speaker_proj2(speaker_embed), [-1])
         # (B, T=1, C)
         frame_pos_embed = self.embed_query_positions(frame_pos, w)
 
diff --git a/parakeet/models/deepvoice3/encoder.py b/parakeet/models/deepvoice3/encoder.py
index a50ae8328ea54147d390508e96d4877647c8dc0a..ebcd62fcae8bf2f9e739f33948996760010b6226 100644
--- a/parakeet/models/deepvoice3/encoder.py
+++ b/parakeet/models/deepvoice3/encoder.py
@@ -35,9 +35,11 @@ class Encoder(dg.Layer):
             std = np.sqrt((1 - dropout) / speaker_dim)
             self.sp_proj1 = Linear(speaker_dim,
                                    embed_dim,
+                                   act="softsign",
                                    param_attr=I.Normal(scale=std))
             self.sp_proj2 = Linear(speaker_dim,
                                    embed_dim,
+                                   act="softsign",
                                    param_attr=I.Normal(scale=std))
         self.n_speakers = n_speakers
@@ -104,9 +106,7 @@ class Encoder(dg.Layer):
                 speaker_embed,
                 self.dropout,
                 dropout_implementation="upscale_in_train")
-            x = F.elementwise_add(x,
-                                  F.softsign(self.sp_proj1(speaker_embed)),
-                                  axis=0)
+            x = F.elementwise_add(x, self.sp_proj1(speaker_embed), axis=0)
 
         input_embed = x
         for layer in self.convolutions:
@@ -117,9 +117,7 @@ class Encoder(dg.Layer):
             x = layer(x)
 
         if self.n_speakers > 1 and speaker_embed is not None:
-            x = F.elementwise_add(x,
-                                  F.softsign(self.sp_proj2(speaker_embed)),
-                                  axis=0)
+            x = F.elementwise_add(x, self.sp_proj2(speaker_embed), axis=0)
 
         keys = x  # (B, C, T)
         values = F.scale(input_embed + x, scale=np.sqrt(0.5))
diff --git a/parakeet/models/deepvoice3/loss.py b/parakeet/models/deepvoice3/loss.py
index 0832c07e7186d62bfd5cca6758c791ba0db5bfc5..86412e7b6776f60dda7feb05df63851318a5660b 100644
--- a/parakeet/models/deepvoice3/loss.py
+++ b/parakeet/models/deepvoice3/loss.py
@@ -156,8 +156,9 @@ class TTSLoss(object):
                  compute_mel_loss=True,
                  compute_done_loss=True,
                  compute_attn_loss=True):
+        total_loss = 0.
+
         # n_frames # mel_lengths # decoder_lengths
-        # 4 个 loss 吧。lin(l1, bce, lin), mel(l1, bce, mel), attn, done
         max_frames = lin_hyp.shape[1]
         max_mel_steps = max_frames // self.downsample_factor
         max_decoder_steps = max_mel_steps // self.r
@@ -182,6 +183,7 @@ class TTSLoss(object):
             lin_bce_loss = self.binary_divergence(lin_hyp, lin_ref, lin_mask)
             lin_loss = self.binary_divergence_weight * lin_bce_loss \
                 + (1 - self.binary_divergence_weight) * lin_l1_loss
+            total_loss += lin_loss
 
         if compute_mel_loss:
             mel_hyp = mel_hyp[:, :-self.time_shift, :]
@@ -192,32 +194,28 @@ class TTSLoss(object):
             # print("=====>", mel_l1_loss.numpy()[0], mel_bce_loss.numpy()[0])
             mel_loss = self.binary_divergence_weight * mel_bce_loss \
                 + (1 - self.binary_divergence_weight) * mel_l1_loss
+            total_loss += mel_loss
 
         if compute_attn_loss:
             attn_loss = self.attention_loss(
                 attn_hyp, input_lengths.numpy(),
                 n_frames.numpy() // (self.downsample_factor * self.r))
+            total_loss += attn_loss
 
         if compute_done_loss:
             done_loss = self.done_loss(done_hyp, done_ref)
+            total_loss += done_loss
 
         result = {
-            "mel": mel_loss if compute_mel_loss else None,
-            "mel_l1_loss": mel_l1_loss if compute_mel_loss else None,
-            "mel_bce_loss": mel_bce_loss if compute_mel_loss else None,
-            "lin": lin_loss if compute_lin_loss else None,
-            "lin_l1_loss": lin_l1_loss if compute_lin_loss else None,
-            "lin_bce_loss": lin_bce_loss if compute_lin_loss else None,
+            "loss": total_loss,
+            "mel/mel_loss": mel_loss if compute_mel_loss else None,
+            "mel/l1_loss": mel_l1_loss if compute_mel_loss else None,
+            "mel/bce_loss": mel_bce_loss if compute_mel_loss else None,
+            "lin/lin_loss": lin_loss if compute_lin_loss else None,
+            "lin/l1_loss": lin_l1_loss if compute_lin_loss else None,
+            "lin/bce_loss": lin_bce_loss if compute_lin_loss else None,
             "done": done_loss if compute_done_loss else None,
             "attn": attn_loss if compute_attn_loss else None,
         }
         return result
-
-    @staticmethod
-    def compose_loss(result):
-        total_loss = 0.
-        for k in ["mel", "lin", "done", "attn"]:
-            if result[k] is not None:
-                total_loss += result[k]
-        return total_loss
\ No newline at end of file
diff --git a/parakeet/modules/weight_norm.py b/parakeet/modules/weight_norm.py
index 6532cd8193156e1e0516844073255d4014f7f887..8db21c033c5e2238c52bafda69a83b13091c8030 100644
--- a/parakeet/modules/weight_norm.py
+++ b/parakeet/modules/weight_norm.py
@@ -6,7 +6,6 @@ import paddle.fluid.layers as F
 
 from parakeet.modules import customized as L
 
-# TODO: just use numpy to init weight norm wrappers
 def norm(param, dim, power):
     powered = F.pow(param, power)
     powered_norm = F.reduce_sum(powered, dim=dim, keep_dim=False)
@@ -73,7 +72,7 @@ class WeightNormWrapper(dg.Layer):
                 w_g,
                 self.create_parameter(shape=temp.shape, dtype=temp.dtype))
             F.assign(temp, getattr(self, w_g))
 
-            # also set this
+            # also set this when setting up
             setattr(
                 self.layer, self.param_name,
                 compute_weight(getattr(self, w_v), getattr(self, w_g), self.dim,