未验证 提交 3bbff25b 编写于 作者: Y Yu Yang 提交者: GitHub

Merge pull request #7453 from reyoung/feature/fix_seed_for_dynrnn_test

Fix random seed of dynamic rnn gradient check
...@@ -197,7 +197,24 @@ class BaseRNN(object): ...@@ -197,7 +197,24 @@ class BaseRNN(object):
return numpy.array([o.mean() for o in outs.itervalues()]).mean() return numpy.array([o.mean() for o in outs.itervalues()]).mean()
class TestSimpleMul(unittest.TestCase): class SeedFixedTestCase(unittest.TestCase):
@classmethod
def setUpClass(cls):
"""Fix random seeds to remove randomness from tests"""
cls._np_rand_state = numpy.random.get_state()
cls._py_rand_state = random.getstate()
numpy.random.seed(123)
random.seed(124)
@classmethod
def tearDownClass(cls):
"""Restore random seeds"""
numpy.random.set_state(cls._np_rand_state)
random.setstate(cls._py_rand_state)
class TestSimpleMul(SeedFixedTestCase):
DATA_NAME = 'X' DATA_NAME = 'X'
DATA_WIDTH = 32 DATA_WIDTH = 32
PARAM_NAME = 'W' PARAM_NAME = 'W'
...@@ -263,7 +280,7 @@ class TestSimpleMul(unittest.TestCase): ...@@ -263,7 +280,7 @@ class TestSimpleMul(unittest.TestCase):
self.assertTrue(numpy.allclose(i_g_num, i_g, rtol=0.05)) self.assertTrue(numpy.allclose(i_g_num, i_g, rtol=0.05))
class TestSimpleMulWithMemory(unittest.TestCase): class TestSimpleMulWithMemory(SeedFixedTestCase):
DATA_WIDTH = 32 DATA_WIDTH = 32
HIDDEN_WIDTH = 20 HIDDEN_WIDTH = 20
DATA_NAME = 'X' DATA_NAME = 'X'
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册