From 2933eb7e5785adf6ef4f2480039e8a8c7dfd8f12 Mon Sep 17 00:00:00 2001 From: ShenYuhan Date: Fri, 14 Aug 2020 17:48:24 +0800 Subject: [PATCH] replace add_scalar to add_scalars --- examples/transformer_tts/train_transformer.py | 67 +++++-------------- examples/transformer_tts/train_vocoder.py | 2 +- 2 files changed, 19 insertions(+), 50 deletions(-) diff --git a/examples/transformer_tts/train_transformer.py b/examples/transformer_tts/train_transformer.py index a0ca16b..3499a5f 100644 --- a/examples/transformer_tts/train_transformer.py +++ b/examples/transformer_tts/train_transformer.py @@ -29,41 +29,6 @@ from parakeet.models.transformer_tts import TransformerTTS from parakeet.utils import io -def add_scalars(self, main_tag, tag_scalar_dict, step, walltime=None): - """Add scalars to vdl record file. - Args: - main_tag (string): The parent name for the tags - tag_scalar_dict (dict): Key-value pair storing the tag and corresponding values - step (int): Step of scalars - walltime (float): Wall time of scalars. - Example: - for index in range(1, 101): - writer.add_scalar(tag="train/loss", value=index*0.2, step=index) - writer.add_scalar(tag="train/lr", value=index*0.5, step=index) - """ - import time - from visualdl.writer.record_writer import RecordFileWriter - from visualdl.component.base_component import scalar - - fw_logdir = self.logdir - walltime = round(time.time()) if walltime is None else walltime - for tag, value in tag_scalar_dict.items(): - tag = os.path.join(fw_logdir, main_tag, tag) - if '%' in tag: - raise RuntimeError("% can't appear in tag!") - if tag in self._all_writers: - fw = self._all_writers[tag] - else: - fw = RecordFileWriter( - logdir=tag, - max_queue_size=self._max_queue, - flush_secs=self._flush_secs, - filename_suffix=self._filename_suffix) - self._all_writers.update({tag: fw}) - fw.add_record( - scalar(tag=main_tag, value=value, step=step, walltime=walltime)) - - def add_config_options_to_parser(parser): parser.add_argument("--config", type=str, help="path of the config file") parser.add_argument("--use_gpu", type=int, default=0, help="device to use") @@ -99,7 +64,6 @@ def main(args): writer = LogWriter(os.path.join(args.output, 'log')) if local_rank == 0 else None - writer.add_scalars = add_scalars fluid.enable_dygraph(place) network_cfg = cfg['network'] @@ -167,23 +131,28 @@ def main(args): loss = loss + stop_loss if local_rank == 0: - writer.add_scalars('training_loss', { - 'mel_loss': mel_loss.numpy(), - 'post_mel_loss': post_mel_loss.numpy() - }, global_step) - + writer.add_scalar('training_loss/mel_loss', + mel_loss.numpy(), + global_step) + writer.add_scalar('training_loss/post_mel_loss', + post_mel_loss.numpy(), + global_step) writer.add_scalar('stop_loss', stop_loss.numpy(), global_step) if parallel: - writer.add_scalars('alphas', { - 'encoder_alpha': model._layers.encoder.alpha.numpy(), - 'decoder_alpha': model._layers.decoder.alpha.numpy(), - }, global_step) + writer.add_scalar('alphas/encoder_alpha', + model._layers.encoder.alpha.numpy(), + global_step) + writer.add_scalar('alphas/decoder_alpha', + model._layers.decoder.alpha.numpy(), + global_step) else: - writer.add_scalars('alphas', { - 'encoder_alpha': model.encoder.alpha.numpy(), - 'decoder_alpha': model.decoder.alpha.numpy(), - }, global_step) + writer.add_scalar('alphas/encoder_alpha', + model.encoder.alpha.numpy(), + global_step) + writer.add_scalar('alphas/decoder_alpha', + model.decoder.alpha.numpy(), + global_step) writer.add_scalar('learning_rate', optimizer._learning_rate.step().numpy(), diff --git a/examples/transformer_tts/train_vocoder.py b/examples/transformer_tts/train_vocoder.py index 4b95f31..ccea796 100644 --- a/examples/transformer_tts/train_vocoder.py +++ b/examples/transformer_tts/train_vocoder.py @@ -121,7 +121,7 @@ def main(args): model.clear_gradients() if local_rank == 0: - writer.add_scalars('training_loss', {'loss': loss.numpy(), }, + writer.add_scalar('training_loss/loss', loss.numpy(), global_step) # save checkpoint -- GitLab