# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
from functools import reduce
from operator import mul

import numpy as np
from eager_op_test import (
    OpTest,
    _set_use_system_allocator,
    convert_float_to_uint16,
)

import paddle
import paddle.nn.functional as F
from paddle import fluid
from paddle.fluid import Program, core, program_guard
from paddle.static.amp.fp16_utils import _keep_layer_norm_scale_bias_to_fp32

paddle.enable_static()

np.random.seed(123)
paddle.seed(123)

_set_use_system_allocator(True)


def _reference_layer_norm_naive(x, scale, beta, epsilon, begin_norm_axis=1):
    """NumPy reference for the layer_norm forward pass.

    ``x`` is viewed as an ``[N, D]`` matrix, where ``N`` is the product of
    the dims before ``begin_norm_axis`` and ``D`` the product of the rest.
    Each row is normalized, then optionally multiplied by ``scale`` and
    shifted by ``beta`` (both flattened to length ``D``).

    Note: ``x.shape`` is reassigned in place during the computation and
    restored before returning, so ``x`` must allow an in-place reshape.

    Returns:
        (output, mean, var): ``output`` has the original shape of ``x``;
        ``mean`` and ``var`` have shape ``[N]``. ``var`` already includes
        ``epsilon``.
    """
    x_shape = x.shape
    N = reduce(mul, x_shape[0:begin_norm_axis], 1)
    D = reduce(mul, x_shape[begin_norm_axis : len(x_shape)], 1)
    x.shape = [N, D]

    mean = np.mean(x, axis=1)
    var = np.var(x, axis=1) + epsilon
    output = np.divide(
        (x - mean.reshape([N, 1])), (np.sqrt(var)).reshape([N, 1])
    )
    if scale is not None:
        output = scale.reshape([1, D]) * output
    if beta is not None:
        output = output + beta.reshape([1, D])

    # Restore the caller's view of x and give the output the same shape.
    x.shape, output.shape = x_shape, x_shape
    return output, mean, var


def _reference_layer_norm_grad(
    x, grad_y, scale, bias, mean, var, begin_norm_axis=1
):
    """NumPy reference for the layer_norm backward pass.

    ``mean`` and ``var`` are the per-row statistics returned by
    ``_reference_layer_norm_naive``, where ``var`` already includes epsilon.
    The shapes of ``x``, ``grad_y``, ``mean``, ``var`` and ``scale`` are
    reassigned in place during the computation and restored afterwards.

    Returns:
        (grad_x, d_scale, d_bias): ``grad_x`` has the shape of ``x``.
        ``d_scale`` and ``d_bias`` have shape ``[1, D]``, or are ``None``
        when ``scale`` or ``bias`` respectively is ``None``.
    """
    x_shape = x.shape
    N = reduce(mul, x_shape[0:begin_norm_axis], 1)
    D = reduce(mul, x_shape[begin_norm_axis : len(x_shape)], 1)

    if scale is not None:
        scale_shape = scale.shape
        scale.shape = [1, D]
    x.shape, grad_y.shape = [N, D], [N, D]
    var.shape, mean.shape = [N, 1], [N, 1]

    # d_bias
    if bias is not None:
        d_bias = np.sum(grad_y, axis=0).reshape([1, D])
    else:
        d_bias = None
    # d_scale
    if scale is not None:
        d_scale = np.sum(
            ((x - mean) * np.sqrt(1 / var)) * grad_y, axis=0
        ).reshape([1, D])
    else:
        d_scale = None
    # dx: a missing scale behaves exactly like a scale of 1.0, so a single
    # formula covers both cases.
    gamma = scale if scale is not None else 1.0
    dx_end = gamma * np.sqrt(1.0 / var) * grad_y
    d_mean_0 = np.sum(-np.sqrt(1.0 / var) * grad_y * gamma, axis=1).reshape(
        [N, 1]
    )
    # the second part equals to zero.
    d_mean = 1.0 / D * d_mean_0
    d_std = np.sum(
        -(1.0 / var) * (x - mean) * grad_y * gamma, axis=1
    ).reshape([N, 1]) * (
        1.0 / D * np.sqrt(1.0 / var).reshape([N, 1]) * (x - mean)
    )

    grad_x = dx_end + d_mean + d_std

    # Restore every array the caller passed in to its original shape.
    grad_x.shape, x.shape, grad_y.shape = x_shape, x_shape, x_shape
    var.shape, mean.shape = [N], [N]

    if scale is not None:
        scale.shape = scale_shape
    return grad_x, d_scale, d_bias


