# wrapped_infermeta_gen.py
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse

17
import yaml
18
from api_gen import ForwardAPI
19

# Kernel function names for which a wrapped InferMeta function has already
# been generated; gene_wrapped_infermeta_and_register checks and updates it
# so the same wrapper is not emitted twice.
kernel_func_set = set()

def get_wrapped_infermeta_name(api_name):
    """Return the C++ name of the wrapped InferMeta function for a kernel.

    The name is ``api_name.capitalize()`` (first character upper-cased, the
    rest lower-cased) followed by ``InferMeta``, e.g. ``"add"`` becomes
    ``"AddInferMeta"``.
    """
    return f"{api_name.capitalize()}InferMeta"


def gene_wrapped_infermeta_and_register(api):
    """Generate the wrapped InferMeta declaration, definition and registration.

    A wrapper is generated when the kernel's parameter list differs from the
    InferMeta function's parameter list. The wrapper takes the kernel's inputs
    and attrs plus every output as MetaTensor arguments, and forwards the
    InferMeta params followed by the output names to the InferMeta function.

    Args:
        api: A ``ForwardAPI`` built from one op yaml entry.

    Returns:
        A tuple ``(declare_code, define_code, register_code)`` of C++ snippets;
        an element is ``''`` when there is nothing to emit for it:

        * not a base API, or a dygraph API -> ``('', '', '')``;
        * InferMeta has no explicit ``param`` list, or the kernel params equal
          the InferMeta params -> only a registration pointing directly at
          ``phi::<infer_meta func>``;
        * a wrapper for this kernel function was already generated
          -> ``('', '', '')``;
        * otherwise -> a new wrapper plus a registration pointing at it.

    Raises:
        AssertionError: if the InferMeta param list is longer than the
            kernel param list.
        KeyError: if an input/output type is not in ``tensor_type_map``.
    """
    if not api.is_base_api or api.is_dygraph_api:
        return '', '', ''

    kernel_func = api.kernel['func'][0]
    infer_meta_params = api.infer_meta['param']

    # Registration used when the InferMeta function can be used as-is.
    register_code = (
        f"\nPD_REGISTER_INFER_META_FN({kernel_func}, "
        f"phi::{api.infer_meta['func']});"
    )

    if infer_meta_params is None:
        return '', '', register_code

    # An earlier API already produced the wrapper (and its registration).
    if kernel_func in kernel_func_set:
        return '', '', ''

    kernel_params = api.kernel['param']
    if kernel_params is None:
        kernel_params = api.inputs['names'] + api.attrs['names']
    if kernel_params == infer_meta_params:
        return '', '', register_code

    assert len(infer_meta_params) <= len(kernel_params), (
        f"{api.api} api: Parameters error. The params of infer_meta should be a subset of kernel params."
    )

    # Maps API tensor types to the MetaTensor types used by InferMeta
    # signatures; any type not listed here raises KeyError below.
    tensor_type_map = {
        'const Tensor&': 'const MetaTensor&',
        'const std::vector<Tensor>&': 'const std::vector<const MetaTensor*>&',
        'Tensor': 'MetaTensor*',
        'std::vector<Tensor>': 'std::vector<MetaTensor*>',
        'const paddle::optional<Tensor>&': 'const MetaTensor&',
    }

    wrapped_infermeta_name = get_wrapped_infermeta_name(kernel_func)

    # Wrapper signature: kernel inputs and attrs (in api.inputs['names'] /
    # api.attrs['names'] order), then every output.
    args = [
        tensor_type_map[api.inputs['input_info'][name]] + ' ' + name
        for name in api.inputs['names']
        if name in kernel_params
    ]
    args += [
        api.attrs['attr_info'][name][0] + ' ' + name
        for name in api.attrs['names']
        if name in kernel_params
    ]
    args += [
        tensor_type_map[out_type] + ' ' + out_name
        for out_type, out_name in zip(
            api.outputs['types'], api.outputs['names']
        )
    ]

    # Build a fresh list instead of extending api.infer_meta['param'] in
    # place, so the parsed API description is left unmodified.
    invoke_param = [*infer_meta_params, *api.outputs['names']]

    arg_list = ", ".join(args)
    declare_code = f"\nvoid {wrapped_infermeta_name}({arg_list});\n"
    define_code = (
        f"\nvoid {wrapped_infermeta_name}({arg_list}) {{\n"
        f"  {api.infer_meta['func']}({', '.join(invoke_param)});\n"
        "}\n"
    )

    # The kernel is registered against the generated wrapper instead.
    register_code = (
        f"\nPD_REGISTER_INFER_META_FN({kernel_func}, "
        f"phi::{wrapped_infermeta_name});"
    )

    kernel_func_set.add(kernel_func)
    return declare_code, define_code, register_code


