Unverified · Commit 37f914c8 authored by xiongkun, committed by GitHub

[ Optest ] refactor optest check_output_with_place logic (#40928)

* first version, maybe many errors

* refactor op_test

* fix compare list

* fix bug

* fix bugs
Parent afa0e82c
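For readers skimming the diff below: the core of the refactor is a template-method split. A Checker base class owns the comparison loop, while StaticChecker, DygraphChecker, and EagerChecker only override how the op outputs are produced and fetched. The following is a minimal, self-contained sketch of that pattern with made-up data and simplified constructor signatures; it is not the OpTest code itself, whose checkers close over place, atol, equal_nan, and no_check_set from check_output_with_place.

import numpy as np


class Checker(object):
    """Template-method base: check() computes outputs, then compares them."""

    def __init__(self, expects, atol=1e-5, rtol=1e-5):
        self.expects = expects  # {output name: expected numpy array}
        self.atol = atol
        self.rtol = rtol

    def calculate_output(self):
        # subclasses decide how the op outputs are produced
        raise NotImplementedError

    def find_actual_value(self, name):
        # subclasses decide how a named output is fetched as a numpy array
        raise NotImplementedError

    def check(self):
        self.calculate_output()
        for name, expect_np in self.expects.items():
            actual_np = self.find_actual_value(name)
            assert np.allclose(
                actual_np, expect_np, atol=self.atol, rtol=self.rtol), \
                "Output (%s) has diff in %s" % (name, type(self).__name__)


class StaticChecker(Checker):
    """Stands in for the executor path; the real code fetches from a graph run."""

    def __init__(self, fetched, expects):
        super(StaticChecker, self).__init__(expects)
        self._fetched = fetched  # placeholder for executor fetch results

    def calculate_output(self):
        self.outputs = self._fetched

    def find_actual_value(self, name):
        return np.asarray(self.outputs[name])


# usage: each execution mode gets its own checker, the comparison loop is shared
expects = {"Out": np.array([2.0, 4.0])}
StaticChecker({"Out": np.array([2.0, 4.0])}, expects).check()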
@@ -1317,7 +1317,239 @@ class OpTest(unittest.TestCase):
check_dygraph=True,
inplace_atol=None,
check_eager=False):
def find_imperative_actual(target_name, dygraph_outs, place):
for name in dygraph_outs:
if name == target_name:
return dygraph_outs[name][0]
var_list = dygraph_outs[name]
for i, var in enumerate(var_list):
if var.name == target_name:
return dygraph_outs[name][i]
self.assertTrue(False, "Found failed {} {}".format(
dygraph_outs.keys(), target_name))
def find_actual(target_name, fetch_list):
found = [
i for i, var_name in enumerate(fetch_list)
if var_name == target_name
]
self.assertTrue(
len(found) == 1, "Found {} {}".format(len(found), target_name))
return found[0]
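# The Checker hierarchy below shares one comparison pipeline
# (calculate_output -> compare_outputs_with_expects); the subclasses only
# differ in how the outputs are produced: static-graph executor
# (StaticChecker), dygraph (DygraphChecker), or the eager final-state API
# (EagerChecker).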
class Checker(object):
""" base class for checking op outputs against self.outputs.
cross-checker comparison is currently not supported.
"""
def __init__(self, op_test, expect_dict):
""" expect_dict is the self.outputs
support : {str: [numpy]} and {str: [(str, numpy), (str, numpy)]}
"""
self.expects = expect_dict
self.checker_name = "checker"
self.op_test = op_test # store the op_test object.
self.op_type = op_test.op_type
def convert_uint16_to_float_ifneed(self, actual_np, expect_np):
raise NotImplementedError("base class, not implemented!")
def calculate_output(self):
"""
calculate the outputs of the tested op for this checker
and cache them on self for the later comparison.
"""
pass
def _is_skip_name(self, name):
if name not in self.expects:
return True
if no_check_set is not None and name in no_check_set:
return True
return False
def find_actual_value(self, name):
""" return: (actual_tensor(var_base), actual_numpy)
"""
raise NotImplementedError("base class, not implemented!")
def _compare_numpy(self, name, actual_np, expect_np):
self.op_test.assertTrue(
np.allclose(
actual_np,
expect_np,
atol=atol,
rtol=self.rtol if hasattr(self, 'rtol') else 1e-5,
equal_nan=equal_nan),
"Output (" + name + ") has diff at " + str(place) + " in " +
self.checker_name + " checker")
def _compare_list(self, name, actual, expect):
""" if expect is a tuple, we need to compare list.
"""
raise NotImplementedError("base class, not implemented!")
def compare_single_output_with_expect(self, name, expect):
actual, actual_np = self.find_actual_value(name)
expect_np = expect[0] \
if isinstance(expect, tuple) else expect
actual_np, expect_np = self.convert_uint16_to_float_ifneed(
actual_np, expect_np)
# NOTE(zhiqiu): np.allclose([], [1.]) returns True
# see details: https://stackoverflow.com/questions/38331703/why-does-numpys-broadcasting-sometimes-allow-comparing-arrays-of-different-leng
if expect_np.size == 0:
self.op_test.assertTrue(actual_np.size == 0)
self._compare_numpy(name, actual_np, expect_np)
if isinstance(expect, tuple):
self._compare_list(name, actual, expect)
def compare_outputs_with_expects(self):
for out_name, out_dup in Operator.get_op_outputs(self.op_type):
if self._is_skip_name(out_name): continue
if out_dup:
# if self.output = {'name': [(subname, Tensor), (subname, Tensor)]}
sub_out = self.expects[out_name]
if not isinstance(sub_out, list):
raise AssertionError("sub_out type %s is not list",
type(sub_out))
for item in sub_out:
sub_out_name, expect = item[0], item[1]
self.compare_single_output_with_expect(sub_out_name,
expect)
else:
expect = self.expects[out_name]
self.compare_single_output_with_expect(out_name, expect)
def check(self):
"""
return None means ok, raise Error means failed.
the main entry point of the Checker class
"""
self.calculate_output()
self.compare_outputs_with_expects()
class StaticChecker(Checker):
def calculate_output(self):
outs, fetch_list = self.op_test._calc_output(
place, no_check_set=no_check_set)
self.outputs = outs
self.fetch_list = fetch_list
def find_actual_value(self, name):
idx = find_actual(name, self.fetch_list)
actual = self.outputs[idx]
actual_t = np.array(actual)
return actual, actual_t
def convert_uint16_to_float_ifneed(self, actual_np, expect_np):
"""
convert actual_np and expect_np to float32 when they carry
bfloat16 (np.uint16) data, and relax rtol/atol accordingly.
return the (possibly converted) pair (actual_np, expect_np).
"""
if actual_np.dtype == np.uint16 and expect_np.dtype in [
np.float32, np.float64
]:
actual_np = convert_uint16_to_float(actual_np)
self.rtol = 1.e-2
else:
self.rtol = 1.e-5
if expect_np.dtype == np.uint16 and actual_np.dtype == np.uint16:
nonlocal atol
expect_np = convert_uint16_to_float(expect_np)
actual_np = convert_uint16_to_float(actual_np)
atol = max(atol, 0.03)
return actual_np, expect_np
def _compare_list(self, name, actual, expect):
""" if expect is a tuple, we need to compare list.
"""
self.op_test.assertListEqual(
actual.recursive_sequence_lengths(), expect[1],
"Output (" + name + ") has different lod at " + str(place))
class DygraphChecker(Checker):
def calculate_output(self):
self.outputs = self.op_test._calc_dygraph_output(
place, no_check_set=no_check_set)
def find_actual_value(self, name):
with fluid.dygraph.base.guard(place=place):
imperative_actual = find_imperative_actual(
name, self.outputs, place)
imperative_actual_t = np.array(imperative_actual.value()
.get_tensor())
return imperative_actual, imperative_actual_t
def convert_uint16_to_float_ifneed(self, actual_np, expect_np):
if self.op_test.is_bfloat16_op():
if actual_np.dtype == np.uint16:
actual_np = convert_uint16_to_float(actual_np)
if expect_np.dtype == np.uint16:
expect_np = convert_uint16_to_float(expect_np)
return actual_np, expect_np
def _compare_list(self, name, actual, expect):
""" if expect is a tuple, we need to compare list.
"""
with fluid.dygraph.base.guard(place=place):
self.op_test.assertListEqual(
actual.value().get_tensor()
.recursive_sequence_lengths(), expect[1],
"Output (" + name + ") has different lod at " +
str(place) + " in dygraph mode")
def _compare_numpy(self, name, actual_np, expect_np):
if six.moves.reduce(lambda x, y: x * y, actual_np.shape,
1) == 0 and six.moves.reduce(
lambda x, y: x * y, expect_np.shape,
1) == 0:
pass
else:
self.op_test.assertTrue(
np.allclose(
actual_np,
expect_np,
atol=atol,
rtol=self.rtol if hasattr(self, 'rtol') else 1e-5,
equal_nan=equal_nan),
"Output (" + name + ") has diff at " + str(place) +
" in " + self.checker_name + " checker")
class EagerChecker(DygraphChecker):
def calculate_output(self):
# we only check end2end api when check_eager=True
with _test_eager_guard():
eager_dygraph_outs = self.op_test._calc_python_api_output(
place)
if eager_dygraph_outs is None:
# missing KernelSignature, fall back to eager middle output.
eager_dygraph_outs = self.op_test._calc_dygraph_output(
place, no_check_set=no_check_set)
self.outputs = eager_dygraph_outs
def _compare_numpy(self, name, actual_np, expect_np):
with _test_eager_guard():
super()._compare_numpy(name, actual_np, expect_np)
def convert_uint16_to_float_ifneed(self, actual_np, expect_np):
with _test_eager_guard():
return super().convert_uint16_to_float_ifneed(actual_np,
expect_np)
def find_actual_value(self, name):
with _test_eager_guard():
return super().find_actual_value(name)
def _compare_list(self, name, actual, expect):
""" if expect is a tuple, we need to compare list.
"""
with _test_eager_guard():
super()._compare_list(name, actual, expect)
# set some flags by the combination of arguments.
self.infer_dtype_from_inputs_outputs(self.inputs, self.outputs)
if self.dtype == np.float64 and \
self.op_type not in op_threshold_white_list.NEED_FIX_FP64_CHECK_OUTPUT_THRESHOLD_OP_LIST:
atol = 0
@@ -1337,260 +1569,19 @@ class OpTest(unittest.TestCase):
if no_check_set is not None:
if self.op_type not in no_check_set_white_list.no_check_set_white_list:
raise AssertionError(
"no_check_set of op %s must be set to None." %
self.op_type)
static_checker = StaticChecker(self, self.outputs)
static_checker.check()
outs, fetch_list = static_checker.outputs, static_checker.fetch_list
if check_dygraph:
dygraph_checker = DygraphChecker(self, self.outputs)
dygraph_checker.check()
dygraph_outs = dygraph_checker.outputs
if check_eager:
eager_checker = EagerChecker(self, self.outputs)
eager_checker.check()
eager_dygraph_outs = eager_checker.outputs
outs, fetch_list = self._calc_output(place, no_check_set=no_check_set)
for out_name, out_dup in Operator.get_op_outputs(self.op_type):
if out_name not in self.outputs:
continue
if no_check_set is not None and out_name in no_check_set:
continue
def find_imperative_actual(target_name, dygraph_outs, place):
with fluid.dygraph.base.guard(place=place):
for name in dygraph_outs:
if name == target_name:
return dygraph_outs[name][0]
var_list = dygraph_outs[name]
for i, var in enumerate(var_list):
if var.name == target_name:
return dygraph_outs[name][i]
self.assertTrue(False, "Found failed {} {}".format(
dygraph_outs.keys(), target_name))
def find_actual(target_name, fetch_list):
found = [
i for i, var_name in enumerate(fetch_list)
if var_name == target_name
]
self.assertTrue(
len(found) == 1, "Found {} {}".format(
len(found), target_name))
return found[0]
if out_dup:
sub_out = self.outputs[out_name]
if not isinstance(sub_out, list):
raise AssertionError("sub_out type %s is not list",
type(sub_out))
for item in sub_out:
sub_out_name, expect = item[0], item[1]
if check_dygraph:
if _in_eager_without_dygraph_check():
_enable_legacy_dygraph()
imperative_actual = find_imperative_actual(
sub_out_name, dygraph_outs, place)
imperative_actual_t = np.array(
imperative_actual.value().get_tensor())
_disable_legacy_dygraph()
else:
imperative_actual = find_imperative_actual(
sub_out_name, dygraph_outs, place)
imperative_actual_t = np.array(
imperative_actual.value().get_tensor())
if check_eager:
with fluid.dygraph.base.guard(place):
with _test_eager_guard():
eager_imperative_actual = find_imperative_actual(
sub_out_name, eager_dygraph_outs, place)
eager_imperative_actual_t = eager_imperative_actual.numpy(
)
idx = find_actual(sub_out_name, fetch_list)
actual = outs[idx]
actual_t = np.array(actual)
expect_t = expect[0] \
if isinstance(expect, tuple) else expect
self.assertTrue(
np.allclose(
actual_t, expect_t, atol=atol, equal_nan=equal_nan),
"Output (" + sub_out_name + ") has diff at " +
str(place))
if check_dygraph:
self.assertTrue(
np.allclose(
imperative_actual_t,
expect_t,
atol=atol,
equal_nan=equal_nan),
"Output (" + sub_out_name + ") has diff at " +
str(place) + " in dygraph mode")
if check_eager:
with fluid.dygraph.base.guard(place):
with _test_eager_guard():
self.assertTrue(
np.allclose(
eager_imperative_actual_t,
expect_t,
atol=atol,
equal_nan=equal_nan),
"Output (" + sub_out_name + ") has diff at "
+ str(place) + " in eager dygraph mode")
if isinstance(expect, tuple):
self.assertListEqual(
actual.recursive_sequence_lengths(), expect[1],
"Output (" + sub_out_name +
") has different lod at " + str(place))
if check_dygraph:
self.assertListEqual(
imperative_actual.value().get_tensor()
.recursive_sequence_lengths(), expect[1],
"Output (" + out_name +
") has different lod at " + str(place) +
" in dygraph mode")
if check_eager:
with fluid.dygraph.base.guard(place):
with _test_eager_guard():
self.assertListEqual(
eager_imperative_actual.value(
).get_tensor()
.recursive_sequence_lengths(),
expect[1], "Output (" + out_name +
") has different lod at " + str(place) +
" in eager dygraph mode")
else:
if check_dygraph:
with fluid.dygraph.base.guard(place=place):
imperative_actual = find_imperative_actual(
out_name, dygraph_outs, place)
imperative_actual_t = np.array(imperative_actual.value()
.get_tensor())
if check_eager:
with fluid.dygraph.base.guard(place):
with _test_eager_guard():
eager_imperative_actual = find_imperative_actual(
out_name, eager_dygraph_outs, place)
eager_imperative_actual_t = eager_imperative_actual.numpy(
)
idx = find_actual(out_name, fetch_list)
actual = outs[idx]
actual_t = np.array(actual)
expect = self.outputs[out_name]
expect_t = expect[0] if isinstance(expect, tuple) else expect
# np.uint16 represents bfloat16
if actual_t.dtype == np.uint16 and expect_t.dtype in [
np.float32, np.float64
]:
actual_t = convert_uint16_to_float(actual_t)
rtol = 1.e-2
else:
rtol = 1.e-5
if expect_t.dtype == np.uint16 and actual_t.dtype == np.uint16:
expect_t = convert_uint16_to_float(expect_t)
actual_t = convert_uint16_to_float(actual_t)
atol = max(atol, 0.03)
# NOTE(zhiqiu): np.allclose([], [1.]) returns True
# see details: https://stackoverflow.com/questions/38331703/why-does-numpys-broadcasting-sometimes-allow-comparing-arrays-of-different-leng
if expect_t.size == 0:
self.assertTrue(actual_t.size == 0)
self.assertTrue(
np.allclose(
actual_t,
expect_t,
atol=atol,
rtol=rtol,
equal_nan=equal_nan),
"Output (" + out_name + ") has diff at " + str(place) +
"\nExpect " + str(expect_t) + "\n" + "But Got" +
str(actual_t) + " in class " + self.__class__.__name__)
if check_dygraph:
if self.is_bfloat16_op():
if imperative_actual_t.dtype == np.uint16:
imperative_actual_t = convert_uint16_to_float(
imperative_actual_t)
if expect_t.dtype == np.uint16:
expect_t = convert_uint16_to_float(expect_t)
if six.moves.reduce(
lambda x, y: x * y, imperative_actual_t.shape,
1) == 0 and six.moves.reduce(
lambda x, y: x * y, expect_t.shape, 1) == 0:
pass
else:
self.assertTrue(
np.allclose(
imperative_actual_t,
expect_t,
atol=atol,
rtol=rtol,
equal_nan=equal_nan),
"Output (" + out_name + ") has diff at " +
str(place) + "\nExpect " + str(expect_t) + "\n" +
"But Got" + str(imperative_actual_t) + " in class "
+ self.__class__.__name__)
if check_eager:
with fluid.dygraph.base.guard(place):
with _test_eager_guard():
if self.is_bfloat16_op():
if eager_imperative_actual_t.dtype == np.uint16:
eager_imperative_actual_t = convert_uint16_to_float(
eager_imperative_actual_t)
if expect_t.dtype == np.uint16:
expect_t = convert_uint16_to_float(expect_t)
if six.moves.reduce(lambda x, y: x * y,
eager_imperative_actual_t.shape,
1) == 0 and six.moves.reduce(
lambda x, y: x * y,
expect_t.shape, 1) == 0:
pass
else:
self.assertTrue(
np.allclose(
eager_imperative_actual_t,
expect_t,
atol=atol,
rtol=rtol,
equal_nan=equal_nan),
"Output (" + out_name + ") has diff at " +
str(place) + "\nExpect " + str(expect_t) +
"\n" + "But Got" +
str(eager_imperative_actual_t) +
" in class " + self.__class__.__name__)
if isinstance(expect, tuple):
self.assertListEqual(actual.recursive_sequence_lengths(),
expect[1], "Output (" + out_name +
") has different lod at " + str(place))
if check_dygraph:
self.assertListEqual(
imperative_actual.value().get_tensor()
.recursive_sequence_lengths(), expect[1],
"Output (" + out_name + ") has different lod at " +
str(place) + " in eager dygraph mode")
if check_eager:
with fluid.dygraph.base.guard(place):
with _test_eager_guard():
self.assertListEqual(
eager_imperative_actual.value().get_tensor()
.recursive_sequence_lengths(), expect[1],
"Output (" + out_name +
") has different lod at " + str(place) +
" in eager dygraph mode")
# Note(zhiqiu): inplace_atol should be only set when op doesn't ensure
# computational consistency.
@@ -1619,6 +1610,8 @@ class OpTest(unittest.TestCase):
else:
return outs, fetch_list
def check_compile_vs_runtime(self, fetch_list, fetch_outs):
def find_fetch_index(target_name, fetch_list):
found = [
...
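For reference, the refactored path is exercised through the usual OpTest workflow. A typical operator unit test looks like the sketch below; the operator name, shapes, attribute values, and the import path are illustrative placeholders, not part of this patch.

import numpy as np

from op_test import OpTest  # the module this diff modifies


class TestScaleOp(OpTest):
    def setUp(self):
        self.op_type = "scale"
        self.inputs = {'X': np.random.random((10, 10)).astype("float64")}
        self.attrs = {'scale': -2.3}
        self.outputs = {'Out': self.inputs['X'] * self.attrs['scale']}

    def test_check_output(self):
        # check_output() iterates over the available places and calls
        # check_output_with_place(), which after this patch drives the
        # StaticChecker, DygraphChecker and, when check_eager=True,
        # the EagerChecker.
        self.check_output(check_eager=True)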