diff --git a/python/paddle/fluid/tests/unittests/test_retain_graph.py b/python/paddle/fluid/tests/unittests/test_retain_graph.py index 53fde086dd23e681bd79ec663a7acb82759193bc..db4b922afcd230c852a4859ec5a7e7497d59ffff 100644 --- a/python/paddle/fluid/tests/unittests/test_retain_graph.py +++ b/python/paddle/fluid/tests/unittests/test_retain_graph.py @@ -20,7 +20,7 @@ import unittest paddle.disable_static() SEED = 2020 np.random.seed(SEED) -fluid.default_main_program().random_seed = SEED +paddle.manual_seed(SEED) class Generator(fluid.dygraph.Layer): @@ -90,7 +90,7 @@ class TestRetainGraph(unittest.TestCase): else: return 0.0, None - def test_retain(self): + def run_retain(self, need_retain): g = Generator() d = Discriminator() @@ -117,7 +117,7 @@ class TestRetainGraph(unittest.TestCase): d, realA, fakeB, lambda_gp=10.0) loss_d = gan_criterion(G_pred_fake, false_target) + G_gradient_penalty - loss_d.backward(retain_graph=True) + loss_d.backward(retain_graph=need_retain) optim_d.minimize(loss_d) optim_g.clear_gradients() @@ -130,6 +130,11 @@ class TestRetainGraph(unittest.TestCase): loss_g.backward() optim_g.minimize(loss_g) + def test_retain(self): + self.run_retain(need_retain=True) + self.assertRaises( + fluid.core.EnforceNotMet, self.run_retain, need_retain=False) + if __name__ == '__main__': unittest.main()