def layer_norm_wrapper(
    x, scale=None, bias=None, epsilon=1e-05, begin_norm_axis=1
):
    """Adapt the op-style ``begin_norm_axis`` argument to
    ``paddle.nn.functional.layer_norm``, which takes ``normalized_shape``."""
    input_shape = list(x.shape)
    normalized_shape = input_shape[begin_norm_axis:]
    return paddle.nn.functional.layer_norm(
        x, normalized_shape, weight=scale, bias=bias, epsilon=epsilon
    )


@unittest.skipIf(
    paddle.is_compiled_with_rocm(),
    "ROCm doesn't support fp64 LayerNormOpByOp currently",
)
class TestLayerNormOpByOpTest(OpTest):
    """OpTest for layer_norm against the NumPy reference, with prim checks.

    ``initConfig`` sets the tolerances, dtype, shape and whether Scale/Bias
    are present; subclasses override it to cover other configurations.
    """

    def setUp(self):
        self.python_api = layer_norm_wrapper
        self.public_python_api = layer_norm_wrapper
        self.op_type = "layer_norm"
        self.prim_op_type = "comp"
        self.python_out_sig = ["Y"]
        self.initConfig()
        self.initTestCase()

    def test_check_output(self):
        self.check_output(
            no_check_set=["Mean", "Variance"],
            atol=self.ori_atol,
            rtol=self.ori_rtol,
            check_prim=True,
        )

    def test_check_grad(self):
        self.check_grad(
            self.check_grad_input_list,
            ['Y'],
            max_relative_error=self.max_relative_error,
            check_prim=True,
        )

    def initConfig(self):
        self.rev_comp_atol = 1e-7
        self.rev_comp_rtol = 1e-7
        self.fw_comp_atol = 1e-6
        self.fw_comp_rtol = 1e-6

        self.ori_atol = 1e-4
        self.ori_rtol = 1e-4
        self.cinn_atol = 1e-5
        self.cinn_rtol = 1e-5

        self.max_relative_error = 1e-5
        # ROCm does not have float64 LayerNorm kernel
        self.dtype = "float64"
        self.x_shape = [2, 6, 6, 3]
        self.epsilon = 0.00001
        self.begin_norm_axis = 1
        self.has_scale = True
        self.has_bias = True

    def initTestCase(self):
        """Build random inputs, op attrs and reference outputs."""
        np.random.seed(123)

        self.D = reduce(
            mul, self.x_shape[self.begin_norm_axis : len(self.x_shape)], 1
        )
        self.scale_shape = [self.D]
        x = np.random.random(self.x_shape).astype(self.dtype)
        scale = (
            np.random.random(self.scale_shape).astype(self.dtype)
            if self.has_scale
            else None
        )
        bias = (
            np.random.random(self.scale_shape).astype(self.dtype)
            if self.has_bias
            else None
        )
        self.inputs = {
            "X": x,
        }
        self.check_grad_input_list = ['X']

        if self.has_scale:
            self.inputs.update({"Scale": scale})
            self.check_grad_input_list.append('Scale')
        if self.has_bias:
            self.inputs.update({"Bias": bias})
            self.check_grad_input_list.append('Bias')

        self.attrs = {
            "epsilon": self.epsilon,
            "begin_norm_axis": self.begin_norm_axis,
        }
        y, mean, variance = _reference_layer_norm_naive(
            x, scale, bias, self.epsilon, self.begin_norm_axis
        )
        self.outputs = {
            "Y": y,
            "Mean": mean,
            "Variance": variance,
        }


