未验证 提交 84504f35 编写于 作者: W Weilong Wu 提交者: GitHub

support prim & cinn test for layer_norm (#51272)

* support layer_norm prim and cinn test

* enable cinn test

* fix merge conflict

* polish input for check_output_with_place

* fix merge conflict

* add more test case

* fix merge conflict

* polish test case

* polish op_test

* change ln_g rules

* modify scale is none case

* modify scale is none case

* add public_python_api for check prim

* modify setoutputgrad and fp64bug

* add todo & delete log

* recover

* fix some errors

* recover

* recover

* recover

* recover

* fix merge conflicts

---------
Co-authored-by: Nwangruting <wangruting@baidu.com>
上级 a4e0f666
...@@ -982,7 +982,7 @@ else() ...@@ -982,7 +982,7 @@ else()
set_tests_properties(test_conv3d_transpose_op PROPERTIES TIMEOUT 120) set_tests_properties(test_conv3d_transpose_op PROPERTIES TIMEOUT 120)
set_tests_properties(test_conv3d_op PROPERTIES TIMEOUT 120) set_tests_properties(test_conv3d_op PROPERTIES TIMEOUT 120)
set_tests_properties(test_norm_op PROPERTIES TIMEOUT 120) set_tests_properties(test_norm_op PROPERTIES TIMEOUT 120)
set_tests_properties(test_layer_norm_op PROPERTIES TIMEOUT 150) set_tests_properties(test_layer_norm_op PROPERTIES TIMEOUT 200)
set_tests_properties(test_pool3d_op PROPERTIES TIMEOUT 150) set_tests_properties(test_pool3d_op PROPERTIES TIMEOUT 150)
endif() endif()
set_tests_properties(test_imperative_selected_rows_to_lod_tensor set_tests_properties(test_imperative_selected_rows_to_lod_tensor
...@@ -1198,6 +1198,7 @@ set(TEST_CINN_OPS ...@@ -1198,6 +1198,7 @@ set(TEST_CINN_OPS
test_meshgrid_op test_meshgrid_op
test_scatter_op test_scatter_op
test_gather_op test_gather_op
test_layer_norm_op
test_cast_op test_cast_op
test_dropout_op test_dropout_op
test_group_norm_op) test_group_norm_op)
......
...@@ -1516,6 +1516,7 @@ class OpTest(unittest.TestCase): ...@@ -1516,6 +1516,7 @@ class OpTest(unittest.TestCase):
self, self,
place, place,
atol=0, atol=0,
rtol=0,
no_check_set=None, no_check_set=None,
equal_nan=False, equal_nan=False,
check_dygraph=True, check_dygraph=True,
...@@ -1630,7 +1631,7 @@ class OpTest(unittest.TestCase): ...@@ -1630,7 +1631,7 @@ class OpTest(unittest.TestCase):
actual_np, actual_np,
expect_np, expect_np,
atol=atol, atol=atol,
rtol=self.rtol if hasattr(self, 'rtol') else 1e-5, rtol=self.rtol if hasattr(self, 'rtol') else rtol,
equal_nan=equal_nan, equal_nan=equal_nan,
err_msg=( err_msg=(
"Output (" "Output ("
...@@ -1647,7 +1648,7 @@ class OpTest(unittest.TestCase): ...@@ -1647,7 +1648,7 @@ class OpTest(unittest.TestCase):
actual_np, actual_np,
expect_np, expect_np,
atol=atol, atol=atol,
rtol=self.rtol if hasattr(self, 'rtol') else 1e-5, rtol=self.rtol if hasattr(self, 'rtol') else rtol,
equal_nan=equal_nan, equal_nan=equal_nan,
), ),
"Output (" "Output ("
...@@ -1819,7 +1820,7 @@ class OpTest(unittest.TestCase): ...@@ -1819,7 +1820,7 @@ class OpTest(unittest.TestCase):
actual_np, actual_np,
expect_np, expect_np,
atol=atol, atol=atol,
rtol=self.rtol if hasattr(self, 'rtol') else 1e-5, rtol=self.rtol if hasattr(self, 'rtol') else rtol,
equal_nan=equal_nan, equal_nan=equal_nan,
err_msg=( err_msg=(
"Output (" "Output ("
...@@ -1836,7 +1837,7 @@ class OpTest(unittest.TestCase): ...@@ -1836,7 +1837,7 @@ class OpTest(unittest.TestCase):
actual_np, actual_np,
expect_np, expect_np,
atol=atol, atol=atol,
rtol=self.rtol if hasattr(self, 'rtol') else 1e-5, rtol=self.rtol if hasattr(self, 'rtol') else rtol,
equal_nan=equal_nan, equal_nan=equal_nan,
), ),
"Output (" "Output ("
...@@ -2058,6 +2059,7 @@ class OpTest(unittest.TestCase): ...@@ -2058,6 +2059,7 @@ class OpTest(unittest.TestCase):
def check_output( def check_output(
self, self,
atol=1e-5, atol=1e-5,
rtol=1e-5,
no_check_set=None, no_check_set=None,
equal_nan=False, equal_nan=False,
check_dygraph=True, check_dygraph=True,
...@@ -2080,6 +2082,7 @@ class OpTest(unittest.TestCase): ...@@ -2080,6 +2082,7 @@ class OpTest(unittest.TestCase):
res = self.check_output_with_place( res = self.check_output_with_place(
place, place,
atol, atol,
rtol,
no_check_set, no_check_set,
equal_nan, equal_nan,
check_dygraph=check_dygraph, check_dygraph=check_dygraph,
......
...@@ -17,7 +17,7 @@ from functools import reduce ...@@ -17,7 +17,7 @@ from functools import reduce
from operator import mul from operator import mul
import numpy as np import numpy as np
from eager_op_test import _set_use_system_allocator from eager_op_test import OpTest, _set_use_system_allocator
import paddle import paddle
import paddle.nn.functional as F import paddle.nn.functional as F
...@@ -27,7 +27,8 @@ from paddle.static.amp.fp16_utils import _keep_layer_norm_scale_bias_to_fp32 ...@@ -27,7 +27,8 @@ from paddle.static.amp.fp16_utils import _keep_layer_norm_scale_bias_to_fp32
paddle.enable_static() paddle.enable_static()
np.random.random(123) np.random.seed(123)
paddle.seed(123)
_set_use_system_allocator(True) _set_use_system_allocator(True)
...@@ -111,6 +112,240 @@ def _reference_layer_norm_grad( ...@@ -111,6 +112,240 @@ def _reference_layer_norm_grad(
return grad_x, d_scale, d_bias return grad_x, d_scale, d_bias
def layer_norm_wrapper(
x, scale=None, bias=None, epsilon=1e-05, begin_norm_axis=1
):
input_shape = list(x.shape)
normalized_shape = input_shape[begin_norm_axis:]
return paddle.nn.functional.layer_norm(
x, normalized_shape, weight=scale, bias=bias, epsilon=epsilon
)
class TestLayerNormOpByOpTest(OpTest):
def setUp(self):
self.python_api = layer_norm_wrapper
self.public_python_api = layer_norm_wrapper
self.op_type = "layer_norm"
self.prim_op_type = "comp"
self.python_out_sig = ["Y"]
self.initConfig()
self.initTestCase()
def test_check_output(self):
self.check_output(
no_check_set=["Mean", "Variance"],
atol=self.ori_atol,
rtol=self.ori_rtol,
check_prim=True,
)
def test_check_grad(self):
self.check_grad(
self.check_grad_input_list,
['Y'],
max_relative_error=self.max_relative_error,
check_prim=True,
)
def initConfig(self):
self.rev_comp_atol = 1e-7
self.rev_comp_rtol = 1e-7
self.fw_comp_atol = 1e-6
self.fw_comp_rtol = 1e-6
self.ori_atol = 1e-4
self.ori_rtol = 1e-4
self.cinn_atol = 1e-5
self.cinn_rtol = 1e-5
self.max_relative_error = 1e-5
self.dtype = "float64"
self.x_shape = [2, 6, 6, 3]
self.epsilon = 0.00001
self.begin_norm_axis = 1
self.has_scale = True
self.has_bias = True
def initTestCase(self):
np.random.seed(123)
self.D = reduce(
mul, self.x_shape[self.begin_norm_axis : len(self.x_shape)], 1
)
self.scale_shape = [self.D]
x = np.random.random(self.x_shape).astype(self.dtype)
scale = (
np.random.random(self.scale_shape).astype(self.dtype)
if self.has_scale
else None
)
bias = (
np.random.random(self.scale_shape).astype(self.dtype)
if self.has_bias
else None
)
self.inputs = {
"X": x,
}
self.check_grad_input_list = ['X']
if self.has_scale:
self.inputs.update({"Scale": scale})
self.check_grad_input_list.append('Scale')
if self.has_bias:
self.inputs.update({"Bias": bias})
self.check_grad_input_list.append('Bias')
self.attrs = {
"epsilon": self.epsilon,
"begin_norm_axis": self.begin_norm_axis,
}
y, mean, variance = _reference_layer_norm_naive(
x, scale, bias, self.epsilon, self.begin_norm_axis
)
self.outputs = {
"Y": y,
"Mean": mean,
"Variance": variance,
}
class TestLayerNormOpByOpTestFP64_case2(TestLayerNormOpByOpTest):
def initConfig(self):
self.rev_comp_atol = 1e-6
self.rev_comp_rtol = 1e-6
self.fw_comp_atol = 1e-7
self.fw_comp_rtol = 1e-7
self.ori_atol = 1e-4
self.ori_rtol = 1e-4
self.cinn_atol = 1e-5
self.cinn_rtol = 1e-5
self.max_relative_error = 1e-5
self.dtype = "float64"
self.x_shape = [2, 6, 6, 3]
self.epsilon = 0.00001
self.begin_norm_axis = 1
self.has_scale = False
self.has_bias = False
class TestLayerNormOpByOpTestFP64_case3(TestLayerNormOpByOpTest):
def initConfig(self):
self.rev_comp_atol = 1e-7
self.rev_comp_rtol = 1e-7
self.fw_comp_atol = 1e-7
self.fw_comp_rtol = 1e-7
self.ori_atol = 1e-4
self.ori_rtol = 1e-4
self.cinn_atol = 1e-5
self.cinn_rtol = 1e-5
self.max_relative_error = 1e-5
self.dtype = "float64"
self.x_shape = [2, 6, 6, 3]
self.epsilon = 0.00001
self.begin_norm_axis = 1
self.has_scale = True
self.has_bias = False
class TestLayerNormOpByOpTestFP64_case4(TestLayerNormOpByOpTest):
def initConfig(self):
self.rev_comp_atol = 1e-6
self.rev_comp_rtol = 1e-6
self.fw_comp_atol = 1e-7
self.fw_comp_rtol = 1e-7
self.ori_atol = 1e-4
self.ori_rtol = 1e-4
self.cinn_atol = 1e-5
self.cinn_rtol = 1e-5
self.max_relative_error = 1e-5
self.dtype = "float64"
self.x_shape = [2, 6, 6, 3]
self.epsilon = 0.00001
self.begin_norm_axis = 1
self.has_scale = False
self.has_bias = True
class TestLayerNormOpByOpTestFP32(TestLayerNormOpByOpTest):
def initConfig(self):
self.rev_comp_atol = 1e-5
self.rev_comp_rtol = 1e-5
self.ori_atol = 1e-4
self.ori_rtol = 1e-4
self.max_relative_error = 7e-3
self.dtype = "float32"
self.x_shape = [2, 6, 6, 3]
self.epsilon = 0.00001
self.begin_norm_axis = 1
self.has_scale = True
self.has_bias = True
class TestLayerNormOpByOpTestFP32_case2(TestLayerNormOpByOpTest):
def initConfig(self):
self.rev_comp_atol = 1e-5
self.rev_comp_rtol = 1e-5
self.ori_atol = 1e-4
self.ori_rtol = 1e-4
self.max_relative_error = 1e-5
self.dtype = "float32"
self.x_shape = [2, 6, 6, 3]
self.epsilon = 0.00001
self.begin_norm_axis = 1
self.has_scale = False
self.has_bias = False
class TestLayerNormOpByOpTestFP32_case3(TestLayerNormOpByOpTest):
def initConfig(self):
self.rev_comp_atol = 1e-5
self.rev_comp_rtol = 1e-5
self.ori_atol = 1e-4
self.ori_rtol = 1e-4
self.max_relative_error = 3e-3
self.dtype = "float32"
self.x_shape = [2, 6, 6, 3]
self.epsilon = 0.00001
self.begin_norm_axis = 1
self.has_scale = True
self.has_bias = False
class TestLayerNormOpByOpTestFP32_case4(TestLayerNormOpByOpTest):
def initConfig(self):
self.rev_comp_atol = 1e-5
self.rev_comp_rtol = 1e-5
self.ori_atol = 1e-4
self.ori_rtol = 1e-4
self.max_relative_error = 1e-3
self.dtype = "float32"
self.x_shape = [2, 6, 6, 3]
self.epsilon = 0.00001
self.begin_norm_axis = 1
self.has_scale = False
self.has_bias = True
class TestLayerNormOp(unittest.TestCase): class TestLayerNormOp(unittest.TestCase):
def setUp(self): def setUp(self):
self.use_cudnn = True self.use_cudnn = True
......
...@@ -79,7 +79,9 @@ class TestSpaceToDepthOp(OpTest): ...@@ -79,7 +79,9 @@ class TestSpaceToDepthOp(OpTest):
if fluid.core.is_compiled_with_cuda() if fluid.core.is_compiled_with_cuda()
else fluid.core.CPUPlace() else fluid.core.CPUPlace()
) )
self.check_output_with_place(place, 1e-5, None, False) self.check_output_with_place(
place=place, atol=1e-5, no_check_set=None, equal_nan=False
)
def test_check_grad(self): def test_check_grad(self):
place = ( place = (
......
...@@ -38,4 +38,5 @@ no_check_set_white_list = [ ...@@ -38,4 +38,5 @@ no_check_set_white_list = [
'einsum', 'einsum',
'rmsprop', 'rmsprop',
'rrelu', 'rrelu',
'layer_norm',
] ]
...@@ -52,9 +52,11 @@ NEED_FIX_FP64_CHECK_GRAD_THRESHOLD_OP_LIST = [ ...@@ -52,9 +52,11 @@ NEED_FIX_FP64_CHECK_GRAD_THRESHOLD_OP_LIST = [
'cholesky_solve', 'cholesky_solve',
'solve', 'solve',
'qr', 'qr',
'layer_norm',
] ]
NEED_FIX_FP64_CHECK_OUTPUT_THRESHOLD_OP_LIST = [ NEED_FIX_FP64_CHECK_OUTPUT_THRESHOLD_OP_LIST = [
'bilinear_interp', 'bilinear_interp',
'bilinear_interp_v2', 'bilinear_interp_v2',
'layer_norm',
] ]
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册