diff --git a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h index 823a068dcd78b47bfd6d82e6a215877152101db6..2fe33c92ff8c32284d8ae6bc0b79c932b249e422 100644 --- a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h +++ b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h @@ -14,6 +14,12 @@ #pragma once +#ifndef _USE_MATH_DEFINES +#define _USE_MATH_DEFINES +#endif + +#include + #include "paddle/fluid/prim/api/all.h" #include "paddle/fluid/prim/api/generated_prim/prim_generated_api.h" #include "paddle/phi/common/int_array.h" @@ -881,5 +887,16 @@ void gather_nd_grad(const Tensor& x, } } +template +void erf_grad(const Tensor& x, const Tensor& out_grad, Tensor* x_grad) { + if (x_grad) { + auto m_2_sqrt_pi = full(phi::vectorize(x.dims()), M_2_SQRTPI, x.dtype()); + auto neg_one = full(phi::vectorize(x.dims()), -1.0, x.dtype()); + auto neg_tmp = neg_one * x * x; + auto mul_tmp = m_2_sqrt_pi * exp(neg_tmp); + set_output(out_grad * mul_tmp, x_grad); + } +} + } // namespace prim } // namespace paddle diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml index 4a42799764fa8014d06d0d4b6e0c84029bad258c..6108fbb5081ab6f5829e66bab8bbab5b92a2bf97 100644 --- a/paddle/phi/api/yaml/backward.yaml +++ b/paddle/phi/api/yaml/backward.yaml @@ -423,6 +423,7 @@ kernel : func : erf_grad data_type : out_grad + composite : erf_grad(x, out_grad, x_grad) - backward_op : erfinv_grad forward : erfinv (Tensor x) -> Tensor(out) diff --git a/python/paddle/fluid/tests/unittests/test_erf_op.py b/python/paddle/fluid/tests/unittests/test_erf_op.py index db5c48151c5052e83be61912ad7368b734c79aba..0d78457b1452a52632d99158ff59560dd83a7690 100644 --- a/python/paddle/fluid/tests/unittests/test_erf_op.py +++ b/python/paddle/fluid/tests/unittests/test_erf_op.py @@ -26,6 +26,8 @@ import paddle.fluid.dygraph as dg class TestErfOp(OpTest): def setUp(self): self.op_type = "erf" + self.prim_op_type = "prim" + self.enable_cinn = True self.python_api = paddle.erf self.dtype = self._init_dtype() self.x_shape = [11, 17] @@ -41,7 +43,7 @@ class TestErfOp(OpTest): self.check_output() def test_check_grad(self): - self.check_grad(['X'], 'Out') + self.check_grad(['X'], 'Out', check_prim=True) class TestErfLayer(unittest.TestCase):