@unittest.skipIf(
    not core.is_compiled_with_cuda()
    or paddle.is_compiled_with_rocm()
    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
    "core is not compiled with CUDA or not support the bfloat16",
)
class TestLayerNormBF16OpByOpTest(OpTest):
    """bfloat16 variant of the layer_norm OpTest, run on CUDAPlace(0).

    The reference is computed in float32; inputs and expected outputs are
    then converted with ``convert_float_to_uint16`` (dtype ``np.uint16``).
    """

    def setUp(self):
        self.python_api = layer_norm_wrapper
        self.public_python_api = layer_norm_wrapper
        self.op_type = "layer_norm"
        self.prim_op_type = "comp"
        self.python_out_sig = ["Y"]
        self.initConfig()
        self.initTestCase()

    def test_check_output(self):
        self.check_output_with_place(
            place=core.CUDAPlace(0),
            no_check_set=["Mean", "Variance"],
            atol=self.ori_atol,
            rtol=self.ori_rtol,
            check_prim=True,
        )

    def test_check_grad(self):
        self.check_grad_with_place(
            core.CUDAPlace(0),
            self.check_grad_input_list,
            ['Y'],
            max_relative_error=self.max_relative_error,
            check_prim=True,
        )

    def initConfig(self):
        self.ori_atol = 1e-2
        self.ori_rtol = 1e-2

        self.max_relative_error = 1e-5

        self.dtype = np.uint16
        self.x_shape = [2, 6, 6, 3]
        self.epsilon = 0.00001
        self.begin_norm_axis = 1
        self.has_scale = True
        self.has_bias = True

    def initTestCase(self):
        """Build float32 data, then convert inputs/outputs to bf16 bits."""
        np.random.seed(123)

        self.D = reduce(
            mul, self.x_shape[self.begin_norm_axis : len(self.x_shape)], 1
        )
        self.scale_shape = [self.D]
        x = np.random.random(self.x_shape).astype("float32")
        scale = (
            np.random.random(self.scale_shape).astype("float32")
            if self.has_scale
            else None
        )
        bias = (
            np.random.random(self.scale_shape).astype("float32")
            if self.has_bias
            else None
        )
        self.inputs = {
            "X": convert_float_to_uint16(x),
        }
        self.check_grad_input_list = ['X']

        if self.has_scale:
            self.inputs.update({"Scale": convert_float_to_uint16(scale)})
            self.check_grad_input_list.append('Scale')
        if self.has_bias:
            self.inputs.update({"Bias": convert_float_to_uint16(bias)})
            self.check_grad_input_list.append('Bias')

        self.attrs = {
            "epsilon": self.epsilon,
            "begin_norm_axis": self.begin_norm_axis,
        }
        y, mean, variance = _reference_layer_norm_naive(
            x, scale, bias, self.epsilon, self.begin_norm_axis
        )
        self.outputs = {
            "Y": convert_float_to_uint16(y),
            "Mean": convert_float_to_uint16(mean),
            "Variance": convert_float_to_uint16(variance),
        }


@unittest.skipIf(
    paddle.is_compiled_with_rocm(),
    "ROCm doesn't support fp64 LayerNormOpByOp currently",
)
class TestLayerNormOpByOpTestFP64_case2(TestLayerNormOpByOpTest):
    """float64, no Scale and no Bias."""

    def initConfig(self):
        self.rev_comp_atol = 1e-6
        self.rev_comp_rtol = 1e-6
        self.fw_comp_atol = 1e-7
        self.fw_comp_rtol = 1e-7

        self.ori_atol = 1e-4
        self.ori_rtol = 1e-4
        self.cinn_atol = 1e-5
        self.cinn_rtol = 1e-5

        self.max_relative_error = 1e-5

        self.dtype = "float64"
        self.x_shape = [2, 6, 6, 3]
        self.epsilon = 0.00001
        self.begin_norm_axis = 1
        self.has_scale = False
        self.has_bias = False


@unittest.skipIf(
    not core.is_compiled_with_cuda()
    or paddle.is_compiled_with_rocm()
    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
    "core is not compiled with CUDA or not support the bfloat16",
)
class TestLayerNormBF16OpByOpTest_case2(TestLayerNormBF16OpByOpTest):
    """bfloat16, no Scale and no Bias."""

    def initConfig(self):
        self.ori_atol = 1e-2
        self.ori_rtol = 1e-2

        self.max_relative_error = 1e-5

        self.dtype = np.uint16
        self.x_shape = [2, 6, 6, 3]
        self.epsilon = 0.00001
        self.begin_norm_axis = 1
        self.has_scale = False
        self.has_bias = False


@unittest.skipIf(
    paddle.is_compiled_with_rocm(),
    "ROCm doesn't support fp64 LayerNormOpByOp currently",
)
class TestLayerNormOpByOpTestFP64_case3(TestLayerNormOpByOpTest):
    """float64, Scale only."""

    def initConfig(self):
        self.rev_comp_atol = 1e-7
        self.rev_comp_rtol = 1e-7
        self.fw_comp_atol = 1e-7
        self.fw_comp_rtol = 1e-7

        self.ori_atol = 1e-4
        self.ori_rtol = 1e-4
        self.cinn_atol = 1e-5
        self.cinn_rtol = 1e-5

        self.max_relative_error = 1e-5

        self.dtype = "float64"
        self.x_shape = [2, 6, 6, 3]
        self.epsilon = 0.00001
        self.begin_norm_axis = 1
        self.has_scale = True
        self.has_bias = False


