diff --git a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
index 4d2e31ebd4facbc79a02dff7462b94039db4e163..692dda588df567c5ea37ff8bacd84323fd664aab 100644
--- a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
+++ b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
@@ -542,7 +542,15 @@ void log_grad(const Tensor& x, const Tensor& out_grad, Tensor* x_grad) {
 template <typename T>
 void exp_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) {
   if (x_grad) {
-    set_output<T>(out_grad * out, x_grad);
+    if (out.dtype() == phi::DataType::FLOAT16 ||
+        out.dtype() == phi::DataType::BFLOAT16) {
+      Tensor out_promote = cast<T>(out, phi::DataType::FLOAT32);
+      Tensor out_grad_promote = cast<T>(out_grad, phi::DataType::FLOAT32);
+      set_output<T>(cast<T>(out_promote * out_grad_promote, out.dtype()),
+                    x_grad);
+    } else {
+      set_output<T>(out_grad * out, x_grad);
+    }
   }
 }
 
diff --git a/test/legacy_test/test_activation_op.py b/test/legacy_test/test_activation_op.py
index 999510e6733078243e95969c91bc6336a0a7eefc..5157b070ad94d351e09dc052416b1aa63f7334b6 100644
--- a/test/legacy_test/test_activation_op.py
+++ b/test/legacy_test/test_activation_op.py
@@ -113,6 +113,7 @@ class TestExpFp32_Prim(OpTest):
         self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
         self.outputs = {'Out': out}
         self.if_enable_cinn()
+        self.convert_input_output()
 
     def test_check_output(self):
         self.check_output()
@@ -129,6 +130,9 @@ class TestExpFp32_Prim(OpTest):
     def if_enable_cinn(self):
         pass
 
+    def convert_input_output(self):
+        pass
+
 
 class TestExpFp64_Prim(TestExpFp32_Prim):
     def init_dtype(self):
@@ -4003,6 +4007,7 @@ def create_test_act_fp16_class(
 
 
 create_test_act_fp16_class(TestActivation)
+create_test_act_fp16_class(TestExpFp32_Prim, check_prim=True, enable_cinn=True)
 create_test_act_fp16_class(TestExpm1)
 create_test_act_fp16_class(TestSigmoid, check_prim=True, enable_cinn=True)
 create_test_act_fp16_class(TestSilu, check_prim=True, enable_cinn=True)
@@ -4133,6 +4138,7 @@ def create_test_act_bf16_class(
 
 
 create_test_act_bf16_class(TestActivation)
+create_test_act_bf16_class(TestExpFp32_Prim, check_prim=True)
 create_test_act_bf16_class(TestExpm1)
 create_test_act_bf16_class(TestSigmoid, check_prim=True)
 create_test_act_bf16_class(TestSilu, check_prim=True)
diff --git a/test/legacy_test/test_cast_op.py b/test/legacy_test/test_cast_op.py
index baa5bc3d90dd348a1270f05b12013861165780e2..c830f5f9f81aae435ce55c7ae8712096418cbfdc 100644
--- a/test/legacy_test/test_cast_op.py
+++ b/test/legacy_test/test_cast_op.py
@@ -108,6 +108,9 @@ class TestCastOpBf16ToFp32(OpTest):
         self.prim_op_type = "prim"
         self.python_api = cast_wrapper
         self.public_python_api = cast_wrapper
+        self.if_enable_cinn()
+
+    def if_enable_cinn(self):
         self.enable_cinn = False
 
     def test_check_output(self):
@@ -130,6 +133,9 @@ class TestCastOpFp32ToBf16(OpTest):
         self.prim_op_type = "prim"
         self.python_api = cast_wrapper
         self.public_python_api = cast_wrapper
+        self.if_enable_cinn()
+
+    def if_enable_cinn(self):
         self.enable_cinn = False
 
     def test_check_output(self):
diff --git a/test/legacy_test/test_elementwise_div_op.py b/test/legacy_test/test_elementwise_div_op.py
index ea8fbac7f0978f84c819ab50a114c49dca88d27a..2432b3b04e4ab81c685cc61d23e1877bddb7a0fb 100644
--- a/test/legacy_test/test_elementwise_div_op.py
+++ b/test/legacy_test/test_elementwise_div_op.py
@@ -222,9 +222,8 @@ class TestElementwiseDivOpBF16(ElementwiseDivOp):
         check_args.insert(0, self.place)
         self.check_grad_with_place(*check_args, **check_kwargs)
 
-    # elementwise_pow does't support bfloat16
     def if_check_prim(self):
-        self.check_prim = False
+        self.check_prim = True
 
     def if_enable_cinn(self):
         self.enable_cinn = False
diff --git a/test/legacy_test/test_mean_op.py b/test/legacy_test/test_mean_op.py
index 8e48a334ee40c2fc77e74b4b605edd9acd750ce9..78bcc7a46bc9b8e59fea593055a311e1fce4f7f4 100644
--- a/test/legacy_test/test_mean_op.py
+++ b/test/legacy_test/test_mean_op.py
@@ -211,7 +211,7 @@ class TestReduceMeanBF16Op(OpTest):
         self.axis = [0]
         self.keepdim = False
         self.set_attrs()
-        self.enable_cinn = False
+        self.if_enable_cinn()
         np.random.seed(10)
 
         x_np = np.random.uniform(-1, 1, self.shape).astype(np.float32)
@@ -227,6 +227,9 @@ class TestReduceMeanBF16Op(OpTest):
             'reduce_all': self.reduce_all,
         }
 
+    def if_enable_cinn(self):
+        self.enable_cinn = False
+
     def set_attrs(self):
         pass
 
diff --git a/test/legacy_test/test_softmax_op.py b/test/legacy_test/test_softmax_op.py
index abf753ef8e04b7220f5bf5940678a2cc54bb9532..4374dede00b6650b676eaabfaa4d980fae66a667 100644
--- a/test/legacy_test/test_softmax_op.py
+++ b/test/legacy_test/test_softmax_op.py
@@ -398,7 +398,9 @@ class TestSoftmaxFP16CUDNNOp2(TestSoftmaxFP16CUDNNOp):
 class TestSoftmaxBF16Op(OpTest):
     def setUp(self):
         self.op_type = "softmax"
+        self.prim_op_type = "comp"
         self.python_api = F.softmax
+        self.public_python_api = F.softmax
         self.use_cudnn = self.init_cudnn()
         self.use_mkldnn = False
         self.dtype = np.uint16
@@ -424,7 +426,9 @@ class TestSoftmaxBF16Op(OpTest):
 
     def test_check_output(self):
         place = core.CUDAPlace(0)
-        self.check_output_with_place(place, check_dygraph=(not self.use_mkldnn))
+        self.check_output_with_place(
+            place, check_dygraph=(not self.use_mkldnn), check_prim=True
+        )
 
     def test_check_grad(self):
         place = core.CUDAPlace(0)
@@ -434,6 +438,7 @@ class TestSoftmaxBF16Op(OpTest):
             "Out",
             numeric_grad_delta=0.05,
             check_dygraph=(not self.use_mkldnn),
+            check_prim=True,
         )