From 297182f73ed452c95aca15cea82a8732e818e573 Mon Sep 17 00:00:00 2001
From: SylarTiaNII <121000916+SylarTiaNII@users.noreply.github.com>
Date: Wed, 15 Mar 2023 19:03:16 +0800
Subject: [PATCH] add assign composite backward op (#51430)

* add assign composite backward op

* fix log msg

* code style

* fix comp rule

* replace assign with by_pass
---
 paddle/fluid/operators/assign_op.cc          | 22 +++++++++++++++++++
 .../composite_backward_api.h                 |  7 ++++++
 paddle/phi/api/yaml/legacy_backward.yaml     |  1 +
 .../fluid/tests/unittests/test_assign_op.py  |  8 +++++--
 4 files changed, 36 insertions(+), 2 deletions(-)

diff --git a/paddle/fluid/operators/assign_op.cc b/paddle/fluid/operators/assign_op.cc
index 244b3aec9c9..ebd4fbf491e 100644
--- a/paddle/fluid/operators/assign_op.cc
+++ b/paddle/fluid/operators/assign_op.cc
@@ -17,6 +17,10 @@ limitations under the License. */
 #include <string>
 
 #include "paddle/fluid/framework/infershape_utils.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h"
+#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
+#include "paddle/fluid/prim/utils/static/desc_tensor.h"
 #include "paddle/phi/core/infermeta_utils.h"
 #include "paddle/phi/infermeta/unary.h"
 
 namespace paddle {
@@ -109,6 +113,23 @@ class AssignGradMaker : public framework::SingleGradOpMaker<T> {
   }
 };
 
+class AssignCompositeGradOpMaker : public prim::CompositeGradOpMakerBase {
+  using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase;
+
+ public:
+  void Apply() override {
+    paddle::Tensor out_grad = this->GetSingleOutputGrad("Out");
+    paddle::Tensor input_grad = this->GetSingleInputGrad("X");
+
+    auto dx_ptr = this->GetOutputPtr(&input_grad);
+    std::string dx_name = this->GetOutputName(input_grad);
+
+    VLOG(6) << "Running assign_grad composite func";
+    prim::assign_grad<prim::DescTensor>(out_grad, dx_ptr);
+    this->RecoverOutputName(input_grad, dx_name);
+  }
+};
+
 DECLARE_INPLACE_OP_INFERER(AssignOpInplaceInferer, {"X", "Out"});
 
 }  // namespace operators
@@ -122,6 +143,7 @@ DECLARE_INFER_SHAPE_FUNCTOR(assign,
                             PD_INFER_META(phi::UnchangedInferMeta));
 REGISTER_OPERATOR(assign,
                   ops::AssignOp,
+                  ops::AssignCompositeGradOpMaker,
                   ops::AssignGradMaker<paddle::framework::OpDesc>,
                   ops::AssignGradMaker<paddle::imperative::OpBase>,
                   ops::AssignOpProtoMaker,
diff --git a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
index 7afd190069d..6fa1388e469 100644
--- a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
+++ b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
@@ -930,6 +930,13 @@ void gather_nd_grad(const Tensor& x,
   }
 }
 
+template <typename T>
+void assign_grad(const Tensor& out_grad, Tensor* x_grad) {
+  if (x_grad) {
+    by_pass<T>(out_grad, x_grad);
+  }
+}
+
 template <typename T>
 void erf_grad(const Tensor& x, const Tensor& out_grad, Tensor* x_grad) {
   if (x_grad) {
diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml
index 8eb3095933a..8381aeacdd3 100755
--- a/paddle/phi/api/yaml/legacy_backward.yaml
+++ b/paddle/phi/api/yaml/legacy_backward.yaml
@@ -94,6 +94,7 @@
   forward : assign (Tensor x) -> Tensor(out)
   args : (Tensor out_grad)
   output : Tensor(x_grad)
+  composite: assign_grad(out_grad, x_grad)
   invoke : assign(out_grad)
 
 - backward_op : assign_out__grad
diff --git a/python/paddle/fluid/tests/unittests/test_assign_op.py b/python/paddle/fluid/tests/unittests/test_assign_op.py
index 113d6a883fb..e53f0b8a119 100644
--- a/python/paddle/fluid/tests/unittests/test_assign_op.py
+++ b/python/paddle/fluid/tests/unittests/test_assign_op.py
@@ -30,6 +30,8 @@ class TestAssignOp(op_test.OpTest):
     def setUp(self):
         self.python_api = paddle.assign
         self.op_type = "assign"
+        self.prim_op_type = "prim"
+        self.enable_cinn = False
         x = np.random.random(size=(100, 10)).astype('float64')
         self.inputs = {'X': x}
         self.outputs = {'Out': x}
@@ -41,7 +43,7 @@ class TestAssignOp(op_test.OpTest):
 
     def test_backward(self):
         paddle.enable_static()
-        self.check_grad(['X'], 'Out', check_eager=True)
+        self.check_grad(['X'], 'Out', check_eager=True, check_prim=True)
         paddle.disable_static()
@@ -49,6 +51,8 @@ class TestAssignFP16Op(op_test.OpTest):
     def setUp(self):
         self.python_api = paddle.assign
         self.op_type = "assign"
+        self.prim_op_type = "prim"
+        self.enable_cinn = False
         x = np.random.random(size=(100, 10)).astype('float16')
         self.inputs = {'X': x}
         self.outputs = {'Out': x}
@@ -60,7 +64,7 @@ class TestAssignFP16Op(op_test.OpTest):
 
     def test_backward(self):
         paddle.enable_static()
-        self.check_grad(['X'], 'Out', check_eager=True)
+        self.check_grad(['X'], 'Out', check_eager=True, check_prim=True)
         paddle.disable_static()
-- 
GitLab
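For context, the composite rule registered above reduces assign_grad to a plain pass-through: out_grad is forwarded to x_grad unchanged. Below is a minimal dygraph sketch of that behavior; it is illustrative only, not part of the patch, and assumes a recent Paddle build where paddle.assign and dygraph autograd are available.

import paddle

# The composite backward for assign forwards out_grad to x_grad unchanged,
# so differentiating through paddle.assign behaves like an identity mapping.
x = paddle.randn([3, 4])
x.stop_gradient = False   # make x a leaf tensor that accumulates gradient
y = paddle.assign(x)
y.sum().backward()
print(x.grad)             # tensor of ones with the same shape as x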