From dd74b3d1859d3349b32f64ff966d4d50c85c81ad Mon Sep 17 00:00:00 2001
From: Xiaoxu Chen
Date: Tue, 11 Apr 2023 15:46:42 +0800
Subject: [PATCH] [prim]use Operator to reconstruct the primitive operator
 defined in c++ (#51997)

---
 .../utils/static/composite_grad_desc_maker.h  |  3 +
 paddle/fluid/pybind/protobuf.cc               |  3 +-
 python/paddle/fluid/backward.py               | 69 ++++++++++++++-----
 python/paddle/fluid/framework.py              | 35 ++++++++--
 .../unittests/prim/test_comp_dispensable.py   | 44 ++++++++++++
 5 files changed, 128 insertions(+), 26 deletions(-)
 create mode 100644 python/paddle/fluid/tests/unittests/prim/test_comp_dispensable.py

diff --git a/paddle/fluid/prim/utils/static/composite_grad_desc_maker.h b/paddle/fluid/prim/utils/static/composite_grad_desc_maker.h
index 83b18814b19..b1b24af231f 100644
--- a/paddle/fluid/prim/utils/static/composite_grad_desc_maker.h
+++ b/paddle/fluid/prim/utils/static/composite_grad_desc_maker.h
@@ -575,6 +575,9 @@ class CompositeGradOpMakerBase {
   const std::unordered_map<std::string, Attribute>& RuntimeAttrs() const {
+    LOG(WARNING) << "CompositeGradOpMaker doesn't support runtime attrs, "
+                    "but op "
+                 << fwd_op_.Type() << " uses runtime attrs.";
     return fwd_op_.GetRuntimeAttrMap();
   }

diff --git a/paddle/fluid/pybind/protobuf.cc b/paddle/fluid/pybind/protobuf.cc
index 9661d552414..5493cc945cf 100644
--- a/paddle/fluid/pybind/protobuf.cc
+++ b/paddle/fluid/pybind/protobuf.cc
@@ -425,7 +425,8 @@ void BindOpDesc(pybind11::module *m) {
            &pd::OpDesc::SetDistAttr,
            pybind11::return_value_policy::reference)
       .def("inputs", [](pd::OpDesc &self) { return self.Inputs(); })
-      .def("outputs", &pd::OpDesc::Outputs);
+      .def("outputs", &pd::OpDesc::Outputs)
+      .def("get_attr_map", &pd::OpDesc::GetAttrMap);

   pybind11::class_ scalar(*m, "Scalar", "");
   scalar.def(py::init())
diff --git a/python/paddle/fluid/backward.py b/python/paddle/fluid/backward.py
index 9a6572db727..46f225e0d09 100755
--- a/python/paddle/fluid/backward.py
+++ b/python/paddle/fluid/backward.py
@@ -1715,35 +1715,68 @@ def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map):

 def infershape_for_composite(block, grad_op_desc):
-    # pruning empty output
+    # NOTE: why prune operators with empty outputs here?
+    # Some backward operators produce empty output vars, which cause
+    # infer-shape errors, e.g. assign when its input has stop_gradient=True.
     if len(grad_op_desc.output_arg_names()) == 0:
         return

-    # append op to block
-    op_desc = block.desc.append_op()
-    op_desc.copy_from(grad_op_desc)
-    op_desc._set_attr(
-        core.op_proto_and_checker_maker.kOpRoleAttrName(),
-        core.op_proto_and_checker_maker.OpRole.Backward,
-    )
-
-    # create output var
+    # create output variables
     new_vars = set()
-    # create new gradient variables
-    for grad_var_name in op_desc.output_arg_names():
+    for grad_var_name in grad_op_desc.output_arg_names():
         if not (
             block.desc.has_var_recursive(grad_var_name.encode())
             or grad_var_name == core.empty_var_name()
         ):
-            block.desc.var(grad_var_name.encode())
+            # NOTE: stop_gradient will be set in append_op
+            desc = block.desc.var(grad_var_name.encode())
+            block.create_var(name=grad_var_name, desc=desc, type=desc.type())
             new_vars.add(grad_var_name)

-    # infer shape and infer dthype
-    op_desc.check_attrs()
-    op_desc.infer_var_type(block.desc)
-    op_desc.infer_shape(block.desc)
+    # NOTE: For primitive operators generated by decomposing a phi grad kernel,
+    # we use Operator to reconstruct the op_desc, reusing complex logic such as
+    # processing dispensable inputs, intermediate outputs, extra attrs, etc.
+    if framework.OpProtoHolder.instance().has_op_proto(grad_op_desc.type()):
+        op = block.append_op(
+            type=grad_op_desc.type(),
+            inputs={
+                name: [block._find_var_recursive(arg) for arg in args]
+                for name, args in grad_op_desc.inputs().items()
+            },
+            outputs={
+                name: [block._find_var_recursive(arg) for arg in args]
+                for name, args in grad_op_desc.outputs().items()
+            },
+            # NOTE: Runtime attrs are ignored because the C++ GetRuntimeAttr
+            # interface can't be exported to Python. Note the WARNING message
+            # logged in RuntimeAttrs of composite_grad_desc_maker.h.
+            attrs=grad_op_desc.get_attr_map(),
+        )
+        op.desc._set_attr(
+            core.op_proto_and_checker_maker.kOpRoleAttrName(),
+            core.op_proto_and_checker_maker.OpRole.Backward,
+        )
+        grad_op_desc.copy_from(op.desc)
+    # For other backward operators, we reuse the logic of _append_backward_vars_
+    else:
+        op_desc = block.desc.append_op()
+        op_desc.copy_from(grad_op_desc)
+        op_desc._set_attr(
+            core.op_proto_and_checker_maker.kOpRoleAttrName(),
+            core.op_proto_and_checker_maker.OpRole.Backward,
+        )
+        op_desc.check_attrs()
+        op_desc.infer_var_type(block.desc)
+        op_desc.infer_shape(block.desc)
+        for arg in op_desc.output_arg_names():
+            if arg in new_vars:
+                _infer_var_data_type_shape_(arg, block)
+
+        grad_op_desc.copy_from(op_desc)

-    for arg in op_desc.output_arg_names():
+    # NOTE: Some operators don't infer dtype correctly; this patch sets the
+    # grad_var dtype to match the corresponding forward variable's.
+    for arg in grad_op_desc.output_arg_names():
         if arg in new_vars:
             _infer_var_data_type_shape_(arg, block)
diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py
index 708cc462e78..db17ea36884 100644
--- a/python/paddle/fluid/framework.py
+++ b/python/paddle/fluid/framework.py
@@ -2916,14 +2916,35 @@ class Operator:
             for m in proto.outputs:
                 if (m.name not in outputs) and m.dispensable:
                     continue
-                if not ((m.name in outputs) or m.dispensable):
-                    raise ValueError(
-                        (
-                            "Incorrect setting for output(s) of "
-                            "operator \"%s\", should set: [%s]."
-                        )
-                        % (type, m.name)
-                    )
+
+                # FIXME: The outputs of a primitive operator currently don't
+                # include intermediate outputs, as these are dropped in
+                # operator codegen (e.g. the xshape output of reshape2).
+                # This will be fixed once operator codegen supports
+                # intermediate outputs.
+                if core._is_bwd_prim_enabled():
+                    if not (
+                        (m.name in outputs)
+                        or m.dispensable
+                        or m.intermediate
+                    ):
+                        raise ValueError(
+                            (
+                                "Incorrect setting for output(s) of "
+                                "operator \"%s\", should set: [%s]."
+                            )
+                            % (type, m.name)
+                        )
+                else:
+                    if not ((m.name in outputs) or m.dispensable):
+                        raise ValueError(
+                            (
+                                "Incorrect setting for output(s) of "
+                                "operator \"%s\", should set: [%s]."
+                            )
+                            % (type, m.name)
+                        )
+
             for out_proto in proto.outputs:
                 if out_proto.name not in outputs:
                     continue
diff --git a/python/paddle/fluid/tests/unittests/prim/test_comp_dispensable.py b/python/paddle/fluid/tests/unittests/prim/test_comp_dispensable.py
new file mode 100644
index 00000000000..a4f4df5fdd1
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/prim/test_comp_dispensable.py
@@ -0,0 +1,44 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import paddle
+
+
+class TestDispensable(unittest.TestCase):
+    def setUp(self):
+        paddle.fluid.core._set_prim_all_enabled(True)
+
+    def tearDown(self):
+        paddle.fluid.core._set_prim_all_enabled(False)
+
+    def test_dispensable(self):
+        @paddle.jit.to_static
+        def f(x):
+            return paddle.split(x, num_or_sections=2)
+
+        x = paddle.rand((8,))
+        x.stop_gradient = False
+
+        op = f.get_concrete_program(x)[1].backward_program.block(0).ops[-1]
+        self.assertEqual(
+            op.attr('op_role'),
+            int(paddle.fluid.core.op_proto_and_checker_maker.OpRole.Backward),
+        )
+        self.assertIn('AxisTensor', op.input_names)
+
+
+if __name__ == '__main__':
+    unittest.main()
-- 
GitLab
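A minimal sketch of what the new get_attr_map binding exposes on the Python side: infershape_for_composite above relies on it to forward the compile-time attrs of a grad op desc into block.append_op. The reshape program and the printed attr map below are illustrative assumptions, not output taken from this patch.

import paddle

paddle.enable_static()

main = paddle.static.Program()
with paddle.static.program_guard(main):
    x = paddle.static.data('x', shape=[2, 3], dtype='float32')
    y = paddle.reshape(x, [3, 2])

# paddle.reshape lowers to the reshape2 op; the patched pybind binding
# exposes OpDesc::GetAttrMap on its desc.
op_desc = main.block(0).ops[-1].desc
print(op_desc.get_attr_map())  # e.g. {'shape': [3, 2], ...} (assumed output)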
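And a sketch of the end-to-end path the new unit test exercises, assuming the same private helpers the test itself uses (_set_prim_all_enabled, get_concrete_program): with prim enabled, the composite grad maker decomposes the phi grad kernel, and infershape_for_composite rebuilds each generated primitive op through Operator. paddle.tanh is only an assumed example of an op with a composite backward.

import paddle

paddle.fluid.core._set_prim_all_enabled(True)

@paddle.jit.to_static
def f(x):
    return paddle.tanh(x)

x = paddle.rand((8,))
x.stop_gradient = False

# Inspect the backward program assembled through infershape_for_composite;
# every op in it should carry the Backward op role set by this patch.
backward_program = f.get_concrete_program(x)[1].backward_program
for op in backward_program.block(0).ops:
    print(op.type, op.attr('op_role'))

paddle.fluid.core._set_prim_all_enabled(False)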