From 04325d2cbefb029a4478bdc069d3279cd566ac6a Mon Sep 17 00:00:00 2001 From: xiongkun Date: Wed, 30 Mar 2022 11:24:31 +0800 Subject: [PATCH] Optest refactor (#40998) * first version, maybe many errors * refactor op_test * fix compare list * fix bg * fix bugs * skip name --- .../paddle/fluid/tests/unittests/op_test.py | 22 ++++++++++++------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py index a0c8723323..4368ef69f4 100644 --- a/python/paddle/fluid/tests/unittests/op_test.py +++ b/python/paddle/fluid/tests/unittests/op_test.py @@ -1398,7 +1398,7 @@ class OpTest(unittest.TestCase): # NOTE(zhiqiu): np.allclose([], [1.]) returns True # see details: https://stackoverflow.com/questions/38331703/why-does-numpys-broadcasting-sometimes-allow-comparing-arrays-of-different-leng if expect_np.size == 0: - self.op_test.assertTrue(actual_np.size == 0) # }}} + self.op_test.assertTrue(actual_np.size == 0) self._compare_numpy(name, actual_np, expect_np) if isinstance(expect, tuple): self._compare_list(name, actual, expect) @@ -1486,7 +1486,7 @@ class OpTest(unittest.TestCase): if actual_np.dtype == np.uint16: actual_np = convert_uint16_to_float(actual_np) if expect_np.dtype == np.uint16: - expect_np = convert_uint16_to_float(expect_np) # }}} + expect_np = convert_uint16_to_float(expect_np) return actual_np, expect_np def _compare_list(self, name, actual, expect): @@ -1519,11 +1519,13 @@ class OpTest(unittest.TestCase): class EagerChecker(DygraphChecker): def calculate_output(self): # we only check end2end api when check_eager=True + self.is_python_api_test = True with _test_eager_guard(): eager_dygraph_outs = self.op_test._calc_python_api_output( place) if eager_dygraph_outs is None: # missing KernelSignature, fall back to eager middle output. + self.is_python_api_test = False eager_dygraph_outs = self.op_test._calc_dygraph_output( place, no_check_set=no_check_set) self.outputs = eager_dygraph_outs @@ -1547,9 +1549,16 @@ class OpTest(unittest.TestCase): with _test_eager_guard(): super()._compare_list(name, actual, expect) -# set some flags by the combination of arguments. + def _is_skip_name(self, name): + # if in final state and kernel signature don't have name, then skip it. + if self.is_python_api_test and hasattr( + self.op_test, "python_out_sig" + ) and name not in self.op_test.python_out_sig: + return True + return super()._is_skip_name(name) - self.infer_dtype_from_inputs_outputs(self.inputs, self.outputs) # {{{ + # set some flags by the combination of arguments. + self.infer_dtype_from_inputs_outputs(self.inputs, self.outputs) if self.dtype == np.float64 and \ self.op_type not in op_threshold_white_list.NEED_FIX_FP64_CHECK_OUTPUT_THRESHOLD_OP_LIST: atol = 0 @@ -1569,8 +1578,7 @@ class OpTest(unittest.TestCase): if no_check_set is not None: if self.op_type not in no_check_set_white_list.no_check_set_white_list: raise AssertionError( - "no_check_set of op %s must be set to None." % - self.op_type) # }}} + "no_check_set of op %s must be set to None." % self.op_type) static_checker = StaticChecker(self, self.outputs) static_checker.check() outs, fetch_list = static_checker.outputs, static_checker.fetch_list @@ -1610,8 +1618,6 @@ class OpTest(unittest.TestCase): else: return outs, fetch_list -# }}} - def check_compile_vs_runtime(self, fetch_list, fetch_outs): def find_fetch_index(target_name, fetch_list): found = [ -- GitLab