generate_op_map.py 3.9 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Notice: This file is executed automatically when the whole Paddle project is built.
#         You can also run it on its own to preview the generated file without building.

import argparse
import json
import re

import yaml


def ParseArguments(argv=None):
    """Parse the command-line options of the op-map generator.

    Args:
        argv: optional list of argument strings to parse instead of the real
            command line. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace with ``ops_yaml_path``, ``ops_legacy_yaml_path``,
        ``ops_compat_yaml_path`` and ``phi_ops_map_path`` attributes.
    """
    parser = argparse.ArgumentParser(
        description='prim ops Code Generator Args Parser'
    )
    # The three yaml inputs have no usable default: generate_code opens each of
    # them, so a missing one would otherwise only fail later as an opaque
    # ``open(None)`` TypeError. Let argparse report the missing option instead.
    parser.add_argument(
        '--ops_yaml_path', type=str, required=True, help="path to ops.yaml"
    )
    parser.add_argument(
        '--ops_legacy_yaml_path',
        type=str,
        required=True,
        help="path to legacy_ops.yaml",
    )
    parser.add_argument(
        '--ops_compat_yaml_path',
        type=str,
        required=True,
        help="path to op_compat.yaml",
    )
    parser.add_argument(
        '--phi_ops_map_path',
        type=str,
        default="./phi_ops_map.py",
        help='path to target phi_ops_map.py',
    )

    return parser.parse_args(argv)


def _trans_value_type(item):
    for key in item.keys():
        for subkey in item[key]:
            value = str(item[key][subkey])
            item[key][subkey] = value


def _load_op_info(op_paths):
    """Collect ``args``/``output`` for every op defined in the given yaml files.

    Args:
        op_paths: iterable of paths to op-definition yaml files (each a list of
            entries with ``op``, ``args`` and ``output`` keys).

    Returns:
        dict mapping op name -> {"args": <text inside the outermost
        parentheses of the entry's ``args``>, "output": <entry's ``output``>}.

    Raises:
        ValueError: if an op name appears more than once across ALL files, or
            an op's ``args`` string contains no parenthesised argument list.
    """
    # Greedy ``.*`` captures from the first "(" to the last ")"; re.S lets the
    # argument list span several lines.
    pattern = re.compile(r'[(](.*)[)]', re.S)
    op_info = {}
    for op_path in op_paths:
        with open(op_path, "rt", encoding="utf-8") as f:
            ops = yaml.safe_load(f)
        # An empty yaml file loads as None; treat it as "no ops".
        for item in ops or []:
            key = item['op']
            if key in op_info:
                raise ValueError(f"There already exists op {key}")
            match = pattern.search(item["args"])
            if match is None:
                raise ValueError(
                    f"Cannot parse args of op {key}: {item['args']!r}"
                )
            op_info[key] = {
                "args": match.group(1),
                "output": item["output"],
            }
    return op_info


def _load_op_compat(ops_compat_yaml_path):
    """Build the op-name -> compatibility-info mapping from op_compat.yaml.

    An ``op`` entry of the form ``"phi_name (op_name)"`` is keyed by
    ``op_name`` and records ``phi_name``; a bare ``"name"`` maps to itself.
    Any ``inputs``/``attrs`` entries are copied as-is; ``scalar``/``int_array``
    entries are copied after their nested values are converted to ``str``.

    Raises:
        ValueError: if an entry ending in ")" is not of the form
            ``"phi_name(op_name)"``.
    """
    with open(ops_compat_yaml_path, "rt", encoding="utf-8") as f:
        ops_compat = yaml.safe_load(f)

    map_dct = {}
    for item in ops_compat or []:
        key = item['op']
        if key.endswith(")"):
            tmp = re.match(r"(.*)\((.*)\)", key.replace(" ", ""))
            if tmp is None:
                raise ValueError(f"Cannot parse op name {key!r}")
            phi_name, op_name = tmp.group(1), tmp.group(2)
        else:
            phi_name = op_name = key
        entry = {"phi_name": phi_name}
        for element in ["inputs", "attrs"]:
            if element in item:
                entry[element] = item[element]
        for element in ["scalar", "int_array"]:
            if element in item:
                # Stringify nested values before storing. NOTE(review):
                # presumably so booleans/None are not dumped as JSON
                # true/false/null, which are not valid Python in the generated
                # .py file -- confirm.
                _trans_value_type(item[element])
                entry[element] = item[element]
        map_dct[op_name] = entry
    return map_dct


def _write_op_map(phi_ops_map_path, op_map, op_info):
    """Write ``op_map`` and ``op_info`` as JSON-literal assignments to a file."""
    with open(phi_ops_map_path, "w", encoding="utf-8") as f:
        f.write("op_map = ")
        json.dump(op_map, f, indent=4)
        f.write('\n')
        f.write("op_info = ")
        json.dump(op_info, f, indent=4)
        f.write('\n')


def generate_code(
    ops_yaml_path, ops_legacy_yaml_path, ops_compat_yaml_path, phi_ops_map_path
):
    """
    Generate dictionaries and save them to the file phi_ops_map.py. The target
    file records the gap in description between current ops and standard ones.

    The written file contains two assignments:

    * ``op_map``  -- from ``ops_compat_yaml_path``: op name -> phi name plus any
      ``inputs``/``attrs``/``scalar``/``int_array`` entries.
    * ``op_info`` -- from BOTH ``ops_yaml_path`` and ``ops_legacy_yaml_path``:
      op name -> its ``args`` and ``output``.

    Args:
        ops_yaml_path: path to ops.yaml.
        ops_legacy_yaml_path: path to legacy_ops.yaml.
        ops_compat_yaml_path: path to op_compat.yaml.
        phi_ops_map_path: path of the file to (over)write.

    Raises:
        ValueError: if an op is defined more than once across ops.yaml and
            legacy_ops.yaml, or an entry cannot be parsed.
    """
    # Both op files are merged into a single op_info (previously the dict was
    # reset per file, so only the last file's ops reached the output), and the
    # output file is written exactly once.
    op_info = _load_op_info([ops_yaml_path, ops_legacy_yaml_path])
    op_map = _load_op_compat(ops_compat_yaml_path)
    _write_op_map(phi_ops_map_path, op_map, op_info)


if __name__ == "__main__":
    # Script entry point: parse the CLI options and forward them unchanged.
    cli_args = ParseArguments()
    generate_code(
        cli_args.ops_yaml_path,
        cli_args.ops_legacy_yaml_path,
        cli_args.ops_compat_yaml_path,
        cli_args.phi_ops_map_path,
    )