提交 e5b23785 编写于 作者: Y Yang Yu

Fix random seed of dynamic rnn gradient check

上级 9867a379
......@@ -197,7 +197,24 @@ class BaseRNN(object):
return numpy.array([o.mean() for o in outs.itervalues()]).mean()
class TestSimpleMul(unittest.TestCase):
class SeedFixedTestCase(unittest.TestCase):
@classmethod
def setUpClass(cls):
"""Fix random seeds to remove randomness from tests"""
cls._np_rand_state = numpy.random.get_state()
cls._py_rand_state = random.getstate()
numpy.random.seed(123)
random.seed(124)
@classmethod
def tearDownClass(cls):
"""Restore random seeds"""
numpy.random.set_state(cls._np_rand_state)
random.setstate(cls._py_rand_state)
class TestSimpleMul(SeedFixedTestCase):
DATA_NAME = 'X'
DATA_WIDTH = 32
PARAM_NAME = 'W'
......@@ -263,7 +280,7 @@ class TestSimpleMul(unittest.TestCase):
self.assertTrue(numpy.allclose(i_g_num, i_g, rtol=0.05))
class TestSimpleMulWithMemory(unittest.TestCase):
class TestSimpleMulWithMemory(SeedFixedTestCase):
DATA_WIDTH = 32
HIDDEN_WIDTH = 20
DATA_NAME = 'X'
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册