# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import logging
import os

import yaml
from op_build_gen import gen_build_func_str
from op_interface_gen import (
    gen_exclusive_interface_str,
    gen_op_infer_meta_str,
    gen_op_vjp_str,
)
from op_member_func_gen import gen_op_get_inputs_outputs_str
from op_verify_gen import gen_verify_func_str
from vjp_interface_gen_op_list import (
    vjp_interface_declare_gen_op_list,
    vjp_interface_implementation_gen_op_list,
)

# =====================================
# String Template for h file code gen
# =====================================
# NOTE: every template below is filled with str.format, so literal C++
# braces are written doubled ("{{" / "}}").

# Wraps generated code in one C++ namespace level.
# Placeholders: {namespace}, {input}.
NAMESPACE_GARD_TEMPLATE = """namespace {namespace} {{
{input}
}} // namespace {namespace}"""

# Whole generated header. When GET_OP_LIST is defined it only expands to
# the op list ({op_declare}); otherwise it emits the op class declarations
# ({input}) followed by the type-id declarations ({declare_type_id}).
H_FILE_TEMPLATE = """#ifdef GET_OP_LIST
#undef GET_OP_LIST
{op_declare}
#else

// This file is generated by "paddle/fluid/ir/dialect/op_generator/op_gen.py"

#include

#include "paddle/ir/core/builder.h"
#include "paddle/ir/core/operation_utils.h"
#include "paddle/ir/core/op_base.h"
#include "paddle/fluid/ir/dialect/paddle_dialect/utils/utils.h"
#include "paddle/fluid/ir/dialect/paddle_dialect/utils/op_yaml_info_util.h"
#include "paddle/fluid/ir/dialect/paddle_dialect/interface/op_yaml_info.h"
#include "paddle/fluid/ir/dialect/paddle_dialect/interface/infermeta.h"
#include "paddle/fluid/ir/dialect/paddle_dialect/interface/vjp.h"
#include "paddle/fluid/ir/dialect/paddle_dialect/trait/inplace.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_op.h"

{input}

{declare_type_id}
#endif
"""

# Comma-separated list of fully qualified op class names (positional {}).
GET_OP_LIST_TEMPALTE = """{}
"""

# One explicit type-id declaration per op class; placeholder {op_name}.
DECLARE_OP_TYPE_ID = """
IR_DECLARE_EXPLICIT_TYPE_ID({op_name})
"""

# Declaration of a single op class. {interfaces}/{traits} are either empty
# or start with "," so they can be appended to the ir::Op template list.
OP_DECLARE_TEMPLATE = """
class {op_name} : public ir::Op<{op_name}{interfaces}{traits}> {{
 public:
  using Op::Op;
  static const char *name() {{ return "{dialect_op_name}"; }}
  {attribute_declare}
  static constexpr uint32_t attributes_num = {attribute_num};
  static OpInfoTuple GetOpInfo();
  static void Build({build_args});
  {build_mutable_attr_is_input}
  {build_attr_num_over_1}
  void Verify();
{get_inputs_and_outputs}
{exclusive_interface}
}};
"""

# attributes_name declaration for ops without / with attributes.
op_0_attribute_declare_str = (
    "static constexpr const char **attributes_name = nullptr;"
)
op_n_attribute_declare_str = (
    "static const char *attributes_name[{attribute_num}];"
)

# =====================================
# String Template for cc file code gen
# =====================================
# Whole generated source file; {h_file} is the generated header path,
# {input} the op definitions and {define_type_id} the type-id definitions.
CC_FILE_TEMPLATE = """// This file is generated by "paddle/fluid/ir/dialect/op_generator/op_gen.py"

#include "{h_file}"
#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_type.h"
#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_attribute.h"
#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_meta_tensor.h"
#include "paddle/ir/core/builtin_attribute.h"
#include "paddle/ir/core/builtin_type.h"
#include "paddle/ir/core/builtin_op.h"
#include "paddle/ir/core/ir_context.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/infermeta/binary.h"
#include "paddle/phi/infermeta/multiary.h"
#include "paddle/phi/infermeta/nullary.h"
#include "paddle/phi/infermeta/unary.h"
#include "paddle/phi/infermeta/ternary.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/infermeta/fusion.h"
#include "paddle/phi/api/lib/utils/allocator.h"
#include "paddle/fluid/primitive/rule/vjp/vjp.h"
#include "paddle/ir/core/op_base.h"

{input}

{define_type_id}
"""

# =====================================
# String Template for pd_op_vjp.cc file code gen
# =====================================
# Source file holding the generated Vjp implementations ({input}).
VJP_CC_FILE_TEMPLATE = """// This file is generated by "paddle/fluid/ir/dialect/op_generator/op_gen.py"

#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_attribute.h"
#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_op.h"
#include "paddle/fluid/primitive/rule/vjp/vjp.h"
#include "paddle/fluid/primitive/type/lazy_tensor.h"
#include "paddle/ir/core/builtin_op.h"
#include "paddle/ir/core/op_base.h"
#include "paddle/phi/common/int_array.h"

namespace paddle {{
namespace dialect {{
{input}
}} // namespace dialect
}} // namespace paddle
"""

# Definition of the attributes_name array for ops with attributes.
OP_N_ATTRIBUTE_DEFINED_TEMPLATE = """
const char *{op_name}::attributes_name[{attribute_num}] = {{ {attribute_names} }};
"""

