Unverified commit edd578a1, authored by J jiangcheng, committed by GitHub

[CINN] optest add cinn check test (#52205)

* [CINN] optest add cinn check test

* replace setting self.check_cinn with passing check_cinn as a function parameter

* fix CI bug

* add cinn_atol/cinn_rtol
Parent eb93b5c9
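In effect, a test opts in by passing the new check_cinn flag through check_output/check_grad, and may loosen tolerances for the CINN comparison via cinn_atol/cinn_rtol. A minimal sketch of the intended usage (the test class and values here are hypothetical; check_cinn, cinn_atol, and cinn_rtol are the names this commit introduces):

import numpy as np
from eager_op_test import OpTest


class TestMyOpWithCinn(OpTest):  # hypothetical example, not part of this commit
    def setUp(self):
        self.op_type = "scale"
        self.inputs = {'X': np.random.random((10, 10)).astype(np.float32)}
        self.attrs = {'scale': -2.3}
        self.outputs = {'Out': self.inputs['X'] * self.attrs['scale']}
        # Optional overrides, applied only when the CINN check actually runs.
        self.cinn_atol = 1e-5
        self.cinn_rtol = 1e-5

    def test_check_output(self):
        # Additionally build the program with the CINN pass and compare outputs.
        self.check_output(check_cinn=True)

    def test_check_grad(self):
        self.check_grad(['X'], 'Out', check_cinn=True)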
@@ -1115,6 +1115,8 @@ set(TEST_CINN_OPS
    test_mean_op
    test_unsqueeze2_op
    test_meshgrid_op
    test_scale_op
    test_clip_op
    test_scatter_op
    test_gather_op
    test_layer_norm_op
...
@@ -494,6 +494,28 @@ class OpTest(unittest.TestCase):
    def disable_cal_ref_output(self):
        self.is_calc_ref = False
    def _enable_check_cinn_test(self, place, inputs, outputs):
        # If the test does not run on CUDA, or Paddle is not compiled
        # with CINN, skip the CINN test.
        if (
            not core.is_compiled_with_cinn()
            or not core.is_compiled_with_cuda()
            or not isinstance(place, fluid.CUDAPlace)
        ):
            return False
        # CINN does not support bfloat16 yet, skip the CINN test.
        if self.is_bfloat16_op():
            return False
        # CINN does not support 0-D tensors yet, skip the CINN test.
        # Note: input values expose shape() as a method, while output
        # variables expose shape as an attribute.
        for var in inputs.values():
            if len(var.shape()) == 0:
                return False
        for var in outputs.values():
            if len(var.shape) == 0:
                return False
        # CINN does not support dynamic shapes yet, skip the CINN test.
        # TODO(thisjiang): dynamic-shape ops cannot be detected
        # automatically yet, so they must be excluded manually for now.
        return True
    # Set self.output_dtype.
    def infer_dtype_from_inputs_outputs(self, inputs, outputs):
        def is_np_data(input):
@@ -1044,6 +1066,7 @@ class OpTest(unittest.TestCase):
        loss=None,
        enable_inplace=None,
        for_inplace_test=None,
        check_cinn=False,
    ):
        with paddle.fluid.framework._static_guard():
            program = Program()
@@ -1087,9 +1110,21 @@ class OpTest(unittest.TestCase):
            for out_name, out_dup in Operator.get_op_outputs(self.op_type):
                fetch_list.append(str(out_name))

            enable_cinn_test = check_cinn and self._enable_check_cinn_test(
                place, feed_map, outputs
            )
            if enable_cinn_test:
                if hasattr(self, 'cinn_atol'):
                    self.atol = self.cinn_atol
                if hasattr(self, 'cinn_rtol'):
                    self.rtol = self.cinn_rtol

            if (enable_inplace is not None) or enable_cinn_test:
                build_strategy = fluid.BuildStrategy()
                if enable_inplace is not None:
                    build_strategy.enable_inplace = enable_inplace
                if enable_cinn_test:
                    build_strategy.build_cinn_pass = check_cinn

                compiled_prog = fluid.CompiledProgram(
                    program, build_strategy=build_strategy
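For reference, the switch used above is BuildStrategy.build_cinn_pass, which routes the compiled program through CINN. A self-contained sketch under the same assumptions as the gate (a CUDA build of Paddle with CINN enabled; the tiny program is hypothetical):

import paddle
from paddle import fluid, static

paddle.enable_static()
main = static.Program()
with static.program_guard(main):
    x = static.data('x', [4, 4], 'float32')
    y = paddle.scale(x, scale=2.0)

# Mirror the _calc_output change above: enable the CINN build pass.
build_strategy = fluid.BuildStrategy()
build_strategy.build_cinn_pass = True
compiled = fluid.CompiledProgram(main, build_strategy=build_strategy)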
@@ -1518,6 +1553,7 @@ class OpTest(unittest.TestCase):
        check_dygraph=True,
        check_prim=False,
        inplace_atol=None,
        check_cinn=False,
    ):
        core._set_prim_all_enabled(False)
        core.set_prim_eager_enabled(False)
@@ -1626,7 +1662,7 @@ class OpTest(unittest.TestCase):
            np.testing.assert_allclose(
                actual_np,
                expect_np,
                atol=self.atol if hasattr(self, 'atol') else atol,
                rtol=self.rtol if hasattr(self, 'rtol') else rtol,
                equal_nan=equal_nan,
                err_msg=(
@@ -1645,7 +1681,7 @@ class OpTest(unittest.TestCase):
                np.allclose(
                    actual_np,
                    expect_np,
                    atol=self.atol if hasattr(self, 'atol') else atol,
                    rtol=self.rtol if hasattr(self, 'rtol') else rtol,
                    equal_nan=equal_nan,
                ),
@@ -1721,7 +1757,7 @@ class OpTest(unittest.TestCase):
        def calculate_output(self):
            outs, fetch_list = self.op_test._calc_output(
                place, no_check_set=no_check_set, check_cinn=check_cinn
            )
            self.outputs = outs
            self.fetch_list = fetch_list
@@ -2077,6 +2113,7 @@ class OpTest(unittest.TestCase):
        check_dygraph=True,
        check_prim=False,
        inplace_atol=None,
        check_cinn=False,
    ):
        self.__class__.op_type = self.op_type
@@ -2100,6 +2137,7 @@ class OpTest(unittest.TestCase):
                check_dygraph=check_dygraph,
                check_prim=check_prim,
                inplace_atol=inplace_atol,
                check_cinn=check_cinn,
            )
            if check_dygraph:
                outs, dygraph_dygraph_outs, fetch_list = res
@@ -2257,6 +2295,7 @@ class OpTest(unittest.TestCase):
        check_prim=False,
        only_check_prim=False,
        atol=1e-5,
        check_cinn=False,
    ):
        if hasattr(self, "use_custom_device") and self.use_custom_device:
            check_dygraph = False
@@ -2278,6 +2317,7 @@ class OpTest(unittest.TestCase):
            check_prim=check_prim,
            only_check_prim=only_check_prim,
            atol=atol,
            check_cinn=check_cinn,
        )

    def check_grad_with_place(
@@ -2296,6 +2336,7 @@ class OpTest(unittest.TestCase):
        only_check_prim=False,
        numeric_place=None,
        atol=1e-5,
        check_cinn=False,
    ):
        if hasattr(self, "use_custom_device") and self.use_custom_device:
            check_dygraph = False
@@ -2427,6 +2468,7 @@ class OpTest(unittest.TestCase):
            output_names,
            no_grad_set,
            user_defined_grad_outputs,
            check_cinn=check_cinn,
        )

        # comparison of bf16 results will happen as fp32
        # loop over list of grads and convert bf16 to fp32
@@ -2655,6 +2697,7 @@ class OpTest(unittest.TestCase):
        no_grad_set,
        user_defined_grad_outputs=None,
        parallel=False,
        check_cinn=False,
    ):
        with paddle.fluid.framework._static_guard():
            prog = Program()
@@ -2721,11 +2764,26 @@ class OpTest(unittest.TestCase):
                )
                fetch_list = grad_inputs

            enable_cinn_test = check_cinn and self._enable_check_cinn_test(
                place, feed_dict, outputs
            )
            if enable_cinn_test:
                if hasattr(self, 'cinn_atol'):
                    self.atol = self.cinn_atol
                if hasattr(self, 'cinn_rtol'):
                    self.rtol = self.cinn_rtol

            if parallel or enable_cinn_test:
                use_cuda = False
                if isinstance(place, fluid.CUDAPlace):
                    use_cuda = True
                build_strategy = None
                if enable_cinn_test:
                    build_strategy = fluid.BuildStrategy()
                    build_strategy.build_cinn_pass = check_cinn
                compiled_prog = fluid.CompiledProgram(prog, build_strategy)
                prog = compiled_prog
            executor = fluid.Executor(place)
            res = list(
...
@@ -49,10 +49,13 @@ class TestClipOp(OpTest):
        input[np.abs(input - max_v) < self.max_relative_error] = 0.5
        self.inputs['X'] = input
        self.outputs = {'Out': np.clip(self.inputs['X'], min_v, max_v)}
        self.check_cinn = ('Min' not in self.inputs) and (
            'Max' not in self.inputs
        )

    def test_check_output(self):
        paddle.enable_static()
        self.check_output(check_cinn=self.check_cinn)
        paddle.disable_static()

    def test_check_grad_normal(self):
...
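The conditional above enables the CINN check only when Min and Max arrive as attributes rather than tensor inputs. A hedged illustration of a case that therefore opts out (hypothetical subclass; it assumes the surrounding TestClipOp harness reads self.inputs['Min'] the way the existing cases do):

class TestClipOpMinAsInput(TestClipOp):  # hypothetical
    def initTestCase(self):
        self.dtype = np.float32
        self.shape = (8, 16)
        self.max = 0.8
        self.min = 0.3
        # 'Min' is fed as a tensor input, so setUp computes
        # self.check_cinn = False and the CINN run is skipped.
        self.inputs['Min'] = np.array([0.1]).astype(self.dtype)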
@@ -42,7 +42,7 @@ class TestScaleOp(OpTest):
        pass

    def test_check_output(self):
        self.check_output(check_cinn=True)

    def test_check_grad(self):
        self.check_grad(['X'], 'Out')
@@ -66,7 +66,7 @@ class TestScaleOpScaleVariable(OpTest):
        pass

    def test_check_output(self):
        self.check_output(check_cinn=True)

    def test_check_grad(self):
        self.check_grad(['X'], 'Out')
@@ -148,7 +148,7 @@ class TestScaleFp16Op(TestScaleOp):
        self.dtype = np.float16

    def test_check_output(self):
        self.check_output(check_cinn=True)

    def test_check_grad(self):
        self.check_grad(["X"], "Out")
...