@unittest.skipIf(
    not core.is_compiled_with_cuda()
    or paddle.is_compiled_with_rocm()
    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
    "core is not compiled with CUDA or not support the bfloat16",
)
class TestLayerNormBF16OpByOpTest_case3(TestLayerNormBF16OpByOpTest):
    """bfloat16, Scale only."""

    def initConfig(self):
        self.ori_atol = 1e-2
        self.ori_rtol = 1e-2

        self.max_relative_error = 1e-5

        self.dtype = np.uint16
        self.x_shape = [2, 6, 6, 3]
        self.epsilon = 0.00001
        self.begin_norm_axis = 1
        self.has_scale = True
        self.has_bias = False


@unittest.skipIf(
    paddle.is_compiled_with_rocm(),
    "ROCm doesn't support fp64 LayerNormOpByOp currently",
)
class TestLayerNormOpByOpTestFP64_case4(TestLayerNormOpByOpTest):
    """float64, Bias only."""

    def initConfig(self):
        self.rev_comp_atol = 1e-6
        self.rev_comp_rtol = 1e-6
        self.fw_comp_atol = 1e-7
        self.fw_comp_rtol = 1e-7

        self.ori_atol = 1e-4
        self.ori_rtol = 1e-4
        self.cinn_atol = 1e-5
        self.cinn_rtol = 1e-5

        self.max_relative_error = 1e-5

        self.dtype = "float64"
        self.x_shape = [2, 6, 6, 3]
        self.epsilon = 0.00001
        self.begin_norm_axis = 1
        self.has_scale = False
        self.has_bias = True


@unittest.skipIf(
    not core.is_compiled_with_cuda()
    or paddle.is_compiled_with_rocm()
    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
    "core is not compiled with CUDA or not support the bfloat16",
)
class TestLayerNormBF16OpByOpTest_case4(TestLayerNormBF16OpByOpTest):
    """bfloat16, Bias only."""

    def initConfig(self):
        self.ori_atol = 1e-2
        self.ori_rtol = 1e-2

        self.max_relative_error = 1e-5

        self.dtype = np.uint16
        self.x_shape = [2, 6, 6, 3]
        self.epsilon = 0.00001
        self.begin_norm_axis = 1
        self.has_scale = False
        self.has_bias = True


class TestLayerNormOpByOpTestFP32(TestLayerNormOpByOpTest):
    """float32, Scale and Bias."""

    def initConfig(self):
        self.rev_comp_atol = 1e-5
        self.rev_comp_rtol = 1e-5

        self.ori_atol = 1e-4
        self.ori_rtol = 1e-4

        self.max_relative_error = 7e-3

        self.dtype = "float32"
        self.x_shape = [2, 6, 6, 3]
        self.epsilon = 0.00001
        self.begin_norm_axis = 1
        self.has_scale = True
        self.has_bias = True


class TestLayerNormOpByOpTestFP32_case2(TestLayerNormOpByOpTest):
    """float32, no Scale and no Bias."""

    def initConfig(self):
        self.rev_comp_atol = 1e-5
        self.rev_comp_rtol = 1e-5

        self.ori_atol = 1e-4
        self.ori_rtol = 1e-4

        self.max_relative_error = 1e-5

        self.dtype = "float32"
        self.x_shape = [2, 6, 6, 3]
        self.epsilon = 0.00001
        self.begin_norm_axis = 1
        self.has_scale = False
        self.has_bias = False


class TestLayerNormOpByOpTestFP32_case3(TestLayerNormOpByOpTest):
    """float32, Scale only."""

    def initConfig(self):
        self.rev_comp_atol = 1e-5
        self.rev_comp_rtol = 1e-5

        self.ori_atol = 1e-4
        self.ori_rtol = 1e-4

        self.max_relative_error = 3e-3

        self.dtype = "float32"
        self.x_shape = [2, 6, 6, 3]
        self.epsilon = 0.00001
        self.begin_norm_axis = 1
        self.has_scale = True
        self.has_bias = False


class TestLayerNormOpByOpTestFP32_case4(TestLayerNormOpByOpTest):
    """float32, Bias only."""

    def initConfig(self):
        self.rev_comp_atol = 1e-5
        self.rev_comp_rtol = 1e-5

        self.ori_atol = 1e-4
        self.ori_rtol = 1e-4

        self.max_relative_error = 1e-3

        self.dtype = "float32"
        self.x_shape = [2, 6, 6, 3]
        self.epsilon = 0.00001
        self.begin_norm_axis = 1
        self.has_scale = False
        self.has_bias = True


