From e397a3ff88d19d89b265449ebcccf74dd0d43fb2 Mon Sep 17 00:00:00 2001 From: Zhang Zheng <32410583+ZzSean@users.noreply.github.com> Date: Thu, 9 Mar 2023 15:27:25 +0800 Subject: [PATCH] [AMP_OP&Test] Add float16 OpTest for full_op (#50723) --- .../tests/unittests/test_fill_constant_op.py | 79 +++++++++++-------- 1 file changed, 45 insertions(+), 34 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_fill_constant_op.py b/python/paddle/fluid/tests/unittests/test_fill_constant_op.py index 3151744aa4c..8f8703093b3 100644 --- a/python/paddle/fluid/tests/unittests/test_fill_constant_op.py +++ b/python/paddle/fluid/tests/unittests/test_fill_constant_op.py @@ -30,60 +30,71 @@ def fill_wrapper(shape, value=0.0): # Situation 1: Attr(shape) is a list(without tensor) -class TestFillConstantOp1(OpTest): +# Base case +class TestFillConstantOp(OpTest): def setUp(self): - '''Test fill_constant op with specified value''' + '''Test fill_constant op with default value''' self.op_type = "fill_constant" self.python_api = fill_wrapper + self.init_dtype() + self.init_shape() + self.init_value() self.inputs = {} - self.attrs = {'shape': [123, 92], 'value': 3.8} - self.outputs = {'Out': np.full((123, 92), 3.8)} + self.attrs = {'shape': self.shape, 'value': self.value} + self.outputs = {'Out': np.full(self.shape, self.value)} def test_check_output(self): self.check_output() + def init_dtype(self): + self.dtype = np.float64 -class TestFillConstantOp2(OpTest): - def setUp(self): - '''Test fill_constant op with default value''' - self.op_type = "fill_constant" - self.python_api = fill_wrapper + def init_shape(self): + self.shape = [123, 92] - self.inputs = {} - self.attrs = {'shape': [123, 92]} - self.outputs = {'Out': np.full((123, 92), 0.0)} + def init_value(self): + self.value = 0.0 - def test_check_output(self): - self.check_output() +class TestFillConstantFP32Op(TestFillConstantOp): + '''Test fill_constant op with specified value''' -class TestFillConstantOp3(OpTest): - def setUp(self): - '''Test fill_constant op with specified int64 value''' - self.op_type = "fill_constant" - self.python_api = fill_wrapper + def init_dtype(self): + self.dtype = np.float32 - self.inputs = {} - self.attrs = {'shape': [123, 92], 'value': 10000000000} - self.outputs = {'Out': np.full((123, 92), 10000000000)} + def init_value(self): + self.value = 3.8 - def test_check_output(self): - self.check_output() +class TestFillConstantFP16Op(TestFillConstantOp): + '''Test fill_constant op with specified value''' -class TestFillConstantOp4(OpTest): - def setUp(self): - '''Test fill_constant op with specified int value''' - self.op_type = "fill_constant" - self.python_api = fill_wrapper + def init_dtype(self): + self.dtype = np.float16 - self.inputs = {} - self.attrs = {'shape': [123, 92], 'value': 3} - self.outputs = {'Out': np.full((123, 92), 3)} + def init_value(self): + self.value = 3.8 - def test_check_output(self): - self.check_output() + +class TestFillConstantINT64Op(TestFillConstantOp): + '''Test fill_constant op with specified int64 value''' + + def init_dtype(self): + self.dtype = np.int64 + + def init_value(self): + self.value = 10000000000 + + +class TestFillConstantINT32Op(TestFillConstantOp): + '''Test fill_constant op with specified int value''' + + def init_dtype(self): + self.dtype = np.int32 + + def init_value(self): + self.value = 3 @unittest.skipIf( -- GitLab