diff --git a/paddle/fluid/ir/dialect/op_generator/op_gen.py b/paddle/fluid/ir/dialect/op_generator/op_gen.py
index 12a39f82ddea7f8fdab083f54034debd0e025355..8663d23059d45284b469c5d463870b1af1ee91e9 100644
--- a/paddle/fluid/ir/dialect/op_generator/op_gen.py
+++ b/paddle/fluid/ir/dialect/op_generator/op_gen.py
@@ -722,14 +722,14 @@ def to_pascal_case(s):
     return "".join([word.capitalize() for word in words]) + ""
 
 
-def OpInputGradSemanticCheck(op_info, op_info_items):
-    input_grad_semantic_list = []
+def get_input_grad_semantic(op_info, op_info_items):
+    input_grad_semantics = []
     num_inputs = len(op_info.input_name_list)
 
     # get backward op
     bwd_op_name = op_info.backward_name
     if (bwd_op_name is None) or (bwd_op_name not in op_info_items.keys()):
-        input_grad_semantic_list = ["false" for i in range(num_inputs)]
+        input_grad_semantics = ["false" for i in range(num_inputs)]
     else:
         bwd_op_info = op_info_items[bwd_op_name]
 
@@ -747,23 +747,23 @@ def OpInputGradSemanticCheck(op_info, op_info_items):
         ), "Configuration of forward op and backward op is not match."
             for i in range(num_inputs):
                 if bwd_fwd_input_list[i] in bwd_output_list_new:
-                    input_grad_semantic_list.append("true")
+                    input_grad_semantics.append("true")
                 else:
-                    input_grad_semantic_list.append("false")
+                    input_grad_semantics.append("false")
         else:
-            input_grad_semantic_list = ["false" for i in range(num_inputs)]
+            input_grad_semantics = ["false" for i in range(num_inputs)]
 
-    return input_grad_semantic_list
+    return input_grad_semantics
 
 
-def OpMutableAttributeGradSemanticCheck(op_info, op_info_items):
-    mutable_attribute_grad_semantic_list = []
+def get_mutable_attribute_grad_semantic(op_info, op_info_items):
+    mutable_attribute_grad_semantics = []
     fwd_mutable_attribute_list = op_info.mutable_attribute_name_list
 
     # get backward op
     bwd_op_name = op_info.backward_name
     if (bwd_op_name is None) or (bwd_op_name not in op_info_items.keys()):
-        mutable_attribute_grad_semantic_list = [
+        mutable_attribute_grad_semantics = [
             "false" for i in range(len(fwd_mutable_attribute_list))
         ]
     else:
@@ -778,11 +778,11 @@ def OpMutableAttributeGradSemanticCheck(op_info, op_info_items):
 
         for i in range(len(fwd_mutable_attribute_list)):
             if fwd_mutable_attribute_list[i] in bwd_output_list_new:
-                mutable_attribute_grad_semantic_list.append("true")
+                mutable_attribute_grad_semantics.append("true")
             else:
-                mutable_attribute_grad_semantic_list.append("false")
+                mutable_attribute_grad_semantics.append("false")
 
-    return mutable_attribute_grad_semantic_list
+    return mutable_attribute_grad_semantics
 
 
 def OpGenerator(
@@ -874,12 +874,10 @@ def OpGenerator(
     exclusive_interface_str = gen_exclusive_interface_str(op_info)
 
     # check op inputs and mutable_attributes grad semantics
-    input_grad_semantic_list = OpInputGradSemanticCheck(
+    input_grad_semantics = get_input_grad_semantic(op_info, op_info_items)
+    mutable_attribute_grad_semantics = get_mutable_attribute_grad_semantic(
         op_info, op_info_items
     )
-    mutable_attribute_grad_semantic_list = (
-        OpMutableAttributeGradSemanticCheck(op_info, op_info_items)
-    )
 
     # If op has inplace info, we will generate inplace op and non-inplace op.
     for op_name in op_info.op_phi_name:
@@ -1059,7 +1057,7 @@ def OpGenerator(
                         optional=op_input_optional_list[idx],
                         no_need_buffer=op_input_no_need_buffer_list[idx],
                         is_mutable_attribute='false',
-                        with_grad_semantic=input_grad_semantic_list[idx],
+                        with_grad_semantic=input_grad_semantics[idx],
                     )
                 )
             for idx in range(len(op_mutable_attribute_name_list)):
@@ -1070,7 +1068,7 @@ def OpGenerator(
                         optional='false',
                         no_need_buffer='false',
                         is_mutable_attribute='true',
-                        with_grad_semantic=mutable_attribute_grad_semantic_list[
+                        with_grad_semantic=mutable_attribute_grad_semantics[
                             idx
                         ],
                     )
diff --git a/paddle/fluid/ir/dialect/paddle_dialect/utils/op_yaml_info_util.h b/paddle/fluid/ir/dialect/paddle_dialect/utils/op_yaml_info_util.h
index 69a4d848cfe60c8bdfd441790491a4bc759cb7f9..3df6ce5e22c15eda5bffa9eb3af7407b97ac459e 100644
--- a/paddle/fluid/ir/dialect/paddle_dialect/utils/op_yaml_info_util.h
+++ b/paddle/fluid/ir/dialect/paddle_dialect/utils/op_yaml_info_util.h
@@ -27,7 +27,17 @@ struct OpInputInfo {
   bool optional = false;
   bool no_need_buffer = false;
   bool is_mutable_attribute = false;
+  /***
+   * "with_grad_semantic" represents whether the input of the OP has gradient
+   * semantics. For example, gather op contains three inputs (x, index, axis),
+   * but the backward op gather_grad calculates only the gradient with respect
+   * to x. Therefore, for gather op, only x has gradient semantics.
+   * The "with_grad_semantic" field in OpInputInfo for x is true, and the
+   * "with_grad_semantic" fields in OpInputInfo for index and axis are both
+   * false.
+   */
   bool with_grad_semantic = true;
+
   OpInputInfo() = default;
   OpInputInfo(const OpInputInfo& input_info) = default;
diff --git a/test/ir/test_op_input_grad_semantic.py b/test/ir/test_op_input_grad_semantic.py
index f204cff0aeaf367a5f6a679f9a3352ff8b2a1550..7b932245fe51555e5ac5aa0425a235ca05ddbec9 100644
--- a/test/ir/test_op_input_grad_semantic.py
+++ b/test/ir/test_op_input_grad_semantic.py
@@ -20,7 +20,7 @@ from paddle import ir
 paddle.enable_static()
 
 
-def get_ir_program_0():
+def get_gather_program_new_ir():
     main_program, start_program = (
         paddle.static.Program(),
         paddle.static.Program(),
@@ -36,7 +36,7 @@
     return newir_program
 
 
-def get_ir_program_1():
+def get_multiply_program_new_ir():
     main_program, start_program = (
         paddle.static.Program(),
         paddle.static.Program(),
@@ -54,15 +54,17 @@
 class TestOpInputGradSemantic(unittest.TestCase):
-    def test_gatherop_input_grad_semantic(self):
-        newir_program = get_ir_program_0()
-        op = newir_program.block().ops[-1]
-        self.assertEqual(op.get_input_grad_semantics(), [True, False, False])
+    def test_gather_op_input_grad_semantic(self):
+        newir_program = get_gather_program_new_ir()
+        gather_op = newir_program.block().ops[-1]
+        self.assertEqual(
+            gather_op.get_input_grad_semantics(), [True, False, False]
+        )
 
-    def test_multiplyop_input_grad_semantic(self):
-        newir_program = get_ir_program_1()
-        op = newir_program.block().ops[-1]
-        self.assertEqual(op.get_input_grad_semantics(), [True, True])
+    def test_multiply_op_input_grad_semantic(self):
+        newir_program = get_multiply_program_new_ir()
+        multiply_op = newir_program.block().ops[-1]
+        self.assertEqual(multiply_op.get_input_grad_semantics(), [True, True])
 
 
 if __name__ == "__main__":
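For reference, the rule that the renamed `get_input_grad_semantic` helper implements can be sketched in isolation: a forward input has gradient semantics exactly when the backward op produces a matching `<name>_grad` output. The sketch below is illustrative only — `SimpleOpInfo` and `input_grad_semantics` are hypothetical stand-ins, not the real op-info objects in op_gen.py — and it assumes for simplicity that the backward op's recorded forward inputs line up one-to-one with the forward op's inputs (the real helper guards this with the "Configuration of forward op and backward op is not match." assert).

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class SimpleOpInfo:
    # Hypothetical stand-in for the generator's op-info objects.
    input_name_list: list = field(default_factory=list)
    output_name_list: list = field(default_factory=list)
    backward_name: Optional[str] = None


def input_grad_semantics(op_info, op_info_items):
    """Mark each forward input "true"/"false" by whether the backward op
    produces a corresponding "<name>_grad" output."""
    num_inputs = len(op_info.input_name_list)
    bwd_name = op_info.backward_name
    if bwd_name is None or bwd_name not in op_info_items:
        # No backward op configured: nothing has gradient semantics.
        return ["false"] * num_inputs

    bwd_op_info = op_info_items[bwd_name]
    # Strip the "_grad" suffix from each backward output name.
    grads_produced = {
        name[: -len("_grad")] for name in bwd_op_info.output_name_list
    }
    return [
        "true" if name in grads_produced else "false"
        for name in op_info.input_name_list
    ]


# gather has inputs (x, index, axis), but gather_grad only produces x_grad,
# so only x carries gradient semantics -- matching the new unit test above.
gather = SimpleOpInfo(
    input_name_list=["x", "index", "axis"], backward_name="gather_grad"
)
gather_grad = SimpleOpInfo(output_name_list=["x_grad"])
print(input_grad_semantics(gather, {"gather_grad": gather_grad}))
# -> ['true', 'false', 'false']
```

Note that the helpers return the strings `"true"`/`"false"` rather than Python booleans; judging from the `with_grad_semantic=...` substitutions in OpGenerator, these values appear to be spliced directly into generated C++ source, where lowercase bool literals are required.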