未验证 提交 12355ccc 编写于 作者: Z Zeng Jinle 提交者: GitHub

add clear_gradients to star gan ut, test=develop (#23296)

上级 a582f105
......@@ -482,6 +482,13 @@ class DyGraphTrainModel(object):
self.backward_strategy = fluid.dygraph.BackwardStrategy()
self.backward_strategy.sort_sum_gradient = cfg.sort_sum_gradient
def clear_gradients(self):
if self.g_optimizer:
self.g_optimizer.clear_gradients()
if self.d_optimizer:
self.d_optimizer.clear_gradients()
def run(self, image_real, label_org, label_trg):
image_real = fluid.dygraph.to_variable(image_real)
label_org = fluid.dygraph.to_variable(label_org)
......@@ -493,7 +500,8 @@ class DyGraphTrainModel(object):
g_loss.backward(self.backward_strategy)
if self.g_optimizer:
self.g_optimizer.minimize(g_loss)
self.generator.clear_gradients()
self.clear_gradients()
d_loss = get_discriminator_loss(image_real, label_org, label_trg,
self.generator, self.discriminator,
......@@ -501,7 +509,8 @@ class DyGraphTrainModel(object):
d_loss.backward(self.backward_strategy)
if self.d_optimizer:
self.d_optimizer.minimize(d_loss)
self.discriminator.clear_gradients()
self.clear_gradients()
return g_loss.numpy()[0], d_loss.numpy()[0]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册