未验证 提交 e656ca47 编写于 作者: Z Zhen Wang 提交者: GitHub

add assert raises in the test_retain_graph UT. (#25983)

上级 7165f484
......@@ -20,7 +20,7 @@ import unittest
paddle.disable_static()
SEED = 2020
np.random.seed(SEED)
fluid.default_main_program().random_seed = SEED
paddle.manual_seed(SEED)
class Generator(fluid.dygraph.Layer):
......@@ -90,7 +90,7 @@ class TestRetainGraph(unittest.TestCase):
else:
return 0.0, None
def test_retain(self):
def run_retain(self, need_retain):
g = Generator()
d = Discriminator()
......@@ -117,7 +117,7 @@ class TestRetainGraph(unittest.TestCase):
d, realA, fakeB, lambda_gp=10.0)
loss_d = gan_criterion(G_pred_fake, false_target) + G_gradient_penalty
loss_d.backward(retain_graph=True)
loss_d.backward(retain_graph=need_retain)
optim_d.minimize(loss_d)
optim_g.clear_gradients()
......@@ -130,6 +130,11 @@ class TestRetainGraph(unittest.TestCase):
loss_g.backward()
optim_g.minimize(loss_g)
def test_retain(self):
self.run_retain(need_retain=True)
self.assertRaises(
fluid.core.EnforceNotMet, self.run_retain, need_retain=False)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册