未验证 提交 12355ccc 编写于 作者: Z Zeng Jinle 提交者: GitHub

add clear_gradients to star gan ut, test=develop (#23296)

上级 a582f105
...@@ -482,6 +482,13 @@ class DyGraphTrainModel(object): ...@@ -482,6 +482,13 @@ class DyGraphTrainModel(object):
self.backward_strategy = fluid.dygraph.BackwardStrategy() self.backward_strategy = fluid.dygraph.BackwardStrategy()
self.backward_strategy.sort_sum_gradient = cfg.sort_sum_gradient self.backward_strategy.sort_sum_gradient = cfg.sort_sum_gradient
def clear_gradients(self):
if self.g_optimizer:
self.g_optimizer.clear_gradients()
if self.d_optimizer:
self.d_optimizer.clear_gradients()
def run(self, image_real, label_org, label_trg): def run(self, image_real, label_org, label_trg):
image_real = fluid.dygraph.to_variable(image_real) image_real = fluid.dygraph.to_variable(image_real)
label_org = fluid.dygraph.to_variable(label_org) label_org = fluid.dygraph.to_variable(label_org)
...@@ -493,7 +500,8 @@ class DyGraphTrainModel(object): ...@@ -493,7 +500,8 @@ class DyGraphTrainModel(object):
g_loss.backward(self.backward_strategy) g_loss.backward(self.backward_strategy)
if self.g_optimizer: if self.g_optimizer:
self.g_optimizer.minimize(g_loss) self.g_optimizer.minimize(g_loss)
self.generator.clear_gradients()
self.clear_gradients()
d_loss = get_discriminator_loss(image_real, label_org, label_trg, d_loss = get_discriminator_loss(image_real, label_org, label_trg,
self.generator, self.discriminator, self.generator, self.discriminator,
...@@ -501,7 +509,8 @@ class DyGraphTrainModel(object): ...@@ -501,7 +509,8 @@ class DyGraphTrainModel(object):
d_loss.backward(self.backward_strategy) d_loss.backward(self.backward_strategy)
if self.d_optimizer: if self.d_optimizer:
self.d_optimizer.minimize(d_loss) self.d_optimizer.minimize(d_loss)
self.discriminator.clear_gradients()
self.clear_gradients()
return g_loss.numpy()[0], d_loss.numpy()[0] return g_loss.numpy()[0], d_loss.numpy()[0]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册