From e656ca4783c5fb6ef1f4dd1c0982b45257d0d950 Mon Sep 17 00:00:00 2001
From: Zhen Wang
Date: Thu, 13 Aug 2020 10:18:28 +0800
Subject: [PATCH] add assert raises in the test_retain_graph UT. (#25983)

---
 .../paddle/fluid/tests/unittests/test_retain_graph.py | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/python/paddle/fluid/tests/unittests/test_retain_graph.py b/python/paddle/fluid/tests/unittests/test_retain_graph.py
index 53fde086dd..db4b922afc 100644
--- a/python/paddle/fluid/tests/unittests/test_retain_graph.py
+++ b/python/paddle/fluid/tests/unittests/test_retain_graph.py
@@ -20,7 +20,7 @@ import unittest
 paddle.disable_static()
 SEED = 2020
 np.random.seed(SEED)
-fluid.default_main_program().random_seed = SEED
+paddle.manual_seed(SEED)
 
 
 class Generator(fluid.dygraph.Layer):
@@ -90,7 +90,7 @@ class TestRetainGraph(unittest.TestCase):
         else:
             return 0.0, None
 
-    def test_retain(self):
+    def run_retain(self, need_retain):
         g = Generator()
         d = Discriminator()
 
@@ -117,7 +117,7 @@
             d, realA, fakeB, lambda_gp=10.0)
         loss_d = gan_criterion(G_pred_fake,
                                false_target) + G_gradient_penalty
-        loss_d.backward(retain_graph=True)
+        loss_d.backward(retain_graph=need_retain)
         optim_d.minimize(loss_d)
 
         optim_g.clear_gradients()
@@ -130,6 +130,11 @@
         loss_g.backward()
         optim_g.minimize(loss_g)
 
+    def test_retain(self):
+        self.run_retain(need_retain=True)
+        self.assertRaises(
+            fluid.core.EnforceNotMet, self.run_retain, need_retain=False)
+
 
 if __name__ == '__main__':
     unittest.main()
--
GitLab
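
A minimal sketch of the retain_graph behavior this test exercises, assuming the paddle 2.x dygraph API; the tensor values and names below are illustrative and not taken from the patch. Running backward() a second time over (part of) the same graph only succeeds if the earlier call passed retain_graph=True; otherwise the intermediate buffers are freed after the first pass and the second call raises, which is the error (surfaced as fluid.core.EnforceNotMet in this version of Paddle) that the new assertRaises check expects when run_retain is invoked with need_retain=False.

    import paddle

    paddle.disable_static()  # dygraph mode, the default in paddle 2.x

    # Build a tiny graph: y depends on x through a multiply and a sum.
    x = paddle.ones(shape=[2, 2], dtype='float32')
    x.stop_gradient = False  # track gradients for x
    y = (x * x).sum()

    # Keep the graph's intermediate buffers alive after this pass.
    y.backward(retain_graph=True)
    # Legal only because the graph was retained above; had the first call
    # used retain_graph=False, this second backward would raise, which is
    # exactly the failure mode the test now asserts via assertRaises.
    y.backward()

The test applies the same idea across two losses that share a subgraph (loss_g reuses fakeB from the loss_d pass), so loss_d.backward must retain the graph for loss_g.backward to succeed.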