# get op info
# Body of {op_name}::GetOpInfo(): input/attribute/output info lists plus the
# runtime info (infer-meta, kernel, kernel keys, inplace and view maps).
OP_INFO_TEMPLATE = """
OpInfoTuple {op_name}::GetOpInfo() {{
  std::vector inputs = {{ {inputs} }};
  std::vector attributes = {{ {attributes} }};
  std::vector outputs = {{ {outputs} }};
  paddle::dialect::OpRunTimeInfo run_time_info = paddle::dialect::OpRunTimeInfo("{infer_meta_func}", {{"{infer_meta_param}"}}, {{"{kernel_func}"}}, {{"{kernel_param}"}}, {{{kernel_key_dtype}}}, {{{kernel_key_backend}}}, {{{inplace}}}, {{{view}}});
  return std::make_tuple(inputs, attributes, outputs, run_time_info, "{origin_op_name}");
}}
"""

# Single-entry constructors used inside OP_INFO_TEMPLATE's lists.
CONSTRUCT_INPUT_INFO_TEMPLATE = """paddle::dialect::OpInputInfo("{name}", "{typename}", {optional}, {no_need_buffer}, {is_mutable_attribute})"""
CONSTRUCT_OUTPUT_INFO_TEMPLATE = """paddle::dialect::OpOutputInfo("{name}", "{typename}", {optional}, {intermediate})"""
CONSTRUCT_ATTRIBUTE_INFO_TEMPLATE = """paddle::dialect::OpAttributeInfo("{name}", "{typename}", "{data_type}")"""

# One explicit type-id definition per op class; placeholder {op_name}.
DEFINE_OP_TYPE_ID = """
IR_DEFINE_EXPLICIT_TYPE_ID({op_name})
"""

# Maps a C++ scalar type name to the IR attribute class that stores it.
# 'dobule' is a historical misspelling kept for backward compatibility with
# any caller still using that key; 'double' is the correct spelling.
scalar_type_maps = {
    'int': 'ir::Int32Attribute',
    'int64_t': 'ir::Int64Attribute',
    'float': 'ir::FloatAttribute',
    'dobule': 'ir::DoubleAttribute',
    'double': 'ir::DoubleAttribute',
    'bool': 'ir::BoolAttribute',
}
_NO_NEED_GEN_OPS = {'add_n', 'add_n_', 'add_n_with_kernel', 'split_grad'}


def to_phi_and_fluid_op_name(op_item):
    """Split an op_compat name entry into its phi and fluid names.

    Args:
        op_item (str): ``"phi_name (fluid_name)"`` or just ``"phi_name"``.

    Returns:
        tuple[str, str]: ``(phi_name, fluid_name)``; when no fluid name is
        given, the phi name is returned for both.
    """
    # Template: - op : phi_name (fluid_name)
    names = op_item.split('(')
    if len(names) == 1:
        phi_fluid_name = names[0].strip()
        return phi_fluid_name, phi_fluid_name
    phi_name = names[0].strip()
    fluid_name = names[1].split(')')[0].strip()
    return phi_name, fluid_name


def to_phi_and_fluid_grad_op_name(op_item):
    """Split a comma-separated backward-name entry into [phi, fluid] pairs.

    Args:
        op_item (str): e.g. ``"sum_grad (reduce_sum_grad), sum_double_grad"``.

    Returns:
        list[list[str]]: one ``[phi_name, fluid_name]`` pair per entry.
    """
    # Template: sum_grad (reduce_sum_grad), sum_double_grad
    # Split on "," rather than ", " so entries written without a space after
    # the comma also work; to_phi_and_fluid_op_name strips the whitespace.
    return [
        list(to_phi_and_fluid_op_name(name)) for name in op_item.split(',')
    ]


# =====================================
# Parse Op Compat From Yaml
# =====================================
class OpCompatParser:
    """Loads op_compat.yaml and finds the entry that describes a phi op."""

    def __init__(self, ops_compat_yaml_file):
        self.ops_compat_yaml_file = ops_compat_yaml_file
        with open(self.ops_compat_yaml_file, "r") as f:
            # An empty yaml document loads as None; treat it as no entries.
            self.ops_compat = yaml.safe_load(f) or []

    def get_compat(self, op_name):
        """Return the compat entry whose forward or backward phi name is
        ``op_name``, or None when there is no such entry."""
        for compat in self.ops_compat:
            forward_phi_name, _ = to_phi_and_fluid_op_name(compat['op'])
            if op_name == forward_phi_name:
                return compat
            if 'backward' in compat:
                for backward_phi_name, _ in to_phi_and_fluid_grad_op_name(
                    compat['backward']
                ):
                    if op_name == backward_phi_name:
                        return compat
        return None


