From 84504f35726a40f31f0e3e00cc95befd5b6f406d Mon Sep 17 00:00:00 2001 From: Weilong Wu Date: Thu, 30 Mar 2023 23:54:17 +0800 Subject: [PATCH] support prim & cinn test for layer_norm (#51272) * support layer_norm prim and cinn test * enable cinn test * fix merge conflict * polish input for check_output_with_place * fix merge conflict * add more test case * fix merge conflict * polish test case * polish op_test * change ln_g rules * modify scale is none case * modify scale is none case * add public_python_api for check prim * modify setoutputgrad and fp64bug * add todo & delete log * recover * fix some errors * recover * recover * recover * recover * fix merge conflicts --------- Co-authored-by: wangruting --- .../fluid/tests/unittests/CMakeLists.txt | 3 +- .../fluid/tests/unittests/eager_op_test.py | 11 +- .../tests/unittests/test_layer_norm_op.py | 239 +++++++++++++++++- .../tests/unittests/test_space_to_depth_op.py | 4 +- .../white_list/no_check_set_white_list.py | 1 + .../white_list/op_threshold_white_list.py | 2 + 6 files changed, 252 insertions(+), 8 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index f1f6d0add21..bdcd4d006ff 100755 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -982,7 +982,7 @@ else() set_tests_properties(test_conv3d_transpose_op PROPERTIES TIMEOUT 120) set_tests_properties(test_conv3d_op PROPERTIES TIMEOUT 120) set_tests_properties(test_norm_op PROPERTIES TIMEOUT 120) - set_tests_properties(test_layer_norm_op PROPERTIES TIMEOUT 150) + set_tests_properties(test_layer_norm_op PROPERTIES TIMEOUT 200) set_tests_properties(test_pool3d_op PROPERTIES TIMEOUT 150) endif() set_tests_properties(test_imperative_selected_rows_to_lod_tensor @@ -1198,6 +1198,7 @@ set(TEST_CINN_OPS test_meshgrid_op test_scatter_op test_gather_op + test_layer_norm_op test_cast_op test_dropout_op test_group_norm_op) diff --git a/python/paddle/fluid/tests/unittests/eager_op_test.py b/python/paddle/fluid/tests/unittests/eager_op_test.py index f8f6c8023da..061c4763d0c 100644 --- a/python/paddle/fluid/tests/unittests/eager_op_test.py +++ b/python/paddle/fluid/tests/unittests/eager_op_test.py @@ -1516,6 +1516,7 @@ class OpTest(unittest.TestCase): self, place, atol=0, + rtol=0, no_check_set=None, equal_nan=False, check_dygraph=True, @@ -1630,7 +1631,7 @@ class OpTest(unittest.TestCase): actual_np, expect_np, atol=atol, - rtol=self.rtol if hasattr(self, 'rtol') else 1e-5, + rtol=self.rtol if hasattr(self, 'rtol') else rtol, equal_nan=equal_nan, err_msg=( "Output (" @@ -1647,7 +1648,7 @@ class OpTest(unittest.TestCase): actual_np, expect_np, atol=atol, - rtol=self.rtol if hasattr(self, 'rtol') else 1e-5, + rtol=self.rtol if hasattr(self, 'rtol') else rtol, equal_nan=equal_nan, ), "Output (" @@ -1819,7 +1820,7 @@ class OpTest(unittest.TestCase): actual_np, expect_np, atol=atol, - rtol=self.rtol if hasattr(self, 'rtol') else 1e-5, + rtol=self.rtol if hasattr(self, 'rtol') else rtol, equal_nan=equal_nan, err_msg=( "Output (" @@ -1836,7 +1837,7 @@ class OpTest(unittest.TestCase): actual_np, expect_np, atol=atol, - rtol=self.rtol if hasattr(self, 'rtol') else 1e-5, + rtol=self.rtol if hasattr(self, 'rtol') else rtol, equal_nan=equal_nan, ), "Output (" @@ -2058,6 +2059,7 @@ class OpTest(unittest.TestCase): def check_output( self, atol=1e-5, + rtol=1e-5, no_check_set=None, equal_nan=False, check_dygraph=True, @@ -2080,6 +2082,7 @@ class OpTest(unittest.TestCase): 
res = self.check_output_with_place( place, atol, + rtol, no_check_set, equal_nan, check_dygraph=check_dygraph, diff --git a/python/paddle/fluid/tests/unittests/test_layer_norm_op.py b/python/paddle/fluid/tests/unittests/test_layer_norm_op.py index a43c3ebcd0d..75cabf85b7f 100644 --- a/python/paddle/fluid/tests/unittests/test_layer_norm_op.py +++ b/python/paddle/fluid/tests/unittests/test_layer_norm_op.py @@ -17,7 +17,7 @@ from functools import reduce from operator import mul import numpy as np -from eager_op_test import _set_use_system_allocator +from eager_op_test import OpTest, _set_use_system_allocator import paddle import paddle.nn.functional as F @@ -27,7 +27,8 @@ from paddle.static.amp.fp16_utils import _keep_layer_norm_scale_bias_to_fp32 paddle.enable_static() -np.random.random(123) +np.random.seed(123) +paddle.seed(123) _set_use_system_allocator(True) @@ -111,6 +112,240 @@ def _reference_layer_norm_grad( return grad_x, d_scale, d_bias +def layer_norm_wrapper( + x, scale=None, bias=None, epsilon=1e-05, begin_norm_axis=1 +): + input_shape = list(x.shape) + normalized_shape = input_shape[begin_norm_axis:] + return paddle.nn.functional.layer_norm( + x, normalized_shape, weight=scale, bias=bias, epsilon=epsilon + ) + + +class TestLayerNormOpByOpTest(OpTest): + def setUp(self): + self.python_api = layer_norm_wrapper + self.public_python_api = layer_norm_wrapper + self.op_type = "layer_norm" + self.prim_op_type = "comp" + self.python_out_sig = ["Y"] + self.initConfig() + self.initTestCase() + + def test_check_output(self): + self.check_output( + no_check_set=["Mean", "Variance"], + atol=self.ori_atol, + rtol=self.ori_rtol, + check_prim=True, + ) + + def test_check_grad(self): + self.check_grad( + self.check_grad_input_list, + ['Y'], + max_relative_error=self.max_relative_error, + check_prim=True, + ) + + def initConfig(self): + self.rev_comp_atol = 1e-7 + self.rev_comp_rtol = 1e-7 + self.fw_comp_atol = 1e-6 + self.fw_comp_rtol = 1e-6 + + self.ori_atol = 1e-4 + self.ori_rtol = 1e-4 + self.cinn_atol = 1e-5 + self.cinn_rtol = 1e-5 + + self.max_relative_error = 1e-5 + + self.dtype = "float64" + self.x_shape = [2, 6, 6, 3] + self.epsilon = 0.00001 + self.begin_norm_axis = 1 + self.has_scale = True + self.has_bias = True + + def initTestCase(self): + np.random.seed(123) + + self.D = reduce( + mul, self.x_shape[self.begin_norm_axis : len(self.x_shape)], 1 + ) + self.scale_shape = [self.D] + x = np.random.random(self.x_shape).astype(self.dtype) + scale = ( + np.random.random(self.scale_shape).astype(self.dtype) + if self.has_scale + else None + ) + bias = ( + np.random.random(self.scale_shape).astype(self.dtype) + if self.has_bias + else None + ) + self.inputs = { + "X": x, + } + self.check_grad_input_list = ['X'] + + if self.has_scale: + self.inputs.update({"Scale": scale}) + self.check_grad_input_list.append('Scale') + if self.has_bias: + self.inputs.update({"Bias": bias}) + self.check_grad_input_list.append('Bias') + + self.attrs = { + "epsilon": self.epsilon, + "begin_norm_axis": self.begin_norm_axis, + } + y, mean, variance = _reference_layer_norm_naive( + x, scale, bias, self.epsilon, self.begin_norm_axis + ) + self.outputs = { + "Y": y, + "Mean": mean, + "Variance": variance, + } + + +class TestLayerNormOpByOpTestFP64_case2(TestLayerNormOpByOpTest): + def initConfig(self): + self.rev_comp_atol = 1e-6 + self.rev_comp_rtol = 1e-6 + self.fw_comp_atol = 1e-7 + self.fw_comp_rtol = 1e-7 + + self.ori_atol = 1e-4 + self.ori_rtol = 1e-4 + self.cinn_atol = 1e-5 + self.cinn_rtol = 1e-5 + + 
self.max_relative_error = 1e-5 + + self.dtype = "float64" + self.x_shape = [2, 6, 6, 3] + self.epsilon = 0.00001 + self.begin_norm_axis = 1 + self.has_scale = False + self.has_bias = False + + +class TestLayerNormOpByOpTestFP64_case3(TestLayerNormOpByOpTest): + def initConfig(self): + self.rev_comp_atol = 1e-7 + self.rev_comp_rtol = 1e-7 + self.fw_comp_atol = 1e-7 + self.fw_comp_rtol = 1e-7 + + self.ori_atol = 1e-4 + self.ori_rtol = 1e-4 + self.cinn_atol = 1e-5 + self.cinn_rtol = 1e-5 + + self.max_relative_error = 1e-5 + + self.dtype = "float64" + self.x_shape = [2, 6, 6, 3] + self.epsilon = 0.00001 + self.begin_norm_axis = 1 + self.has_scale = True + self.has_bias = False + + +class TestLayerNormOpByOpTestFP64_case4(TestLayerNormOpByOpTest): + def initConfig(self): + self.rev_comp_atol = 1e-6 + self.rev_comp_rtol = 1e-6 + self.fw_comp_atol = 1e-7 + self.fw_comp_rtol = 1e-7 + + self.ori_atol = 1e-4 + self.ori_rtol = 1e-4 + self.cinn_atol = 1e-5 + self.cinn_rtol = 1e-5 + + self.max_relative_error = 1e-5 + + self.dtype = "float64" + self.x_shape = [2, 6, 6, 3] + self.epsilon = 0.00001 + self.begin_norm_axis = 1 + self.has_scale = False + self.has_bias = True + + +class TestLayerNormOpByOpTestFP32(TestLayerNormOpByOpTest): + def initConfig(self): + self.rev_comp_atol = 1e-5 + self.rev_comp_rtol = 1e-5 + + self.ori_atol = 1e-4 + self.ori_rtol = 1e-4 + self.max_relative_error = 7e-3 + + self.dtype = "float32" + self.x_shape = [2, 6, 6, 3] + self.epsilon = 0.00001 + self.begin_norm_axis = 1 + self.has_scale = True + self.has_bias = True + + +class TestLayerNormOpByOpTestFP32_case2(TestLayerNormOpByOpTest): + def initConfig(self): + self.rev_comp_atol = 1e-5 + self.rev_comp_rtol = 1e-5 + + self.ori_atol = 1e-4 + self.ori_rtol = 1e-4 + self.max_relative_error = 1e-5 + + self.dtype = "float32" + self.x_shape = [2, 6, 6, 3] + self.epsilon = 0.00001 + self.begin_norm_axis = 1 + self.has_scale = False + self.has_bias = False + + +class TestLayerNormOpByOpTestFP32_case3(TestLayerNormOpByOpTest): + def initConfig(self): + self.rev_comp_atol = 1e-5 + self.rev_comp_rtol = 1e-5 + + self.ori_atol = 1e-4 + self.ori_rtol = 1e-4 + self.max_relative_error = 3e-3 + + self.dtype = "float32" + self.x_shape = [2, 6, 6, 3] + self.epsilon = 0.00001 + self.begin_norm_axis = 1 + self.has_scale = True + self.has_bias = False + + +class TestLayerNormOpByOpTestFP32_case4(TestLayerNormOpByOpTest): + def initConfig(self): + self.rev_comp_atol = 1e-5 + self.rev_comp_rtol = 1e-5 + + self.ori_atol = 1e-4 + self.ori_rtol = 1e-4 + self.max_relative_error = 1e-3 + + self.dtype = "float32" + self.x_shape = [2, 6, 6, 3] + self.epsilon = 0.00001 + self.begin_norm_axis = 1 + self.has_scale = False + self.has_bias = True + + class TestLayerNormOp(unittest.TestCase): def setUp(self): self.use_cudnn = True diff --git a/python/paddle/fluid/tests/unittests/test_space_to_depth_op.py b/python/paddle/fluid/tests/unittests/test_space_to_depth_op.py index db0c00151ea..4f3af779242 100644 --- a/python/paddle/fluid/tests/unittests/test_space_to_depth_op.py +++ b/python/paddle/fluid/tests/unittests/test_space_to_depth_op.py @@ -79,7 +79,9 @@ class TestSpaceToDepthOp(OpTest): if fluid.core.is_compiled_with_cuda() else fluid.core.CPUPlace() ) - self.check_output_with_place(place, 1e-5, None, False) + self.check_output_with_place( + place=place, atol=1e-5, no_check_set=None, equal_nan=False + ) def test_check_grad(self): place = ( diff --git a/python/paddle/fluid/tests/unittests/white_list/no_check_set_white_list.py 
b/python/paddle/fluid/tests/unittests/white_list/no_check_set_white_list.py index baf9e3bf6e6..806b0891ea9 100644 --- a/python/paddle/fluid/tests/unittests/white_list/no_check_set_white_list.py +++ b/python/paddle/fluid/tests/unittests/white_list/no_check_set_white_list.py @@ -38,4 +38,5 @@ no_check_set_white_list = [ 'einsum', 'rmsprop', 'rrelu', + 'layer_norm', ] diff --git a/python/paddle/fluid/tests/unittests/white_list/op_threshold_white_list.py b/python/paddle/fluid/tests/unittests/white_list/op_threshold_white_list.py index 22bc42d9694..fa151bfb072 100644 --- a/python/paddle/fluid/tests/unittests/white_list/op_threshold_white_list.py +++ b/python/paddle/fluid/tests/unittests/white_list/op_threshold_white_list.py @@ -52,9 +52,11 @@ NEED_FIX_FP64_CHECK_GRAD_THRESHOLD_OP_LIST = [ 'cholesky_solve', 'solve', 'qr', + 'layer_norm', ] NEED_FIX_FP64_CHECK_OUTPUT_THRESHOLD_OP_LIST = [ 'bilinear_interp', 'bilinear_interp_v2', + 'layer_norm', ] -- GitLab
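
A note on what the new TestLayerNormOpByOpTest cases actually verify: the Python-side ground truth comes from _reference_layer_norm_naive, whose body lies outside the hunks above. The sketch below is a minimal, assumed reconstruction of its semantics, inferred from how the test consumes it (D multiplies out every axis from begin_norm_axis onward, Scale and Bias are optional, and the helper returns Y together with the per-row Mean and Variance); the function name here is illustrative, not the one in the test file.

    from functools import reduce
    from operator import mul

    import numpy as np


    def reference_layer_norm(x, scale=None, bias=None, epsilon=1e-05,
                             begin_norm_axis=1):
        # Collapse x to [N, D]; D is the same value the test computes as
        # reduce(mul, x_shape[begin_norm_axis:], 1).
        x_shape = x.shape
        N = reduce(mul, x_shape[:begin_norm_axis], 1)
        D = reduce(mul, x_shape[begin_norm_axis:], 1)
        x2d = x.reshape([N, D])

        # Per-row statistics; epsilon stabilizes the square root
        # (standard layer_norm; the exact reference helper is not
        # shown in this patch, so this is an assumption).
        mean = np.mean(x2d, axis=1)
        var = np.var(x2d, axis=1) + epsilon
        y = (x2d - mean.reshape([N, 1])) / np.sqrt(var).reshape([N, 1])

        # Scale and Bias mirror the has_scale / has_bias test cases.
        if scale is not None:
            y = scale.reshape([1, D]) * y
        if bias is not None:
            y = y + bias.reshape([1, D])
        return y.reshape(x_shape), mean, var


    x = np.random.random([2, 6, 6, 3]).astype("float64")
    y, mean, var = reference_layer_norm(x, begin_norm_axis=1)  # D == 6 * 6 * 3

Two details of the patch follow from this setup. First, test_check_output passes no_check_set=["Mean", "Variance"], so only Y is compared numerically; that is why 'layer_norm' joins no_check_set_white_list. Second, the eager_op_test changes give rtol the same treatment atol already had: check_output forwards an explicit rtol (default 1e-5) into check_output_with_place, while a test-defined self.rtol attribute, when present, still takes precedence, which lets each new case pass its own ori_atol and ori_rtol per configuration.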