# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import re

import yaml


def map_code_template(attrs_str, attrs_checker_str):
    """Render the complete C++ source of the generated extra-info file.

    Args:
        attrs_str: C++ initializer entries (already joined) placed inside
            ``g_extra_attrs_map_ = { ... };``.
        attrs_checker_str: C++ initializer entries (already joined) placed
            inside ``g_extra_attrs_checker_ = { ... };``.

    Returns:
        The generated C++ translation unit as a single string.
    """
    # Literal C++ braces are doubled ({{ / }}) because this is an f-string.
    return f"""// This file is generated by paddle/fluid/operators/generator/ops_extra_info_gen.py
#include "paddle/fluid/operators/ops_extra_info.h"

#include "paddle/phi/backends/gpu/cuda/cudnn_workspace_helper.h"

namespace paddle {{
namespace operators {{

ExtraInfoUtils::ExtraInfoUtils() {{
  g_extra_attrs_map_ = {{
    {attrs_str}
  }};

  g_extra_attrs_checker_ = {{
    {attrs_checker_str}
  }};
}}

}}  // namespace operators
}}  // namespace paddle
"""


ATTR_TYPE_STRING_MAP = {
    'bool': 'bool',
    'int': 'int',
    'int64_t': 'int64_t',
    'float': 'float',
    'double': 'double',
    'str': 'std::string',
    'int[]': 'std::vector<int>',
    'int64_t[]': 'std::vector<int64_t>',
    'float[]': 'std::vector<float>',
    'double[]': 'std::vector<double>',
56
    'str[]': 'std::vector<std::string>',
57 58 59 60 61
}


def parse_attr(attr_str):
    """Parse one extra-attribute declaration of the form ``<type> <name> = <default>``.

    Example: ``"bool use_mkldnn = false"`` -> ``("bool", "use_mkldnn", "false")``.

    Args:
        attr_str: the declaration string.

    Returns:
        A ``(cpp_type, name, default_val)`` tuple. ``cpp_type`` is the C++
        spelling of ``<type>`` from ``ATTR_TYPE_STRING_MAP``; ``default_val``
        is the default expression text, taken verbatim up to the last
        non-whitespace character (so ``{1.0f, 2.0f}`` is kept whole).

    Raises:
        ValueError: if ``attr_str`` does not have the expected form.
        KeyError: if ``<type>`` has no entry in ``ATTR_TYPE_STRING_MAP``.
    """
    result = re.search(
        r"(?P<attr_type>[a-zA-Z0-9_[\]]+)\s+(?P<name>[a-zA-Z0-9_]+)"
        # Capture the whole default (it may contain spaces, e.g. "{1, 2}");
        # trailing whitespace is excluded by the lazy match + "\s*$".
        r"\s*=\s*(?P<default_val>\S.*?)\s*$",
        attr_str,
    )
    if result is None:
        raise ValueError(
            f"Cannot parse extra attribute declaration {attr_str!r}; "
            "expected '<type> <name> = <default>'"
        )
    attr_type = result.group('attr_type')
    try:
        cpp_type = ATTR_TYPE_STRING_MAP[attr_type]
    except KeyError:
        raise KeyError(
            f"Unsupported extra attribute type {attr_type!r} in {attr_str!r}"
        ) from None
    return cpp_type, result.group('name'), result.group('default_val')


def generate_extra_info(op_compat_yaml_path, ops_extra_info_path):
    """Generate the extra-info C++ file from the ``extra`` sections of a compat YAML.

    For every op entry that has ``extra: attrs: [...]``, one entry is emitted
    into the default-value map and one into the checker map, keyed by the op
    name. The same entries are repeated for every op listed under the entry's
    ``backward`` key.

    Args:
        op_compat_yaml_path: path of the op compat YAML file to read.
        ops_extra_info_path: path of the C++ file to write (overwritten).
    """
    with open(op_compat_yaml_path, 'rt', encoding='utf-8') as f:
        # An empty YAML document loads as None; treat it as "no ops".
        compat_apis = yaml.safe_load(f) or []

    def get_op_name(api_item):
        # "name (other_name)" -> "other_name"; a plain "name" -> "name".
        names = api_item.split('(')
        if len(names) == 1:
            return names[0].strip()
        return names[1].split(')')[0].strip()

    extra_map_str_list = []
    extra_checker_str_list = []

    for op_compat_args in compat_apis:
        # `extra:` present but empty (null) is treated like no extra section.
        extra_args_map = op_compat_args.get('extra') or {}
        # TODO(chenweihang): add inputs and outputs
        if 'attrs' not in extra_args_map:
            continue

        attr_map_list = []
        attr_checker_func_list = []
        for attr in extra_args_map['attrs']:
            attr_type, attr_name, default_val = parse_attr(attr)
            attr_checker_func_list.append(
                f"[](framework::AttributeMap* attr_map, bool only_check_exist_value)-> void {{ ExtraAttrChecker<{attr_type}>(\"{attr_name}\", {default_val})(attr_map, only_check_exist_value);}}"
            )
            if attr_type.startswith("std::vector"):
                # The default is appended right after the type with no extra
                # braces; presumably the YAML writes vector defaults as
                # brace-initializers such as "{}" -- confirm for new entries.
                attr_map_list.append(
                    f"{{\"{attr_name}\", {attr_type}{default_val}}}"
                )
            else:
                attr_map_list.append(
                    f"{{\"{attr_name}\", {attr_type}{{{default_val}}}}}"
                )
        api_extra_attr_map = ", ".join(attr_map_list)
        api_extra_attr_checkers = ",\n      ".join(attr_checker_func_list)

        # The forward op first, then every backward op; backward ops share
        # the forward op's extra attributes.
        op_names = [get_op_name(op_compat_args['op'])]
        if op_compat_args.get('backward'):
            op_names.extend(
                get_op_name(bw_item)
                for bw_item in op_compat_args['backward'].split(',')
            )
        for op_name in op_names:
            extra_map_str_list.append(
                f"{{\"{op_name}\", {{ {api_extra_attr_map} }}}}"
            )
            extra_checker_str_list.append(
                f"{{\"{op_name}\", {{ {api_extra_attr_checkers} }}}}"
            )

    # `with` guarantees the output file is closed even if writing fails.
    with open(ops_extra_info_path, 'w', encoding='utf-8') as ops_extra_info_file:
        ops_extra_info_file.write(
            map_code_template(
                ",\n    ".join(extra_map_str_list),
                ",\n    ".join(extra_checker_str_list),
            )
        )


def main(argv=None):
    """Command-line entry point: parse options and generate the extra-info file.

    Args:
        argv: optional list of argument strings. ``None`` (the default) lets
            argparse read ``sys.argv[1:]``, matching the previous behavior.
    """
    parser = argparse.ArgumentParser(
        description='Generate PaddlePaddle Extra Param Info for Op'
    )
    parser.add_argument(
        '--op_compat_yaml_path',
        help='path to api compat yaml file',
        default='paddle/phi/api/yaml/op_compat.yaml',
    )
    parser.add_argument(
        '--ops_extra_info_path',
        help='output of generated extra_param_info code file',
        default='paddle/fluid/operators/ops_extra_info.cc',
    )

    options = parser.parse_args(argv)
    generate_extra_info(options.op_compat_yaml_path, options.ops_extra_info_path)


if __name__ == '__main__':
    main()