# =====================================
# Parse Op Information From Yaml
# =====================================
class OpInfoParser:
    """Extracts inputs, outputs, attributes and runtime info of one op.

    Args:
        op_yaml_item (dict): one parsed op entry from an op yaml file.
        op_compat_item (dict | None): the matching op_compat.yaml entry, as
            returned by OpCompatParser.get_compat.
    """

    def __init__(self, op_yaml_item, op_compat_item):
        self.op_yaml_item = op_yaml_item
        self.op_compat_item = op_compat_item
        self.op_phi_name = self.parse_op_phi_name()
        # parse inputs
        self.input_name_list = self.parse_input_name_list()
        self.input_type_list = self.parse_input_type_list()
        self.input_optional_list = self.parse_input_optional_list()
        self.input_no_need_buffer_list = self.parse_input_no_need_buffer_list()
        self.cross_check(
            self.input_name_list, self.input_type_list, self.input_optional_list
        )
        # parse outputs
        self.output_name_list = self.parse_output_name_list()
        self.output_type_list = self.parse_output_type_list()
        self.output_size_list = self.parse_output_size_list()
        self.output_optional_list = self.parse_output_optional_list()
        self.output_intermediate_list = self.parse_output_intermediate_list()
        self.cross_check(
            self.output_name_list,
            self.output_type_list,
            self.output_optional_list,
        )
        # parse attributes
        # yaml typename -> [IR attribute class, C++ build argument type]
        self.attr_types_map = {
            'IntArray': ['paddle::dialect::IntArrayAttribute', 'IntArray'],
            'Scalar': ['paddle::dialect::ScalarAttribute', 'Scalar'],
            'Scalar(int)': ['ir::Int32Attribute', 'int'],
            'Scalar(int64_t)': ['ir::Int64Attribute', 'int64_t'],
            'Scalar(float)': ['ir::FloatAttribute', 'float'],
            'Scalar(double)': ['ir::DoubleAttribute', 'double'],
            # Misspelled key kept for backward compatibility; it now emits
            # the valid C++ type "double" instead of "dobule".
            'Scalar(dobule)': ['ir::DoubleAttribute', 'double'],
            'Scalar[]': [
                'ir::ArrayAttribute',
                'const std::vector&',
            ],
            'int': ['ir::Int32Attribute', 'int'],
            'int32_t': ['ir::Int32Attribute', 'int32_t'],
            'int64_t': ['ir::Int64Attribute', 'int64_t'],
            'long': ['ir::LongAttribute', 'long'],
            'size_t': ['ir::Size_tAttribute', 'size_t'],
            'float': ['ir::FloatAttribute', 'float'],
            'float[]': [
                'ir::ArrayAttribute',
                'const std::vector&',
            ],
            'double': ['ir::DoubleAttribute', 'double'],
            'bool': ['ir::BoolAttribute', 'bool'],
            'bool[]': [
                'ir::ArrayAttribute',
                'const std::vector&',
            ],
            'str': ['ir::StrAttribute', 'const std::string&'],
            'str[]': [
                'ir::ArrayAttribute',
                'const std::vector&',
            ],
            'Place': ['paddle::dialect::PlaceAttribute', 'const Place&'],
            'DataLayout': [
                'paddle::dialect::DataLayoutAttribute',
                'DataLayout',
            ],
            'DataType': ['paddle::dialect::DataTypeAttribute', 'DataType'],
            'int64_t[]': [
                'ir::ArrayAttribute',
                'const std::vector&',
            ],
            'int[]': [
                'ir::ArrayAttribute',
                'const std::vector&',
            ],
        }
        self.attribute_name_list = self.parse_attribute_name_list()
        self.attribute_type_list = self.parse_attribute_type_list()
        self.attribute_build_arg_type_list = (
            self.parse_attribute_build_arg_type_list()
        )
        self.attribute_gen_arg_type_list = (
            self.parse_attribute_gen_arg_type_list()
        )
        self.attribute_data_type_list = self.parse_attribute_data_type_list()
        self.attribute_default_value_list = (
            self.parse_attribute_default_value_list()
        )
        self.cross_check(self.attribute_name_list, self.attribute_type_list)

        # parse mutable attributes (as inputs)
        (
            self.mutable_attribute_name_list,
            self.mutable_attribute_type_list,
        ) = self.parse_mutable_attribute()

        (
            self.non_mutable_attribute_name_list,
            self.non_mutable_attribute_type_list,
            self.non_mutable_attribute_data_type_list,
            self.non_mutable_attribute_build_arg_type_list,
            self.non_mutable_attribute_default_value_list,
        ) = self.parse_non_nutable_attribute()

        # parse infermeta && kernel
        self.infer_meta_map = self.parse_infer_meta_map()
        self.kernel_map = self.parse_kernel_map()
        if 'infer_meta' in self.op_yaml_item:
            self.infer_meta_func = self.op_yaml_item['infer_meta']["func"]
        else:
            self.infer_meta_func = None

        # parse backward name
        self.backward_name = self.parse_backward_name()

        # parse inplace && view
        self.inplace_map = self.parse_op_inplace_info()
        self.view_map = self.parse_op_view_info()

        # parse has_custom_verify
        self.custom_verify = self.parse_custom_verify()

    def cross_check(self, name_list, type_list, optional_list=None):
        """Assert that the parallel name/type(/optional) lists match in size."""
        assert len(name_list) == len(
            type_list
        ), "name list size != type list size."
        if optional_list is not None:
            assert len(type_list) == len(
                optional_list
            ), "type list size != optional list size."

    def parse_custom_verify(self):
        """Return the yaml 'custom_verify' value, or False when absent."""
        return self.op_yaml_item.get('custom_verify', False)

    def parse_op_phi_name(self):
        """Return the op names to generate.

        An op with inplace or view info whose name does not already end in
        "_" yields both the plain name and the "_"-suffixed inplace name.
        """
        name = self.op_yaml_item['name']
        has_inplace_or_view = (
            self.parse_op_inplace_info() is not None
            or self.parse_op_view_info() is not None
        )
        if not has_inplace_or_view or name.endswith("_"):
            return [name]
        return [name, name + "_"]

    def parse_op_inplace_info(self):
        """Return the yaml 'inplace' map, or None when absent."""
        return self.op_yaml_item.get('inplace')

    def parse_op_view_info(self):
        """Return the yaml 'view' map, or None when absent."""
        return self.op_yaml_item.get('view')

    def parse_mutable_attribute(self):
        """Collect attributes that op_compat marks as scalar / int_array.

        Returns:
            tuple[list, list]: names and [attribute class, data type] pairs,
            ordered like ``self.attribute_name_list``; entries that are not
            attributes of this op are dropped. Example of the mapping:
            {'axis': 'paddle::dialect::ScalarAttribute',
             'rotl': 'paddle::dialect::IntArrayAttribute'}
        """
        mutable_attribute_name_list = []
        mutable_attribute_type_list = []
        compat = self.op_compat_item
        # scalar
        if (compat is not None) and ('scalar' in compat):
            for scalar_attr in compat['scalar'].keys():
                scalar_info = compat['scalar'][scalar_attr]
                if 'data_type' in scalar_info:
                    if (
                        scalar_attr == "depth"
                        and self.op_phi_name[0] == "one_hot"
                    ):
                        mutable_attribute_name_list.append("num_classes")
                    else:
                        mutable_attribute_name_list.append(scalar_attr)
                    data_type = scalar_info['data_type']
                    # patch for isclose and allclose
                    if compat['op'] in ("isclose", "allclose"):
                        data_type = "float"
                    mutable_attribute_type_list.append(
                        [
                            "paddle::dialect::ScalarAttribute",
                            data_type,
                        ]
                    )
                # See eye in op_compat.yaml
                # A compat entry is shared by an op and its backward ops, so
                # it may list attributes this op does not have. Those would be
                # dropped by the ordering step below anyway; skipping them here
                # avoids a ValueError from list.index().
                elif scalar_attr in self.attribute_name_list:
                    mutable_attribute_name_list.append(scalar_attr)
                    mutable_attribute_type_list.append(
                        [
                            "paddle::dialect::ScalarAttribute",
                            self.attribute_data_type_list[
                                self.attribute_name_list.index(scalar_attr)
                            ],
                        ]
                    )
        # int_array
        if (compat is not None) and ('int_array' in compat):
            for int_array_attr in compat['int_array']:
                mutable_attribute_name_list.append(int_array_attr)
                mutable_attribute_type_list.append(
                    [
                        "paddle::dialect::IntArrayAttribute",
                        compat['int_array'][int_array_attr]['data_type'],
                    ]
                )

        # Keep only this op's attributes, in yaml attribute order.
        sorted_mutable_attribute_name_list = []
        sorted_mutable_attribute_type_list = []
        for attr_name in self.attribute_name_list:
            if attr_name in mutable_attribute_name_list:
                sorted_mutable_attribute_name_list.append(attr_name)
                sorted_mutable_attribute_type_list.append(
                    mutable_attribute_type_list[
                        mutable_attribute_name_list.index(attr_name)
                    ]
                )

        return (
            sorted_mutable_attribute_name_list,
            sorted_mutable_attribute_type_list,
        )

    def parse_non_nutable_attribute(self):
        """Return name/type/data_type/build_arg_type/default lists for the
        attributes that are not mutable (not passed as inputs)."""
        names, types, data_types, build_arg_types, default_values = (
            [],
            [],
            [],
            [],
            [],
        )
        for name, attr_type, data_type, build_arg_type, default_value in zip(
            self.attribute_name_list,
            self.attribute_type_list,
            self.attribute_data_type_list,
            self.attribute_build_arg_type_list,
            self.attribute_default_value_list,
        ):
            if name in self.mutable_attribute_name_list:
                continue
            names.append(name)
            types.append(attr_type)
            data_types.append(data_type)
            build_arg_types.append(build_arg_type)
            default_values.append(default_value)
        return (names, types, data_types, build_arg_types, default_values)

    def parse_input_name_list(self):
        return [input_info['name'] for input_info in self.op_yaml_item['inputs']]

    def parse_input_type_list(self):
        """Map each input typename to its IR type; only Tensor/Tensor[]."""
        input_types_map = {
            'Tensor': 'paddle::dialect::DenseTensorType',
            'Tensor[]': 'ir::VectorType',
        }
        type_list = []
        for input_info in self.op_yaml_item['inputs']:
            assert (
                input_info['typename'] in input_types_map
            ), f"{self.op_phi_name} : Input type error: the input type only support Tensor and Tensor[], but now is {input_info['typename']}."
            type_list.append(input_types_map[input_info['typename']])
        return type_list

    def parse_input_optional_list(self):
        # C++ boolean literals are emitted as strings.
        return [
            "true" if input_info['optional'] else "false"
            for input_info in self.op_yaml_item['inputs']
        ]

    def parse_input_no_need_buffer_list(self):
        return [
            "true" if input_info['no_need_buffer'] else "false"
            for input_info in self.op_yaml_item['inputs']
        ]

    def parse_output_name_list(self):
        return [
            output_info['name'] for output_info in self.op_yaml_item['outputs']
        ]

    def parse_output_type_list(self):
        """Map each output typename to its IR type."""
        output_type_map = {
            'Tensor': 'paddle::dialect::DenseTensorType',
            'Tensor[]': 'ir::VectorType',
            'SelectedRows': 'paddle::dialect::SelectedRowsType',
        }
        type_list = []
        for output_info in self.op_yaml_item['outputs']:
            assert (
                output_info['typename'] in output_type_map
            ), f"{self.op_phi_name} : Output type error: the output type only support Tensor, Tensor[] and SelectedRows, but now is {output_info['typename']}."
            type_list.append(output_type_map[output_info['typename']])
        return type_list

    def parse_output_size_list(self):
        # None when an output has no 'size' entry.
        return [
            output_info.get('size')
            for output_info in self.op_yaml_item['outputs']
        ]

    def parse_output_optional_list(self):
        return [
            "true" if output_info.get('optional') else "false"
            for output_info in self.op_yaml_item['outputs']
        ]

    def parse_output_intermediate_list(self):
        return [
            "true" if output_info.get('intermediate') else "false"
            for output_info in self.op_yaml_item['outputs']
        ]

    def parse_attribute_name_list(self):
        return [
            attribute_info['name']
            for attribute_info in self.op_yaml_item['attrs']
        ]

    def parse_attribute_build_arg_type_list(self):
        """Return the C++ argument type of every attribute for Build()."""
        type_list = []
        for attribute_info in self.op_yaml_item['attrs']:
            assert (
                attribute_info['typename'] in self.attr_types_map
            ), f"{self.op_phi_name} : Attr type error."

            # Scalar & IntArray has data_type
            temp_type = self.attr_types_map[attribute_info['typename']][1]
            if 'Scalar' in temp_type:
                if 'data_type' in attribute_info:
                    temp_type = attribute_info['data_type']
                op_name = self.op_yaml_item['name']
                attr_name = attribute_info['name']
                # op_compat may pin the scalar's C++ type (isclose/allclose
                # are excluded, matching the patch in parse_mutable_attribute).
                if (
                    op_name not in ["isclose", "allclose"]
                    and self.op_compat_item is not None
                    and 'scalar' in self.op_compat_item.keys()
                    and attr_name in self.op_compat_item['scalar'].keys()
                    and 'data_type'
                    in self.op_compat_item['scalar'][attr_name].keys()
                ):
                    temp_type = self.op_compat_item['scalar'][attr_name][
                        'data_type'
                    ]
            if 'IntArray' in temp_type:
                if 'data_type' in attribute_info:
                    temp_type = "const " + attribute_info['data_type'] + "&"
            type_list.append(self.get_phi_dtype_name(temp_type))
        return type_list

    def parse_attribute_gen_arg_type_list(self):
        """Return the phi-qualified C++ type of every attribute."""
        type_list = []
        for attribute_info in self.op_yaml_item['attrs']:
            assert (
                attribute_info['typename'] in self.attr_types_map
            ), f"{self.op_phi_name} : Attr type error."

            temp_type = self.attr_types_map[attribute_info['typename']][1]
            type_list.append(self.get_phi_dtype_name(temp_type))
        return type_list

    def parse_attribute_type_list(self):
        """Return the IR attribute class of every attribute."""
        type_list = []
        for attribute_info in self.op_yaml_item['attrs']:
            assert (
                attribute_info['typename'] in self.attr_types_map
            ), f"{self.op_phi_name} : Attr type error."
            type_list.append(self.attr_types_map[attribute_info['typename']][0])
        return type_list

    def parse_attribute_data_type_list(self):
        # "" when an attribute has no 'data_type' entry.
        return [
            attribute_info.get('data_type', "")
            for attribute_info in self.op_yaml_item['attrs']
        ]

    def parse_attribute_default_value_list(self):
        # None when an attribute has no 'default_value' entry.
        default_value_list = []
        for attribute_info in self.op_yaml_item['attrs']:
            if 'default_value' in attribute_info:
                default_value_list.append(
                    self.get_phi_dtype_name(attribute_info['default_value'])
                )
            else:
                default_value_list.append(None)
        return default_value_list

    def parse_infer_meta_map(self):
        return self.op_yaml_item.get('infer_meta')

    def parse_kernel_map(self):
        return self.op_yaml_item.get('kernel')

    def parse_backward_name(self):
        return self.op_yaml_item.get('backward')

    def get_phi_dtype_name(self, name):
        """Prefix known phi type names (Scalar, IntArray, ...) with phi::."""
        name = name.replace('Scalar', 'phi::Scalar')
        name = name.replace('IntArray', 'phi::IntArray')
        name = name.replace('DataLayout', 'phi::DataLayout')
        name = name.replace('DataType', 'phi::DataType')
        if name.startswith(
            (
                "Place",
                "CPUPlace",
                "GPUPlace",
                "GPUPinnedPlace",
                "XPUPlace",
                "IPUPlace",
                "CustomPlace",
            )
        ):
            return "phi::" + name
        return name


