提交 4dfc10cc 编写于 作者: P Peng Li

a patch for fixing random seeds in gradient checkers

上级 7cc5ae99
import unittest
import numpy as np
import random
import itertools
import paddle.v2.framework.core as core
from paddle.v2.framework.op import Operator
......@@ -192,6 +193,21 @@ def get_gradient(scope, op, inputs, outputs, grad_name, place,
class OpTest(unittest.TestCase):
@classmethod
def setUpClass(cls):
'''Fix random seeds to remove randomness from tests'''
cls._np_rand_state = np.random.get_state()
cls._py_rand_state = random.getstate()
np.random.seed(123)
random.seed(124)
@classmethod
def tearDownClass(cls):
'''Restore random seeds'''
np.random.set_state(cls._np_rand_state)
random.setstate(cls._py_rand_state)
def check_output_with_place(self, place, atol):
self.scope = core.Scope()
op_inputs = self.inputs if hasattr(self, "inputs") else dict()
......
......@@ -43,7 +43,7 @@ class TestSoftmaxWithCrossEntropyOp2(OpTest):
def setUp(self):
self.op_type = "softmax_with_cross_entropy"
batch_size = 2
class_num = 17
class_num = 37
logits = np.random.uniform(0.1, 1.0,
[batch_size, class_num]).astype("float32")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册