From d8122a23d5068d21f74c11eb4bb8e0a5b779d567 Mon Sep 17 00:00:00 2001
From: Xianduo Li <30922914+lxd-cumt@users.noreply.github.com>
Date: Wed, 6 Sep 2023 16:56:50 +0800
Subject: [PATCH] support for checking op's inputs grad semantic (#56925)

* add a member in OpInputInfo to represent whether an input of an OP has
  grad semantic

* add support for checking an OP's input grad semantic by comparing fwd_op
  inputs and bwd_op outputs

* add a pybind interface to support checking an OP's input grad semantic at
  the Python level

* add test

* fix bugs

* fix bugs in op_gen

* fix bugs in op_gen

* add test for multiply_op

* fix bugs in codestyle

* fix bugs in codestyle

---
 .../fluid/ir/dialect/op_generator/op_gen.py   | 94 ++++++++++++++++++-
 .../dialect/paddle_dialect/ir/pd_manual_op.cc | 14 ++-
 .../paddle_dialect/utils/op_yaml_info_util.h  |  7 +-
 paddle/fluid/pybind/ir.cc                     | 11 +++
 .../pattern_rewrite/pattern_rewrite_test.cc   | 32 +++++--
 test/ir/test_op_input_grad_semantic.py        | 69 ++++++++++++++
 6 files changed, 212 insertions(+), 15 deletions(-)
 create mode 100644 test/ir/test_op_input_grad_semantic.py

diff --git a/paddle/fluid/ir/dialect/op_generator/op_gen.py b/paddle/fluid/ir/dialect/op_generator/op_gen.py
index 7ee65d05058..12a39f82dde 100644
--- a/paddle/fluid/ir/dialect/op_generator/op_gen.py
+++ b/paddle/fluid/ir/dialect/op_generator/op_gen.py
@@ -156,7 +156,7 @@ OpInfoTuple {op_name}::GetOpInfo() {{
   return std::make_tuple(inputs, attributes, outputs, run_time_info, "{origin_op_name}");
 }}
 """
-CONSTRUCT_INPUT_INFO_TEMPLATE = """paddle::dialect::OpInputInfo("{name}", "{typename}", {optional}, {no_need_buffer}, {is_mutable_attribute})"""
+CONSTRUCT_INPUT_INFO_TEMPLATE = """paddle::dialect::OpInputInfo("{name}", "{typename}", {optional}, {no_need_buffer}, {is_mutable_attribute}, {with_grad_semantic})"""
 CONSTRUCT_OUTPUT_INFO_TEMPLATE = """paddle::dialect::OpOutputInfo("{name}", "{typename}", {optional}, {intermediate})"""
 CONSTRUCT_ATTRIBUTE_INFO_TEMPLATE = """paddle::dialect::OpAttributeInfo("{name}", "{typename}", "{data_type}")"""
 
@@ -345,6 +345,23 @@ class OpInfoParser:
         # parse has_custom_verify
         self.custom_verify = self.parse_custom_verify()
 
+        # parse forward input name list
+        self.forward_input_name_list = self.parse_forward_input_name()
+
+    def parse_forward_input_name(self):
+        if 'forward' in self.op_yaml_item:
+            forward_input_name_list = []
+            forward_map = self.op_yaml_item['forward']
+            if forward_map is not None:
+                inputs = forward_map['inputs']
+                for input in inputs:
+                    forward_input_name_list.append(input['name'])
+                return forward_input_name_list
+            else:
+                return None
+        else:
+            return None
+
     def cross_check(self, name_list, type_list, optional_list=None):
         assert len(name_list) == len(
             type_list
@@ -705,6 +722,69 @@ def to_pascal_case(s):
         return "".join([word.capitalize() for word in words]) + ""
 
 
+def OpInputGradSemanticCheck(op_info, op_info_items):
+    input_grad_semantic_list = []
+    num_inputs = len(op_info.input_name_list)
+
+    # get backward op
+    bwd_op_name = op_info.backward_name
+    if (bwd_op_name is None) or (bwd_op_name not in op_info_items.keys()):
+        input_grad_semantic_list = ["false" for i in range(num_inputs)]
+    else:
+        bwd_op_info = op_info_items[bwd_op_name]
+
+        # strip "_grad" from each output of bwd_op, then compare each stripped
+        # name with the corresponding input to determine its grad semantic
+        bwd_output_list = bwd_op_info.output_name_list
+        bwd_output_list_new = []
+        for bwd_output in bwd_output_list:
+            bwd_output_list_new.append(bwd_output[:-5])  # cut "_grad"
+
+        bwd_fwd_input_list = bwd_op_info.forward_input_name_list
+        if bwd_fwd_input_list is not None:
+            assert (
+                len(bwd_fwd_input_list) == num_inputs
+            ), "Configuration of forward op and backward op does not match."
+            for i in range(num_inputs):
+                if bwd_fwd_input_list[i] in bwd_output_list_new:
+                    input_grad_semantic_list.append("true")
+                else:
+                    input_grad_semantic_list.append("false")
+        else:
+            input_grad_semantic_list = ["false" for i in range(num_inputs)]
+
+    return input_grad_semantic_list
+
+
+def OpMutableAttributeGradSemanticCheck(op_info, op_info_items):
+    mutable_attribute_grad_semantic_list = []
+    fwd_mutable_attribute_list = op_info.mutable_attribute_name_list
+
+    # get backward op
+    bwd_op_name = op_info.backward_name
+    if (bwd_op_name is None) or (bwd_op_name not in op_info_items.keys()):
+        mutable_attribute_grad_semantic_list = [
+            "false" for i in range(len(fwd_mutable_attribute_list))
+        ]
+    else:
+        bwd_op_info = op_info_items[bwd_op_name]
+
+        # strip "_grad" from each output of bwd_op, then compare each stripped
+        # name with the corresponding attribute to determine its grad semantic
+        bwd_output_list = bwd_op_info.output_name_list
+        bwd_output_list_new = []
+        for bwd_output in bwd_output_list:
+            bwd_output_list_new.append(bwd_output[:-5])  # cut "_grad"
+
+        for i in range(len(fwd_mutable_attribute_list)):
+            if fwd_mutable_attribute_list[i] in bwd_output_list_new:
+                mutable_attribute_grad_semantic_list.append("true")
+            else:
+                mutable_attribute_grad_semantic_list.append("false")
+
+    return mutable_attribute_grad_semantic_list
+
+
 def OpGenerator(
     op_yaml_files,
     op_compat_yaml_file,
@@ -793,6 +873,14 @@ def OpGenerator(
             op_interfaces += ["paddle::dialect::VjpInterface"]
         exclusive_interface_str = gen_exclusive_interface_str(op_info)
 
+        # check grad semantics of op inputs and mutable attributes
+        input_grad_semantic_list = OpInputGradSemanticCheck(
+            op_info, op_info_items
+        )
+        mutable_attribute_grad_semantic_list = (
+            OpMutableAttributeGradSemanticCheck(op_info, op_info_items)
+        )
+
         # If op has inplace info, we will generate inplace op and non-inplace op.
         for op_name in op_info.op_phi_name:
             if op_name in _NO_NEED_GEN_OPS:
@@ -971,6 +1059,7 @@ def OpGenerator(
                         optional=op_input_optional_list[idx],
                         no_need_buffer=op_input_no_need_buffer_list[idx],
                         is_mutable_attribute='false',
+                        with_grad_semantic=input_grad_semantic_list[idx],
                     )
                 )
             for idx in range(len(op_mutable_attribute_name_list)):
@@ -981,6 +1070,9 @@ def OpGenerator(
                         optional='false',
                         no_need_buffer='false',
                         is_mutable_attribute='true',
+                        with_grad_semantic=mutable_attribute_grad_semantic_list[
+                            idx
+                        ],
                     )
                 )
             if len(input_info_list) > 0:
diff --git a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_op.cc b/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_op.cc
index d50faedfd56..d6e1ad52524 100644
--- a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_op.cc
+++ b/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_op.cc
@@ -35,7 +35,8 @@ OpInfoTuple AddNOp::GetOpInfo() {
                   "ir::VectorType<paddle::dialect::DenseTensorType>",
                   false,
                   false,
-                  false)};
+                  false,
+                  true)};
   std::vector<paddle::dialect::OpAttributeInfo> attributes = {};
   std::vector<paddle::dialect::OpOutputInfo> outputs = {
       OpOutputInfo("out", "paddle::dialect::DenseTensorType", false, false)};
@@ -416,9 +417,14 @@ OpInfoTuple SplitGradOp::GetOpInfo() {
                   "ir::VectorType<paddle::dialect::DenseTensorType>",
                   false,
                   false,
-                  false),
-      OpInputInfo(
-          "axis", "paddle::dialect::ScalarAttribute", false, false, true)};
+                  false,
+                  true),
+      OpInputInfo("axis",
+                  "paddle::dialect::ScalarAttribute",
+                  false,
+                  false,
+                  true,
+                  false)};
   std::vector<paddle::dialect::OpAttributeInfo> attributes = {};
   std::vector<paddle::dialect::OpOutputInfo> outputs = {
       OpOutputInfo("x_grad", "paddle::dialect::DenseTensorType", false, false)};
diff --git a/paddle/fluid/ir/dialect/paddle_dialect/utils/op_yaml_info_util.h b/paddle/fluid/ir/dialect/paddle_dialect/utils/op_yaml_info_util.h
index eaa37a3a7de..69a4d848cfe 100644
--- a/paddle/fluid/ir/dialect/paddle_dialect/utils/op_yaml_info_util.h
+++ b/paddle/fluid/ir/dialect/paddle_dialect/utils/op_yaml_info_util.h
@@ -27,6 +27,7 @@ struct OpInputInfo {
   bool optional = false;
   bool no_need_buffer = false;
   bool is_mutable_attribute = false;
+  bool with_grad_semantic = true;
 
   OpInputInfo() = default;
   OpInputInfo(const OpInputInfo& input_info) = default;
@@ -34,12 +35,14 @@ struct OpInputInfo {
               const std::string& type_name,
               bool optional,
               bool no_need_buffer,
-              bool is_mutable_attribute)
+              bool is_mutable_attribute,
+              bool with_grad_semantic)
       : name(name),
         type_name(type_name),
         optional(optional),
         no_need_buffer(no_need_buffer),
-        is_mutable_attribute(is_mutable_attribute) {}
+        is_mutable_attribute(is_mutable_attribute),
+        with_grad_semantic(with_grad_semantic) {}
 };
 
 struct OpOutputInfo {
diff --git a/paddle/fluid/pybind/ir.cc b/paddle/fluid/pybind/ir.cc
index 2838bfa2fb2..4dc36fe785e 100644
--- a/paddle/fluid/pybind/ir.cc
+++ b/paddle/fluid/pybind/ir.cc
@@ -247,6 +247,17 @@ void BindOperation(py::module *m) {
              }
              return op_list;
            })
+      .def("get_input_grad_semantics",
+           [](Operation &self) -> py::list {
+             py::list op_list;
+             paddle::dialect::OpYamlInfoInterface yaml_interface =
+                 self.dyn_cast<paddle::dialect::OpYamlInfoInterface>();
+             auto inputs_grad_info = std::get<0>(yaml_interface.GetOpInfo());
+             for (auto &input_grad_info : inputs_grad_info) {
+               op_list.append(input_grad_info.with_grad_semantic);
+             }
+             return op_list;
+           })
       .def("replace_all_uses_with",
            [](Operation &self, const std::vector<OpResult> &op_results) {
              self.ReplaceAllUsesWith(op_results);
diff --git a/test/cpp/ir/pattern_rewrite/pattern_rewrite_test.cc b/test/cpp/ir/pattern_rewrite/pattern_rewrite_test.cc
index fcca8cde7d5..6dc24d09d78 100644
--- a/test/cpp/ir/pattern_rewrite/pattern_rewrite_test.cc
+++ b/test/cpp/ir/pattern_rewrite/pattern_rewrite_test.cc
@@ -429,14 +429,30 @@ const char *Conv2dFusionOpTest::attributes_name[10] = {  // NOLINT
 
 OpInfoTuple Conv2dFusionOpTest::GetOpInfo() {
   std::vector<paddle::dialect::OpInputInfo> inputs = {
-      OpInputInfo(
-          "input", "paddle::dialect::DenseTensorType", false, false, false),
-      OpInputInfo(
-          "filter", "paddle::dialect::DenseTensorType", false, false, false),
-      OpInputInfo(
-          "bias", "paddle::dialect::DenseTensorType", false, false, false),
-      OpInputInfo(
-          "residual", "paddle::dialect::DenseTensorType", true, false, false)};
+      OpInputInfo("input",
+                  "paddle::dialect::DenseTensorType",
+                  false,
+                  false,
+                  false,
+                  true),
+      OpInputInfo("filter",
+                  "paddle::dialect::DenseTensorType",
+                  false,
+                  false,
+                  false,
+                  true),
+      OpInputInfo("bias",
+                  "paddle::dialect::DenseTensorType",
+                  false,
+                  false,
+                  false,
+                  true),
+      OpInputInfo("residual",
+                  "paddle::dialect::DenseTensorType",
+                  true,
+                  false,
+                  false,
+                  true)};
   std::vector<paddle::dialect::OpAttributeInfo> attributes = {
       OpAttributeInfo("strides", "ir::ArrayAttribute<ir::Int32Attribute>", ""),
       OpAttributeInfo(
diff --git a/test/ir/test_op_input_grad_semantic.py b/test/ir/test_op_input_grad_semantic.py
new file mode 100644
index 00000000000..f204cff0aea
--- /dev/null
+++ b/test/ir/test_op_input_grad_semantic.py
@@ -0,0 +1,69 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import paddle
+from paddle import ir
+
+paddle.enable_static()
+
+
+def get_ir_program_0():
+    main_program, start_program = (
+        paddle.static.Program(),
+        paddle.static.Program(),
+    )
+    with paddle.static.program_guard(main_program, start_program):
+        x = paddle.tensor.fill_constant(
+            shape=[3, 4], dtype='float32', value=2.0
+        )
+        index = paddle.tensor.fill_constant(shape=[1], dtype='int32', value=1)
+        axis = paddle.tensor.fill_constant(shape=[1], dtype='int32', value=2)
+        out = paddle.gather(x, index, axis)
+    newir_program = ir.translate_to_new_ir(main_program.desc)
+    return newir_program
+
+
+def get_ir_program_1():
+    main_program, start_program = (
+        paddle.static.Program(),
+        paddle.static.Program(),
+    )
+    with paddle.static.program_guard(main_program, start_program):
+        x = paddle.tensor.fill_constant(
+            shape=[3, 4], dtype='float32', value=2.0
+        )
+        y = paddle.tensor.fill_constant(
+            shape=[3, 4], dtype='float32', value=3.0
+        )
+        out = paddle.multiply(x, y)
+    newir_program = ir.translate_to_new_ir(main_program.desc)
+    return newir_program
+
+
+class TestOpInputGradSemantic(unittest.TestCase):
+    def test_gatherop_input_grad_semantic(self):
+        newir_program = get_ir_program_0()
+        op = newir_program.block().ops[-1]
+        self.assertEqual(op.get_input_grad_semantics(), [True, False, False])
+
+    def test_multiplyop_input_grad_semantic(self):
+        newir_program = get_ir_program_1()
+        op = newir_program.block().ops[-1]
+        self.assertEqual(op.get_input_grad_semantics(), [True, True])
+
+
+if __name__ == "__main__":
+    unittest.main()
-- 
GitLab
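
A minimal standalone sketch of the suffix-matching rule the generator applies (not part of the patch; input_grad_semantics, fwd_inputs, and bwd_outputs are hypothetical stand-ins for the parsed YAML data):

    # A forward input has grad semantic iff the backward op declares an
    # output named "<input>_grad", mirroring OpInputGradSemanticCheck above.
    def input_grad_semantics(fwd_inputs, bwd_outputs):
        stripped = [name[: -len("_grad")] for name in bwd_outputs]
        return [name in stripped for name in fwd_inputs]

    # gather's backward op only produces "x_grad", so "index" and "axis"
    # carry no grad semantic -- matching the test's [True, False, False].
    assert input_grad_semantics(["x", "index", "axis"], ["x_grad"]) == [
        True,
        False,
        False,
    ]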