From 12355cccfdfbe7293b94f484f56225af241259b3 Mon Sep 17 00:00:00 2001 From: Zeng Jinle <32832641+sneaxiy@users.noreply.github.com> Date: Mon, 30 Mar 2020 20:41:22 -0500 Subject: [PATCH] add clear_gradients to star gan ut, test=develop (#23296) --- ...est_imperative_star_gan_with_gradient_penalty.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_imperative_star_gan_with_gradient_penalty.py b/python/paddle/fluid/tests/unittests/test_imperative_star_gan_with_gradient_penalty.py index 7bd026daa9f..afb113ed941 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_star_gan_with_gradient_penalty.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_star_gan_with_gradient_penalty.py @@ -482,6 +482,13 @@ class DyGraphTrainModel(object): self.backward_strategy = fluid.dygraph.BackwardStrategy() self.backward_strategy.sort_sum_gradient = cfg.sort_sum_gradient + def clear_gradients(self): + if self.g_optimizer: + self.g_optimizer.clear_gradients() + + if self.d_optimizer: + self.d_optimizer.clear_gradients() + def run(self, image_real, label_org, label_trg): image_real = fluid.dygraph.to_variable(image_real) label_org = fluid.dygraph.to_variable(label_org) @@ -493,7 +500,8 @@ class DyGraphTrainModel(object): g_loss.backward(self.backward_strategy) if self.g_optimizer: self.g_optimizer.minimize(g_loss) - self.generator.clear_gradients() + + self.clear_gradients() d_loss = get_discriminator_loss(image_real, label_org, label_trg, self.generator, self.discriminator, @@ -501,7 +509,8 @@ class DyGraphTrainModel(object): d_loss.backward(self.backward_strategy) if self.d_optimizer: self.d_optimizer.minimize(d_loss) - self.discriminator.clear_gradients() + + self.clear_gradients() return g_loss.numpy()[0], d_loss.numpy()[0] -- GitLab