class TestLayerNormOp(unittest.TestCase):
    """Run a hand-built static program (layer_norm op plus its generated grad
    op) and compare forward/backward results with the NumPy references."""

    def setUp(self):
        self.use_cudnn = True

    def __assert_close(self, tensor, np_array, msg, atol=1e-4):
        np.testing.assert_allclose(
            np.array(tensor).flatten(),
            np_array.flatten(),
            rtol=1e-3,
            atol=atol,
            err_msg=msg,
        )

    def check_forward_backward(
        self,
        shape,
        begin_norm_axis,
        has_scale=True,
        has_bias=True,
        y_grad_scale=1.0,
        use_mkldnn=False,
    ):
        """Check one configuration on CPU, and on CUDAPlace(0) when CUDA is
        available, layer_norm has GPU support and ``self.use_cudnn`` is set.

        ``y_grad_scale`` multiplies the random upstream gradient.
        ``use_mkldnn`` is forwarded as the op's ``use_mkldnn`` attribute.
        """

        def test_with_place(
            place, shape, begin_norm_axis, use_mkldnn=use_mkldnn
        ):
            # attr
            epsilon = 0.00001
            x_shape = shape
            D = reduce(mul, x_shape[begin_norm_axis : len(x_shape)], 1)
            scale_shape = [D]

            np.random.seed(123)
            x = np.random.random_sample(x_shape).astype(np.float32)
            scale = (
                np.random.random_sample(scale_shape).astype(np.float32)
                if has_scale
                else None
            )
            bias = (
                np.random.random_sample(scale_shape).astype(np.float32)
                if has_bias
                else None
            )
            y_grad = (np.random.random_sample(x_shape) * y_grad_scale).astype(
                np.float32
            )

            # reference forward & backward
            y, mean, variance = _reference_layer_norm_naive(
                x, scale, bias, epsilon, begin_norm_axis
            )
            x_grad, scale_grad, bias_grad = _reference_layer_norm_grad(
                x, y_grad, scale, bias, mean, variance, begin_norm_axis
            )

            # Name -> value table for both the program's variables and the
            # feed dict (explicit rather than relying on locals()).
            var_dict = {
                'x': x,
                'scale': scale,
                'bias': bias,
                'mean': mean,
                'variance': variance,
                'y': y,
                'y@GRAD': y_grad,
            }
            var_names = ['x', 'mean', 'variance', 'y', 'y@GRAD']
            if has_scale:
                var_names += ['scale']
            if has_bias:
                var_names += ['bias']
            ground_truth = {name: var_dict[name] for name in var_names}

            program = fluid.Program()
            with fluid.program_guard(program):
                block = program.global_block()
                for name in ground_truth:
                    block.create_var(
                        name=name,
                        dtype='float32',
                        shape=ground_truth[name].shape,
                    )
                inputs = {"X": block.var('x')}
                fetch_list = [
                    'y',
                    'mean',
                    'variance',
                    'x@GRAD',
                ]
                if has_scale:
                    inputs["Scale"] = block.var('scale')
                    fetch_list += ['scale@GRAD']
                if has_bias:
                    inputs["Bias"] = block.var('bias')
                    fetch_list += ['bias@GRAD']
                layer_norm_op = block.append_op(
                    type="layer_norm",
                    inputs=inputs,
                    outputs={
                        "Y": block.var('y'),
                        "Mean": block.var('mean'),  # share the same memory
                        "Variance": block.var(
                            'variance'
                        ),  # share the same memory
                    },
                    attrs={
                        "epsilon": epsilon,
                        "begin_norm_axis": begin_norm_axis,
                        "use_mkldnn": use_mkldnn,
                    },
                )
                # generate backward op_desc
                grad_op_desc_list, _ = core.get_grad_op_desc(
                    layer_norm_op.desc, set(), []
                )
                grad_op_desc = grad_op_desc_list[0]
                new_op_desc = block.desc.append_op()
                new_op_desc.copy_from(grad_op_desc)
                for var_name in grad_op_desc.output_arg_names():
                    block.desc.var(var_name.encode("ascii"))
                grad_op_desc.infer_var_type(block.desc)
                grad_op_desc.infer_shape(block.desc)
                for arg in grad_op_desc.output_arg_names():
                    grad_var = block.desc.find_var(arg.encode("ascii"))
                    grad_var.set_dtype(core.VarDesc.VarType.FP32)

                program._sync_with_cpp()
                exe = fluid.Executor(place)
                out = exe.run(
                    program,
                    feed={
                        name: var_dict[name]
                        for name in ['x', 'scale', 'bias', 'y@GRAD']
                    },
                    fetch_list=fetch_list,
                )
                self.__assert_close(y, out[0], "y")
                self.__assert_close(mean, out[1], "mean")
                self.__assert_close(variance, out[2], "variance", 1e-3)
                self.__assert_close(x_grad, out[3], "x_grad")
                if has_scale:
                    self.__assert_close(
                        scale_grad,
                        out[fetch_list.index('scale@GRAD')],
                        "scale_grad",
                        1e-3,
                    )
                if has_bias:
                    self.__assert_close(
                        bias_grad,
                        out[fetch_list.index('bias@GRAD')],
                        "bias_grad",
                    )

        places = [core.CPUPlace()]
        if (
            core.is_compiled_with_cuda()
            and core.op_support_gpu("layer_norm")
            and self.use_cudnn
        ):
            places.append(core.CUDAPlace(0))

        for place in places:
            test_with_place(place, shape, begin_norm_axis)

    def test_check_forward_backward_with_scale_and_bias(self):
        self.check_forward_backward(shape=[2, 3, 4, 5], begin_norm_axis=1)
        self.check_forward_backward(shape=[1, 3, 4, 5], begin_norm_axis=1)
        self.check_forward_backward(
            shape=[2, 3, 4, 5],
            begin_norm_axis=1,
            has_scale=False,
            has_bias=True,
        )
        self.check_forward_backward(
            shape=[2, 3, 4, 5],
            begin_norm_axis=1,
            has_scale=True,
            has_bias=False,
        )
        self.check_forward_backward(
            shape=[2, 3, 4, 5],
            begin_norm_axis=1,
            has_scale=False,
            has_bias=False,
        )
        self.check_forward_backward(shape=[2, 3, 4, 5], begin_norm_axis=3)
        self.check_forward_backward(
            shape=[92, 513, 129], begin_norm_axis=2, y_grad_scale=0.1
        )
        self.check_forward_backward(shape=[3, 34, 1134], begin_norm_axis=2)
        self.check_forward_backward(shape=[3, 2, 1133], begin_norm_axis=2)
        self.check_forward_backward(
            shape=[92, 513, 1134], begin_norm_axis=2, y_grad_scale=0.1
        )
        self.check_forward_backward(
            shape=[92, 513, 1134],
            begin_norm_axis=2,
            has_scale=False,
            has_bias=True,
            y_grad_scale=0.1,
        )
        self.check_forward_backward(
            shape=[92, 513, 1134],
            begin_norm_axis=2,
            has_scale=True,
            has_bias=False,
            y_grad_scale=0.1,
        )
        self.check_forward_backward(
            shape=[92, 513, 1134],
            begin_norm_axis=2,
            has_scale=False,
            has_bias=False,
            y_grad_scale=0.1,
        )
        self.check_forward_backward(
            shape=[512, 1024], begin_norm_axis=1, has_scale=True, has_bias=True
        )
        self.check_forward_backward(
            shape=[1, 128, 256, 256],
            begin_norm_axis=3,
            has_scale=True,
            has_bias=True,
        )
        self.check_forward_backward(
            shape=[1, 256, 384],
            begin_norm_axis=2,
            has_scale=True,
            has_bias=True,
        )


