tracking hooks

Parent 32e7df1e
@@ -6,13 +6,14 @@ import torch.nn as nn
 import torch.utils.data
 from torchvision import transforms
+import labml.utils.pytorch as pytorch_utils
 from labml import tracker, monit, experiment
 from labml.configs import option, calculate
 from labml_helpers.datasets.mnist import MNISTConfigs
 from labml_helpers.device import DeviceConfigs
 from labml_helpers.module import Module
 from labml_helpers.optimizer import OptimizerConfigs
-from labml_helpers.train_valid import MODE_STATE, BatchStepProtocol, TrainValidConfigs
+from labml_helpers.train_valid import MODE_STATE, BatchStepProtocol, TrainValidConfigs, hook_model_outputs, Mode
 from labml_nn.gan import DiscriminatorLogitsLoss, GeneratorLogitsLoss

 plt.rcParams['image.interpolation'] = 'nearest'
@@ -71,6 +72,9 @@ class GANBatchStep(BatchStepProtocol):
         self.discriminator_loss = discriminator_loss
         self.generator_optimizer = generator_optimizer
         self.discriminator_optimizer = discriminator_optimizer
+        hook_model_outputs(self.generator, 'generator')
+        hook_model_outputs(self.discriminator, 'discriminator')
         tracker.set_scalar("loss.generator.*", True)
         tracker.set_scalar("loss.discriminator.*", True)
         tracker.set_image("generated", True, 1 / 100)
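
Note: hook_model_outputs (imported above from labml_helpers.train_valid) registers forward hooks on a model's submodules so that their outputs are recorded by the tracker on logging steps. A minimal sketch of the same idea in plain PyTorch follows; the helper name hook_outputs and the stats dictionary are illustrative stand-ins, not the labml_helpers API:

import torch
import torch.nn as nn

def hook_outputs(model: nn.Module, prefix: str, stats: dict):
    # Record the mean of each child module's output on every forward pass.
    def make_hook(name: str):
        def hook(module, inputs, output):
            if isinstance(output, torch.Tensor):
                stats[f'{prefix}.{name}'] = output.detach().float().mean().item()
        return hook
    for name, module in model.named_children():
        module.register_forward_hook(make_hook(name))

stats = {}
net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
hook_outputs(net, 'generator', stats)
net(torch.randn(1, 4))
print(stats)  # e.g. {'generator.0': ..., 'generator.1': ..., 'generator.2': ...}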
@@ -99,6 +103,8 @@ class GANBatchStep(BatchStepProtocol):
             tracker.add("loss.generator.", loss)
             if MODE_STATE.is_train:
                 loss.backward()
+                if MODE_STATE.is_log_parameters:
+                    pytorch_utils.store_model_indicators(self.generator, 'generator')
                 self.generator_optimizer.step()

         with monit.section("discriminator"):
@@ -114,6 +120,8 @@ class GANBatchStep(BatchStepProtocol):
             tracker.add("loss.discriminator.", loss)
             if MODE_STATE.is_train:
                 loss.backward()
+                if MODE_STATE.is_log_parameters:
+                    pytorch_utils.store_model_indicators(self.discriminator, 'discriminator')
                 self.discriminator_optimizer.step()

         return {'samples': len(data)}, None
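
Note: store_model_indicators is called between loss.backward() and optimizer.step(), and only on logging steps (MODE_STATE.is_log_parameters), so the recorded statistics reflect the gradients of the current batch. A rough plain-PyTorch equivalent, using simple means in place of labml's indicators (store_indicators and stats are hypothetical names, not the labml API):

import torch
import torch.nn as nn

def store_indicators(model: nn.Module, model_name: str, stats: dict):
    # Record per-parameter value and gradient means after a backward pass.
    for name, param in model.named_parameters():
        stats[f'{model_name}.param.{name}'] = param.detach().float().mean().item()
        if param.grad is not None:
            stats[f'{model_name}.grad.{name}'] = param.grad.float().mean().item()

stats = {}
model = nn.Linear(4, 2)
model(torch.randn(1, 4)).sum().backward()
store_indicators(model, 'discriminator', stats)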
@@ -163,8 +171,8 @@ def _discriminator_optimizer(c: Configs):
     opt_conf = OptimizerConfigs()
     opt_conf.optimizer = 'Adam'
     opt_conf.parameters = c.discriminator.parameters()
-    opt_conf.learning_rate = 2.5e-4
-    opt_conf.betas = (0.5, 0.999)
+    opt_conf.learning_rate = 2.5e-5
+    # opt_conf.betas = (0.5, 0.999)
     return opt_conf
@@ -173,8 +181,8 @@ def _generator_optimizer(c: Configs):
     opt_conf = OptimizerConfigs()
     opt_conf.optimizer = 'Adam'
     opt_conf.parameters = c.generator.parameters()
-    opt_conf.learning_rate = 2.5e-4
-    opt_conf.betas = (0.5, 0.999)
+    opt_conf.learning_rate = 2.5e-5
+    # opt_conf.betas = (0.5, 0.999)
     return opt_conf
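
Note: the last two hunks lower both Adam learning rates from 2.5e-4 to 2.5e-5 and comment out the (0.5, 0.999) betas override, so the optimizers fall back to Adam's default betas of (0.9, 0.999). Assuming OptimizerConfigs forwards these fields to torch.optim.Adam, the result is equivalent to this sketch (the nn.Linear model is a stand-in, not the experiment's real discriminator):

import torch
import torch.nn as nn

discriminator = nn.Linear(784, 1)  # stand-in model for illustration
optimizer = torch.optim.Adam(
    discriminator.parameters(),
    lr=2.5e-5,  # lowered from 2.5e-4 in this commit
)  # betas now default to (0.9, 0.999) since the override is commented out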