diff --git a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
index ca99b818dbaf01d0395dcf0e064e16b6c849f73f..737ec0ce6e4f1c660b9ec2df45efa5018d69fb42 100644
--- a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
+++ b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
@@ -934,8 +934,9 @@ void layer_norm_grad(const Tensor& x,
     scale_cast = reshape<T>(*scale_ptr, std::vector<int64_t>({1, shape_2}));
   }
 
-  // cast dtype to float32 if dtype =float16
-  if (x.dtype() == phi::DataType::FLOAT16) {
+  // cast dtype to float32 if dtype =float16 or bfloat16
+  if (x.dtype() == phi::DataType::FLOAT16 ||
+      x.dtype() == phi::DataType::BFLOAT16) {
     x_cast = cast<T>(x_cast, phi::DataType::FLOAT32);
     out_grad_cast = cast<T>(out_grad_cast, phi::DataType::FLOAT32);
     if (scale_ptr) {
@@ -967,7 +968,8 @@ void layer_norm_grad(const Tensor& x,
     auto x_grad_tmp = dx_end - d_mean_d_std;
     x_grad_tmp = reshape<T>(x_grad_tmp, phi::vectorize(x.dims()));
 
-    if (x.dtype() == phi::DataType::FLOAT16) {
+    if (x.dtype() == phi::DataType::FLOAT16 ||
+        x.dtype() == phi::DataType::BFLOAT16) {
       x_grad_tmp = cast<T>(x_grad_tmp, x.dtype());
     }
     set_output<T>(x_grad_tmp, x_grad);
@@ -979,6 +981,10 @@ void layer_norm_grad(const Tensor& x,
           (x_sub_mean_mul_sqrt_var_1 * out_grad_cast)
               .sum(std::vector<int64_t>({0}), x_cast.dtype(), true);
       scale_grad_tmp = reshape<T>(scale_grad_tmp, scale_ptr->shape());
+      if (scale_ptr->dtype() == phi::DataType::FLOAT16 ||
+          scale_ptr->dtype() == phi::DataType::BFLOAT16) {
+        scale_grad_tmp = cast<T>(scale_grad_tmp, scale_ptr->dtype());
+      }
       set_output<T>(scale_grad_tmp, scale_grad);
     } else {
       scale_grad = nullptr;
@@ -990,6 +996,10 @@ void layer_norm_grad(const Tensor& x,
       auto bias_grad_tmp =
           out_grad_cast.sum(std::vector<int64_t>({0}), x_cast.dtype(), true);
       bias_grad_tmp = reshape<T>(bias_grad_tmp, bias_ptr->shape());
+      if (bias_ptr->dtype() == phi::DataType::FLOAT16 ||
+          bias_ptr->dtype() == phi::DataType::BFLOAT16) {
+        bias_grad_tmp = cast<T>(bias_grad_tmp, bias_ptr->dtype());
+      }
       set_output<T>(bias_grad_tmp, bias_grad);
     } else {
       bias_grad = nullptr;
diff --git a/python/paddle/incubate/autograd/composite_rules.py b/python/paddle/incubate/autograd/composite_rules.py
index a5d616b50641f7ba69ed82aca7ac1cec268517f5..a458cc73b5f86142188bcadac2e1ad8beb0ed91c 100644
--- a/python/paddle/incubate/autograd/composite_rules.py
+++ b/python/paddle/incubate/autograd/composite_rules.py
@@ -150,9 +150,12 @@ def layernorm_composite(x, scale, bias, epsilon, begin_norm_axis):
     is_amp = False
     from paddle.fluid.data_feeder import convert_dtype
 
-    if convert_dtype(x.dtype) == "float16":
+    dtype = convert_dtype(x.dtype)
+    if dtype in ["float16", "uint16"]:
         is_amp = True
         x = cast(x, "float32")
+        scale = cast(scale, "float32") if scale else scale
+        bias = cast(bias, "float32") if bias else bias
 
     axis = tuple(range(begin_norm_axis, len(x.shape)))
     mean_ = mean(x, axis=axis, keepdim=True)
@@ -175,8 +178,7 @@ def layernorm_composite(x, scale, bias, epsilon, begin_norm_axis):
     mean_ = reshape(mean_, [-1])
     variance = reshape(variance, [-1])
     if is_amp:
-        out = cast(out, "float16")
-
+        out = cast(out, dtype)
     return out, mean_, variance
 
 
@@ -632,7 +634,7 @@ def rsqrt_composite(x):
     from paddle.fluid.data_feeder import convert_dtype
 
     dtype = convert_dtype(x.dtype)
-    if dtype == "float16" or dtype == "uint16":
+    if dtype in ["float16", "uint16"]:
         is_amp = True
         x = cast(x, "float32")
     y = full(x.shape if len(x.shape) == 0 else [1], -0.5, x.dtype)
diff --git a/test/legacy_test/CMakeLists.txt b/test/legacy_test/CMakeLists.txt
index ddbf15294f75e1e4c098c9b13b06dad5f0627caf..22457958f7bfd4dce6f8bbbd6cdc6b48a01da215 100644
--- a/test/legacy_test/CMakeLists.txt
+++ b/test/legacy_test/CMakeLists.txt
@@ -958,7 +958,7 @@ else()
   set_tests_properties(test_conv3d_transpose_op PROPERTIES TIMEOUT 120)
   set_tests_properties(test_conv3d_op PROPERTIES TIMEOUT 120)
   set_tests_properties(test_norm_op PROPERTIES TIMEOUT 120)
-  set_tests_properties(test_layer_norm_op PROPERTIES TIMEOUT 150)
+  set_tests_properties(test_layer_norm_op PROPERTIES TIMEOUT 250)
   set_tests_properties(test_pool3d_op PROPERTIES TIMEOUT 150)
 endif()
 set_tests_properties(test_imperative_selected_rows_to_lod_tensor
diff --git a/test/legacy_test/test_layer_norm_op.py b/test/legacy_test/test_layer_norm_op.py
index 75cabf85b7f91805a764d12d99caef0183d71299..6fa2c41da3eeac937e96bf20500a1aa4f824a2ba 100644
--- a/test/legacy_test/test_layer_norm_op.py
+++ b/test/legacy_test/test_layer_norm_op.py
@@ -17,7 +17,11 @@ from functools import reduce
 from operator import mul
 
 import numpy as np
-from eager_op_test import OpTest, _set_use_system_allocator
+from eager_op_test import (
+    OpTest,
+    _set_use_system_allocator,
+    convert_float_to_uint16,
+)
 
 import paddle
 import paddle.nn.functional as F
@@ -212,6 +216,96 @@ class TestLayerNormOpByOpTest(OpTest):
         }
 
 
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA or not support the bfloat16",
+)
+class TestLayerNormBF16OpByOpTest(OpTest):
+    def setUp(self):
+        self.python_api = layer_norm_wrapper
+        self.public_python_api = layer_norm_wrapper
+        self.op_type = "layer_norm"
+        self.prim_op_type = "comp"
+        self.python_out_sig = ["Y"]
+        self.initConfig()
+        self.initTestCase()
+
+    def test_check_output(self):
+        self.check_output_with_place(
+            place=core.CUDAPlace(0),
+            no_check_set=["Mean", "Variance"],
+            atol=self.ori_atol,
+            rtol=self.ori_rtol,
+            check_prim=True,
+        )
+
+    def test_check_grad(self):
+        self.check_grad_with_place(
+            core.CUDAPlace(0),
+            self.check_grad_input_list,
+            ['Y'],
+            max_relative_error=self.max_relative_error,
+            check_prim=True,
+        )
+
+    def initConfig(self):
+        self.ori_atol = 1e-2
+        self.ori_rtol = 1e-2
+
+        self.max_relative_error = 1e-5
+
+        self.dtype = np.uint16
+        self.x_shape = [2, 6, 6, 3]
+        self.epsilon = 0.00001
+        self.begin_norm_axis = 1
+        self.has_scale = True
+        self.has_bias = True
+
+    def initTestCase(self):
+        np.random.seed(123)
+
+        self.D = reduce(
+            mul, self.x_shape[self.begin_norm_axis : len(self.x_shape)], 1
+        )
+        self.scale_shape = [self.D]
+        x = np.random.random(self.x_shape).astype("float32")
+        scale = (
+            np.random.random(self.scale_shape).astype("float32")
+            if self.has_scale
+            else None
+        )
+        bias = (
+            np.random.random(self.scale_shape).astype("float32")
+            if self.has_bias
+            else None
+        )
+        self.inputs = {
+            "X": convert_float_to_uint16(x),
+        }
+        self.check_grad_input_list = ['X']
+
+        if self.has_scale:
+            self.inputs.update({"Scale": convert_float_to_uint16(scale)})
+            self.check_grad_input_list.append('Scale')
+        if self.has_bias:
+            self.inputs.update({"Bias": convert_float_to_uint16(bias)})
+            self.check_grad_input_list.append('Bias')
+
+        self.attrs = {
+            "epsilon": self.epsilon,
+            "begin_norm_axis": self.begin_norm_axis,
+        }
+        y, mean, variance = _reference_layer_norm_naive(
+            x, scale, bias, self.epsilon, self.begin_norm_axis
+        )
+        self.outputs = {
+            "Y": convert_float_to_uint16(y),
+            "Mean": convert_float_to_uint16(mean),
+            "Variance": convert_float_to_uint16(variance),
+        }
+
+
 class TestLayerNormOpByOpTestFP64_case2(TestLayerNormOpByOpTest):
     def initConfig(self):
         self.rev_comp_atol = 1e-6
@@ -234,6 +328,21 @@ class TestLayerNormOpByOpTestFP64_case2(TestLayerNormOpByOpTest):
         self.has_bias = False
 
 
+class TestLayerNormBF16OpByOpTest_case2(TestLayerNormBF16OpByOpTest):
+    def initConfig(self):
+        self.ori_atol = 1e-2
+        self.ori_rtol = 1e-2
+
+        self.max_relative_error = 1e-5
+
+        self.dtype = np.uint16
+        self.x_shape = [2, 6, 6, 3]
+        self.epsilon = 0.00001
+        self.begin_norm_axis = 1
+        self.has_scale = False
+        self.has_bias = False
+
+
 class TestLayerNormOpByOpTestFP64_case3(TestLayerNormOpByOpTest):
     def initConfig(self):
         self.rev_comp_atol = 1e-7
@@ -256,6 +365,21 @@ class TestLayerNormOpByOpTestFP64_case3(TestLayerNormOpByOpTest):
         self.has_bias = False
 
 
+class TestLayerNormBF16OpByOpTest_case3(TestLayerNormBF16OpByOpTest):
+    def initConfig(self):
+        self.ori_atol = 1e-2
+        self.ori_rtol = 1e-2
+
+        self.max_relative_error = 1e-5
+
+        self.dtype = np.uint16
+        self.x_shape = [2, 6, 6, 3]
+        self.epsilon = 0.00001
+        self.begin_norm_axis = 1
+        self.has_scale = True
+        self.has_bias = False
+
+
 class TestLayerNormOpByOpTestFP64_case4(TestLayerNormOpByOpTest):
     def initConfig(self):
         self.rev_comp_atol = 1e-6
@@ -278,6 +402,21 @@ class TestLayerNormOpByOpTestFP64_case4(TestLayerNormOpByOpTest):
         self.has_bias = True
 
 
+class TestLayerNormBF16OpByOpTest_case4(TestLayerNormBF16OpByOpTest):
+    def initConfig(self):
+        self.ori_atol = 1e-2
+        self.ori_rtol = 1e-2
+
+        self.max_relative_error = 1e-5
+
+        self.dtype = np.uint16
+        self.x_shape = [2, 6, 6, 3]
+        self.epsilon = 0.00001
+        self.begin_norm_axis = 1
+        self.has_scale = False
+        self.has_bias = True
+
+
 class TestLayerNormOpByOpTestFP32(TestLayerNormOpByOpTest):
     def initConfig(self):
         self.rev_comp_atol = 1e-5
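
A note on the "uint16" that appears in both the composite rule and the new tests: on the Python/NumPy side, Paddle carries bfloat16 data as np.uint16 bit patterns (bfloat16 is the upper half of a float32), so convert_dtype(x.dtype) reports "uint16" for BF16 tensors and the tests build their inputs with convert_float_to_uint16. The snippet below is a minimal, illustrative sketch of that representation only; it is not Paddle's convert_float_to_uint16 (which may round rather than truncate), and the helper names are invented for the example.

import numpy as np

def float32_to_bf16_bits(x):
    # Keep the upper 16 bits of each float32 value: a truncating
    # float32 -> bfloat16 conversion, stored in a uint16 array.
    x = np.ascontiguousarray(x, dtype=np.float32)
    return (x.view(np.uint32) >> 16).astype(np.uint16)

def bf16_bits_to_float32(bits):
    # Zero-fill the lower 16 bits to widen bfloat16 back to float32.
    return (bits.astype(np.uint32) << 16).view(np.float32)

x = np.random.random([2, 6, 6, 3]).astype("float32")  # same shape as the BF16 cases
bits = float32_to_bf16_bits(x)                        # dtype uint16, like self.inputs["X"]
recovered = bf16_bits_to_float32(bits)
# bfloat16 keeps only ~8 mantissa bits, so the round-trip error is small but
# noticeable, which is why the BF16 cases relax atol/rtol to 1e-2.
print(np.max(np.abs(recovered - x)))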