未验证 提交 04325d2c 编写于 作者: X xiongkun 提交者: GitHub

Optest refactor (#40998)

* first version, maybe many errors

* refactor op_test

* fix compare list

* fix bg

* fix bugs

* skip name
上级 45078d9f
......@@ -1398,7 +1398,7 @@ class OpTest(unittest.TestCase):
# NOTE(zhiqiu): np.allclose([], [1.]) returns True
# see details: https://stackoverflow.com/questions/38331703/why-does-numpys-broadcasting-sometimes-allow-comparing-arrays-of-different-leng
if expect_np.size == 0:
self.op_test.assertTrue(actual_np.size == 0) # }}}
self.op_test.assertTrue(actual_np.size == 0)
self._compare_numpy(name, actual_np, expect_np)
if isinstance(expect, tuple):
self._compare_list(name, actual, expect)
......@@ -1486,7 +1486,7 @@ class OpTest(unittest.TestCase):
if actual_np.dtype == np.uint16:
actual_np = convert_uint16_to_float(actual_np)
if expect_np.dtype == np.uint16:
expect_np = convert_uint16_to_float(expect_np) # }}}
expect_np = convert_uint16_to_float(expect_np)
return actual_np, expect_np
def _compare_list(self, name, actual, expect):
......@@ -1519,11 +1519,13 @@ class OpTest(unittest.TestCase):
class EagerChecker(DygraphChecker):
def calculate_output(self):
# we only check end2end api when check_eager=True
self.is_python_api_test = True
with _test_eager_guard():
eager_dygraph_outs = self.op_test._calc_python_api_output(
place)
if eager_dygraph_outs is None:
# missing KernelSignature, fall back to eager middle output.
self.is_python_api_test = False
eager_dygraph_outs = self.op_test._calc_dygraph_output(
place, no_check_set=no_check_set)
self.outputs = eager_dygraph_outs
......@@ -1547,9 +1549,16 @@ class OpTest(unittest.TestCase):
with _test_eager_guard():
super()._compare_list(name, actual, expect)
# set some flags by the combination of arguments.
def _is_skip_name(self, name):
# if in final state and kernel signature don't have name, then skip it.
if self.is_python_api_test and hasattr(
self.op_test, "python_out_sig"
) and name not in self.op_test.python_out_sig:
return True
return super()._is_skip_name(name)
self.infer_dtype_from_inputs_outputs(self.inputs, self.outputs) # {{{
# set some flags by the combination of arguments.
self.infer_dtype_from_inputs_outputs(self.inputs, self.outputs)
if self.dtype == np.float64 and \
self.op_type not in op_threshold_white_list.NEED_FIX_FP64_CHECK_OUTPUT_THRESHOLD_OP_LIST:
atol = 0
......@@ -1569,8 +1578,7 @@ class OpTest(unittest.TestCase):
if no_check_set is not None:
if self.op_type not in no_check_set_white_list.no_check_set_white_list:
raise AssertionError(
"no_check_set of op %s must be set to None." %
self.op_type) # }}}
"no_check_set of op %s must be set to None." % self.op_type)
static_checker = StaticChecker(self, self.outputs)
static_checker.check()
outs, fetch_list = static_checker.outputs, static_checker.fetch_list
......@@ -1610,8 +1618,6 @@ class OpTest(unittest.TestCase):
else:
return outs, fetch_list
# }}}
def check_compile_vs_runtime(self, fetch_list, fetch_outs):
def find_fetch_index(target_name, fetch_list):
found = [
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册