未验证 提交 04325d2c 编写于 作者: X xiongkun 提交者: GitHub

Optest refactor (#40998)

* first version, maybe many errors

* refactor op_test

* fix compare list

* fix bg

* fix bugs

* skip name
上级 45078d9f
...@@ -1398,7 +1398,7 @@ class OpTest(unittest.TestCase): ...@@ -1398,7 +1398,7 @@ class OpTest(unittest.TestCase):
# NOTE(zhiqiu): np.allclose([], [1.]) returns True # NOTE(zhiqiu): np.allclose([], [1.]) returns True
# see details: https://stackoverflow.com/questions/38331703/why-does-numpys-broadcasting-sometimes-allow-comparing-arrays-of-different-leng # see details: https://stackoverflow.com/questions/38331703/why-does-numpys-broadcasting-sometimes-allow-comparing-arrays-of-different-leng
if expect_np.size == 0: if expect_np.size == 0:
self.op_test.assertTrue(actual_np.size == 0) # }}} self.op_test.assertTrue(actual_np.size == 0)
self._compare_numpy(name, actual_np, expect_np) self._compare_numpy(name, actual_np, expect_np)
if isinstance(expect, tuple): if isinstance(expect, tuple):
self._compare_list(name, actual, expect) self._compare_list(name, actual, expect)
...@@ -1486,7 +1486,7 @@ class OpTest(unittest.TestCase): ...@@ -1486,7 +1486,7 @@ class OpTest(unittest.TestCase):
if actual_np.dtype == np.uint16: if actual_np.dtype == np.uint16:
actual_np = convert_uint16_to_float(actual_np) actual_np = convert_uint16_to_float(actual_np)
if expect_np.dtype == np.uint16: if expect_np.dtype == np.uint16:
expect_np = convert_uint16_to_float(expect_np) # }}} expect_np = convert_uint16_to_float(expect_np)
return actual_np, expect_np return actual_np, expect_np
def _compare_list(self, name, actual, expect): def _compare_list(self, name, actual, expect):
...@@ -1519,11 +1519,13 @@ class OpTest(unittest.TestCase): ...@@ -1519,11 +1519,13 @@ class OpTest(unittest.TestCase):
class EagerChecker(DygraphChecker): class EagerChecker(DygraphChecker):
def calculate_output(self): def calculate_output(self):
# we only check end2end api when check_eager=True # we only check end2end api when check_eager=True
self.is_python_api_test = True
with _test_eager_guard(): with _test_eager_guard():
eager_dygraph_outs = self.op_test._calc_python_api_output( eager_dygraph_outs = self.op_test._calc_python_api_output(
place) place)
if eager_dygraph_outs is None: if eager_dygraph_outs is None:
# missing KernelSignature, fall back to eager middle output. # missing KernelSignature, fall back to eager middle output.
self.is_python_api_test = False
eager_dygraph_outs = self.op_test._calc_dygraph_output( eager_dygraph_outs = self.op_test._calc_dygraph_output(
place, no_check_set=no_check_set) place, no_check_set=no_check_set)
self.outputs = eager_dygraph_outs self.outputs = eager_dygraph_outs
...@@ -1547,9 +1549,16 @@ class OpTest(unittest.TestCase): ...@@ -1547,9 +1549,16 @@ class OpTest(unittest.TestCase):
with _test_eager_guard(): with _test_eager_guard():
super()._compare_list(name, actual, expect) super()._compare_list(name, actual, expect)
# set some flags by the combination of arguments. def _is_skip_name(self, name):
# if in final state and kernel signature don't have name, then skip it.
if self.is_python_api_test and hasattr(
self.op_test, "python_out_sig"
) and name not in self.op_test.python_out_sig:
return True
return super()._is_skip_name(name)
self.infer_dtype_from_inputs_outputs(self.inputs, self.outputs) # {{{ # set some flags by the combination of arguments.
self.infer_dtype_from_inputs_outputs(self.inputs, self.outputs)
if self.dtype == np.float64 and \ if self.dtype == np.float64 and \
self.op_type not in op_threshold_white_list.NEED_FIX_FP64_CHECK_OUTPUT_THRESHOLD_OP_LIST: self.op_type not in op_threshold_white_list.NEED_FIX_FP64_CHECK_OUTPUT_THRESHOLD_OP_LIST:
atol = 0 atol = 0
...@@ -1569,8 +1578,7 @@ class OpTest(unittest.TestCase): ...@@ -1569,8 +1578,7 @@ class OpTest(unittest.TestCase):
if no_check_set is not None: if no_check_set is not None:
if self.op_type not in no_check_set_white_list.no_check_set_white_list: if self.op_type not in no_check_set_white_list.no_check_set_white_list:
raise AssertionError( raise AssertionError(
"no_check_set of op %s must be set to None." % "no_check_set of op %s must be set to None." % self.op_type)
self.op_type) # }}}
static_checker = StaticChecker(self, self.outputs) static_checker = StaticChecker(self, self.outputs)
static_checker.check() static_checker.check()
outs, fetch_list = static_checker.outputs, static_checker.fetch_list outs, fetch_list = static_checker.outputs, static_checker.fetch_list
...@@ -1610,8 +1618,6 @@ class OpTest(unittest.TestCase): ...@@ -1610,8 +1618,6 @@ class OpTest(unittest.TestCase):
else: else:
return outs, fetch_list return outs, fetch_list
# }}}
def check_compile_vs_runtime(self, fetch_list, fetch_outs): def check_compile_vs_runtime(self, fetch_list, fetch_outs):
def find_fetch_index(target_name, fetch_list): def find_fetch_index(target_name, fetch_list):
found = [ found = [
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册