diff --git a/python/paddle/fluid/tests/unittests/test_imperative_star_gan_with_gradient_penalty.py b/python/paddle/fluid/tests/unittests/test_imperative_star_gan_with_gradient_penalty.py index 7bd026daa9f2b86fdb4a219bced62d5e5f7b8051..afb113ed941e91b0d5cdf50b938704c3f23cda44 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_star_gan_with_gradient_penalty.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_star_gan_with_gradient_penalty.py @@ -482,6 +482,13 @@ class DyGraphTrainModel(object): self.backward_strategy = fluid.dygraph.BackwardStrategy() self.backward_strategy.sort_sum_gradient = cfg.sort_sum_gradient + def clear_gradients(self): + if self.g_optimizer: + self.g_optimizer.clear_gradients() + + if self.d_optimizer: + self.d_optimizer.clear_gradients() + def run(self, image_real, label_org, label_trg): image_real = fluid.dygraph.to_variable(image_real) label_org = fluid.dygraph.to_variable(label_org) @@ -493,7 +500,8 @@ class DyGraphTrainModel(object): g_loss.backward(self.backward_strategy) if self.g_optimizer: self.g_optimizer.minimize(g_loss) - self.generator.clear_gradients() + + self.clear_gradients() d_loss = get_discriminator_loss(image_real, label_org, label_trg, self.generator, self.discriminator, @@ -501,7 +509,8 @@ class DyGraphTrainModel(object): d_loss.backward(self.backward_strategy) if self.d_optimizer: self.d_optimizer.minimize(d_loss) - self.discriminator.clear_gradients() + + self.clear_gradients() return g_loss.numpy()[0], d_loss.numpy()[0]