# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import yaml
import argparse
18
import re
19

20
from api_base import BaseAPI
21

22

23
class ForwardAPI(BaseAPI):
    """Generator for one forward C++ API described by an api.yaml entry.

    Extends ``BaseAPI`` with handling of the optional ``intermediate`` key.
    An entry that lists intermediate outputs is flagged as a dygraph api:
    its dygraph variant (function name ``<api>_intermediate``) returns every
    output, while the variant generated with ``is_dygraph_api`` set to False
    returns only the outputs that are not intermediate.
    """

    def __init__(self, api_item_yaml):
        super().__init__(api_item_yaml)
        self.is_dygraph_api, self.intermediate_outs = self.parse_intermediate(
            api_item_yaml)

    def get_api_func_name(self):
        """Return the C++ function name; dygraph variants get a suffix."""
        if self.is_dygraph_api:
            return self.api + '_intermediate'
        return self.api

    def parse_intermediate(self, api_item_yaml):
        """Parse the optional ``intermediate`` entry of an api yaml item.

        Args:
            api_item_yaml: mapping for one api entry; ``intermediate``, when
                present, is a comma-separated string of output names.

        Returns:
            ``(has_intermediate, intermediate_out_names)``. The flag is True
            whenever the ``intermediate`` key exists.
        """
        if 'intermediate' not in api_item_yaml:
            return False, []
        stripped = (item.strip()
                    for item in api_item_yaml['intermediate'].split(','))
        # Drop empty names produced by trailing or doubled commas.
        return True, [name for name in stripped if name]

    def get_return_type(self, out_type_list):
        """Return the C++ type used to return ``out_type_list`` outputs.

        A single output type is returned unchanged; several are packed into
        ``std::tuple<...>``.
        """
        if len(out_type_list) == 1:
            return out_type_list[0]
        return "std::tuple<" + ",".join(out_type_list) + ">"

    def _get_return_out_indices(self):
        """Return positions (in ``self.outputs``) of non-intermediate outputs.

        Raises:
            ValueError: if every output is marked intermediate, since the
                public api would then have nothing to return.
        """
        indices = [
            i for i, name in enumerate(self.outputs['names'])
            if name not in self.intermediate_outs
        ]
        if not indices:
            raise ValueError(
                "{} : Output error: all outputs are marked as intermediate, "
                "at least one output must be returned.".format(self.api))
        return indices

    def gene_return_type_code(self):
        """Return the C++ return type of the generated api function.

        The dygraph variant (or an api without intermediate outputs) returns
        the full output type; otherwise intermediate outputs are removed.
        """
        if self.is_dygraph_api or not self.intermediate_outs:
            return self.outputs['return_type']
        return self.get_return_type(
            [self.outputs['types'][i] for i in self._get_return_out_indices()])

    def gene_return_code(self):
        """Return the C++ expression returned by the generated api function.

        ``api_output`` holds every kernel output; when intermediate outputs
        must be hidden, only the remaining tuple elements are selected.
        """
        if self.is_dygraph_api or not self.intermediate_outs:
            return "api_output"
        selected_code = [
            f"std::get<{i}>(api_output)"
            for i in self._get_return_out_indices()
        ]
        if len(selected_code) == 1:
            return selected_code[0]
        return '{' + ", ".join(selected_code) + '}'

    def _get_inplace_source(self, out_index, inplace_flag):
        """Return the value output ``out_index`` aliases in inplace mode.

        Returns None when inplace generation is off, there is no inplace map,
        or the output has no inplace entry.
        """
        if not inplace_flag or self.inplace_map is None:
            return None
        out_name = self.outputs['names'][out_index]
        if out_name in self.inplace_map:
            return self.inplace_map[out_name]
        return None

    def gene_output(self,
                    output_type_list,
                    set_out_func,
                    code_indent,
                    inplace_flag=False):
        """Generate C++ code that declares ``api_output`` and binds kernel outs.

        Args:
            output_type_list: output type strings; only its length is used.
            set_out_func: name of the C++ helper emitted as
                ``{set_out_func}(kernel_backend, &<output>)``.
            code_indent: indentation prefix for every emitted line.
            inplace_flag: when True, outputs listed in ``self.inplace_map``
                are initialised from the mapped value.

        Returns:
            ``(kernel_output, output_names, output_create)``: the
            comma-separated kernel output variable names, the same names as a
            list, and the generated declaration code.

        Raises:
            ValueError: if ``output_type_list`` is empty.
        """
        output_names = []

        if len(output_type_list) == 1:
            output_names.append('kernel_out')
            inplace_source = self._get_inplace_source(0, inplace_flag)
            inplace_assign = (" = " + inplace_source
                              if inplace_source is not None else "")
            output_create = f"""
{code_indent}  {self.outputs['return_type']} api_output{inplace_assign};
{code_indent}  auto kernel_out = {set_out_func}(kernel_backend, &api_output);"""

        elif len(output_type_list) > 1:
            output_create = f"""
{code_indent}  {self.outputs['return_type']} api_output;"""

            for i in range(len(output_type_list)):
                output_names.append(f'kernel_out_{i}')
                inplace_source = self._get_inplace_source(i, inplace_flag)
                if inplace_source is not None:
                    output_create += f"""
{code_indent}  std::get<{i}>(api_output) = {inplace_source};"""

                output_create += f"""
{code_indent}  auto kernel_out_{i} = {set_out_func}(kernel_backend, &std::get<{i}>(api_output));"""

        else:
            raise ValueError(
                "{} : Output error: the output should not be empty.".format(
                    self.api))

        kernel_output = ", ".join(output_names)
        return kernel_output, output_names, output_create

