From e5b23785fe67d99fed19d4190070ef5df9e4992d Mon Sep 17 00:00:00 2001 From: Yang Yu Date: Thu, 11 Jan 2018 16:09:49 +0800 Subject: [PATCH] Fix random seed of dynamic rnn gradient check --- .../fluid/tests/test_dynrnn_gradient_check.py | 21 +++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/python/paddle/v2/fluid/tests/test_dynrnn_gradient_check.py b/python/paddle/v2/fluid/tests/test_dynrnn_gradient_check.py index c02c59284..cdbf582a3 100644 --- a/python/paddle/v2/fluid/tests/test_dynrnn_gradient_check.py +++ b/python/paddle/v2/fluid/tests/test_dynrnn_gradient_check.py @@ -197,7 +197,24 @@ class BaseRNN(object): return numpy.array([o.mean() for o in outs.itervalues()]).mean() -class TestSimpleMul(unittest.TestCase): +class SeedFixedTestCase(unittest.TestCase): + @classmethod + def setUpClass(cls): + """Fix random seeds to remove randomness from tests""" + cls._np_rand_state = numpy.random.get_state() + cls._py_rand_state = random.getstate() + + numpy.random.seed(123) + random.seed(124) + + @classmethod + def tearDownClass(cls): + """Restore random seeds""" + numpy.random.set_state(cls._np_rand_state) + random.setstate(cls._py_rand_state) + + +class TestSimpleMul(SeedFixedTestCase): DATA_NAME = 'X' DATA_WIDTH = 32 PARAM_NAME = 'W' @@ -263,7 +280,7 @@ class TestSimpleMul(unittest.TestCase): self.assertTrue(numpy.allclose(i_g_num, i_g, rtol=0.05)) -class TestSimpleMulWithMemory(unittest.TestCase): +class TestSimpleMulWithMemory(SeedFixedTestCase): DATA_WIDTH = 32 HIDDEN_WIDTH = 20 DATA_NAME = 'X' -- GitLab