diff --git a/paddle/fluid/ir/dialect/op_generator/op_gen.py b/paddle/fluid/ir/dialect/op_generator/op_gen.py index 2d5ef75d1c6c4570326cbe0637130c1fdb0045c8..5fa5a27ed94a11a55e1394f5915ea9113c2a4d31 100644 --- a/paddle/fluid/ir/dialect/op_generator/op_gen.py +++ b/paddle/fluid/ir/dialect/op_generator/op_gen.py @@ -307,6 +307,9 @@ class OpInfoParser: self.inplace_map = self.parse_op_inplace_info() self.view_map = self.parse_op_view_info() + # parse has_custom_verify + self.custom_verify = self.parse_custom_verify() + def cross_check(self, name_list, type_list, optional_list=None): assert len(name_list) == len( type_list @@ -316,6 +319,11 @@ class OpInfoParser: optional_list ), "type list size != optional list size." + def parse_custom_verify(self): + if 'custom_verify' in self.op_yaml_item: + return self.op_yaml_item['custom_verify'] + return False + def parse_op_phi_name(self): if (self.parse_op_inplace_info() is None) and ( self.parse_op_view_info() is None @@ -980,17 +988,19 @@ def OpGenerator( ) # generate op verify function str - op_verify_str = gen_verify_func_str( - op_class_name, - op_input_type_list, - op_input_optional_list, - op_mutable_attribute_name_list, - op_mutable_attribute_type_list, - op_non_mutable_attribute_name_list, - op_non_mutable_attribute_type_list, - op_output_type_list, - op_output_optional_list, - ) + op_verify_str = '' + if not op_info.custom_verify: + op_verify_str = gen_verify_func_str( + op_class_name, + op_input_type_list, + op_input_optional_list, + op_mutable_attribute_name_list, + op_mutable_attribute_type_list, + op_non_mutable_attribute_name_list, + op_non_mutable_attribute_type_list, + op_output_type_list, + op_output_optional_list, + ) op_infer_meta_str = gen_op_infer_meta_str(op_info, op_class_name) diff --git a/paddle/fluid/ir/dialect/pd_op.yaml b/paddle/fluid/ir/dialect/pd_op.yaml index 8160a636ddcb6a40898085e202122945ff9c1bf6..a343b11f185d05224a6c51331691c7780acb4a39 100644 --- a/paddle/fluid/ir/dialect/pd_op.yaml +++ b/paddle/fluid/ir/dialect/pd_op.yaml @@ -9,6 +9,7 @@ data_transform: null inplace: null backward: null + - name: fetch inputs: - typename: Tensor @@ -150,18 +151,6 @@ no_need_buffer: null data_transform: null -- name: py_func_ - inputs: - - {typename: 'Tensor', name: x, optional: false, no_need_buffer: false, data_transform: {}} - attrs: - - {typename: 'int', name: forward_callable_id, default_value: '0'} - - {typename: 'int', name: backward_callable_id, default_value: '-1'} - - {typename: 'str[]', name: backward_skip_vars, default_value: '{}'} - outputs: - - {typename: 'Tensor', name: out, optional: false, intermediate: false} - no_need_buffer: null - data_transform: null - - name: embedding_grad_sparse inputs: - typename: Tensor