# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
import re

import yaml
from op_gen import OpCompatParser, OpInfoParser, to_pascal_case

# Skeleton of the generated API header; `{body}` receives the
# namespace-wrapped API declarations.
H_FILE_TEMPLATE = """

#pragma once

#include <vector>

#include "paddle/ir/core/value.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_api.h"

{body}

"""

# Skeleton of the generated API source file; `{body}` receives the
# namespace-wrapped API definitions.
CPP_FILE_TEMPLATE = """

#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_api.h"
#include "paddle/fluid/ir/dialect/paddle_dialect/ir/api_builder.h"
#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_op.h"
#include "paddle/ir/core/builder.h"
#include "paddle/ir/core/builtin_op.h"

{body}

"""


# Wraps `{body}` in one C++ namespace; applied repeatedly for nested
# namespaces. Literal braces are doubled for str.format.
NAMESPACE_TEMPLATE = """
namespace {namespace} {{
{body}
}} // namespace {namespace}
"""


# One API declaration in the header.
API_DECLARE_TEMPLATE = """
{ret_type} {api_name}({args});
"""


# One API definition. Each slot holds a generated code fragment that may be
# empty; the caller removes the indentation-only lines that empty slots leave.
API_IMPL_TEMPLATE = """
{ret_type} {api_name}({args}){{
    {in_combine}
    {compute_op}
    {out_split}
    {return_result}
}}

"""

# Packs a std::vector<ir::OpResult> input into a single value via CombineOp.
COMBINE_OP_TEMPLATE = """
    auto {op_name} = APIBuilder::Instance().GetBuilder()->Build<ir::CombineOp>({in_name});"""

# Unpacks a vector-typed op result into individual values via SplitOp.
SPLIT_OP_TEMPLATE = """
    auto {op_name} = APIBuilder::Instance().GetBuilder()->Build<ir::SplitOp>({in_name});"""

# Builds the dialect op itself.
COMPUTE_OP_TEMPLATE = """
    paddle::dialect::{op_class_name} {op_inst_name} = APIBuilder::Instance().GetBuilder()->Build<paddle::dialect::{op_class_name}>({args});"""

# C++ type used for mutable attributes passed as IR values.
OP_RESULT = 'ir::OpResult'
# Substring identifying vector-typed inputs/outputs in the op type strings.
VECTOR_TYPE = 'ir::VectorType'
# Ops that are still generated even when their infer_meta_func is None
# (see the skip check in CodeGen).
PD_MANUAL_OP_LIST = ['add_n']


def get_op_class_name(op_name):
    """Return the C++ dialect op class name for ``op_name``.

    The name is the PascalCase form of ``op_name`` with an ``Op`` suffix.
    """
    pascal_name = to_pascal_case(op_name)
    return f'{pascal_name}Op'


