diff --git a/paddle/fluid/operators/abs_op.cc b/paddle/fluid/operators/abs_op.cc
index 0bf78f41d64691176a1c02d2c3b6517d8821fa04..adc9f20b8c93d17a35eeddee48275ba4a23a7737 100644
--- a/paddle/fluid/operators/abs_op.cc
+++ b/paddle/fluid/operators/abs_op.cc
@@ -19,6 +19,9 @@
 
 #include "paddle/fluid/framework/infershape_utils.h"
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h"
+#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
+#include "paddle/fluid/prim/utils/static/desc_tensor.h"
 #include "paddle/phi/core/infermeta_utils.h"
 #include "paddle/phi/infermeta/unary.h"
 
@@ -92,6 +95,24 @@ class AbsGradMaker : public framework::SingleGradOpMaker<T> {
   }
 };
 
+class AbsCompositeGradOpMaker : public prim::CompositeGradOpMakerBase {
+  using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase;
+
+ public:
+  void Apply() override {
+    paddle::Tensor input = this->GetSingleForwardInput("X");
+    paddle::Tensor out_grad = this->GetSingleOutputGrad("Out");
+    paddle::Tensor input_grad = this->GetSingleInputGrad("X");
+
+    auto dx_ptr = this->GetOutputPtr(&input_grad);
+    std::string dx_name = this->GetOutputName(input_grad);
+
+    VLOG(6) << "Running abs_grad composite func";
+    prim::abs_grad<prim::DescTensor>(input, out_grad, dx_ptr);
+    this->RecoverOutputName(input_grad, dx_name);
+  }
+};
+
 // AbsGrad: dx=dy if x >=0 else -dy
 // AbsDoubleGrad: ddy = ddx if x >=0 else -ddx
 template <typename T>
@@ -150,6 +171,7 @@ namespace ops = paddle::operators;
 REGISTER_OPERATOR(abs,
                   ops::AbsOp,
                   ops::AbsOpMaker,
+                  ops::AbsCompositeGradOpMaker,
                   ops::AbsGradMaker<paddle::framework::OpDesc>,
                   ops::AbsGradMaker<paddle::imperative::OpBase>,
                   AbsInferShapeFunctor);
diff --git a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
index 3fa35c034a00d61b28827bfc878cf0838fb7bf24..823a068dcd78b47bfd6d82e6a215877152101db6 100644
--- a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
+++ b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
@@ -453,6 +453,15 @@ void exp_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) {
   }
 }
 
+template <typename T>
+void abs_grad(const Tensor& x, const Tensor& out_grad, Tensor* x_grad) {
+  if (x_grad) {
+    auto abs_tmp = abs<T>(x);
+    auto divide_tmp = divide<T>(x, abs_tmp);
+    set_output<T>(out_grad * divide_tmp, x_grad);
+  }
+}
+
 template <typename T>
 void matmul_double_grad(const Tensor& x,
                         const Tensor& y,
diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml
index 4d98226120d4d561833949db7598609fd57bb869..1f73aa1fac193d5a7b587b022c2ed666583a6108 100755
--- a/paddle/phi/api/yaml/legacy_backward.yaml
+++ b/paddle/phi/api/yaml/legacy_backward.yaml
@@ -17,6 +17,7 @@
     param : [x]
   kernel :
     func : abs_grad
+  composite : abs_grad(x, out_grad, x_grad)
   backward : abs_double_grad
 
 - backward_op : add_double_grad
diff --git a/python/paddle/fluid/tests/unittests/test_activation_op.py b/python/paddle/fluid/tests/unittests/test_activation_op.py
index 6883a769829cec7a66848cf521467fac41d84041..b7e76ef816e12b29f344c0dd0e151f3a3bd869a9 100755
--- a/python/paddle/fluid/tests/unittests/test_activation_op.py
+++ b/python/paddle/fluid/tests/unittests/test_activation_op.py
@@ -1284,6 +1284,9 @@ class TestRsqrt_ZeroDim(TestRsqrt):
 class TestAbs(TestActivation):
     def setUp(self):
         self.op_type = "abs"
+        self.prim_op_type = "prim"
+        self.python_api = paddle.abs
+        self.enable_cinn = False
         self.init_dtype()
         self.init_shape()
 
@@ -1305,7 +1308,7 @@ class TestAbs(TestActivation):
     def test_check_grad(self):
         if self.dtype == np.float16:
             return
-        self.check_grad(['X'], 'Out', check_eager=False)
+        self.check_grad(['X'], 'Out', check_eager=False, check_prim=True)
 
 
 class TestAbs_ZeroDim(TestAbs):
@@ -3828,7 +3831,7 @@ create_test_act_fp16_class(TestTanhshrink)
 create_test_act_fp16_class(TestHardShrink)
 create_test_act_fp16_class(TestSoftshrink)
 create_test_act_fp16_class(TestSqrt)
-create_test_act_fp16_class(TestAbs)
+create_test_act_fp16_class(TestAbs, check_prim=True)
 create_test_act_fp16_class(TestCeil, grad_check=False)
 create_test_act_fp16_class(TestFloor, check_prim=True, grad_check=False)
 create_test_act_fp16_class(TestCos, grad_atol=0.85)
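
For reviewers, a quick sanity sketch (not part of the change) of what the new composite rule computes: abs_grad decomposes d|x|/dx into x / |x|, i.e. sign(x), which is undefined at x == 0, so the check keeps inputs away from zero. This assumes Paddle 2.x eager mode; the array values are illustrative only.

import numpy as np
import paddle

# Inputs chosen away from x == 0, where x / |x| is undefined.
x_np = np.array([-2.0, -0.5, 0.3, 1.7], dtype=np.float32)

# Native gradient via autograd (abs_grad kernel path).
x = paddle.to_tensor(x_np, stop_gradient=False)
paddle.abs(x).sum().backward()
native = x.grad.numpy()

# Composite rule from this diff: x_grad = out_grad * x / abs(x),
# with out_grad = ones because of the sum() reduction above.
composite = np.ones_like(x_np) * x_np / np.abs(x_np)

np.testing.assert_allclose(native, composite, rtol=1e-6)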