class TestLayerNormAPI(unittest.TestCase):
    """Build static-graph layer_norm layers with several scale/shift and
    param_attr combinations."""

    def test_case(self):
        x = paddle.static.data(name='x', shape=[64, 32, 256], dtype='float32')
        x = paddle.static.nn.layer_norm(
            x,
            scale=True,
            shift=True,
            begin_norm_axis=1,
            epsilon=1e-05,
            param_attr=None,
            bias_attr=None,
        )
        x = paddle.static.nn.layer_norm(
            x,
            scale=False,
            shift=False,
            begin_norm_axis=1,
            epsilon=1e-05,
            param_attr=None,
            bias_attr=None,
        )
        x = paddle.static.nn.layer_norm(
            x,
            scale=True,
            shift=True,
            begin_norm_axis=1,
            epsilon=1e-05,
            param_attr="scale",
            bias_attr="shift",
        )


class TestDygraphLayerNormAPIError(unittest.TestCase):
    """paddle.nn.LayerNorm must raise TypeError for invalid inputs."""

    def test_errors(self):
        with program_guard(Program(), Program()):
            paddle.enable_static()

            layer_norm = paddle.nn.LayerNorm([32, 32])
            # the input of LayerNorm must be Variable.
            x1 = np.random.random((3, 32, 32)).astype('float32')
            self.assertRaises(TypeError, layer_norm, x1)

            # the input dtype of LayerNorm must be float32 or float64
            # float16 only can be set on GPU place
            x2 = paddle.static.data(
                name='x2', shape=[-1, 3, 32, 32], dtype="int32"
            )
            self.assertRaises(TypeError, layer_norm, x2)