class CodeGen:
    """Generate C++ API wrappers (header and source) for dialect ops.

    Every generated API builds the matching ``paddle::dialect::<Name>Op``
    through ``APIBuilder``. Vector-typed inputs are first packed with an
    ``ir::CombineOp``, and vector-typed outputs are unpacked with an
    ``ir::SplitOp``. Ops that have mutable attributes additionally get an
    overload in which those attributes are passed as ``ir::OpResult`` values.
    """

    def __init__(self) -> None:
        # IR value type (as listed in the op info) -> C++ type used for the
        # corresponding API parameter / return value.
        self._type_map = {
            'paddle::dialect::DenseTensorType': 'ir::OpResult',
            'paddle::dialect::SelectedRowsType': 'ir::OpResult',
            'ir::VectorType<paddle::dialect::DenseTensorType>': 'std::vector<ir::OpResult>',
        }

    def _parse_yaml(self, op_yaml_files, op_compat_yaml_file):
        """Load every op yaml file and wrap each op in an ``OpInfoParser``.

        Args:
            op_yaml_files: paths of the op definition yaml files.
            op_compat_yaml_file: path of the op compat yaml file.

        Returns:
            A list of ``OpInfoParser`` objects, one per op, in file order.
        """
        op_compat_parser = OpCompatParser(op_compat_yaml_file)

        op_yaml_items = []
        for yaml_file in op_yaml_files:
            with open(yaml_file, "r", encoding="utf-8") as f:
                # An empty yaml file loads as None; treat it as "no ops"
                # instead of failing on list + None.
                op_yaml_items.extend(yaml.safe_load(f) or [])
        return [
            OpInfoParser(op, op_compat_parser.get_compat(op['name']))
            for op in op_yaml_items
        ]

    # =====================================
    # Shared helpers
    # =====================================
    def _need_skip(self, op_info, op_name):
        """Return True if no API should be generated for ``op_name``.

        NOTE: When infer_meta_func is None, the Build() function generated
        in pd_op is wrong, so the automatic generation of these APIs is
        temporarily skipped. Ops listed in PD_MANUAL_OP_LIST are exempt from
        this skip.
        """
        return (
            op_info.infer_meta_func is None
            and op_name not in PD_MANUAL_OP_LIST
        )

    @staticmethod
    def _wrap_namespaces(body, namespaces):
        """Nest ``body`` inside ``namespaces``; the first entry is outermost."""
        for namespace in reversed(namespaces):
            body = NAMESPACE_TEMPLATE.format(namespace=namespace, body=body)
        return body

    # =====================================
    # Gen declare functions
    # =====================================
    def _gen_api_inputs(self, op_info):
        """Return the C++ parameter list for the op inputs, e.g. ``ir::OpResult x``."""
        name_list = op_info.input_name_list
        type_list = op_info.input_type_list
        assert len(name_list) == len(type_list)
        return ', '.join(
            f'{self._type_map[type_]} {name}'
            for name, type_ in zip(name_list, type_list)
        )

    def _gen_api_attrs(self, op_info, with_default, is_mutable_attr):
        """Return the C++ parameter list for the op attributes.

        Args:
            op_info: parsed op info.
            with_default: emit ``= <default>`` for attributes that have a
                default value (used for declarations only).
            is_mutable_attr: if True, mutable attributes become
                ``ir::OpResult`` parameters and are placed before all
                non-mutable attributes.
        """
        name_list = op_info.attribute_name_list
        type_list = op_info.attribute_build_arg_type_list
        default_value_list = op_info.attribute_default_value_list
        mutable_name_list = op_info.mutable_attribute_name_list
        assert len(name_list) == len(type_list) == len(default_value_list)

        no_mutable_attr = []
        mutable_attr = []
        for name, type_, default_value in zip(
            name_list, type_list, default_value_list
        ):
            if is_mutable_attr and name in mutable_name_list:
                mutable_attr.append(f'{OP_RESULT} {name}')
                continue
            if with_default and default_value is not None:
                if type_ in ['float', 'double']:
                    # Strip surrounding double quotes so the default is
                    # emitted as a numeric literal.
                    default_value = default_value.strip('"')
                no_mutable_attr.append(f'{type_} {name} = {default_value}')
            else:
                no_mutable_attr.append(f'{type_} {name}')
        return ', '.join(mutable_attr + no_mutable_attr)

    def _gen_api_args(self, op_info, with_default_attr, is_mutable_attr):
        """Return the full C++ parameter list: inputs followed by attributes."""
        inputs = self._gen_api_inputs(op_info)
        attrs = self._gen_api_attrs(op_info, with_default_attr, is_mutable_attr)
        # Join only the non-empty parts. Stripping ', ' from the concatenation
        # would also remove trailing ',' / ' ' characters of the last argument.
        return ', '.join(part for part in (inputs, attrs) if part)

    def _gen_ret_type(self, op_info):
        """Return the C++ return type: ``void``, one type, or a ``std::tuple``."""
        type_list = op_info.output_type_list
        if not type_list:
            return 'void'
        if len(type_list) == 1:
            return self._type_map[type_list[0]]
        return 'std::tuple<{}>'.format(
            ', '.join(self._type_map[type_] for type_ in type_list)
        )

    def _gen_one_declare(self, op_info, op_name, is_mutable_attr):
        """Return the header declaration of one API (attrs keep defaults)."""
        return API_DECLARE_TEMPLATE.format(
            ret_type=self._gen_ret_type(op_info),
            api_name=op_name,
            args=self._gen_api_args(op_info, True, is_mutable_attr),
        )

    def _gen_h_file(self, op_info_items, namespaces, h_file_path):
        """Write the header with every API declaration to ``h_file_path``."""
        declare_str = ''
        for op_info in op_info_items:
            for op_name in op_info.op_phi_name:
                if self._need_skip(op_info, op_name):
                    continue
                declare_str += self._gen_one_declare(op_info, op_name, False)
                if len(op_info.mutable_attribute_name_list) > 0:
                    declare_str += self._gen_one_declare(op_info, op_name, True)

        body = self._wrap_namespaces(declare_str, namespaces)
        with open(h_file_path, 'w') as f:
            f.write(H_FILE_TEMPLATE.format(body=body))

    # =====================================
    # Gen impl functions
    # =====================================
    def _gen_in_combine(self, op_info):
        """Generate ``ir::CombineOp`` builders for vector-typed inputs.

        Returns:
            ``(code, combine_op_list)``. ``code`` is the C++ text that builds
            one CombineOp per vector input. ``combine_op_list`` holds, per
            input, the CombineOp variable name, or ``None`` when the input is
            passed through unchanged.
        """
        name_list = op_info.input_name_list
        type_list = op_info.input_type_list
        assert len(name_list) == len(type_list)
        combine_op = ''
        combine_op_list = []
        for name, type_ in zip(name_list, type_list):
            if VECTOR_TYPE in type_:
                op_name = f'{name}_combine_op'
                combine_op += COMBINE_OP_TEMPLATE.format(
                    op_name=op_name, in_name=name
                )
                combine_op_list.append(op_name)
            else:
                combine_op_list.append(None)
        return combine_op, combine_op_list

    def _gen_compute_op_args(
        self, op_info, in_combine_op_list, is_mutable_attr
    ):
        """Return the argument list passed to the op's ``Build`` call.

        Vector inputs are replaced by their CombineOp output. Attributes
        follow the same order as the API signature: mutable attributes
        first when ``is_mutable_attr`` is set.
        """
        input_name_list = op_info.input_name_list
        all_attr_list = op_info.attribute_name_list
        no_mutable_attr_list = op_info.non_mutable_attribute_name_list
        mutable_attr_list = op_info.mutable_attribute_name_list
        assert len(input_name_list) == len(in_combine_op_list)
        ret = [
            input_name if combine_op is None else f'{combine_op}.out()'
            for input_name, combine_op in zip(
                input_name_list, in_combine_op_list
            )
        ]
        if is_mutable_attr:
            ret += list(mutable_attr_list + no_mutable_attr_list)
        else:
            ret += list(all_attr_list)
        return ', '.join(ret)

    def _gen_compute_op(
        self, op_info, op_name, in_combine_op_list, is_mutable_attr
    ):
        """Return ``(code, op_inst_name)`` for building the dialect op."""
        op_class_name = get_op_class_name(op_name)
        op_inst_name = op_name + '_op'
        return (
            COMPUTE_OP_TEMPLATE.format(
                op_class_name=op_class_name,
                op_inst_name=op_inst_name,
                args=self._gen_compute_op_args(
                    op_info, in_combine_op_list, is_mutable_attr
                ),
            ),
            op_inst_name,
        )

    def _gen_out_split_and_ret_list(self, op_info, op_inst_name):
        """Generate SplitOps for vector outputs and the returned expressions.

        Returns:
            ``(split_code, ret_list)``. ``ret_list`` holds one C++ expression
            per op output.
        """
        name_list = op_info.output_name_list
        type_list = op_info.output_type_list

        split_op_str = ''
        ret_list = []
        for i, (name, type_) in enumerate(zip(name_list, type_list)):
            if VECTOR_TYPE in type_:
                split_op_name = f'{name}_split_op'
                split_op_str += SPLIT_OP_TEMPLATE.format(
                    op_name=split_op_name, in_name=f'{op_inst_name}.result({i})'
                )
                ret_list.append(f'{split_op_name}.outputs()')
            else:
                ret_list.append(f'{op_inst_name}.result({i})')
        return split_op_str, ret_list

    def _gen_return_result(self, ret_list):
        """Return the C++ return statement matching ``_gen_ret_type``."""
        if len(ret_list) > 1:
            return 'return std::make_tuple({});'.format(', '.join(ret_list))
        elif len(ret_list) == 1:
            return f'return {ret_list[0]};'
        elif len(ret_list) == 0:
            return 'return;'

    def _gen_one_impl(self, op_info, op_name, is_mutable_attr):
        """Return the C++ definition of one API (attrs without defaults)."""
        in_combine, in_combine_op_list = self._gen_in_combine(op_info)
        compute_op, op_inst_name = self._gen_compute_op(
            op_info, op_name, in_combine_op_list, is_mutable_attr
        )
        out_split, ret_list = self._gen_out_split_and_ret_list(
            op_info, op_inst_name
        )

        ret = API_IMPL_TEMPLATE.format(
            ret_type=self._gen_ret_type(op_info),
            api_name=op_name,
            args=self._gen_api_args(op_info, False, is_mutable_attr),
            in_combine=in_combine,
            compute_op=compute_op,
            out_split=out_split,
            return_result=self._gen_return_result(ret_list),
        )

        # Delete every run of spaces that ends a line together with its
        # newline. This drops the indentation-only lines left by empty slots
        # and the stray indent before fragments that start with '\n'.
        ret = re.sub(r' +\n', '', ret)
        return ret

    def _gen_cpp_file(self, op_info_items, namespaces, cpp_file_path):
        """Write the source file with every API definition to ``cpp_file_path``."""
        impl_str = ''
        for op_info in op_info_items:
            for op_name in op_info.op_phi_name:
                if self._need_skip(op_info, op_name):
                    continue
                impl_str += self._gen_one_impl(op_info, op_name, False)
                if len(op_info.mutable_attribute_name_list) > 0:
                    impl_str += self._gen_one_impl(op_info, op_name, True)

        body = self._wrap_namespaces(impl_str, namespaces)
        with open(cpp_file_path, 'w') as f:
            f.write(CPP_FILE_TEMPLATE.format(body=body))

    def gen_h_and_cpp_file(
        self,
        op_yaml_files,
        op_compat_yaml_file,
        namespaces,
        h_file_path,
        cpp_file_path,
    ):
        """Regenerate the API header and source files from the op yaml files.

        Args:
            op_yaml_files: paths of the op definition yaml files.
            op_compat_yaml_file: path of the op compat yaml file.
            namespaces: C++ namespaces to wrap the APIs in, outermost first.
            h_file_path: output path of the header.
            cpp_file_path: output path of the source file.
        """
        if os.path.exists(h_file_path):
            os.remove(h_file_path)
        if os.path.exists(cpp_file_path):
            os.remove(cpp_file_path)

        op_info_items = self._parse_yaml(op_yaml_files, op_compat_yaml_file)

        self._gen_h_file(op_info_items, namespaces, h_file_path)
        self._gen_cpp_file(op_info_items, namespaces, cpp_file_path)


