From d1d43c22c81e12b825cfbadb304435deda0f2389 Mon Sep 17 00:00:00 2001
From: Charles-hit <56987902+Charles-hit@users.noreply.github.com>
Date: Fri, 2 Jun 2023 10:58:05 +0800
Subject: [PATCH] support some prim ops bf16 dtype (#54263)

---
 .../api/composite_backward/composite_backward_api.h | 10 +++++++++-
 test/legacy_test/test_activation_op.py              |  6 ++++++
 test/legacy_test/test_cast_op.py                    |  6 ++++++
 test/legacy_test/test_elementwise_div_op.py         |  3 +--
 test/legacy_test/test_mean_op.py                    |  5 ++++-
 test/legacy_test/test_softmax_op.py                 |  7 ++++++-
 6 files changed, 32 insertions(+), 5 deletions(-)

diff --git a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
index 4d2e31ebd4f..692dda588df 100644
--- a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
+++ b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
@@ -542,7 +542,15 @@ void log_grad(const Tensor& x, const Tensor& out_grad, Tensor* x_grad) {
 template <typename T>
 void exp_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) {
   if (x_grad) {
-    set_output<T>(out_grad * out, x_grad);
+    if (out.dtype() == phi::DataType::FLOAT16 ||
+        out.dtype() == phi::DataType::BFLOAT16) {
+      Tensor out_promote = cast<T>(out, phi::DataType::FLOAT32);
+      Tensor out_grad_promote = cast<T>(out_grad, phi::DataType::FLOAT32);
+      set_output<T>(cast<T>(out_promote * out_grad_promote, out.dtype()),
+                    x_grad);
+    } else {
+      set_output<T>(out_grad * out, x_grad);
+    }
   }
 }
 
diff --git a/test/legacy_test/test_activation_op.py b/test/legacy_test/test_activation_op.py
index 999510e6733..5157b070ad9 100644
--- a/test/legacy_test/test_activation_op.py
+++ b/test/legacy_test/test_activation_op.py
@@ -113,6 +113,7 @@ class TestExpFp32_Prim(OpTest):
         self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
         self.outputs = {'Out': out}
         self.if_enable_cinn()
+        self.convert_input_output()
 
     def test_check_output(self):
         self.check_output()
@@ -129,6 +130,9 @@ class TestExpFp32_Prim(OpTest):
     def if_enable_cinn(self):
         pass
 
+    def convert_input_output(self):
+        pass
+
 
 class TestExpFp64_Prim(TestExpFp32_Prim):
     def init_dtype(self):
@@ -4003,6 +4007,7 @@ def create_test_act_fp16_class(
 
 
 create_test_act_fp16_class(TestActivation)
+create_test_act_fp16_class(TestExpFp32_Prim, check_prim=True, enable_cinn=True)
 create_test_act_fp16_class(TestExpm1)
 create_test_act_fp16_class(TestSigmoid, check_prim=True, enable_cinn=True)
 create_test_act_fp16_class(TestSilu, check_prim=True, enable_cinn=True)
@@ -4133,6 +4138,7 @@ def create_test_act_bf16_class(
 
 
 create_test_act_bf16_class(TestActivation)
+create_test_act_bf16_class(TestExpFp32_Prim, check_prim=True)
 create_test_act_bf16_class(TestExpm1)
 create_test_act_bf16_class(TestSigmoid, check_prim=True)
 create_test_act_bf16_class(TestSilu, check_prim=True)
diff --git a/test/legacy_test/test_cast_op.py b/test/legacy_test/test_cast_op.py
index baa5bc3d90d..c830f5f9f81 100644
--- a/test/legacy_test/test_cast_op.py
+++ b/test/legacy_test/test_cast_op.py
@@ -108,6 +108,9 @@ class TestCastOpBf16ToFp32(OpTest):
         self.prim_op_type = "prim"
         self.python_api = cast_wrapper
         self.public_python_api = cast_wrapper
+        self.if_enable_cinn()
+
+    def if_enable_cinn(self):
         self.enable_cinn = False
 
     def test_check_output(self):
@@ -130,6 +133,9 @@ class TestCastOpFp32ToBf16(OpTest):
         self.prim_op_type = "prim"
         self.python_api = cast_wrapper
         self.public_python_api = cast_wrapper
+        self.if_enable_cinn()
+
+    def if_enable_cinn(self):
         self.enable_cinn = False
 
     def test_check_output(self):
diff --git a/test/legacy_test/test_elementwise_div_op.py b/test/legacy_test/test_elementwise_div_op.py
index ea8fbac7f09..2432b3b04e4 100644
--- a/test/legacy_test/test_elementwise_div_op.py
+++ b/test/legacy_test/test_elementwise_div_op.py
@@ -222,9 +222,8 @@ class TestElementwiseDivOpBF16(ElementwiseDivOp):
             check_args.insert(0, self.place)
             self.check_grad_with_place(*check_args, **check_kwargs)
 
-    # elementwise_pow does't support bfloat16
     def if_check_prim(self):
-        self.check_prim = False
+        self.check_prim = True
 
     def if_enable_cinn(self):
         self.enable_cinn = False
diff --git a/test/legacy_test/test_mean_op.py b/test/legacy_test/test_mean_op.py
index 8e48a334ee4..78bcc7a46bc 100644
--- a/test/legacy_test/test_mean_op.py
+++ b/test/legacy_test/test_mean_op.py
@@ -211,7 +211,7 @@ class TestReduceMeanBF16Op(OpTest):
         self.axis = [0]
         self.keepdim = False
         self.set_attrs()
-        self.enable_cinn = False
+        self.if_enable_cinn()
 
         np.random.seed(10)
         x_np = np.random.uniform(-1, 1, self.shape).astype(np.float32)
@@ -227,6 +227,9 @@ class TestReduceMeanBF16Op(OpTest):
             'reduce_all': self.reduce_all,
         }
 
+    def if_enable_cinn(self):
+        self.enable_cinn = False
+
     def set_attrs(self):
         pass
 
diff --git a/test/legacy_test/test_softmax_op.py b/test/legacy_test/test_softmax_op.py
index abf753ef8e0..4374dede00b 100644
--- a/test/legacy_test/test_softmax_op.py
+++ b/test/legacy_test/test_softmax_op.py
@@ -398,7 +398,9 @@ class TestSoftmaxFP16CUDNNOp2(TestSoftmaxFP16CUDNNOp):
 class TestSoftmaxBF16Op(OpTest):
     def setUp(self):
         self.op_type = "softmax"
+        self.prim_op_type = "comp"
         self.python_api = F.softmax
+        self.public_python_api = F.softmax
         self.use_cudnn = self.init_cudnn()
         self.use_mkldnn = False
         self.dtype = np.uint16
@@ -424,7 +426,9 @@ class TestSoftmaxBF16Op(OpTest):
 
     def test_check_output(self):
         place = core.CUDAPlace(0)
-        self.check_output_with_place(place, check_dygraph=(not self.use_mkldnn))
+        self.check_output_with_place(
+            place, check_dygraph=(not self.use_mkldnn), check_prim=True
+        )
 
     def test_check_grad(self):
         place = core.CUDAPlace(0)
@@ -434,6 +438,7 @@ class TestSoftmaxBF16Op(OpTest):
             "Out",
             numeric_grad_delta=0.05,
             check_dygraph=(not self.use_mkldnn),
+            check_prim=True,
         )
 
 
-- 
GitLab