@unittest.skipIf(
    not core.is_compiled_with_cuda(),
    "core is not compiled with CUDA or not support the float16",
)
class TestFP16ScaleBiasLayerNorm(unittest.TestCase):
    """With float16 x, float16 and float32 weight/bias must give identical
    outputs and gradients."""

    def check_main(self, x_np, weight_np, bias_np, dtype):
        """Run layer_norm in dygraph with weight/bias cast to ``dtype`` and
        return (y, x_grad, weight_grad, bias_grad) as NumPy arrays."""
        paddle.disable_static()

        weight_np = weight_np.astype(dtype)
        bias_np = bias_np.astype(dtype)

        x = paddle.to_tensor(x_np)
        weight = paddle.to_tensor(weight_np)
        bias = paddle.to_tensor(bias_np)
        x.stop_gradient = False
        weight.stop_gradient = False
        bias.stop_gradient = False
        y = F.layer_norm(x, x.shape[1:], weight, bias)
        x_g, w_g, b_g = paddle.grad(y, [x, weight, bias])
        y_np = y.numpy().astype('float32')
        x_g_np = x_g.numpy().astype('float32')
        # NOTE(review): unlike the other results, the weight gradient is cast
        # to float16 before comparison -- presumably so the float16- and
        # float32-weight runs compare bit-exactly; confirm this is intended.
        w_g_np = w_g.numpy().astype('float16')
        b_g_np = b_g.numpy().astype('float32')

        paddle.enable_static()
        return y_np, x_g_np, w_g_np, b_g_np

    def test_main(self):
        x_np = np.random.random([10, 20]).astype('float16')
        weight_np = np.random.random([20]).astype('float16')
        bias_np = np.random.random([20]).astype('float16')

        y_np_1, x_g_np_1, w_g_np_1, b_g_np_1 = self.check_main(
            x_np, weight_np, bias_np, 'float16'
        )
        y_np_2, x_g_np_2, w_g_np_2, b_g_np_2 = self.check_main(
            x_np, weight_np, bias_np, 'float32'
        )

        def assert_equal(x, y):
            np.testing.assert_array_equal(x, y)

        assert_equal(y_np_1, y_np_2)
        assert_equal(x_g_np_1, x_g_np_2)
        assert_equal(w_g_np_1, w_g_np_2)
        assert_equal(b_g_np_1, b_g_np_2)


@unittest.skipIf(
    not core.is_compiled_with_cuda()
    or paddle.is_compiled_with_rocm()
    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
    "core is not compiled with CUDA or not support the bfloat16",
)
class TestBF16ScaleBiasLayerNorm(unittest.TestCase):
    """A bfloat16 x with float32 weight/bias must stay close to the all-float32
    run."""

    def check_main(self, x_np, weight_np, bias_np, dtype):
        """Run layer_norm in dygraph (x cast to bf16 when ``dtype`` is
        "bfloat16") and return float32 NumPy (y, x_grad, w_grad, b_grad)."""
        paddle.disable_static()

        x = paddle.to_tensor(x_np)
        weight = paddle.to_tensor(weight_np)
        bias = paddle.to_tensor(bias_np)

        if dtype == "bfloat16":
            x = x.cast(paddle.fluid.core.VarDesc.VarType.BF16)

        x.stop_gradient = False
        weight.stop_gradient = False
        bias.stop_gradient = False

        y = F.layer_norm(x, x.shape[1:], weight, bias)
        x_g, w_g, b_g = paddle.grad(y, [x, weight, bias])

        y_np = y.cast('float32').numpy()
        x_g_np = x_g.cast('float32').numpy()
        w_g_np = w_g.cast('float32').numpy()
        b_g_np = b_g.cast('float32').numpy()

        paddle.enable_static()
        return y_np, x_g_np, w_g_np, b_g_np

    def test_main(self):
        x_np = np.random.random([10, 20]).astype('float32')
        weight_np = np.random.random([20]).astype('float32')
        bias_np = np.random.random([20]).astype('float32')

        y_np_1, x_g_np_1, w_g_np_1, b_g_np_1 = self.check_main(
            x_np, weight_np, bias_np, 'float32'
        )
        y_np_2, x_g_np_2, w_g_np_2, b_g_np_2 = self.check_main(
            x_np, weight_np, bias_np, 'bfloat16'
        )

        def assert_equal(x, y):
            np.testing.assert_allclose(x, y, rtol=1e-05, atol=3e-2)

        assert_equal(y_np_1, y_np_2)
        assert_equal(x_g_np_1, x_g_np_2)
        assert_equal(w_g_np_1, w_g_np_2)
        assert_equal(b_g_np_1, b_g_np_2)