def ParseArguments():
    """Parse the command-line options of the dialect API generator.

    Every option is an optional string. The comma-separated lists
    (``--op_yaml_files``, ``--namespaces``) are split by the caller.
    """
    arg_parser = argparse.ArgumentParser(
        description='Generate Dialect API Files By Yaml'
    )
    option_names = (
        '--op_yaml_files',
        '--op_compat_yaml_file',
        '--namespaces',
        '--api_def_h_file',
        '--api_def_cc_file',
    )
    for option_name in option_names:
        arg_parser.add_argument(option_name, type=str)
    return arg_parser.parse_args()


if __name__ == '__main__':
    args = ParseArguments()

    op_yaml_files = args.op_yaml_files.split(",")
    op_compat_yaml_file = args.op_compat_yaml_file
    # `--namespaces` is optional. Without this default, `namespaces` would be
    # unbound below and the script would die with a NameError.
    namespaces = []
    if args.namespaces is not None:
        namespaces = args.namespaces.split(",")
    api_def_h_file = args.api_def_h_file
    api_def_cc_file = args.api_def_cc_file

    code_gen = CodeGen()
    code_gen.gen_h_and_cpp_file(
        op_yaml_files,
        op_compat_yaml_file,
        namespaces,
        api_def_h_file,
        api_def_cc_file,
    )