def header_include():
    """Return the ``#include`` preamble for generated API header files.

    ``generate_api`` writes it into both the public and the dygraph header.
    Every line is either blank or an ``#include`` directive.
    """
    return """
#include <tuple>

#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/common/scalar_array.h"
#include "paddle/utils/optional.h"
"""


def source_include(header_file_path):
    """Return the ``#include`` preamble for a generated API source file.

    Args:
        header_file_path: include path of the matching generated header,
            emitted as the first ``#include``.
    """
    return f"""
#include "{header_file_path}"
#include <memory>

#include "glog/logging.h"

#include "paddle/phi/api/lib/api_custom_impl.h"
#include "paddle/phi/api/lib/api_registry.h"
#include "paddle/phi/api/lib/api_gen_utils.h"
#include "paddle/phi/api/lib/data_transform.h"
#include "paddle/phi/api/lib/kernel_dispatch.h"
#include "paddle/phi/api/lib/utils/storage.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/infermeta/binary.h"
#include "paddle/phi/infermeta/multiary.h"
#include "paddle/phi/infermeta/nullary.h"
#include "paddle/phi/infermeta/unary.h"
#include "paddle/phi/infermeta/ternary.h"
#include "paddle/phi/kernels/declarations.h"

#include "paddle/fluid/platform/profiler/event_tracing.h"
"""


def api_register():
    """Return the API registration statement appended to the source file."""
    return """
PD_REGISTER_API(Math);
"""


def api_namespace():
    """Return the ``(opening, closing)`` text wrapping generated code.

    The opening text enters ``paddle::experimental``; the closing text
    leaves it again.
    """
    opening = "\nnamespace paddle {\nnamespace experimental {\n\n"
    closing = "\n\n}  // namespace experimental\n}  // namespace paddle\n"
    return opening, closing


def generate_api(api_yaml_path, header_file_path, source_file_path,
                 dygraph_header_file_path, dygraph_source_file_path):
    """Generate the C++ API header/source files from an api yaml file.

    Every api is written to the public header/source pair. An api that
    declares intermediate outputs is additionally written, first, to the
    dygraph header/source pair in its dygraph form, then regenerated with
    ``is_dygraph_api`` disabled for the public files.

    Args:
        api_yaml_path: path of the api yaml file to read.
        header_file_path: output path of the public api header.
        source_file_path: output path of the public api source.
        dygraph_header_file_path: output path of the dygraph api header.
        dygraph_source_file_path: output path of the dygraph api source.
    """
    # The api yaml is an in-repo build input, not untrusted data; the
    # original FullLoader is kept so parsing behaviour does not change.
    with open(api_yaml_path, 'r', encoding='utf-8') as f:
        apis = yaml.load(f, Loader=yaml.FullLoader)
    # An empty yaml document loads as None; treat it as "no apis".
    apis = apis or []

    namespace = api_namespace()
    include_header_file = "paddle/phi/api/include/api.h"
    dygraph_include_header_file = "paddle/phi/api/lib/dygraph_api.h"

    # Context managers guarantee every output file is closed (and flushed)
    # even if code generation raises part way through.
    with open(header_file_path, 'w', encoding='utf-8') as header_file, \
            open(source_file_path, 'w', encoding='utf-8') as source_file, \
            open(dygraph_header_file_path, 'w',
                 encoding='utf-8') as dygraph_header_file, \
            open(dygraph_source_file_path, 'w',
                 encoding='utf-8') as dygraph_source_file:

        header_file.write("#pragma once\n")
        header_file.write(header_include())
        header_file.write(namespace[0])

        source_file.write(source_include(include_header_file))
        source_file.write(namespace[0])

        dygraph_header_file.write("#pragma once\n")
        dygraph_header_file.write(header_include())
        dygraph_header_file.write(namespace[0])

        dygraph_source_file.write(source_include(dygraph_include_header_file))
        dygraph_source_file.write(namespace[0])

        for api in apis:
            forward_api = ForwardAPI(api)
            if forward_api.is_dygraph_api:
                dygraph_header_file.write(forward_api.gene_api_declaration())
                dygraph_source_file.write(forward_api.gene_api_code())
                # Regenerate the same api in its public form, which hides
                # the intermediate outputs.
                forward_api.is_dygraph_api = False
            header_file.write(forward_api.gene_api_declaration())
            source_file.write(forward_api.gene_api_code())

        header_file.write(namespace[1])
        source_file.write(namespace[1])

        dygraph_header_file.write(namespace[1])
        dygraph_source_file.write(namespace[1])

        source_file.write(api_register())

def main(argv=None):
    """Parse command-line options and run ``generate_api``.

    Args:
        argv: optional list of argument strings. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, as before.
    """
    parser = argparse.ArgumentParser(
        description='Generate PaddlePaddle C++ API files')
    parser.add_argument(
        '--api_yaml_path',
        help='path to api yaml file',
        default='python/paddle/utils/code_gen/api.yaml')

    parser.add_argument(
        '--api_header_path',
        help='output of generated api header code file',
        default='paddle/phi/api/include/api.h')

    parser.add_argument(
        '--api_source_path',
        help='output of generated api source code file',
        default='paddle/phi/api/lib/api.cc')

    parser.add_argument(
        '--dygraph_api_header_path',
        help='output of generated dygraph api header code file',
        default='paddle/phi/api/lib/dygraph_api.h')

    parser.add_argument(
        '--dygraph_api_source_path',
        help='output of generated dygraph api source code file',
        default='paddle/phi/api/lib/dygraph_api.cc')

    options = parser.parse_args(argv)

    generate_api(options.api_yaml_path, options.api_header_path,
                 options.api_source_path, options.dygraph_api_header_path,
                 options.dygraph_api_source_path)


if __name__ == '__main__':
    main()