class TestGetSetKeepLayerNormScaleBiasFP32Flag(unittest.TestCase):
    """The keep-scale/bias-in-fp32 flag starts True and can be toggled."""

    def test_main(self):
        self.assertTrue(_keep_layer_norm_scale_bias_to_fp32())
        _keep_layer_norm_scale_bias_to_fp32(False)
        self.assertFalse(_keep_layer_norm_scale_bias_to_fp32())
        _keep_layer_norm_scale_bias_to_fp32(True)
        self.assertTrue(_keep_layer_norm_scale_bias_to_fp32())


@unittest.skipIf(
    not core.is_compiled_with_cuda() or paddle.is_compiled_with_rocm(),
    "core is not compiled with CUDA or not support the FastMath",
)
class TestFastMathLayerNormOp(unittest.TestCase):
    """layer_norm forward output must match with FLAGS_use_fast_math off
    and on."""

    def check_layer_norm(
        self, dtype, x_np, scale_np, bias_np, norm_axis, has_scale, has_bias
    ):
        """Run layer_norm forward in dygraph and return y as float32 NumPy.

        ``scale_np`` is used only when ``has_scale`` and ``bias_np`` only
        when ``has_bias``; x is cast to bf16 when ``dtype`` is "bfloat16".
        """
        paddle.disable_static()
        x = paddle.to_tensor(x_np)
        if dtype == "bfloat16":
            x = x.cast(paddle.fluid.core.VarDesc.VarType.BF16)

        x.stop_gradient = True
        # Each optional parameter is gated by its own flag.
        bias = paddle.to_tensor(bias_np) if has_bias else None
        scale = paddle.to_tensor(scale_np) if has_scale else None
        if bias is not None:
            bias.stop_gradient = True
        if scale is not None:
            scale.stop_gradient = True

        y = F.layer_norm(x, x.shape[norm_axis:], scale, bias)
        y_np = y.cast('float32').numpy()
        paddle.enable_static()
        return y_np

    def check_with_fast_math(
        self, dtype, shape, norm_axis, has_scale, has_bias
    ):
        """Compare the outputs computed without and with fast math.

        FLAGS_use_fast_math is process-global, so its previous value is
        restored afterwards to avoid leaking into other tests.
        """
        flag_name = 'FLAGS_use_fast_math'

        def use_fast_math(enabled):
            paddle.set_flags({flag_name: enabled})

        def assert_close(x, y):
            np.testing.assert_allclose(x, y, rtol=1e-05, atol=1e-04)

        x_np = np.random.random(shape).astype('float32')
        bias_np = np.random.random(shape[norm_axis:]).astype('float32')
        scale_np = np.random.random(shape[norm_axis:]).astype('float32')

        original_flag = paddle.get_flags([flag_name])[flag_name]
        try:
            use_fast_math(False)
            y_ref = self.check_layer_norm(
                dtype, x_np, scale_np, bias_np, norm_axis, has_scale, has_bias
            )
            use_fast_math(True)
            y_fast = self.check_layer_norm(
                dtype, x_np, scale_np, bias_np, norm_axis, has_scale, has_bias
            )
        finally:
            use_fast_math(original_flag)
        assert_close(y_ref, y_fast)

    def check_with_dtype(self, dtype):
        self.check_with_fast_math(
            dtype,
            shape=[17, 129],
            norm_axis=1,
            has_scale=False,
            has_bias=True,
        )
        self.check_with_fast_math(
            dtype,
            shape=[8, 512],
            norm_axis=1,
            has_scale=False,
            has_bias=False,
        )
        self.check_with_fast_math(
            dtype,
            shape=[2, 768],
            norm_axis=1,
            has_scale=False,
            has_bias=False,
        )

    def init_dtype(self):
        self.dtype = 'float32'

    def test_main(self):
        self.init_dtype()
        self.check_with_dtype(dtype=self.dtype)


@unittest.skipIf(
    not core.is_compiled_with_cuda()
    or paddle.is_compiled_with_rocm()
    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
    "core is not compiled with CUDA or not support the bfloat16",
)
class TestFastMathLayerNormBF16Op(TestFastMathLayerNormOp):
    """bfloat16 variant of the fast-math comparison."""

    def init_dtype(self):
        self.dtype = 'bfloat16'


if __name__ == '__main__':
    paddle.enable_static()
    unittest.main()