# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
from pathlib import Path

import yaml
from filters import (
    cartesian_prod_mapping,
    to_composite_grad_opmaker_name,
    to_input_name,
    to_int_array_tensor_name,
    to_int_array_tensors_name,
    to_op_attr_type,
    to_opmaker_name,
    to_opmaker_name_cstr,
    to_pascal_case,
    to_scalar_tensor_name,
)
from generate_op import replace_compat_name
from jinja2 import Environment, FileSystemLoader, StrictUndefined
from parse_utils import to_named_dict
from tests import (
    is_base_op,
    is_initializer_list,
    is_scalar,
    is_vec,
    supports_inplace,
    supports_no_need_buffer,
)

# Templates are looked up in the "templates" directory next to this script.
file_loader = FileSystemLoader(Path(__file__).parent / "templates")
env = Environment(
    loader=file_loader,
    keep_trailing_newline=True,
    trim_blocks=True,
    lstrip_blocks=True,
    # StrictUndefined makes a reference to a missing variable raise during
    # rendering instead of silently expanding to an empty string.
    undefined=StrictUndefined,
    extensions=['jinja2.ext.do'],
)

# Custom filters made available to the templates.
env.filters["to_op_attr_type"] = to_op_attr_type
env.filters["to_opmaker_name"] = to_opmaker_name
env.filters["to_pascal_case"] = to_pascal_case
env.filters["to_scalar_tensor_name"] = to_scalar_tensor_name
env.filters["to_int_array_tensor_name"] = to_int_array_tensor_name
env.filters["to_int_array_tensors_name"] = to_int_array_tensors_name
env.filters["to_input_name"] = to_input_name
env.filters["to_opmaker_name_cstr"] = to_opmaker_name_cstr
env.filters["cartesian_prod_mapping"] = cartesian_prod_mapping
env.filters["to_composite_grad_opmaker_name"] = to_composite_grad_opmaker_name

# Custom tests made available to the templates.
env.tests["base_op"] = is_base_op
env.tests["vec"] = is_vec
env.tests["scalar"] = is_scalar
env.tests["initializer_list"] = is_initializer_list
env.tests["supports_inplace"] = supports_inplace
env.tests["supports_no_need_buffer"] = supports_no_need_buffer


def restruct_io(op):
    """Attach name-keyed views of an op's inputs, attrs and outputs.

    Adds the keys ``input_dict``, ``attr_dict`` and ``output_dict`` to ``op``,
    each built by ``to_named_dict`` from ``op["inputs"]``, ``op["attrs"]`` and
    ``op["outputs"]`` respectively.

    Args:
        op: mapping describing one op; must contain ``inputs``, ``attrs`` and
            ``outputs``. It is modified in place.

    Returns:
        The same ``op`` object, for convenient use in comprehensions.
    """
    op["input_dict"] = to_named_dict(op["inputs"])
    op["attr_dict"] = to_named_dict(op["attrs"])
    op["output_dict"] = to_named_dict(op["outputs"])
    return op


def _load_yaml(path):
    """Return the parsed content of the YAML file at ``path``.

    The result is ``None`` when the file holds an empty document.
    """
    with open(path, "rt") as f:
        return yaml.safe_load(f)


def _remove_if_exists(path):
    """Delete ``path`` if it is an existing regular file; otherwise do nothing."""
    if os.path.isfile(path):
        os.remove(path)


def main(
    ops_yaml_path,
    op_compat_yaml_path,
    op_version_yaml_path,
    output_op_path,
    output_arg_map_path,
):
    """Render the operator file and the argument-mapping file from op YAMLs.

    Args:
        ops_yaml_path: path of the parsed static ops YAML file.
        op_compat_yaml_path: path of the ops args compat YAML file.
        op_version_yaml_path: path of the ops version YAML file.
        output_op_path: where the rendered ``op.c.j2`` template is written.
        output_arg_map_path: where the rendered ``ks.c.j2`` template is
            written.

    When the ops YAML yields no ops, nothing is rendered; instead any existing
    file at either output path is removed.
    """
    # An empty YAML document loads as None. Treat that like an empty op list
    # so the cleanup branch below is reached instead of a TypeError here.
    ops = [restruct_io(op) for op in (_load_yaml(ops_yaml_path) or [])]
    forward_op_dict = to_named_dict(ops)

    # add op version info into op
    for op_version in _load_yaml(op_version_yaml_path) or []:
        if op_version['op'] in forward_op_dict:
            forward_op_dict[op_version['op']]['version'] = op_version['version']

    # Passed through untouched: replace_compat_name is defined in another
    # module, so its handling of an empty document is not changed here.
    op_op_map = _load_yaml(op_compat_yaml_path)

    # Record each op's name under 'op_name' before the compat renaming step.
    for op in ops:
        op['op_name'] = op['name']

    replace_compat_name(op_op_map, forward_op_dict, {})

    if not ops:
        # Nothing to generate: remove outputs left over from an earlier run.
        _remove_if_exists(output_op_path)
        _remove_if_exists(output_arg_map_path)
        return

    # Render both templates before opening any output file, so a rendering
    # error does not leave a truncated or half-updated pair of outputs behind.
    op_source = env.get_template('op.c.j2').render(
        ops=ops,
        backward_ops=[],
        op_dict=forward_op_dict,
        composite_gen_flag=False,
    )
    arg_map_source = env.get_template('ks.c.j2').render(
        ops=ops, backward_ops=[]
    )

    with open(output_op_path, "wt") as f:
        f.write(op_source)

    with open(output_arg_map_path, 'wt') as f:
        f.write(arg_map_source)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Generate operator file from op yaml."
    )
    # Every path is used unconditionally by main(), so a missing option is
    # reported by argparse rather than as a TypeError from open(None).
    parser.add_argument(
        '--ops_yaml_path',
        type=str,
        required=True,
        help="parsed static ops yaml file.",
    )
    parser.add_argument(
        '--op_compat_yaml_path',
        type=str,
        required=True,
        help="ops args compat yaml file.",
    )
    parser.add_argument(
        '--op_version_yaml_path',
        type=str,
        required=True,
        help="ops version yaml file.",
    )
    parser.add_argument(
        "--output_op_path",
        type=str,
        required=True,
        help="path to save generated operators.",
    )
    parser.add_argument(
        "--output_arg_map_path",
        type=str,
        required=True,
        help="path to save generated argument mapping functions.",
    )

    args = parser.parse_args()
    main(
        args.ops_yaml_path,
        args.op_compat_yaml_path,
        args.op_version_yaml_path,
        args.output_op_path,
        args.output_arg_map_path,
    )