# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import argparse import os import yaml from op_interface_gen import gen_exclusive_interface_str, gen_op_infer_meta_str from op_member_func_gen import gen_op_get_inputs_outputs_str from op_verify_gen import gen_verify_func_str # ===================================== # String Template for h file code gen # ===================================== NAMESPACE_GARD_TEMPLATE = """namespace {namespace} {{ {input} }} // namespace {namespace}""" H_FILE_TEMPLATE = """#ifdef GET_OP_LIST #undef GET_OP_LIST {op_declare} #else // This file is generated by "paddle/fluid/ir/dialect/op_generator/op_gen.py" #include #include "paddle/ir/core/builder.h" #include "paddle/ir/core/operation_utils.h" #include "paddle/ir/core/op_base.h" #include "paddle/fluid/ir/dialect/utils.h" #include "paddle/fluid/ir/dialect/op_yaml_info_util.h" #include "paddle/fluid/ir/interface/op_yaml_info.h" #include "paddle/fluid/ir/interface/infermeta.h" #include "paddle/fluid/framework/infershape_utils.h" #include "paddle/phi/core/infermeta_utils.h" {input} {declare_type_id} #endif """ GET_OP_LIST_TEMPALTE = """{} """ DECLARE_OP_TYPE_ID = """ IR_DECLARE_EXPLICIT_TYPE_ID({op_name}) """ OP_DECLARE_TEMPLATE = """ class {op_name} : public ir::Op<{op_name}{interfaces}{traits}> {{ public: using Op::Op; static const char *name() {{ return "{dialect_op_name}"; }} {attribute_declare} static constexpr uint32_t attributes_num = {attribute_num}; static OpInfoTuple GetOpInfo(); static void Build({build_args}); {build_mutable_attr_is_input} void Verify(); {get_inputs_and_outputs} {exclusive_interface} }}; """ op_0_attribute_declare_str = ( "static constexpr const char **attributes_name = nullptr;" ) op_n_attribute_declare_str = ( "static const char *attributes_name[{attribute_num}];" ) # ===================================== # String Template for cc file code gen # ===================================== CC_FILE_TEMPLATE = """// This file is generated by "paddle/fluid/ir/dialect/op_generator/op_gen.py" #include "paddle/fluid/ir/dialect/pd_op.h" #include "paddle/fluid/ir/dialect/pd_type.h" #include "paddle/fluid/ir/dialect/pd_attribute.h" #include "paddle/ir/core/builtin_attribute.h" #include "paddle/ir/core/builtin_type.h" #include "paddle/ir/core/builtin_op.h" #include "paddle/ir/core/ir_context.h" #include "paddle/phi/core/enforce.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/infermeta/binary.h" #include "paddle/phi/infermeta/multiary.h" #include "paddle/phi/infermeta/nullary.h" #include "paddle/phi/infermeta/unary.h" #include "paddle/phi/infermeta/ternary.h" #include "paddle/phi/infermeta/backward.h" #include "paddle/phi/api/lib/utils/allocator.h" {input} {define_type_id} """ OP_N_ATTRIBUTE_DEFINED_TEMPLATE = """ const char *{op_name}::attributes_name[{attribute_num}] = {{ {attribute_names} }}; """ # get op info OP_INFO_TEMPLATE = """ OpInfoTuple {op_name}::GetOpInfo() {{ std::vector inputs = {{ {inputs} }}; std::vector attributes = {{ {attributes} }}; std::vector outputs = {{ {outputs} }}; paddle::dialect::OpRunTimeInfo run_time_info = OpRunTimeInfo("{infer_meta_func}", {{"{infer_meta_param}"}}, {{"{kernel_func}"}}, {{"{kernel_param}"}}, {{"{kernel_key_dtype}"}}, {{{inplace}}}, {{{view}}}); return std::make_tuple(inputs, attributes, outputs, run_time_info); }} """ CONSTRUCT_INPUT_INFO_TEMPLATE = """OpInputInfo("{name}", "{typename}", {optional}, {no_need_buffer}, {is_mutable_attribute})""" CONSTRUCT_OUTPUT_INFO_TEMPLATE = ( """OpOutputInfo("{name}", "{typename}", {optional}, {intermediate})""" ) CONSTRUCT_ATTRIBUTE_INFO_TEMPLATE = ( """OpAttributeInfo("{name}", "{typename}", "{data_type}")""" ) # build OP_BUILD_TEMPLATE = """ void {op_name}::Build({build_args}) {{ {build_mutable_attributes} {build_inputs} {build_attributes} {build_outputs} }} """ DEFINE_OP_TYPE_ID = """ IR_DEFINE_EXPLICIT_TYPE_ID({op_name}) """ def to_phi_and_fluid_op_name(op_item): # Templat: - op : phi_name (fluid_name) names = op_item.split('(') if len(names) == 1: phi_fluid_name = names[0].strip() return phi_fluid_name, phi_fluid_name else: phi_name = names[0].strip() fluid_name = names[1].split(')')[0].strip() return phi_name, fluid_name scalar_type_maps = { 'int': 'ir::Int32Attribute', 'int64_t': 'ir::Int64Attribute', 'float': 'ir::FloatAttribute', 'dobule': 'ir::DoubleAttribute', 'bool': 'ir::BoolAttribute', } # ===================================== # Parse Op Compat From Yaml # ===================================== class OpCompatParser: def __init__(self, ops_compat_yaml_file): self.ops_compat_yaml_file = ops_compat_yaml_file with open(self.ops_compat_yaml_file, "r") as f: self.ops_compat = yaml.safe_load(f) def get_compat(self, op_name): for compat in self.ops_compat: phi_name, fluid_name = to_phi_and_fluid_op_name(compat['op']) if op_name == phi_name: return compat return None # ===================================== # Parse Op Information From Yaml # ===================================== class OpInfoParser: def __init__(self, op_yaml_item, op_compat_item): self.op_yaml_item = op_yaml_item self.op_compat_item = op_compat_item self.op_phi_name = self.parse_op_phi_name() # parse inputs self.input_name_list = self.parse_input_name_list() self.input_type_list = self.parse_input_type_list() self.input_optional_list = self.parse_input_optional_list() self.input_no_need_buffer_list = self.parse_input_no_need_buffer_list() self.cross_check( self.input_name_list, self.input_type_list, self.input_optional_list ) # parse outputs self.output_name_list = self.parse_output_name_list() self.output_type_list = self.parse_output_type_list() self.output_size_list = self.parse_output_size_list() self.output_optional_list = self.parse_output_optional_list() self.output_intermediate_list = self.parse_output_intermediate_list() self.cross_check( self.output_name_list, self.output_type_list, self.output_optional_list, ) # parse attributes self.attr_types_map = { 'IntArray': ['paddle::dialect::IntArrayAttribute', 'IntArray'], 'Scalar': ['paddle::dialect::ScalarAttribute', 'Scalar'], 'Scalar(int)': ['ir::Int32Attribute', 'int'], 'Scalar(int64_t)': ['ir::Int64Attribute', 'int64_t'], 'Scalar(float)': ['ir::FloatAttribute', 'float'], 'Scalar(dobule)': ['ir::DoubleAttribute', 'dobule'], 'Scalar[]': [ 'ir::ArrayAttribute', 'const std::vector&', ], 'int': ['ir::Int32Attribute', 'int'], 'int32_t': ['ir::Int32Attribute', 'int32_t'], 'int64_t': ['ir::Int64Attribute', 'int64_t'], 'long': ['ir::LongAttribute', 'long'], 'size_t': ['ir::Size_tAttribute', 'size_t'], 'float': ['ir::FloatAttribute', 'float'], 'float[]': [ 'ir::ArrayAttribute', 'const std::vector&', ], 'double': ['ir::DoubleAttribute', 'double'], 'bool': ['ir::BoolAttribute', 'bool'], 'bool[]': [ 'ir::ArrayAttribute', 'const std::vecot&', ], 'str': ['ir::StrAttribute', 'std::string'], 'str[]': [ 'ir::ArrayAttribute', 'const std::vector&', ], 'Place': ['paddle::dialect::PlaceAttribute', 'Place'], 'DataLayout': [ 'paddle::dialect::DataLayoutAttribute', 'DataLayout', ], 'DataType': ['paddle::dialect::DataTypeAttribute', 'DataType'], 'int64_t[]': [ 'ir::ArrayAttribute', 'const std::vector&', ], 'int[]': [ 'ir::ArrayAttribute', 'const std::vector&', ], } self.attribute_name_list = self.parse_attribute_name_list() self.attribute_type_list = self.parse_attribute_type_list() self.attribute_build_arg_type_list = ( self.parse_attribute_build_arg_type_list() ) self.attribute_data_type_list = self.parse_attribute_data_type_list() self.attribute_default_value_list = ( self.parse_attribute_default_value_list() ) self.cross_check(self.attribute_name_list, self.attribute_type_list) # parse mutable attributes (as inputs) ( self.mutable_attribute_name_list, self.mutable_attribute_type_list, ) = self.parse_mutable_attribute() ( self.non_mutable_attribute_name_list, self.non_mutable_attribute_type_list, self.non_mutable_attribute_data_type_list, self.non_mutable_attribute_build_arg_type_list, self.non_mutable_attribute_default_value_list, ) = self.parse_non_nutable_attribute() # parse infermeta && kernel self.infer_meta_map = self.parse_infer_meta_map() self.kernel_map = self.parse_kernel_map() if 'infer_meta' in self.op_yaml_item: self.infer_meta_func = self.op_yaml_item['infer_meta']["func"] else: self.infer_meta_func = None # parse inplace && view self.inplace_map = self.parse_op_inplace_info() self.view_map = self.parse_op_view_info() def cross_check(self, name_list, type_list, optional_list=None): assert len(name_list) == len( type_list ), "name list size != type list size." if optional_list is not None: assert len(type_list) == len( optional_list ), "type list size != optional list size." def parse_op_phi_name(self): if (self.parse_op_inplace_info() is None) and ( self.parse_op_view_info() is None ): return [self.op_yaml_item['name']] else: if self.op_yaml_item['name'][-1] == "_": return [self.op_yaml_item['name']] else: return [ self.op_yaml_item['name'], self.op_yaml_item['name'] + "_", ] def parse_op_inplace_info(self): if 'inplace' in self.op_yaml_item: return self.op_yaml_item['inplace'] return None def parse_op_view_info(self): if 'view' in self.op_yaml_item: return self.op_yaml_item['view'] return None def parse_mutable_attribute(self): """ {'axis': 'paddle::dialect::ScalarAttribute', 'rotl': 'paddle::dialect::IntArrayAttribute'} """ mutable_attribute_name_list = [] mutable_attribute_type_list = [] # scalar if (self.op_compat_item is not None) and ( 'scalar' in self.op_compat_item ): for scalar_attr in self.op_compat_item['scalar'].keys(): if 'data_type' in self.op_compat_item['scalar'][scalar_attr]: if ( self.op_compat_item['scalar'][scalar_attr]['data_type'] == "std::string" ): # see isclose and allclose in op_compat.yaml mutable_attribute_name_list.append(scalar_attr) mutable_attribute_type_list.append( ["ir::StrAttribute", "std::string"] ) else: if ( scalar_attr == "depth" and self.op_phi_name[0] == "one_hot" ): mutable_attribute_name_list.append("num_classes") else: mutable_attribute_name_list.append(scalar_attr) mutable_attribute_type_list.append( [ "paddle::dialect::ScalarAttribute", self.op_compat_item['scalar'][scalar_attr][ 'data_type' ], ] ) # See eye in op_compat.yaml else: mutable_attribute_name_list.append(scalar_attr) mutable_attribute_type_list.append( [ "paddle::dialect::ScalarAttribute", self.attribute_data_type_list[ self.attribute_name_list.index(scalar_attr) ], ] ) # int_array if (self.op_compat_item is not None) and ( 'int_array' in self.op_compat_item ): for int_array_attr in self.op_compat_item['int_array']: mutable_attribute_name_list.append(int_array_attr) mutable_attribute_type_list.append( [ "paddle::dialect::IntArrayAttribute", self.op_compat_item['int_array'][int_array_attr][ 'data_type' ], ] ) sorted_mutable_attribute_name_list = [] sorted_mutable_attribute_type_list = [] for attr_name in self.attribute_name_list: if attr_name in mutable_attribute_name_list: sorted_mutable_attribute_name_list.append(attr_name) sorted_mutable_attribute_type_list.append( mutable_attribute_type_list[ mutable_attribute_name_list.index(attr_name) ] ) return ( sorted_mutable_attribute_name_list, sorted_mutable_attribute_type_list, ) def parse_non_nutable_attribute(self): op_non_mutable_attribute_name_list = [] op_non_mutable_attribute_type_list = [] op_non_mutable_attribute_data_type_list = [] op_non_mutable_attribute_build_arg_type_list = [] op_non_mutable_attribute_default_value_list = [] for idx in range(len(self.attribute_name_list)): if ( self.attribute_name_list[idx] not in self.mutable_attribute_name_list ): op_non_mutable_attribute_name_list.append( self.attribute_name_list[idx] ) op_non_mutable_attribute_type_list.append( self.attribute_type_list[idx] ) op_non_mutable_attribute_data_type_list.append( self.attribute_data_type_list[idx] ) op_non_mutable_attribute_build_arg_type_list.append( self.attribute_build_arg_type_list[idx] ) op_non_mutable_attribute_default_value_list.append( self.attribute_default_value_list[idx] ) return ( op_non_mutable_attribute_name_list, op_non_mutable_attribute_type_list, op_non_mutable_attribute_data_type_list, op_non_mutable_attribute_build_arg_type_list, op_non_mutable_attribute_default_value_list, ) def parse_input_name_list(self): name_list = [] for input_info in self.op_yaml_item['inputs']: name_list.append(input_info['name']) return name_list def parse_input_type_list(self): input_types_map = { 'Tensor': 'paddle::dialect::DenseTensorType', 'Tensor[]': 'ir::VectorType', } type_list = [] for input_info in self.op_yaml_item['inputs']: assert ( input_info['typename'] in input_types_map ), f"{self.op_phi_name} : Input type error: the input type only support Tensor and Tensor[], but now is {input_info['typename']}." type_list.append(input_types_map[input_info['typename']]) return type_list def parse_input_optional_list(self): optional_list = [] for input_info in self.op_yaml_item['inputs']: if input_info['optional']: optional_list.append("true") else: optional_list.append("false") return optional_list def parse_input_no_need_buffer_list(self): no_need_buffer_list = [] for input_info in self.op_yaml_item['inputs']: if input_info['no_need_buffer']: no_need_buffer_list.append("true") else: no_need_buffer_list.append("false") return no_need_buffer_list def parse_output_name_list(self): name_list = [] for output_info in self.op_yaml_item['outputs']: name_list.append(output_info['name']) return name_list def parse_output_type_list(self): output_type_map = { 'Tensor': 'paddle::dialect::DenseTensorType', 'Tensor[]': 'ir::VectorType', } type_list = [] for output_info in self.op_yaml_item['outputs']: assert ( output_info['typename'] in output_type_map ), f"{self.op_phi_name} : Output type error: the output type only support Tensor and Tensor[], but now is {output_info['typename']}." type_list.append(output_type_map[output_info['typename']]) return type_list def parse_output_size_list(self): size_list = [] for output_info in self.op_yaml_item['outputs']: if 'size' in output_info: size_list.append(output_info['size']) else: size_list.append(None) return size_list def parse_output_optional_list(self): optional_list = [] for output_info in self.op_yaml_item['outputs']: if 'optional' in output_info: if output_info['optional']: optional_list.append("true") else: optional_list.append("false") else: optional_list.append("false") return optional_list def parse_output_intermediate_list(self): intermediate_list = [] for output_info in self.op_yaml_item['outputs']: if 'intermediate' in output_info: if output_info['intermediate']: intermediate_list.append("true") else: intermediate_list.append("false") else: intermediate_list.append("false") return intermediate_list def parse_attribute_name_list(self): name_list = [] for attribute_info in self.op_yaml_item['attrs']: name_list.append(attribute_info['name']) return name_list def parse_attribute_build_arg_type_list(self): type_list = [] for attribute_info in self.op_yaml_item['attrs']: assert ( attribute_info['typename'] in self.attr_types_map ), f"{self.op_phi_name} : Attr type error." # Scalar & IntArray has data_type temp_type = self.attr_types_map[attribute_info['typename']][1] if 'Scalar' in temp_type: if 'data_type' in attribute_info: temp_type = attribute_info['data_type'] if 'IntArray' in temp_type: if 'data_type' in attribute_info: temp_type = attribute_info['data_type'] type_list.append(self.get_phi_dtype_name(temp_type)) return type_list def parse_attribute_type_list(self): type_list = [] for attribute_info in self.op_yaml_item['attrs']: assert ( attribute_info['typename'] in self.attr_types_map ), f"{self.op_phi_name} : Attr type error." type_list.append(self.attr_types_map[attribute_info['typename']][0]) return type_list def parse_attribute_data_type_list(self): data_type_list = [] for attribute_info in self.op_yaml_item['attrs']: if 'data_type' in attribute_info: data_type_list.append(attribute_info['data_type']) else: data_type_list.append("") return data_type_list def parse_attribute_default_value_list(self): default_value_list = [] for attribute_info in self.op_yaml_item['attrs']: if 'default_value' in attribute_info: default_value = attribute_info['default_value'] default_value_list.append( self.get_phi_dtype_name(default_value) ) else: default_value_list.append(None) return default_value_list def parse_infer_meta_map(self): if 'infer_meta' in self.op_yaml_item: return self.op_yaml_item['infer_meta'] else: return None def parse_kernel_map(self): if 'kernel' in self.op_yaml_item: return self.op_yaml_item['kernel'] else: return None def get_phi_dtype_name(self, name): name = name.replace('Scalar', 'phi::Scalar') name = name.replace('IntArray', 'phi::IntArray') name = name.replace('DataLayout', 'phi::DataLayout') name = name.replace('DataType', 'phi::DataType') if name.startswith( ( "Place", "CPUPlace", "GPUPlace", "GPUPinnedPlace", "XPUPlace", "IPUPlace", "CustomPlace", ) ): return "phi::" + name return name def to_pascal_case(s): words = s.split("_") if s[-1] == "_": return "".join([word.capitalize() for word in words]) + "_" else: return "".join([word.capitalize() for word in words]) + "" # ===================================== # Generate Op Definition Files # ===================================== def GenBuildInputArgsStr( op_input_name_list, op_attribute_name_list, op_attribute_build_arg_type_list, op_attribute_default_value_list, op_mutable_attribute_name_list, op_non_mutable_attribute_name_list, op_non_mutable_attribute_build_arg_type_list, op_non_mutable_attribute_default_value_list, for_func_define=True, mutable_attr_is_input=False, ): ''' Example: ir::Builder &builder, ir::OperationArgument &argument, ir::OpResult x_, phi::DataType dtype=phi::DataType::UNDEFINED, phi::Place place={} ''' # add inputs build_args_str = "ir::Builder &builder, ir::OperationArgument &argument" if len(op_input_name_list) > 0: for input_name in op_input_name_list: build_args_str += ", ir::OpResult " + input_name + "_" if not mutable_attr_is_input: # add attributes for attr_idx in range(len(op_attribute_name_list)): build_args_str += ( ", " + op_attribute_build_arg_type_list[attr_idx] + " " + op_attribute_name_list[attr_idx] ) if for_func_define: if op_attribute_default_value_list[attr_idx] is not None: default_value = op_attribute_default_value_list[attr_idx] if ( op_attribute_build_arg_type_list[attr_idx] != "std::string" ): if default_value[0] == "'" or default_value[0] == '"': default_value = default_value[1:] if default_value[-1] == "'" or default_value[-1] == '"': default_value = default_value[0:-1] build_args_str += "=" + default_value else: # add mutable attributes as inputs if len(op_mutable_attribute_name_list) > 0: for mutable_attr in op_mutable_attribute_name_list: build_args_str += ", ir::OpResult " + mutable_attr + "_" # add non-mutable attributes for attr_idx in range(len(op_non_mutable_attribute_name_list)): build_args_str += ( ", " + op_non_mutable_attribute_build_arg_type_list[attr_idx] + " " + op_non_mutable_attribute_name_list[attr_idx] ) if for_func_define: if ( op_non_mutable_attribute_default_value_list[attr_idx] is not None ): default_value = op_non_mutable_attribute_default_value_list[ attr_idx ] if ( op_non_mutable_attribute_build_arg_type_list[attr_idx] != "std::string" ): if default_value[0] == "'" or default_value[0] == '"': default_value = default_value[1:] if default_value[-1] == "'" or default_value[-1] == '"': default_value = default_value[0:-1] build_args_str += "=" + default_value return build_args_str mutable_attribute_phi_type_maps = { 'int': 'phi::DataType::INT32', 'int64_t': 'phi::DataType::INT64', 'float': 'phi::DataType::FLOAT32', 'std::vector': 'phi::DataType::INT64', 'const std::vector&': 'phi::DataType::INT64', } def GenBuildInserFullForMutableAttribute( op_attribute_name_list, op_attribute_build_arg_type_list, op_mutable_attribute_name_list, op_mutable_attribute_type_list, ): build_mutable_attribute = "" BUILD_INTARRAY_ATTRIBUTE_TEMPLATE = """ // Generate int_array mutable attribute: {attr_name} paddle::dialect::FullIntArrayOp full_{attr_name}_op = builder.Build({attr_name}, {phi_dtype}, phi::CPUPlace()); ir::OpResult {attr_name}_ = full_{attr_name}_op->result(0); """ BUILD_SCALAR_ATTRIBUTE_TEMPLATE = """ // Generate scalar mutable attribute: {attr_name} paddle::dialect::FullOp full_{attr_name}_op = builder.Build(std::vector{{1}}, {attr_name}, {phi_dtype}, phi::CPUPlace()); ir::OpResult {attr_name}_ = full_{attr_name}_op->result(0); """ for idx in range(len(op_mutable_attribute_name_list)): attr_name = op_mutable_attribute_name_list[idx] attr_type = op_mutable_attribute_type_list[idx][0] if attr_name in op_attribute_name_list: phi_dtype = mutable_attribute_phi_type_maps[ op_attribute_build_arg_type_list[ op_attribute_name_list.index(attr_name) ] ] else: phi_dtype = mutable_attribute_phi_type_maps[ op_mutable_attribute_type_list[idx][1] ] if attr_type == "paddle::dialect::IntArrayAttribute": build_mutable_attribute += BUILD_INTARRAY_ATTRIBUTE_TEMPLATE.format( attr_name=attr_name, phi_dtype=phi_dtype ) else: build_mutable_attribute += BUILD_SCALAR_ATTRIBUTE_TEMPLATE.format( attr_name=attr_name, phi_dtype=phi_dtype ) return build_mutable_attribute def GenBuildInputs(op_input_name_list, op_mutable_attribute_name_list): BUILD_INPUT_TEMPLATE = """ std::vector argument_inputs = {{{inputs_args}}}; argument.AddOperands(argument_inputs.begin(), argument_inputs.end()); """ build_input_str = ' VLOG(4) << "Builder construction inputs";\n' input_name_list = op_input_name_list + op_mutable_attribute_name_list if len(input_name_list) > 0: inputs_args_str = "" inputs_args_str += "_, ".join(input_name_list) + "_" build_input_str += BUILD_INPUT_TEMPLATE.format( inputs_args=inputs_args_str ) return build_input_str def GenBuildAttributes( op_non_mutable_attribute_name_list, op_non_mutable_attribute_type_list ): INTARRAY_STR_TEMPLATE = """ ir::Attribute attr_{attr_name} = {op_attribute_type}::get(ir::IrContext::Instance(), phi::IntArray({attr})); """ SCALAR_STR_TEMPLATE = """ ir::Attribute attr_{attr_name} = TransToIrAttribute({attr}, ir::IrContext::Instance()); """ STR_TEMPLATE = """ ir::Attribute attr_{attr_name} = {op_attribute_type}::get(ir::IrContext::Instance(), {attr}); """ ARRAY_ATTRIBUTE_TEMPLATE = """ std::vector vec_{attr_name}; for (size_t i = 0; i < static_cast({attr_size}); i++) {{ {create_attribute} vec_{attr_name}.push_back(attr_{attr_name}); }} ir::Attribute attr_{attr_name} = ir::ArrayAttribute::get(ir::IrContext::Instance(), vec_{attr_name}); """ attr_str = ' VLOG(4) << "Builder construction attributes";\n' for idx in range(len(op_non_mutable_attribute_name_list)): if "ir::ArrayAttribute<" in op_non_mutable_attribute_type_list[idx]: inner_attribute_type = op_non_mutable_attribute_type_list[idx][ 19:-1 ] if inner_attribute_type == "paddle::dialect::IntArrayAttribute": attr_str += ARRAY_ATTRIBUTE_TEMPLATE.format( attr_name=op_non_mutable_attribute_name_list[idx], attr_size=op_non_mutable_attribute_name_list[idx] + ".size()", create_attribute=INTARRAY_STR_TEMPLATE.format( attr_name=op_non_mutable_attribute_name_list[idx], op_attribute_type=inner_attribute_type, attr=op_non_mutable_attribute_name_list[idx] + "[i]", ), ) elif inner_attribute_type == "paddle::dialect::ScalarAttribute": attr_str += ARRAY_ATTRIBUTE_TEMPLATE.format( attr_name=op_non_mutable_attribute_name_list[idx], attr_size=op_non_mutable_attribute_name_list[idx] + ".size()", create_attribute=SCALAR_STR_TEMPLATE.format( attr_name=op_non_mutable_attribute_name_list[idx], attr=op_non_mutable_attribute_name_list[idx] + "[i]", ), ) else: attr_str += ARRAY_ATTRIBUTE_TEMPLATE.format( attr_name=op_non_mutable_attribute_name_list[idx], attr_size=op_non_mutable_attribute_name_list[idx] + ".size()", create_attribute=STR_TEMPLATE.format( attr_name=op_non_mutable_attribute_name_list[idx], op_attribute_type=inner_attribute_type, attr=op_non_mutable_attribute_name_list[idx] + "[i]", ), ) elif ( op_non_mutable_attribute_type_list[idx] == "paddle::dialect::IntArrayAttribute" ): attr_str += INTARRAY_STR_TEMPLATE.format( attr_name=op_non_mutable_attribute_name_list[idx], op_attribute_type=op_non_mutable_attribute_type_list[idx], attr=op_non_mutable_attribute_name_list[idx], ) elif ( op_non_mutable_attribute_type_list[idx] == "paddle::dialect::ScalarAttribute" ): attr_str += SCALAR_STR_TEMPLATE.format( attr_name=op_non_mutable_attribute_name_list[idx], attr=op_non_mutable_attribute_name_list[idx], ) else: attr_str += STR_TEMPLATE.format( attr_name=op_non_mutable_attribute_name_list[idx], op_attribute_type=op_non_mutable_attribute_type_list[idx], attr=op_non_mutable_attribute_name_list[idx], ) attr_str += """ argument.AddAttribute("{attr_name}", attr_{attr_name});\n""".format( attr_name=op_non_mutable_attribute_name_list[idx] ) return attr_str def GenBuildOutputs( op_input_name_list, op_input_type_list, op_mutable_attribute_name_list, op_mutable_attribute_type_list, op_output_name_list, op_output_type_list, op_output_size_list, op_infer_meta_map, mutable_attr_is_input=False, ): build_output_str = ' VLOG(4) << "Builder construction outputs";\n' CREATE_INPUT_METATENSOR_TEMPLATE = """ VLOG(4) << "Builder construction dense_{name}"; phi::DenseTensor dense_{name}(std::make_unique(paddle::platform::CPUPlace()).get(), phi::DenseTensorMeta(TransToPhiDataType({name}.dtype()), {name}.dims(), {name}.data_layout(), {name}.lod(), {name}.offset())); VLOG(4) << "Builder construction meta_{name}"; phi::MetaTensor meta_{name}(&dense_{name}); """ CREATE_INPUT_VEC_METATENSOR_TEMPLATE = """ std::vector vec_dense_{name}; for (size_t i=0; i < static_cast({name}.size()); i++) {{ vec_dense_{name}.push_back(phi::DenseTensor(std::make_unique(paddle::platform::CPUPlace()).get(), phi::DenseTensorMeta(TransToPhiDataType({name}[i].dyn_cast().dtype()), {name}[i].dyn_cast().dims(), {name}[i].dyn_cast().data_layout(), {name}[i].dyn_cast().lod(), {name}[i].dyn_cast().offset()))); }} std::vector vec_meta_{name}; for (size_t i=0; i < vec_dense_{name}.size(); i++) {{ vec_meta_{name}.push_back(phi::MetaTensor(&vec_dense_{name}[i])); }} std::vector meta_{name}; for (size_t i=0; i < static_cast(vec_meta_{name}.size()); i++) {{ meta_{name}.push_back(&vec_meta_{name}[i]); }} """ CREATE_INTARRAY_MUTABLE_ATTRIBUE_TEMPLATE = """ std::vector {name} = {name}_.owner()->dyn_cast().attributes().at("value").dyn_cast().data().GetData(); (void){name};\n""" CREATE_SCALAR_MUTABLE_ATTRIBUE_TEMPLATE = """ {dtype} {name} = {name}_.owner()->dyn_cast().attributes().at("value").dyn_cast().data().to<{dtype}>(); (void){name};\n""" CREATE_OUTPUT_METATENSOR_TEMPLATE = """ phi::DenseTensor dense_{name}; phi::MetaTensor meta_{name}(&dense_{name}); """ CREATE_OUTPUT_VEC_METATENSOR_TEMPLATE = """ std::vector vec_dense_{name}(({output_size}), phi::DenseTensor()); std::vector vec_meta_{name}; for (size_t i=0; i < static_cast({output_size}); i++) {{ vec_meta_{name}.push_back(phi::MetaTensor(&vec_dense_{name}[i])); }} std::vector meta_{name}; for (size_t i=0; i < static_cast(vec_meta_{name}.size()); i++) {{ meta_{name}.push_back(&vec_meta_{name}[i]); }} """ # Prepar input type for idx in range(len(op_input_name_list)): # is a vector if 'ir::VectorType' in op_input_type_list[idx]: build_output_str += " ir::VectorType {name} = {name}_.type().dyn_cast(); (void){name};\n".format( name=op_input_name_list[idx] ) # is a Tensor else: build_output_str += " paddle::dialect::DenseTensorType {name} = {name}_.type().dyn_cast(); (void){name};\n".format( name=op_input_name_list[idx] ) # Prepare mutable attributes if mutable_attr_is_input: for idx in range(len(op_mutable_attribute_name_list)): attr_dtype = op_mutable_attribute_type_list[idx] # int_array if attr_dtype[0] == "paddle::dialect::IntArrayAttribute": build_output_str += ( CREATE_INTARRAY_MUTABLE_ATTRIBUE_TEMPLATE.format( name=op_mutable_attribute_name_list[idx] ) ) # scalar elif attr_dtype[0] == "paddle::dialect::ScalarAttribute": build_output_str += ( CREATE_SCALAR_MUTABLE_ATTRIBUE_TEMPLATE.format( name=op_mutable_attribute_name_list[idx], dtype=attr_dtype[1], ) ) # string elif attr_dtype[0] == "ir::StrAttribute": build_output_str += "" else: assert "mutable attribtue type is not right." build_output_str += "\n" # Prepare inputs_meta_tensor & attributes for infer meta infer_meta_args = [] for idx in range(len(op_infer_meta_map['param'])): # is input if op_infer_meta_map['param'][idx] in op_input_name_list: if ( "meta_" + op_infer_meta_map['param'][idx] ) not in infer_meta_args: # is a vector if ( 'ir::VectorType' in op_input_type_list[ op_input_name_list.index( op_infer_meta_map['param'][idx] ) ] ): build_output_str += ( CREATE_INPUT_VEC_METATENSOR_TEMPLATE.format( name=op_infer_meta_map['param'][idx] ) ) # is a Tensor else: build_output_str += CREATE_INPUT_METATENSOR_TEMPLATE.format( name=op_infer_meta_map['param'][idx] ) infer_meta_args.append("meta_" + op_infer_meta_map['param'][idx]) # is attribute else: infer_meta_args.append(op_infer_meta_map['param'][idx]) # Prepare outputs_meta_tensor for infer meta for idx in range(len(op_output_name_list)): # is a vector if 'ir::VectorType' in op_output_type_list[idx]: build_output_str += CREATE_OUTPUT_VEC_METATENSOR_TEMPLATE.format( name=op_output_name_list[idx], output_size=op_output_size_list[idx], ) infer_meta_args.append(f"meta_{op_output_name_list[idx]}") # is a Tensor else: build_output_str += CREATE_OUTPUT_METATENSOR_TEMPLATE.format( name=op_output_name_list[idx] ) infer_meta_args.append(f"&meta_{op_output_name_list[idx]}") # Execute infer meta function CREATE_INFER_META_FUNC_TEMPLATE = """ phi::{func}({args}); """ build_output_str += CREATE_INFER_META_FUNC_TEMPLATE.format( func=op_infer_meta_map['func'], args=", ".join(infer_meta_args) ) # use dense_{name} or vec_dense_{name} to create Outputs type build_output_str += "\n std::vector argument_outputs;" CREATE_OUTPUT_DENSE_TENSOR_TEMPLATE = """ ir::Type {name}_dense_tensor_type = paddle::dialect::DenseTensorType::get(ir::IrContext::Instance(), TransToIrDataType(dense_{name}.dtype()), dense_{name}.dims(), dense_{name}.layout(), dense_{name}.lod(), dense_{name}.offset()); argument_outputs.push_back({name}_dense_tensor_type); """ CREATE_OUTPUT_VEC_DENSE_TENSOR_TEMPLATE = """ std::vector {name}_types; for (size_t i=0; i < static_cast({output_size}); i++) {{ {name}_types.push_back(paddle::dialect::DenseTensorType::get(ir::IrContext::Instance(), TransToIrDataType(vec_dense_{name}[i].dtype()), vec_dense_{name}[i].dims(), vec_dense_{name}[i].layout(), vec_dense_{name}[i].lod(), vec_dense_{name}[i].offset())); }} ir::Type {name}_vector_type = ir::VectorType::get(ir::IrContext::Instance(), {name}_types); argument_outputs.push_back({name}_vector_type); """ for idx in range(len(op_output_name_list)): # is a vector if 'ir::VectorType' in op_output_type_list[idx]: build_output_str += CREATE_OUTPUT_VEC_DENSE_TENSOR_TEMPLATE.format( name=op_output_name_list[idx], output_size=op_output_size_list[idx], ) # is a Tensor else: build_output_str += CREATE_OUTPUT_DENSE_TENSOR_TEMPLATE.format( name=op_output_name_list[idx] ) build_output_str += " argument.AddOutputs(argument_outputs.begin(), argument_outputs.end());\n" return build_output_str def GenBuild( op_class_name, op_input_name_list, op_input_type_list, op_attribute_name_list, op_attribute_build_arg_type_list, op_attribute_default_value_list, op_mutable_attribute_name_list, op_mutable_attribute_type_list, op_non_mutable_attribute_name_list, op_non_mutable_attribute_type_list, op_non_mutable_attribute_build_arg_type_list, op_non_mutable_attribute_default_value_list, op_output_name_list, op_output_type_list, op_output_size_list, op_infer_meta_map, muta_attr_is_input=False, ): build_args_for_declare = "" build_func = "" build_args_for_declare = GenBuildInputArgsStr( op_input_name_list, op_attribute_name_list, op_attribute_build_arg_type_list, op_attribute_default_value_list, op_mutable_attribute_name_list, op_non_mutable_attribute_name_list, op_non_mutable_attribute_build_arg_type_list, op_non_mutable_attribute_default_value_list, True, muta_attr_is_input, ) build_args_for_define = GenBuildInputArgsStr( op_input_name_list, op_attribute_name_list, op_attribute_build_arg_type_list, op_attribute_default_value_list, op_mutable_attribute_name_list, op_non_mutable_attribute_name_list, op_non_mutable_attribute_build_arg_type_list, op_non_mutable_attribute_default_value_list, False, muta_attr_is_input, ) inset_full_for_mutable_attributes_str = "" if not muta_attr_is_input: inset_full_for_mutable_attributes_str = ( GenBuildInserFullForMutableAttribute( op_attribute_name_list, op_attribute_build_arg_type_list, op_mutable_attribute_name_list, op_mutable_attribute_type_list, ) ) build_inputs_str = GenBuildInputs( op_input_name_list, op_mutable_attribute_name_list ) build_attributes_str = GenBuildAttributes( op_non_mutable_attribute_name_list, op_non_mutable_attribute_type_list, ) build_outputs_str = GenBuildOutputs( op_input_name_list, op_input_type_list, op_mutable_attribute_name_list, op_mutable_attribute_type_list, op_output_name_list, op_output_type_list, op_output_size_list, op_infer_meta_map, muta_attr_is_input, ) build_func = OP_BUILD_TEMPLATE.format( op_name=op_class_name, build_args=build_args_for_define, build_mutable_attributes=inset_full_for_mutable_attributes_str, build_inputs=build_inputs_str, build_attributes=build_attributes_str, build_outputs=build_outputs_str, ) return (build_args_for_declare, build_func) def OpGenerator( op_yaml_files, op_compat_yaml_file, namespaces, dialect_name, op_def_h_file, op_def_cc_file, ): # (1) Prepare: Delete existing old files: pd_op.h.tmp, pd_op.cc.tmp if os.path.exists(op_def_h_file): os.remove(op_def_h_file) if os.path.exists(op_def_cc_file): os.remove(op_def_cc_file) # (2) Prepare: Get all op item in all op_yaml_files op_compat_parser = OpCompatParser(op_compat_yaml_file) op_yaml_items = [] for yaml_file in op_yaml_files: with open(yaml_file, "r") as f: ops = yaml.safe_load(f) op_yaml_items = op_yaml_items + ops op_info_items = [] for op in op_yaml_items: op_info_items.append( OpInfoParser(op, op_compat_parser.get_compat(op['name'])) ) # (3) CodeGen: Traverse op_info_items and generate ops_name_list = [] # all op class name store in this list ops_declare_list = [] # all op class declare store in this list ops_defined_list = [] # all op class defined store in this list for op_info in op_info_items: # get op inputs info op_input_name_list = op_info.input_name_list op_input_type_list = op_info.input_type_list op_input_optional_list = op_info.input_optional_list op_input_no_need_buffer_list = op_info.input_no_need_buffer_list # get op outputs info op_output_name_list = op_info.output_name_list op_output_type_list = op_info.output_type_list op_output_size_list = op_info.output_size_list op_output_optional_list = op_info.output_optional_list op_output_intermediate_list = op_info.output_intermediate_list # get op mutable attribute op_mutable_attribute_name_list = op_info.mutable_attribute_name_list op_mutable_attribute_type_list = op_info.mutable_attribute_type_list # get op attribute op_attribute_name_list = op_info.attribute_name_list op_attribute_type_list = op_info.attribute_type_list op_attribute_data_type_list = op_info.attribute_data_type_list op_attribute_build_arg_type_list = op_info.attribute_build_arg_type_list op_attribute_default_value_list = op_info.attribute_default_value_list op_non_mutable_attribute_name_list = ( op_info.non_mutable_attribute_name_list ) op_non_mutable_attribute_type_list = ( op_info.non_mutable_attribute_type_list ) op_non_mutable_attribute_data_type_list = ( op_info.non_mutable_attribute_data_type_list ) op_non_mutable_attribute_build_arg_type_list = ( op_info.non_mutable_attribute_build_arg_type_list ) op_non_mutable_attribute_default_value_list = ( op_info.non_mutable_attribute_default_value_list ) # others op_infer_meta_map = op_info.infer_meta_map op_kernel_map = op_info.kernel_map op_inplace_map = op_info.inplace_map op_view_map = op_info.view_map op_interfaces = ["OpYamlInfoInterface"] op_traits = [] if op_info.infer_meta_func: op_interfaces += ["InferMetaInterface"] exclusive_interface_str = gen_exclusive_interface_str(op_info) # If op has inplace info, we will generate inplace op and non-inplace op. for op_name in op_info.op_phi_name: op_class_name = to_pascal_case(op_name) + "Op" op_dialect_name = dialect_name + "." + op_name # =================================== # # gen interface/trait list str # # =================================== # op_interfaces_str = "" if len(op_interfaces) > 0: op_interfaces_str = "," + ",".join(op_interfaces) op_traits_str = "" if len(op_traits) > 0: op_traits_str = "," + ",".join(op_traits) # =================================== # # gen get input/output methods str # # =================================== # op_get_inputs_outputs_str = gen_op_get_inputs_outputs_str( op_input_name_list, op_mutable_attribute_name_list, op_output_name_list, ) # =================================== # # gen Build methods str # # =================================== # build_args_with_muta_attr_not_input_for_declare = "" build_func_with_muta_attr_not_input = "" build_mutable_attr_is_input = "" build_func_with_muta_attr_is_input = "" if op_infer_meta_map is not None: ( build_args_with_muta_attr_not_input_for_declare, build_func_with_muta_attr_not_input, ) = GenBuild( op_class_name, op_input_name_list, op_input_type_list, op_attribute_name_list, op_attribute_build_arg_type_list, op_attribute_default_value_list, op_mutable_attribute_name_list, op_mutable_attribute_type_list, op_non_mutable_attribute_name_list, op_non_mutable_attribute_type_list, op_non_mutable_attribute_build_arg_type_list, op_non_mutable_attribute_default_value_list, op_output_name_list, op_output_type_list, op_output_size_list, op_infer_meta_map, muta_attr_is_input=False, ) if len(op_mutable_attribute_name_list) > 0: ( build_args_with_muta_attr_is_input_for_declare, build_func_with_muta_attr_is_input, ) = GenBuild( op_class_name, op_input_name_list, op_input_type_list, op_attribute_name_list, op_attribute_build_arg_type_list, op_attribute_default_value_list, op_mutable_attribute_name_list, op_mutable_attribute_type_list, op_non_mutable_attribute_name_list, op_non_mutable_attribute_type_list, op_non_mutable_attribute_build_arg_type_list, op_non_mutable_attribute_default_value_list, op_output_name_list, op_output_type_list, op_output_size_list, op_infer_meta_map, muta_attr_is_input=True, ) build_mutable_attr_is_input = "static void Build({build_args});".format( build_args=build_args_with_muta_attr_is_input_for_declare ) # gen op_declare_str/op_defined_str if len(op_non_mutable_attribute_name_list) == 0: op_declare_str = OP_DECLARE_TEMPLATE.format( op_name=op_class_name, dialect_op_name=op_dialect_name, interfaces=op_interfaces_str, traits=op_traits_str, attribute_declare=op_0_attribute_declare_str, attribute_num=0, build_args=build_args_with_muta_attr_not_input_for_declare, build_mutable_attr_is_input=build_mutable_attr_is_input, get_inputs_and_outputs=op_get_inputs_outputs_str, exclusive_interface=exclusive_interface_str, ) op_defined_str = "" else: op_declare_str = OP_DECLARE_TEMPLATE.format( op_name=op_class_name, dialect_op_name=op_dialect_name, interfaces=op_interfaces_str, traits=op_traits_str, attribute_declare=op_n_attribute_declare_str.format( attribute_num=len(op_non_mutable_attribute_name_list) ), attribute_num=len(op_non_mutable_attribute_name_list), build_args=build_args_with_muta_attr_not_input_for_declare, build_mutable_attr_is_input=build_mutable_attr_is_input, get_inputs_and_outputs=op_get_inputs_outputs_str, exclusive_interface=exclusive_interface_str, ) attribute_names_str = ( '"' + '", "'.join(op_non_mutable_attribute_name_list) + '"' ) op_defined_str = OP_N_ATTRIBUTE_DEFINED_TEMPLATE.format( op_name=op_class_name, attribute_num=len(op_non_mutable_attribute_name_list), attribute_names=attribute_names_str, ) # =================================== # # gen GetOpInfo func str # # =================================== # # generate get op info funciton: inputs input_info_list = [] for idx in range(len(op_input_name_list)): input_info_list.append( CONSTRUCT_INPUT_INFO_TEMPLATE.format( name=op_input_name_list[idx], typename=op_input_type_list[idx], optional=op_input_optional_list[idx], no_need_buffer=op_input_no_need_buffer_list[idx], is_mutable_attribute='false', ) ) for idx in range(len(op_mutable_attribute_name_list)): input_info_list.append( CONSTRUCT_INPUT_INFO_TEMPLATE.format( name=op_mutable_attribute_name_list[idx], typename=op_mutable_attribute_type_list[idx][0], optional='false', no_need_buffer='false', is_mutable_attribute='true', ) ) if len(input_info_list) > 0: inputs_info_str = ", ".join(input_info_list) else: inputs_info_str = "" # generate get op info funciton: outputs outputs_info_str = "" if len(op_output_name_list) > 0: output_info_list = [] for idx in range(len(op_output_name_list)): output_info_list.append( CONSTRUCT_OUTPUT_INFO_TEMPLATE.format( name=op_output_name_list[idx], typename=op_output_type_list[idx], optional=op_output_optional_list[idx], intermediate=op_output_intermediate_list[idx], ) ) outputs_info_str = ", ".join(output_info_list) # generate get op info funciton: attributes attribute_info_str = "" if len(op_non_mutable_attribute_name_list) > 0: attribute_info_list = [] for idx in range(len(op_non_mutable_attribute_name_list)): attribute_info_list.append( CONSTRUCT_ATTRIBUTE_INFO_TEMPLATE.format( name=op_non_mutable_attribute_name_list[idx], typename=op_non_mutable_attribute_type_list[idx], data_type=op_non_mutable_attribute_data_type_list[ idx ], ) ) attribute_info_str = ", ".join(attribute_info_list) # generate runtiem info infer_meta_func_str = "" infer_meta_param_str = "" if op_infer_meta_map is not None: infer_meta_func_str = op_infer_meta_map['func'] infer_meta_param_str = '", "'.join(op_infer_meta_map['param']) kernel_func_str = "" kernel_param_str = "" kernel_key_dtype = "" if op_kernel_map is not None: kernel_func_str = '", "'.join(op_kernel_map['func']) kernel_param_str = '", "'.join(op_kernel_map['param']) if 'data_type' in op_kernel_map and op_kernel_map['data_type']: kernel_key_dtype = '", "'.join( op_kernel_map['data_type']['candidates'] ) inplace_str = "" view_str = "" if op_name[-1] == "_": if op_inplace_map is not None: for key, value in op_inplace_map.items(): inplace_str += '{"' + key + '", "' + value + '"},' inplace_str = inplace_str[:-1] if op_view_map is not None: for key, value in op_view_map.items(): view_str += '{"' + key + '", "' + value + '"},' view_str = view_str[:-1] op_info_func_str = OP_INFO_TEMPLATE.format( op_name=op_class_name, inputs=inputs_info_str, attributes=attribute_info_str, outputs=outputs_info_str, infer_meta_func=infer_meta_func_str, infer_meta_param=infer_meta_param_str, kernel_func=kernel_func_str, kernel_param=kernel_param_str, kernel_key_dtype=kernel_key_dtype, inplace=inplace_str, view=view_str, ) # generate op verify function str op_verify_str = gen_verify_func_str( op_class_name, op_input_type_list, op_input_optional_list, op_mutable_attribute_name_list, op_mutable_attribute_type_list, op_non_mutable_attribute_name_list, op_non_mutable_attribute_type_list, op_output_type_list, op_output_optional_list, ) op_infer_meta_str = gen_op_infer_meta_str(op_info, op_class_name) ops_name_list.append(op_class_name) ops_declare_list.append(op_declare_str) ops_defined_list.append(op_defined_str) ops_defined_list.append(op_info_func_str) ops_defined_list.append(build_func_with_muta_attr_not_input) if len(op_mutable_attribute_name_list) > 0: ops_defined_list.append(build_func_with_muta_attr_is_input) ops_defined_list.append(op_verify_str) ops_defined_list.append(op_infer_meta_str) # (4) Generate head file str op_namespaces_prev = "" for name in namespaces: op_namespaces_prev += name + "::" ops_name_with_namespace_list = [] for name in ops_name_list: ops_name_with_namespace_list.append(op_namespaces_prev + name) op_list_str = GET_OP_LIST_TEMPALTE.format( ", ".join(ops_name_with_namespace_list) ) # Add GET_OP_LIST declare_type_id_str = "" for op in ops_name_with_namespace_list: declare_type_id_str += DECLARE_OP_TYPE_ID.format(op_name=op) head_file_str = "" head_file_str += "".join(ops_declare_list) # Add op class for name in reversed(namespaces): head_file_str = NAMESPACE_GARD_TEMPLATE.format( namespace=name, input=head_file_str ) # Add namespaces head_file_str = H_FILE_TEMPLATE.format( op_declare=op_list_str, input=head_file_str, declare_type_id=declare_type_id_str, ) # Add head # (5) Generate source file str source_file_str = "".join(ops_defined_list) # Add op define for name in reversed(namespaces): source_file_str = NAMESPACE_GARD_TEMPLATE.format( namespace=name, input=source_file_str ) # Add namespaces define_type_id_str = "" for op in ops_name_with_namespace_list: define_type_id_str += DEFINE_OP_TYPE_ID.format(op_name=op) source_file_str = CC_FILE_TEMPLATE.format( h_file=op_def_h_file[:-4], input=source_file_str, define_type_id=define_type_id_str, ) # Add head # (5) Generate pd_op.h.tmp, pd_op.cc.tmp with open(op_def_h_file, 'a') as f: f.write(head_file_str) with open(op_def_cc_file, 'a') as f: f.write(source_file_str) # ===================================== # Script parameter parsing # ===================================== def ParseArguments(): parser = argparse.ArgumentParser( description='Generate Dialect OP Definition Files By Yaml' ) parser.add_argument('--op_yaml_files', type=str) parser.add_argument('--op_compat_yaml_file', type=str) parser.add_argument('--namespaces', type=str) parser.add_argument('--dialect_name', type=str) parser.add_argument('--op_def_h_file', type=str) parser.add_argument('--op_def_cc_file', type=str) return parser.parse_args() # ===================================== # Main # ===================================== if __name__ == "__main__": # parse arguments args = ParseArguments() op_yaml_files = args.op_yaml_files.split(",") op_compat_yaml_file = args.op_compat_yaml_file namespaces = [] if args.namespaces is not None: namespaces = args.namespaces.split(",") dialect_name = args.dialect_name op_def_h_file = args.op_def_h_file op_def_cc_file = args.op_def_cc_file # auto code generate OpGenerator( op_yaml_files, op_compat_yaml_file, namespaces, dialect_name, op_def_h_file, op_def_cc_file, )