From d0d739cad7befc8156efa277be30416cd46edd1c Mon Sep 17 00:00:00 2001
From: SylarTiaNII <121000916+SylarTiaNII@users.noreply.github.com>
Date: Thu, 9 Mar 2023 11:19:30 +0800
Subject: [PATCH] add abs composite backward op (#50963)

* add abs composite backward op
* add missing changes during merge
* modify according to new rules
* local UT OK
* fix typo
* codestyle
* register composite operator
* add fp16 test for abs
* replace experimental::tensor
---
 paddle/fluid/operators/abs_op.cc              | 22 +++++++++++++++++++
 .../composite_backward_api.h                  |  9 ++++++++
 paddle/phi/api/yaml/legacy_backward.yaml      |  1 +
 .../tests/unittests/test_activation_op.py    |  7 ++++--
 4 files changed, 37 insertions(+), 2 deletions(-)

diff --git a/paddle/fluid/operators/abs_op.cc b/paddle/fluid/operators/abs_op.cc
index 0bf78f41d6..adc9f20b8c 100644
--- a/paddle/fluid/operators/abs_op.cc
+++ b/paddle/fluid/operators/abs_op.cc
@@ -19,6 +19,9 @@
 #include "paddle/fluid/framework/infershape_utils.h"
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h"
+#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
+#include "paddle/fluid/prim/utils/static/desc_tensor.h"
 #include "paddle/phi/core/infermeta_utils.h"
 #include "paddle/phi/infermeta/unary.h"
@@ -92,6 +95,24 @@ class AbsGradMaker : public framework::SingleGradOpMaker<T> {
   }
 };
 
+class AbsCompositeGradOpMaker : public prim::CompositeGradOpMakerBase {
+  using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase;
+
+ public:
+  void Apply() override {
+    paddle::Tensor input = this->GetSingleForwardInput("X");
+    paddle::Tensor out_grad = this->GetSingleOutputGrad("Out");
+    paddle::Tensor input_grad = this->GetSingleInputGrad("X");
+
+    auto dx_ptr = this->GetOutputPtr(&input_grad);
+    std::string dx_name = this->GetOutputName(input_grad);
+
+    VLOG(6) << "Running abs_grad composite func";
+    prim::abs_grad<prim::DescTensor>(input, out_grad, dx_ptr);
+    this->RecoverOutputName(input_grad, dx_name);
+  }
+};
+
 // AbsGrad: dx=dy if x >=0 else -dy
 // AbsDoubleGrad: ddy = ddx if x >=0 else -ddx
 template <typename T>
@@ -150,6 +171,7 @@ namespace ops = paddle::operators;
 REGISTER_OPERATOR(abs,
                   ops::AbsOp,
                   ops::AbsOpMaker,
+                  ops::AbsCompositeGradOpMaker,
                   ops::AbsGradMaker<paddle::framework::OpDesc>,
                   ops::AbsGradMaker<paddle::imperative::OpBase>,
                   AbsInferShapeFunctor);
diff --git a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
index 3fa35c034a..823a068dcd 100644
--- a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
+++ b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
@@ -453,6 +453,15 @@ void exp_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) {
   }
 }
 
+template <typename T>
+void abs_grad(const Tensor& x, const Tensor& out_grad, Tensor* x_grad) {
+  if (x_grad) {
+    auto abs_tmp = abs<T>(x);
+    auto divide_tmp = divide<T>(x, abs_tmp);
+    set_output<T>(out_grad * divide_tmp, x_grad);
+  }
+}
+
 template <typename T>
 void matmul_double_grad(const Tensor& x,
                         const Tensor& y,
diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml
index 4d98226120..1f73aa1fac 100755
--- a/paddle/phi/api/yaml/legacy_backward.yaml
+++ b/paddle/phi/api/yaml/legacy_backward.yaml
@@ -17,6 +17,7 @@
     param : [x]
   kernel :
     func : abs_grad
+  composite : abs_grad(x, out_grad, x_grad)
   backward : abs_double_grad
 
 - backward_op : add_double_grad
diff --git a/python/paddle/fluid/tests/unittests/test_activation_op.py b/python/paddle/fluid/tests/unittests/test_activation_op.py
index 6883a76982..b7e76ef816 100755
--- a/python/paddle/fluid/tests/unittests/test_activation_op.py
+++ b/python/paddle/fluid/tests/unittests/test_activation_op.py
@@ -1284,6 +1284,9 @@ class TestRsqrt_ZeroDim(TestRsqrt):
 class TestAbs(TestActivation):
     def setUp(self):
         self.op_type = "abs"
+        self.prim_op_type = "prim"
+        self.python_api = paddle.abs
+        self.enable_cinn = False
         self.init_dtype()
         self.init_shape()
@@ -1305,7 +1308,7 @@ class TestAbs(TestActivation):
     def test_check_grad(self):
         if self.dtype == np.float16:
             return
-        self.check_grad(['X'], 'Out', check_eager=False)
+        self.check_grad(['X'], 'Out', check_eager=False, check_prim=True)
 
 
 class TestAbs_ZeroDim(TestAbs):
@@ -3828,7 +3831,7 @@ create_test_act_fp16_class(TestTanhshrink)
 create_test_act_fp16_class(TestHardShrink)
 create_test_act_fp16_class(TestSoftshrink)
 create_test_act_fp16_class(TestSqrt)
-create_test_act_fp16_class(TestAbs)
+create_test_act_fp16_class(TestAbs, check_prim=True)
 create_test_act_fp16_class(TestCeil, grad_check=False)
 create_test_act_fp16_class(TestFloor, check_prim=True, grad_check=False)
 create_test_act_fp16_class(TestCos, grad_atol=0.85)
--
GitLab
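
Note on the composite rule: the abs_grad added to composite_backward_api.h decomposes
the abs backward into the primitives abs, divide, and multiply, computing
dx = out_grad * (x / |x|), i.e. out_grad * sign(x) away from zero. The following is a
minimal NumPy sketch of the same arithmetic, not part of the patch; the helper name
abs_grad_composite is made up here, and the quotient is undefined at x == 0, where the
dedicated kernel's convention from the patch's comment (dx = dy if x >= 0) applies
instead.

    import numpy as np

    def abs_grad_composite(x, out_grad):
        # Same decomposition as prim::abs_grad: dx = out_grad * (x / |x|).
        # Undefined at x == 0; the dedicated abs_grad kernel uses dx = dy there.
        return out_grad * (x / np.abs(x))

    x = np.array([-2.0, 1.5, 3.0])
    print(abs_grad_composite(x, np.ones_like(x)))  # [-1.  1.  1.]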