From f7eb03c654e0a404cf0e87c071dc0bac95202252 Mon Sep 17 00:00:00 2001 From: Charles-hit <56987902+Charles-hit@users.noreply.github.com> Date: Wed, 14 Jun 2023 14:53:13 +0800 Subject: [PATCH] support group_norm and cumsum prim ops bf16 dtype (#54580) --- .../composite_backward_api.h | 20 +- paddle/phi/core/visit_type.h | 13 +- paddle/phi/kernels/gpu/reduce.h | 3 +- .../incubate/autograd/composite_rules.py | 9 +- test/legacy_test/CMakeLists.txt | 2 +- test/legacy_test/test_cumsum_op.py | 20 +- test/legacy_test/test_group_norm_op.py | 175 +++++++++++++++--- 7 files changed, 188 insertions(+), 54 deletions(-) diff --git a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h index 692dda588df..10bf11e3a77 100644 --- a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h +++ b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h @@ -693,11 +693,13 @@ void group_norm_grad(const Tensor& x, Tensor x_data = x; Tensor out_grad_data = out_grad; - if (x.dtype() == phi::DataType::FLOAT16) { + if (x.dtype() == phi::DataType::FLOAT16 || + x.dtype() == phi::DataType::BFLOAT16) { x_data = cast(x, phi::DataType::FLOAT32); } - if (out_grad.dtype() == phi::DataType::FLOAT16) { + if (out_grad.dtype() == phi::DataType::FLOAT16 || + out_grad.dtype() == phi::DataType::BFLOAT16) { out_grad_data = cast(out_grad, phi::DataType::FLOAT32); } @@ -728,7 +730,8 @@ void group_norm_grad(const Tensor& x, Tensor p1; if (scale_ptr) { auto scale_data = scale.get(); - if (scale_data.dtype() == phi::DataType::FLOAT16) { + if (scale_data.dtype() == phi::DataType::FLOAT16 || + scale_data.dtype() == phi::DataType::BFLOAT16) { scale_data = cast(scale_data, phi::DataType::FLOAT32); } d1 = (reshape(sum_y_grad_mul_x * scale_data, shape_group)) @@ -757,7 +760,8 @@ void group_norm_grad(const Tensor& x, auto tmp_2 = reshape(x_data, whole_group_shape) * p2 + p3; auto x_grad_data = tmp_1 + tmp_2; x_grad_data = reshape(x_grad_data, x.shape()); - if (x.dtype() == phi::DataType::FLOAT16) { + if (x.dtype() == phi::DataType::FLOAT16 || + x.dtype() == phi::DataType::BFLOAT16) { x_grad_data = cast(x_grad_data, x.dtype()); } @@ -770,9 +774,9 @@ void group_norm_grad(const Tensor& x, reshape(sum_y_grad, shape_group) * reshape(mean, third_shape)) * reshape(inv_std, third_shape); - auto scale_grad_tmp = - reshape(tmp1.sum(std::vector({0}), dtype, false), - IntArray(std::vector({C}))); + auto scale_grad_tmp = reshape( + tmp1.sum(std::vector({0}), scale_ptr->dtype(), false), + IntArray(std::vector({C}))); set_output(scale_grad_tmp, scale_grad); } else { scale_grad = nullptr; @@ -782,7 +786,7 @@ void group_norm_grad(const Tensor& x, if (bias_grad) { if (bias_ptr) { auto bias_grad_tmp = - sum_y_grad.sum(std::vector({0}), dtype, false); + sum_y_grad.sum(std::vector({0}), bias_ptr->dtype(), false); set_output(bias_grad_tmp, bias_grad); } else { bias_grad = nullptr; diff --git a/paddle/phi/core/visit_type.h b/paddle/phi/core/visit_type.h index 8343343a361..f96fdb1f28b 100644 --- a/paddle/phi/core/visit_type.h +++ b/paddle/phi/core/visit_type.h @@ -298,8 +298,13 @@ namespace phi { } \ }() -#define PD_VISIT_BOOL_AND_FLOATING_AND_COMPLEX_AND_3_TYPES( \ - SPECIFIED_TYPE1, SPECIFIED_TYPE2, SPECIFIED_TYPE3, TYPE, NAME, ...) \ +#define PD_VISIT_BOOL_AND_FLOATING_AND_COMPLEX_AND_4_TYPES(SPECIFIED_TYPE1, \ + SPECIFIED_TYPE2, \ + SPECIFIED_TYPE3, \ + SPECIFIED_TYPE4, \ + TYPE, \ + NAME, \ + ...) 
\ [&] { \ const auto& __dtype__ = TYPE; \ switch (__dtype__) { \ @@ -328,6 +333,10 @@ namespace phi { SPECIFIED_TYPE3, \ ::phi::DataTypeToCppType::type, \ __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE(NAME, \ + SPECIFIED_TYPE4, \ + ::phi::DataTypeToCppType::type, \ + __VA_ARGS__) \ default: \ PD_THROW("function " #NAME " is not implemented for data type `", \ __dtype__, \ diff --git a/paddle/phi/kernels/gpu/reduce.h b/paddle/phi/kernels/gpu/reduce.h index 0d6edd13ac9..5ceb81eabd8 100644 --- a/paddle/phi/kernels/gpu/reduce.h +++ b/paddle/phi/kernels/gpu/reduce.h @@ -47,10 +47,11 @@ void Reduce(const KPDevice& dev_ctx, #ifndef PADDLE_WITH_XPU_KP if (out_dtype != phi::DataType::UNDEFINED && out_dtype != x.dtype()) { auto tmp_tensor = phi::Cast(dev_ctx, x, out_dtype); - PD_VISIT_BOOL_AND_FLOATING_AND_COMPLEX_AND_3_TYPES( + PD_VISIT_BOOL_AND_FLOATING_AND_COMPLEX_AND_4_TYPES( phi::DataType::INT32, phi::DataType::INT64, phi::DataType::FLOAT16, + phi::DataType::BFLOAT16, out_dtype, "ReduceKernel", ([&] { diff --git a/python/paddle/incubate/autograd/composite_rules.py b/python/paddle/incubate/autograd/composite_rules.py index a458cc73b5f..fa036dc1cc3 100644 --- a/python/paddle/incubate/autograd/composite_rules.py +++ b/python/paddle/incubate/autograd/composite_rules.py @@ -656,8 +656,9 @@ def group_norm_composite(x, scale, bias, epsilon, groups, data_layout): is_amp = False from paddle.fluid.data_feeder import convert_dtype - # when inputs are float16, convert to float32 in computing - if convert_dtype(x.dtype) == "float16": + dtype = convert_dtype(x.dtype) + # when inputs are float16 or bfloat16, convert to float32 in computing + if dtype in ["float16", "uint16"]: is_amp = True x = cast(x, "float32") scale = cast(scale, "float32") @@ -676,9 +677,9 @@ def group_norm_composite(x, scale, bias, epsilon, groups, data_layout): out = out + reshape(bias, (-1, 1, 1)) ret_mean_ = reshape(mean_, (N, groups)) ret_var_ = reshape(var_, (N, groups)) - # return output in float16, mean and var in float32 + # return output in float16 or bfloat16, mean and var in float32 if is_amp: - out = cast(out, "float16") + out = cast(out, dtype) return out, ret_mean_, ret_var_ diff --git a/test/legacy_test/CMakeLists.txt b/test/legacy_test/CMakeLists.txt index c39eb7e3852..525cac1c5e2 100644 --- a/test/legacy_test/CMakeLists.txt +++ b/test/legacy_test/CMakeLists.txt @@ -1033,7 +1033,7 @@ set_tests_properties( PROPERTIES TIMEOUT 120) set_tests_properties(test_conv_nn_grad PROPERTIES TIMEOUT 120) set_tests_properties(test_program_prune_backward PROPERTIES TIMEOUT 120) -set_tests_properties(test_group_norm_op PROPERTIES TIMEOUT 120) +set_tests_properties(test_group_norm_op PROPERTIES TIMEOUT 300) set_tests_properties(test_imperative_optimizer PROPERTIES TIMEOUT 250) set_tests_properties(test_imperative_optimizer_v2 PROPERTIES TIMEOUT 250) set_tests_properties(test_pool2d_op PROPERTIES TIMEOUT 120) diff --git a/test/legacy_test/test_cumsum_op.py b/test/legacy_test/test_cumsum_op.py index 0b9ac537f3a..7bb5e41f23c 100644 --- a/test/legacy_test/test_cumsum_op.py +++ b/test/legacy_test/test_cumsum_op.py @@ -122,7 +122,7 @@ class TestSumOp1(OpTest): self.prim_op_type = "prim" self.python_api = cumsum_wrapper self.public_python_api = paddle.cumsum - self.set_enable_cinn() + self.if_enable_cinn() self.init_dtype() self.set_attrs_input_output() if self.dtype == np.uint16: @@ -141,7 +141,7 @@ class TestSumOp1(OpTest): def init_dtype(self): self.dtype = self.dtype_ = np.float64 - def set_enable_cinn(self): + def if_enable_cinn(self): pass def 
set_attrs_input_output(self): @@ -221,7 +221,7 @@ class TestSumOpExclusive1(OpTest): self.prim_op_type = "prim" self.python_api = cumsum_wrapper self.public_python_api = paddle.cumsum - self.set_enable_cinn() + self.if_enable_cinn() self.init_dtype() self.set_attrs_input_output() if self.dtype == np.uint16: @@ -240,7 +240,7 @@ class TestSumOpExclusive1(OpTest): def init_dtype(self): self.dtype = self.dtype_ = np.float64 - def set_enable_cinn(self): + def if_enable_cinn(self): pass def set_attrs_input_output(self): @@ -346,7 +346,7 @@ class TestSumOpReverseExclusive(OpTest): self.prim_op_type = "prim" self.python_api = cumsum_wrapper self.public_python_api = paddle.cumsum - self.set_enable_cinn() + self.if_enable_cinn() self.init_dtype() self.attrs = { 'axis': 2, @@ -378,7 +378,7 @@ class TestSumOpReverseExclusive(OpTest): def init_dtype(self): self.dtype = self.dtype_ = np.float64 - def set_enable_cinn(self): + def if_enable_cinn(self): pass @@ -387,7 +387,7 @@ def create_test_fp16_class(parent, max_relative_error=1e-2): def init_dtype(self): self.dtype = self.dtype_ = np.float16 - def set_enable_cinn(self): + def if_enable_cinn(self): pass def test_check_output(self): @@ -430,7 +430,7 @@ def create_test_bf16_class(parent): self.dtype = np.uint16 self.dtype_ = np.float32 - def set_enable_cinn(self): + def if_enable_cinn(self): self.enable_cinn = False def test_check_output(self): @@ -439,7 +439,9 @@ def create_test_bf16_class(parent): def test_check_grad(self): place = paddle.CUDAPlace(0) - self.check_grad_with_place(place, ["X"], "Out", check_prim=True) + self.check_grad_with_place( + place, ["X"], "Out", check_prim=True, numeric_grad_delta=0.05 + ) cls_name = "{}_{}".format(parent.__name__, "BF16") TestCumsumBF16Op.__name__ = cls_name diff --git a/test/legacy_test/test_group_norm_op.py b/test/legacy_test/test_group_norm_op.py index 83c2e265823..ef1c5e6384e 100644 --- a/test/legacy_test/test_group_norm_op.py +++ b/test/legacy_test/test_group_norm_op.py @@ -19,6 +19,7 @@ import parameterized as param from eager_op_test import ( OpTest, convert_float_to_uint16, + convert_uint16_to_float, paddle_static_guard, skip_check_grad_ci, ) @@ -810,6 +811,96 @@ def apply_to_static(net, use_cinn): [[1e-3, 1e-3, 1e-3]], # gpu thresholds for static, jit, jit_cinn None, ), + ( + 'test0_bfp16', + (2, 100, 3, 5), + 1e-5, + 2, + 'NCHW', + places, + 'bfloat16', + [ + [ + 1e-2, + 1e-2, + 1e-2, + ], # cpu thresholds for static, jit, jit_cinn + [1e-2, 1e-2, 1e-2], + ], # gpu thresholds for static, jit, jit_cinn + None, + ), + ( + 'test1_bfp16', + (2, 100, 3, 5), + 1e-5, + 1, + 'NCHW', + places, + 'bfloat16', + [ + [ + 1e-2, + 1e-2, + 1e-2, + ], # cpu thresholds for static, jit, jit_cinn + [1e-2, 1e-2, 1e-2], + ], # gpu thresholds for static, jit, jit_cinn + None, + ), + ( + 'test2_bfp16', + (2, 100, 3, 5), + 1e-5, + 4, + 'NCHW', + places, + 'bfloat16', + [ + [ + 1e-2, + 1e-2, + 1e-2, + ], # cpu thresholds for static, jit, jit_cinn + [1e-2, 1e-2, 1e-2], + ], # gpu thresholds for static, jit, jit_cinn + None, + ), + ( + 'bigeps3_bfp16', + (2, 100, 3, 5), + 0.5, + 2, + 'NCHW', + places, + 'bfloat16', + [ + [ + 1e-2, + 1e-2, + 1e-2, + ], # cpu thresholds for static, jit, jit_cinn + [1e-2, 1e-2, 1e-2], + ], # gpu thresholds for static, jit, jit_cinn + None, + ), + ( + 'largedata_bfp16', + (2, 32, 64, 64), + 1e-5, + 4, + 'NCHW', + places, + 'bfloat16', + [ + [ + 1e-2, + 1e-2, + 1e-2, + ], # cpu thresholds for static, jit, jit_cinn + [1e-2, 1e-2, 1e-2], + ], # gpu thresholds for static, jit, jit_cinn + None, + ), ), ) 
class TestCompositeGroupNorm(unittest.TestCase): @@ -825,12 +916,23 @@ class TestCompositeGroupNorm(unittest.TestCase): np.random.seed(1234) self.fwd_desire = [] self.rev_desire = [] - self.x = np.random.random(self.shape).astype(self.dtype) - self.scale = np.random.random([self.shape[1]]).astype(self.dtype) - self.bias = np.random.random([self.shape[1]]).astype(self.dtype) + if self.dtype != "bfloat16": + self.x = np.random.random(self.shape).astype(self.dtype) + self.scale = np.random.random([self.shape[1]]).astype(self.dtype) + self.bias = np.random.random([self.shape[1]]).astype(self.dtype) + else: + self.x = convert_float_to_uint16( + np.random.random(self.shape).astype("float32") + ) + self.scale = convert_float_to_uint16( + np.random.random([self.shape[1]]).astype("float32") + ) + self.bias = convert_float_to_uint16( + np.random.random([self.shape[1]]).astype("float32") + ) self.num_channels = self.shape[1] - if self.dtype == 'float16': + if self.dtype in ['float16', 'bfloat16']: self.places = [] if paddle.is_compiled_with_cuda(): self.places.append(paddle.CUDAPlace(0)) @@ -879,7 +981,11 @@ class TestCompositeGroupNorm(unittest.TestCase): paddle.assign(bias_, group_norm.bias) output = group_norm(input_) grad = paddle.grad(output, input_) - + if self.dtype == "bfloat16": + output = paddle.cast(output, "float32") + grad = paddle.utils.map_structure( + lambda x: paddle.cast(x, "float32"), grad + ) return output, grad[0] def get_static_desire(self, place): @@ -923,7 +1029,6 @@ class TestCompositeGroupNorm(unittest.TestCase): output = group_norm(input_) blocks = mp.blocks - names = dict( zip( blocks[0].ops[2].output_names, @@ -964,7 +1069,11 @@ class TestCompositeGroupNorm(unittest.TestCase): ) paddle.disable_static() core._set_prim_all_enabled(True) - + if self.dtype == "bfloat16": + out_list[0] = convert_uint16_to_float(out_list[0]) + i = 3 + for i in range(3, len(out_list)): + out_list[i] = convert_uint16_to_float(out_list[i]) return out_list[:3], out_list[3:] def test_static_comp(self): @@ -1051,6 +1160,11 @@ class TestCompositeGroupNorm(unittest.TestCase): }, fetch_list=vars_list + [grads], ) + if self.dtype == "bfloat16": + out_list[0] = convert_uint16_to_float(out_list[0]) + i = 3 + for i in range(3, len(out_list)): + out_list[i] = convert_uint16_to_float(out_list[i]) fwd_actual[-1].append(out_list[0]) fwd_actual[-1].append(out_list[1]) fwd_actual[-1].append(out_list[2]) @@ -1075,12 +1189,14 @@ class TestCompositeGroupNorm(unittest.TestCase): atol = self.threshold_list[i][0] rtol = self.threshold_list[i][0] for j in range(len(self.static_fwd_desire[i])): - # in float16 type, Y is float16, mean and var are float16 + # in float16 type, Y is float16, mean and var are float32 # so check mean and var with float32 gpu threshold - if self.dtype == 'float16' and j > 0: + if self.dtype == "float16" and j > 0: atol = 1e-5 rtol = 1e-5 - + elif self.dtype == "bfloat16" and j > 0: + atol = 5e-3 + rtol = 5e-3 np.testing.assert_allclose( self.static_fwd_desire[i][j], fwd_actual[i][j], @@ -1091,13 +1207,6 @@ class TestCompositeGroupNorm(unittest.TestCase): max_abs_diff = np.max( np.abs(self.static_fwd_desire[i][j] - fwd_actual[i][j]) ) - print( - self.shape, - self.dtype, - self.places[i], - vars_name[j], - max_abs_diff, - ) # compare with eager_desire np.testing.assert_allclose( self.fwd_desire[i], @@ -1121,14 +1230,6 @@ class TestCompositeGroupNorm(unittest.TestCase): np.abs(self.static_rev_desire[i][j] - rev_actual[i][j]) ) - print( - self.shape, - self.dtype, - self.places[i], - vars_name[j + 
3], - max_abs_diff, - ) - np.testing.assert_allclose( self.static_rev_desire[i][j], rev_actual[i][j], @@ -1183,8 +1284,16 @@ class TestCompositeGroupNorm(unittest.TestCase): net = apply_to_static(net, False) output = net(input_) grad = paddle.grad(output, input_) - fwd_actual.append(output.numpy()) - rev_actual.append(grad[0].numpy()) + fwd_actual.append( + convert_uint16_to_float(output.numpy()) + if self.dtype == "bfloat16" + else output.numpy() + ) + rev_actual.append( + convert_uint16_to_float(grad[0].numpy()) + if self.dtype == "bfloat16" + else grad[0].numpy() + ) for i in range(len(self.places)): atol = self.threshold_list[i][1] @@ -1244,8 +1353,16 @@ class TestCompositeGroupNorm(unittest.TestCase): net = apply_to_static(net, True) output = net(input_) grad = paddle.grad(output, input_) - fwd_actual.append(output.numpy()) - rev_actual.append(grad[0].numpy()) + fwd_actual.append( + convert_uint16_to_float(output.numpy()) + if self.dtype == "bfloat16" + else output.numpy() + ) + rev_actual.append( + convert_uint16_to_float(grad[0].numpy()) + if self.dtype == "bfloat16" + else grad[0].numpy() + ) i = 0 for place in self.places: -- GitLab
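
Note (illustration, not part of the patch): the core pattern this change adds to group_norm_composite and group_norm_grad is "store in bfloat16, compute in float32" — bfloat16 tensors (carried as uint16 in the tests) are upcast to float32 before the normalization math, the normalized output is cast back to the input dtype, and mean/variance stay in float32. The standalone NumPy sketch below mirrors that flow under stated assumptions: NCHW layout, truncation-based bfloat16 round-tripping, and hypothetical helper names (float32_to_bfloat16, bfloat16_to_float32, group_norm_bf16) that stand in for Paddle's convert_float_to_uint16 / convert_uint16_to_float utilities and are not taken from the patch.

import numpy as np


def float32_to_bfloat16(x):
    # Keep only the upper 16 bits of each float32 value (plain truncation here;
    # Paddle's helper may round differently).
    return (x.astype(np.float32).view(np.uint32) >> 16).astype(np.uint16)


def bfloat16_to_float32(x):
    # Re-expand the stored upper 16 bits back into a float32 value.
    return (x.astype(np.uint32) << 16).view(np.float32)


def group_norm_bf16(x_bf16, scale_bf16, bias_bf16, groups, epsilon=1e-5):
    # Upcast to float32 before computing, as the composite rule does for
    # float16 / uint16 (bfloat16) inputs.
    x = bfloat16_to_float32(x_bf16)
    scale = bfloat16_to_float32(scale_bf16)
    bias = bfloat16_to_float32(bias_bf16)

    n, c, h, w = x.shape
    xg = x.reshape(n, groups, -1)
    mean = xg.mean(axis=-1, keepdims=True)
    var = xg.var(axis=-1, keepdims=True)
    out = (xg - mean) / np.sqrt(var + epsilon)
    out = out.reshape(n, c, h, w) * scale.reshape(1, c, 1, 1) + bias.reshape(1, c, 1, 1)

    # Output goes back to bfloat16; mean and variance are returned in float32,
    # matching the composite rule's behavior under AMP.
    return float32_to_bfloat16(out), mean.reshape(n, groups), var.reshape(n, groups)


if __name__ == "__main__":
    rng = np.random.default_rng(1234)
    x = float32_to_bfloat16(rng.random((2, 4, 3, 5), dtype=np.float32))
    scale = float32_to_bfloat16(rng.random(4, dtype=np.float32))
    bias = float32_to_bfloat16(rng.random(4, dtype=np.float32))
    y, mean, var = group_norm_bf16(x, scale, bias, groups=2)
    # Comparisons against a float32 reference would use relaxed tolerances,
    # in the spirit of the 1e-2 / 5e-3 thresholds the bfloat16 test cases use.
    print(y.dtype, mean.dtype, bfloat16_to_float32(y).mean())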