def to_pascal_case(s):
    """Convert a snake_case op name to PascalCase, keeping a trailing "_".

    e.g. "sum_grad" -> "SumGrad", "relu_" -> "Relu_".
    """
    pascal = "".join(word.capitalize() for word in s.split("_"))
    return pascal + "_" if s.endswith("_") else pascal


def _to_cpp_string_list(items):
    """Render items as comma-separated C++ string literals ("" if empty)."""
    joined = '", "'.join(items)
    return '"' + joined + '"' if joined != "" else ""


def _to_cpp_pair_list(mapping):
    """Render {k: v} as '{"k", "v"},{"k2", "v2"}'; "" for None or empty."""
    if mapping is None:
        return ""
    return ",".join(
        '{"' + key + '", "' + value + '"}' for key, value in mapping.items()
    )


def OpGenerator(
    op_yaml_files,
    op_compat_yaml_file,
    namespaces,
    dialect_name,
    op_def_h_file,
    op_def_cc_file,
    op_vjp_cc_file,
):
    """Generate the op header, op source and (non-cinn) Vjp source files.

    Args:
        op_yaml_files (list[str]): op definition yaml files.
        op_compat_yaml_file (str): op_compat.yaml path.
        namespaces (list[str]): C++ namespaces, outermost first.
        dialect_name (str): dialect prefix of the op names; "cinn" disables
            Vjp generation.
        op_def_h_file (str): output header path (e.g. pd_op.h.tmp).
        op_def_cc_file (str): output source path (e.g. pd_op.cc.tmp).
        op_vjp_cc_file (str | None): output Vjp source path.
    """
    # (1) Prepare: Delete existing old files: pd_op.h.tmp, pd_op.cc.tmp
    if os.path.exists(op_def_h_file):
        os.remove(op_def_h_file)
    if os.path.exists(op_def_cc_file):
        os.remove(op_def_cc_file)

    # (2) Prepare: Get all op item in all op_yaml_files
    op_compat_parser = OpCompatParser(op_compat_yaml_file)

    op_yaml_items = []
    for yaml_file in op_yaml_files:
        with open(yaml_file, "r") as f:
            # An empty yaml file loads as None.
            op_yaml_items.extend(yaml.safe_load(f) or [])
    op_info_items = {}
    for op in op_yaml_items:
        op_info_items[op['name']] = OpInfoParser(
            op, op_compat_parser.get_compat(op['name'])
        )

    # (3) CodeGen: Traverse op_info_items and generate
    ops_name_list = []  # all op class name store in this list
    ops_declare_list = []  # all op class declare store in this list
    ops_defined_list = []  # all op class defined store in this list
    ops_vjp_defined_list = []  # all op vjp static interface defination

    if dialect_name == "cinn":
        # Report once instead of once per generated op.
        logging.warning("cinn is currently not support Vjp function")

    for op_info in op_info_items.values():
        # get op inputs info
        op_input_name_list = op_info.input_name_list
        op_input_type_list = op_info.input_type_list
        op_input_optional_list = op_info.input_optional_list
        op_input_no_need_buffer_list = op_info.input_no_need_buffer_list
        # get op outputs info
        op_output_name_list = op_info.output_name_list
        op_output_type_list = op_info.output_type_list
        op_output_size_list = op_info.output_size_list
        op_output_optional_list = op_info.output_optional_list
        op_output_intermediate_list = op_info.output_intermediate_list
        # get op mutable attribute
        op_mutable_attribute_name_list = op_info.mutable_attribute_name_list
        op_mutable_attribute_type_list = op_info.mutable_attribute_type_list
        # get op attribute
        op_attribute_name_list = op_info.attribute_name_list
        op_attribute_type_list = op_info.attribute_type_list
        op_attribute_build_arg_type_list = op_info.attribute_build_arg_type_list
        op_attribute_default_value_list = op_info.attribute_default_value_list
        op_non_mutable_attribute_name_list = (
            op_info.non_mutable_attribute_name_list
        )
        op_non_mutable_attribute_type_list = (
            op_info.non_mutable_attribute_type_list
        )
        op_non_mutable_attribute_data_type_list = (
            op_info.non_mutable_attribute_data_type_list
        )
        op_non_mutable_attribute_build_arg_type_list = (
            op_info.non_mutable_attribute_build_arg_type_list
        )
        op_non_mutable_attribute_default_value_list = (
            op_info.non_mutable_attribute_default_value_list
        )

        # others
        op_infer_meta_map = op_info.infer_meta_map
        op_kernel_map = op_info.kernel_map
        op_inplace_map = op_info.inplace_map
        op_view_map = op_info.view_map
        op_interfaces = ["paddle::dialect::OpYamlInfoInterface"]

        if op_info.infer_meta_func:
            op_interfaces += ["paddle::dialect::InferMetaInterface"]

        if (
            op_info.backward_name
            and op_info.op_phi_name[0] in vjp_interface_declare_gen_op_list
        ):
            op_interfaces += ["paddle::dialect::VjpInterface"]
        exclusive_interface_str = gen_exclusive_interface_str(op_info)

        # If op has inplace info, we will generate inplace op and non-inplace op.
        for op_name in op_info.op_phi_name:
            if op_name in _NO_NEED_GEN_OPS:
                continue
            op_class_name = to_pascal_case(op_name) + "Op"
            op_dialect_name = dialect_name + "." + op_name

            # =================================== #
            #    gen interface/trait list str     #
            # =================================== #
            op_interfaces_str = ""
            if len(op_interfaces) > 0:
                op_interfaces_str = "," + ",".join(op_interfaces)

            # Traits are computed per generated op so the inplace trait of
            # the "_" variant never leaks into the non-inplace variant.
            op_traits = []
            if op_name.endswith("_"):
                op_traits.append("paddle::dialect::InplaceTrait")

            op_traits_str = ""
            if len(op_traits) > 0:
                op_traits_str = "," + ",".join(op_traits)

            # =================================== #
            #    gen get input/output methods str #
            # =================================== #
            op_get_inputs_outputs_str = gen_op_get_inputs_outputs_str(
                op_input_name_list,
                op_mutable_attribute_name_list,
                op_output_name_list,
            )

            # =================================== #
            #         gen Build methods str       #
            # =================================== #
            build_args_with_muta_attr_not_input_for_declare = ""
            build_func_with_muta_attr_not_input = ""
            build_mutable_attr_is_input = ""
            build_attr_num_over_1 = ""
            build_func_with_attr_is_map = ""
            build_func_with_muta_attr_is_input = ""

            if op_infer_meta_map is not None:
                # Positional arguments shared by every gen_build_func_str call.
                build_func_args = (
                    op_class_name,
                    op_input_name_list,
                    op_input_type_list,
                    op_attribute_name_list,
                    op_attribute_type_list,
                    op_attribute_build_arg_type_list,
                    op_attribute_default_value_list,
                    op_mutable_attribute_name_list,
                    op_mutable_attribute_type_list,
                    op_non_mutable_attribute_name_list,
                    op_non_mutable_attribute_type_list,
                    op_non_mutable_attribute_build_arg_type_list,
                    op_non_mutable_attribute_default_value_list,
                    op_output_name_list,
                    op_output_type_list,
                    op_output_size_list,
                    op_infer_meta_map,
                )
                (
                    build_args_with_muta_attr_not_input_for_declare,
                    build_func_with_muta_attr_not_input,
                ) = gen_build_func_str(
                    *build_func_args, muta_attr_is_input=False
                )
                if len(op_attribute_name_list) > 0:
                    (
                        build_args_with_attr_is_map_for_declare,
                        build_func_with_attr_is_map,
                    ) = gen_build_func_str(
                        *build_func_args,
                        muta_attr_is_input=False,
                        attr_args_is_map=True,
                    )
                    build_attr_num_over_1 = (
                        "static void Build({build_args});".format(
                            build_args=build_args_with_attr_is_map_for_declare
                        )
                    )
                if len(op_mutable_attribute_name_list) > 0:
                    (
                        build_args_with_muta_attr_is_input_for_declare,
                        build_func_with_muta_attr_is_input,
                    ) = gen_build_func_str(
                        *build_func_args, muta_attr_is_input=True
                    )
                    build_mutable_attr_is_input = (
                        "static void Build({build_args});".format(
                            build_args=build_args_with_muta_attr_is_input_for_declare
                        )
                    )

            # gen op_declare_str/op_defined_str
            attribute_num = len(op_non_mutable_attribute_name_list)
            if attribute_num == 0:
                attribute_declare = op_0_attribute_declare_str
                op_defined_str = ""
            else:
                attribute_declare = op_n_attribute_declare_str.format(
                    attribute_num=attribute_num
                )
                attribute_names_str = (
                    '"' + '", "'.join(op_non_mutable_attribute_name_list) + '"'
                )
                op_defined_str = OP_N_ATTRIBUTE_DEFINED_TEMPLATE.format(
                    op_name=op_class_name,
                    attribute_num=attribute_num,
                    attribute_names=attribute_names_str,
                )
            op_declare_str = OP_DECLARE_TEMPLATE.format(
                op_name=op_class_name,
                dialect_op_name=op_dialect_name,
                interfaces=op_interfaces_str,
                traits=op_traits_str,
                attribute_declare=attribute_declare,
                attribute_num=attribute_num,
                build_args=build_args_with_muta_attr_not_input_for_declare,
                build_mutable_attr_is_input=build_mutable_attr_is_input,
                build_attr_num_over_1=build_attr_num_over_1,
                get_inputs_and_outputs=op_get_inputs_outputs_str,
                exclusive_interface=exclusive_interface_str,
            )

            # =================================== #
            #         gen GetOpInfo func str      #
            # =================================== #
            # generate get op info funciton: inputs (mutable attributes are
            # appended as extra inputs)
            input_info_list = [
                CONSTRUCT_INPUT_INFO_TEMPLATE.format(
                    name=name,
                    typename=typename,
                    optional=optional,
                    no_need_buffer=no_need_buffer,
                    is_mutable_attribute='false',
                )
                for name, typename, optional, no_need_buffer in zip(
                    op_input_name_list,
                    op_input_type_list,
                    op_input_optional_list,
                    op_input_no_need_buffer_list,
                )
            ]
            input_info_list += [
                CONSTRUCT_INPUT_INFO_TEMPLATE.format(
                    name=name,
                    typename=attr_type[0],
                    optional='false',
                    no_need_buffer='false',
                    is_mutable_attribute='true',
                )
                for name, attr_type in zip(
                    op_mutable_attribute_name_list,
                    op_mutable_attribute_type_list,
                )
            ]
            inputs_info_str = ", ".join(input_info_list)

            # generate get op info funciton: outputs
            outputs_info_str = ", ".join(
                CONSTRUCT_OUTPUT_INFO_TEMPLATE.format(
                    name=name,
                    typename=typename,
                    optional=optional,
                    intermediate=intermediate,
                )
                for name, typename, optional, intermediate in zip(
                    op_output_name_list,
                    op_output_type_list,
                    op_output_optional_list,
                    op_output_intermediate_list,
                )
            )

            # generate get op info funciton: attributes
            attribute_info_str = ", ".join(
                CONSTRUCT_ATTRIBUTE_INFO_TEMPLATE.format(
                    name=name,
                    typename=typename,
                    data_type=data_type,
                )
                for name, typename, data_type in zip(
                    op_non_mutable_attribute_name_list,
                    op_non_mutable_attribute_type_list,
                    op_non_mutable_attribute_data_type_list,
                )
            )

            # generate runtime info
            infer_meta_func_str = ""
            infer_meta_param_str = ""
            if op_infer_meta_map is not None:
                infer_meta_func_str = op_infer_meta_map['func']
                # The template adds the outer quotes.
                infer_meta_param_str = '", "'.join(op_infer_meta_map['param'])

            kernel_func_str = ""
            kernel_param_str = ""
            kernel_key_dtype = ""
            kernel_key_backend = ""
            if op_kernel_map is not None:
                kernel_func_str = '", "'.join(op_kernel_map['func'])
                kernel_param_str = '", "'.join(op_kernel_map['param'])
                if 'data_type' in op_kernel_map and op_kernel_map['data_type']:
                    kernel_key_dtype = _to_cpp_string_list(
                        op_kernel_map['data_type']['candidates']
                    )
                if 'backend' in op_kernel_map and op_kernel_map['backend']:
                    kernel_key_backend = _to_cpp_string_list(
                        op_kernel_map['backend']['candidates']
                    )

            # Inplace / view pairs are only recorded on the "_" variant.
            inplace_str = ""
            view_str = ""
            if op_name.endswith("_"):
                inplace_str = _to_cpp_pair_list(op_inplace_map)
                view_str = _to_cpp_pair_list(op_view_map)

            op_info_func_str = OP_INFO_TEMPLATE.format(
                op_name=op_class_name,
                inputs=inputs_info_str,
                attributes=attribute_info_str,
                outputs=outputs_info_str,
                infer_meta_func=infer_meta_func_str,
                infer_meta_param=infer_meta_param_str,
                kernel_func=kernel_func_str,
                kernel_param=kernel_param_str,
                kernel_key_dtype=kernel_key_dtype,
                kernel_key_backend=kernel_key_backend,
                inplace=inplace_str,
                view=view_str,
                origin_op_name=op_info.op_yaml_item['name'],
            )

            # generate op verify function str
            op_verify_str = ''
            if not op_info.custom_verify:
                op_verify_str = gen_verify_func_str(
                    op_class_name,
                    op_input_type_list,
                    op_input_optional_list,
                    op_mutable_attribute_name_list,
                    op_mutable_attribute_type_list,
                    op_non_mutable_attribute_name_list,
                    op_non_mutable_attribute_type_list,
                    op_output_type_list,
                    op_output_optional_list,
                )

            op_infer_meta_str = gen_op_infer_meta_str(op_info, op_class_name)

            # =================================== #
            #         gen Vjp func str            #
            # =================================== #
            # generate op vjp function str (never for the cinn dialect)
            op_vjp_str = ''
            # TODO(chenzhiyang) add vjp gen code
            if (
                dialect_name != "cinn"
                and op_info.backward_name
                and op_info.op_phi_name[0]
                in vjp_interface_implementation_gen_op_list
            ):
                op_vjp_str = gen_op_vjp_str(
                    op_class_name,
                    op_info.backward_name,
                    op_name,
                    op_info_items[op_info.op_phi_name[0]],
                    op_info_items[op_info.backward_name],
                )

            ops_name_list.append(op_class_name)
            ops_declare_list.append(op_declare_str)
            ops_defined_list.append(op_defined_str)
            ops_defined_list.append(op_info_func_str)
            ops_defined_list.append(build_func_with_muta_attr_not_input)
            ops_defined_list.append(build_func_with_attr_is_map)
            if len(op_mutable_attribute_name_list) > 0:
                ops_defined_list.append(build_func_with_muta_attr_is_input)
            ops_defined_list.append(op_verify_str)
            ops_defined_list.append(op_infer_meta_str)
            # NOTE(chenxi67)skip if dialect_name==cinn
            if dialect_name != "cinn":
                ops_vjp_defined_list.append(op_vjp_str)

    # (4) Generate head file str
    op_namespaces_prev = "".join(name + "::" for name in namespaces)
    ops_name_with_namespace_list = [
        op_namespaces_prev + name for name in ops_name_list
    ]
    op_list_str = GET_OP_LIST_TEMPALTE.format(
        ", ".join(ops_name_with_namespace_list)
    )  # Add GET_OP_LIST

    declare_type_id_str = "".join(
        DECLARE_OP_TYPE_ID.format(op_name=op)
        for op in ops_name_with_namespace_list
    )

    head_file_str = "".join(ops_declare_list)  # Add op class
    for name in reversed(namespaces):
        head_file_str = NAMESPACE_GARD_TEMPLATE.format(
            namespace=name, input=head_file_str
        )  # Add namespaces
    head_file_str = H_FILE_TEMPLATE.format(
        op_declare=op_list_str,
        input=head_file_str,
        declare_type_id=declare_type_id_str,
    )  # Add head

    # (5) Generate source file str
    source_file_str = "".join(ops_defined_list)  # Add op define
    for name in reversed(namespaces):
        source_file_str = NAMESPACE_GARD_TEMPLATE.format(
            namespace=name, input=source_file_str
        )  # Add namespaces

    define_type_id_str = "".join(
        DEFINE_OP_TYPE_ID.format(op_name=op)
        for op in ops_name_with_namespace_list
    )

    # op_def_h_file[:-4] drops the 4-character ".tmp" suffix of the header
    # path so the include names the final header.
    source_file_str = CC_FILE_TEMPLATE.format(
        h_file=op_def_h_file[:-4],
        input=source_file_str,
        define_type_id=define_type_id_str,
    )  # Add head

    vjp_source_file_str = VJP_CC_FILE_TEMPLATE.format(
        input="".join(ops_vjp_defined_list)
    )

    # (6) Generate pd_op.h.tmp, pd_op.cc.tmp
    with open(op_def_h_file, 'w') as f:
        f.write(head_file_str)
    with open(op_def_cc_file, 'w') as f:
        f.write(source_file_str)
    # NOTE(Aurelius84): op_gen.py is called multiply times,
    # and vjp is only avaible for pd dialect.
    if dialect_name != 'cinn' and op_vjp_cc_file:
        with open(op_vjp_cc_file, 'w') as f:
            f.write(vjp_source_file_str)


