From 19bd8b8e6752eb74f428d66234e49e4965682c47 Mon Sep 17 00:00:00 2001
From: Varuna Jayasiri
Date: Sat, 26 Sep 2020 21:21:46 +0530
Subject: [PATCH] tracking hooks

---
 labml_nn/gan/mnist.py | 18 +++++++++++++-----
 1 file changed, 13 insertions(+), 5 deletions(-)

diff --git a/labml_nn/gan/mnist.py b/labml_nn/gan/mnist.py
index 20c24597..18274cad 100644
--- a/labml_nn/gan/mnist.py
+++ b/labml_nn/gan/mnist.py
@@ -6,13 +6,14 @@
 import torch.nn as nn
 import torch.utils.data
 from torchvision import transforms
+import labml.utils.pytorch as pytorch_utils
 from labml import tracker, monit, experiment
 from labml.configs import option, calculate
 from labml_helpers.datasets.mnist import MNISTConfigs
 from labml_helpers.device import DeviceConfigs
 from labml_helpers.module import Module
 from labml_helpers.optimizer import OptimizerConfigs
-from labml_helpers.train_valid import MODE_STATE, BatchStepProtocol, TrainValidConfigs
+from labml_helpers.train_valid import MODE_STATE, BatchStepProtocol, TrainValidConfigs, hook_model_outputs, Mode
 from labml_nn.gan import DiscriminatorLogitsLoss, GeneratorLogitsLoss
 
 plt.rcParams['image.interpolation'] = 'nearest'
@@ -71,6 +72,9 @@ class GANBatchStep(BatchStepProtocol):
         self.discriminator_loss = discriminator_loss
         self.generator_optimizer = generator_optimizer
         self.discriminator_optimizer = discriminator_optimizer
+
+        hook_model_outputs(self.generator, 'generator')
+        hook_model_outputs(self.discriminator, 'discriminator')
         tracker.set_scalar("loss.generator.*", True)
         tracker.set_scalar("loss.discriminator.*", True)
         tracker.set_image("generated", True, 1 / 100)
@@ -99,6 +103,8 @@ class GANBatchStep(BatchStepProtocol):
             tracker.add("loss.generator.", loss)
             if MODE_STATE.is_train:
                 loss.backward()
+                if MODE_STATE.is_log_parameters:
+                    pytorch_utils.store_model_indicators(self.generator, 'generator')
                 self.generator_optimizer.step()
 
         with monit.section("discriminator"):
@@ -114,6 +120,8 @@ class GANBatchStep(BatchStepProtocol):
             tracker.add("loss.discriminator.", loss)
             if MODE_STATE.is_train:
                 loss.backward()
+                if MODE_STATE.is_log_parameters:
+                    pytorch_utils.store_model_indicators(self.discriminator, 'discriminator')
                 self.discriminator_optimizer.step()
 
         return {'samples': len(data)}, None
@@ -163,8 +171,8 @@ def _discriminator_optimizer(c: Configs):
     opt_conf = OptimizerConfigs()
     opt_conf.optimizer = 'Adam'
     opt_conf.parameters = c.discriminator.parameters()
-    opt_conf.learning_rate = 2.5e-4
-    opt_conf.betas = (0.5, 0.999)
+    opt_conf.learning_rate = 2.5e-5
+    # opt_conf.betas = (0.5, 0.999)
     return opt_conf
 
 
@@ -173,8 +181,8 @@ def _generator_optimizer(c: Configs):
     opt_conf = OptimizerConfigs()
     opt_conf.optimizer = 'Adam'
     opt_conf.parameters = c.generator.parameters()
-    opt_conf.learning_rate = 2.5e-4
-    opt_conf.betas = (0.5, 0.999)
+    opt_conf.learning_rate = 2.5e-5
+    # opt_conf.betas = (0.5, 0.999)
    return opt_conf
 
 
--
GitLab
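
For context on what the patch wires in: hook_model_outputs registers forward hooks that track each submodule's outputs, and pytorch_utils.store_model_indicators records per-parameter statistics, gated on MODE_STATE.is_log_parameters so it only runs on logging steps; the patch also drops the Adam learning rate from 2.5e-4 to 2.5e-5 and comments out the custom betas, falling back to Adam's defaults. Below is a rough plain-PyTorch sketch of the same idea, with hypothetical helper names and statistics; it is not labml's implementation.

import torch
import torch.nn as nn


def hook_outputs(model: nn.Module, name: str, store: dict):
    """Register forward hooks that record the output mean/std of every submodule."""
    def make_hook(module_name: str):
        def hook(_module, _inputs, output):
            if isinstance(output, torch.Tensor):
                store[f"{name}.{module_name}.mean"] = output.detach().mean().item()
                store[f"{name}.{module_name}.std"] = output.detach().std().item()
        return hook

    for module_name, module in model.named_modules():
        if module_name:  # skip the root module; named_modules() yields it with an empty name
            module.register_forward_hook(make_hook(module_name))


def store_indicators(model: nn.Module, name: str, store: dict):
    """Record parameter and gradient norms, analogous in spirit to store_model_indicators."""
    for param_name, param in model.named_parameters():
        store[f"{name}.{param_name}.norm"] = param.detach().norm().item()
        if param.grad is not None:
            store[f"{name}.{param_name}.grad_norm"] = param.grad.detach().norm().item()


if __name__ == '__main__':
    stats = {}
    generator = nn.Sequential(nn.Linear(100, 256), nn.ReLU(), nn.Linear(256, 784))
    hook_outputs(generator, 'generator', stats)

    latent = torch.randn(8, 100)
    loss = generator(latent).square().mean()
    loss.backward()

    # The patch gates this on MODE_STATE.is_log_parameters so it only runs on logging steps.
    log_parameters = True
    if log_parameters:
        store_indicators(generator, 'generator', stats)

    for key, value in sorted(stats.items()):
        print(f'{key}: {value:.4f}')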