From 509c8399b8dbf1491fd6adc55b3c423e2d3501be Mon Sep 17 00:00:00 2001 From: Kexin Zhao Date: Mon, 19 Mar 2018 20:16:16 -0700 Subject: [PATCH] address comments --- .../fluid/tests/unittests/test_dropout_op.py | 40 +++++++++---------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_dropout_op.py b/python/paddle/fluid/tests/unittests/test_dropout_op.py index 2939895d79b..eaa3435a864 100644 --- a/python/paddle/fluid/tests/unittests/test_dropout_op.py +++ b/python/paddle/fluid/tests/unittests/test_dropout_op.py @@ -83,36 +83,36 @@ class TestDropoutOp5(OpTest): self.check_output() -class TestFP16DropoutOp1(OpTest): +class TestFP16DropoutOp(OpTest): def setUp(self): - x = np.random.random((32, 64)).astype("float16") - prob = 0.35 - out = x * (1.0 - prob) - self.op_type = "dropout" + self.init_test_case() + + x = np.random.random(self.input_size).astype("float16") + out = x * (1.0 - self.prob) self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} - self.attrs = {'dropout_prob': prob, 'fix_seed': True, 'is_test': True} + self.attrs = { + 'dropout_prob': self.prob, + 'fix_seed': self.fix_seed, + 'is_test': True + } self.outputs = {'Out': out} + def init_test_case(self): + self.input_size = [32, 64] + self.prob = 0.35 + self.fix_seed = True + def test_check_output(self): if core.is_compiled_with_cuda() and core.op_support_gpu("dropout"): self.check_output_with_place(core.CUDAPlace(0), atol=1e-3) -class TestFP16DropoutOp2(OpTest): - def setUp(self): - x = np.random.random((32, 64, 3)).astype("float16") - prob = 0.75 - out = x * (1.0 - prob) - - self.op_type = "dropout" - self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} - self.attrs = {'dropout_prob': prob, 'is_test': True} - self.outputs = {'Out': out} - - def test_check_output(self): - if core.is_compiled_with_cuda() and core.op_support_gpu("dropout"): - self.check_output_with_place(core.CUDAPlace(0), atol=1e-3) +class TestFP16DropoutOp2(TestFP16DropoutOp): + def init_test_case(self): + self.input_size = [32, 64, 3] + self.prob = 0.75 + self.fix_seed = False if __name__ == '__main__': -- GitLab