def header_include():
    """Return the ``#include`` lines written at the top of the generated header."""
    return """
#include "paddle/phi/core/meta_tensor.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/common/int_array.h"
"""


def source_include(header_file_path):
    """Return the ``#include`` lines written at the top of the generated source.

    Args:
        header_file_path: Include path of the generated header; it is emitted
            first, followed by the phi InferMeta headers.
    """
    return f"""
#include "{header_file_path}"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/binary.h"
#include "paddle/phi/infermeta/multiary.h"
#include "paddle/phi/infermeta/nullary.h"
#include "paddle/phi/infermeta/unary.h"
#include "paddle/phi/infermeta/ternary.h"
"""


def api_namespace():
    """Return the ``(opening, closing)`` text of the ``phi`` C++ namespace."""
    return (
        """
namespace phi {
""",
        """
}  // namespace phi
""",
    )


def generate_wrapped_infermeta_and_register(
    api_yaml_path, header_file_path, source_file_path
):
    """Generate the wrapped InferMeta header and source files.

    Args:
        api_yaml_path: Iterable of yaml file paths; each file holds a list of
            API entries (files that load as empty are skipped).
        header_file_path: Output path of the generated C++ header.
        source_file_path: Output path of the generated C++ source.
    """
    apis = []
    for each_api_yaml in api_yaml_path:
        with open(each_api_yaml, 'r', encoding='utf-8') as f:
            api_list = yaml.load(f, Loader=yaml.FullLoader)
        if api_list:
            apis.extend(api_list)

    namespace_begin, namespace_end = api_namespace()

    # The source includes the header by its project-relative path, which is
    # independent of where header_file_path physically writes the file.
    include_header_file = "paddle/phi/infermeta/generated.h"

    # Registrations are written after the namespace, each distinct one once,
    # in first-seen order.
    register_codes = []
    seen_register_codes = set()

    # `with` guarantees both files are closed even if generation raises.
    with open(header_file_path, 'w', encoding='utf-8') as header_file, open(
        source_file_path, 'w', encoding='utf-8'
    ) as source_file:
        header_file.write("#pragma once\n")
        header_file.write(header_include())
        header_file.write(namespace_begin)

        source_file.write(source_include(include_header_file))
        source_file.write(namespace_begin)

        for api in apis:
            (
                declare_code,
                define_code,
                register_code,
            ) = gene_wrapped_infermeta_and_register(ForwardAPI(api))
            header_file.write(declare_code)
            source_file.write(define_code)
            if register_code and register_code not in seen_register_codes:
                seen_register_codes.add(register_code)
                register_codes.append(register_code)

        header_file.write(namespace_end)
        source_file.write(namespace_end)

        source_file.write(''.join(register_codes))


def main():
    """Parse command-line options and generate the wrapped InferMeta files."""
    parser = argparse.ArgumentParser(
        description='Generate PaddlePaddle C++ API files'
    )
    parser.add_argument(
        '--api_yaml_path',
        help='path to api yaml file',
        nargs='+',
        default=['paddle/phi/api/yaml/ops.yaml'],
    )
    parser.add_argument(
        '--wrapped_infermeta_header_path',
        help='output of generated wrapped_infermeta header code file',
        default='paddle/phi/infermeta/generated.h',
    )
    parser.add_argument(
        '--wrapped_infermeta_source_path',
        help='output of generated wrapped_infermeta source code file',
        default='paddle/phi/infermeta/generated.cc',
    )

    options = parser.parse_args()

    generate_wrapped_infermeta_and_register(
        options.api_yaml_path,
        options.wrapped_infermeta_header_path,
        options.wrapped_infermeta_source_path,
    )


if __name__ == '__main__':
    main()