# =====================================
# Script parameter parsing
# =====================================
def ParseArguments():
    """Parse the command-line options of the generator script."""
    parser = argparse.ArgumentParser(
        description='Generate Dialect OP Definition Files By Yaml'
    )
    parser.add_argument('--op_yaml_files', type=str)
    parser.add_argument('--op_compat_yaml_file', type=str)
    parser.add_argument('--namespaces', type=str)
    parser.add_argument('--dialect_name', type=str)
    parser.add_argument('--op_def_h_file', type=str)
    parser.add_argument('--op_def_cc_file', type=str)
    parser.add_argument('--op_vjp_cc_file', type=str)
    return parser.parse_args()


# =====================================
# Main
# =====================================
if __name__ == "__main__":
    # parse arguments
    args = ParseArguments()
    op_yaml_files = args.op_yaml_files.split(",")
    op_compat_yaml_file = args.op_compat_yaml_file
    namespaces = []
    if args.namespaces is not None:
        namespaces = args.namespaces.split(",")
    dialect_name = args.dialect_name
    op_def_h_file = args.op_def_h_file
    op_def_cc_file = args.op_def_cc_file
    op_vjp_cc_file = args.op_vjp_cc_file

    # auto code generate
    OpGenerator(
        op_yaml_files,
        op_compat_yaml_file,
        namespaces,
        dialect_name,
        op_def_h_file,
        op_def_cc_file,
        op_vjp_cc_file,
    )