# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import yaml
import argparse
18
import re
19

20
from api_base import BaseAPI, PREFIX_TENSOR_NAME
21

22 23 24 25 26
# C++ output type -> the reference type substituted for it when the output
# is listed in an API's inplace map and in-place generation is requested.
inplace_out_type_map = {
    cpp_type: cpp_type + "&"
    for cpp_type in ("Tensor", "std::vector<Tensor>")
}

27

28
class ForwardAPI(BaseAPI):
    """Generator helper for the forward C++ API of one yaml api entry.

    On top of ``BaseAPI`` this class handles the optional ``intermediate``
    yaml key (outputs that are dropped from the public return value) and
    produces the C++ statements that declare ``api_output`` and the kernel
    output pointers.

    NOTE(review): ``self.api``, ``self.outputs``, ``self.inplace_map`` and
    ``self.view_map`` are not assigned in this class; presumably they are
    populated by ``BaseAPI.__init__`` -- confirm against api_base.
    """

    def __init__(self, api_item_yaml):
        """Initialise the base class and record the intermediate outputs."""
        super(ForwardAPI, self).__init__(api_item_yaml)
        self.is_dygraph_api, self.intermediate_outs = self.parse_intermediate(
            api_item_yaml)

    def get_api_func_name(self):
        """Return the API name, suffixed with ``_intermediate`` for the
        dygraph variant."""
        if self.is_dygraph_api:
            return self.api + '_intermediate'
        else:
            return self.api

    def parse_intermediate(self, api_item_yaml):
        """Read the ``intermediate`` key of the yaml item.

        Returns:
            A pair ``(has_intermediate, names)``. ``names`` is the
            comma-separated value of the key split into stripped items, or
            an empty list when the key is absent.
        """
        if 'intermediate' in api_item_yaml:
            intermediate_outs = [
                item.strip()
                for item in api_item_yaml['intermediate'].split(',')
            ]
            return True, intermediate_outs
        else:
            return False, []

    def get_return_type_with_intermediate(self, inplace_flag=False):
        """Return the C++ return type covering every output.

        Intermediate outputs are included. When ``inplace_flag`` is set,
        outputs found in ``self.inplace_map`` use the reference type from
        ``inplace_out_type_map``. Several outputs are wrapped in
        ``std::tuple<...>``.
        """
        out_type_list = []
        for i, out_type in enumerate(self.outputs['types']):
            # Only the part of the name before '@' is matched against the
            # inplace map.
            out_name = self.outputs['names'][i].split('@')[0]
            if inplace_flag and out_name in self.inplace_map:
                out_type_list.append(inplace_out_type_map[out_type])
            else:
                out_type_list.append(out_type)

        if len(out_type_list) == 1:
            return out_type_list[0]
        else:
            return "std::tuple<" + ", ".join(out_type_list) + ">"

    def get_return_type(self, inplace_flag=False):
        """Return the C++ return type of the public API function.

        Same as ``get_return_type_with_intermediate`` except that, for the
        non-dygraph variant, outputs named in ``self.intermediate_outs`` are
        left out (unless they are emitted as inplace outputs).
        """
        out_type_list = []
        for i, out_type in enumerate(self.outputs['types']):
            out_name = self.outputs['names'][i].split('@')[0]
            if inplace_flag and out_name in self.inplace_map:
                out_type_list.append(inplace_out_type_map[out_type])
            elif self.is_dygraph_api or out_name not in self.intermediate_outs:
                out_type_list.append(out_type)

        if len(out_type_list) == 1:
            return out_type_list[0]
        else:
            return "std::tuple<" + ", ".join(out_type_list) + ">"

    def gene_return_code(self):
        """Return the C++ ``return`` statement of the generated function.

        The whole ``api_output`` is returned for the dygraph variant or when
        there are no intermediate outputs; otherwise only the tuple elements
        that are not intermediate are selected with ``std::get``.
        """
        if self.is_dygraph_api or len(self.intermediate_outs) == 0:
            return "return api_output;"

        # Tuple positions of the outputs that stay visible to the caller.
        return_out_list = [
            i for i, name in enumerate(self.outputs['names'])
            if name.split('@')[0] not in self.intermediate_outs
        ]
        if len(return_out_list) == 1:
            return f"return std::get<{return_out_list[0]}>(api_output);"

        selected_code = [f"std::get<{i}>(api_output)" for i in return_out_list]
        return 'return {' + ", ".join(selected_code) + '};'

    def gene_output(self,
                    output_type_list,
                    set_out_func,
                    code_indent,
                    inplace_flag=False):
        """Generate the C++ code that declares the API and kernel outputs.

        Args:
            output_type_list: C++ type string of every output; entries equal
                to ``'std::vector<Tensor>'`` need an output size expression.
            set_out_func: name of the C++ function inserted into the
                generated code to obtain each kernel output.
            code_indent: string prepended to every generated line.
            inplace_flag: when true, outputs found in ``self.inplace_map``
                are initialised from the mapped input and no view sharing
                code is emitted.

        Returns:
            A triple ``(kernel_output, output_names, output_create)``:
            the comma-separated kernel output variable names, the same names
            as a list, and the generated C++ declaration code.

        Raises:
            ValueError: if ``output_type_list`` is empty.
            AssertionError: if a ``std::vector<Tensor>`` output has no size
                expression.
        """
        kernel_output = ""
        output_names = []
        output_create = ""
        return_type = self.get_return_type_with_intermediate(inplace_flag)
        out_names = self.outputs['names']
        # View sharing is only generated for the non-inplace variant.
        share_view = not inplace_flag and self.view_map is not None

        def view_code(kernel_var, out_name):
            # C++ lines making `kernel_var` share buffer and inplace version
            # counter with the input tensor mapped in self.view_map.
            view_src = f"{PREFIX_TENSOR_NAME}{self.view_map[out_name]}"
            return f"""
{code_indent}  {kernel_var}->ShareBufferWith(*{view_src});
{code_indent}  {kernel_var}->ShareInplaceVersionCounterWith(*{view_src});
{code_indent}  VLOG(3) << "Perform View between Output and Input Tensor, share allocation and inplace version.";"""

        if len(output_type_list) == 1:
            kernel_output = 'kernel_out'
            output_names.append('kernel_out')
            out_name = out_names[0]
            inplace_assign = (" = " + self.inplace_map[out_name]
                              if inplace_flag and out_name in self.inplace_map
                              else "")
            output_create = f"""
{code_indent}  {return_type} api_output{inplace_assign};"""

            if return_type == 'std::vector<Tensor>':
                # With a single output the size expression is used as is,
                # not indexed.
                size_expr = self.outputs['out_size_expr']
                # '{expr}' in the message is literal text, not a placeholder.
                assert size_expr is not None, \
                    f"{self.api}: The out size expr : '{{expr}}' should be set when output has Tensor[]. You can refer 'split' api."
                output_create += f"""
{code_indent}  auto kernel_out = {set_out_func}({size_expr}, kernel_backend, &api_output);"""
            else:
                output_create += f"""
{code_indent}  auto kernel_out = {set_out_func}(kernel_backend, &api_output);"""

            if share_view and out_name in self.view_map:
                output_create += view_code('kernel_out', out_name)

        elif len(output_type_list) > 1:
            if inplace_flag:
                # Brace-initialise the tuple: inplace outputs start from the
                # mapped input, every other slot from a default Tensor().
                init_values = [
                    self.inplace_map[out_name]
                    if out_name in self.inplace_map else 'Tensor()'
                    for out_name in out_names
                ]
                output_create = f"""
{code_indent}  {return_type} api_output{{""" + ', '.join(init_values) + '};'
            else:
                output_create = f"""
{code_indent}  {return_type} api_output;"""

            for i, out_type in enumerate(output_type_list):
                kernel_var = f'kernel_out_{i}'
                output_names.append(kernel_var)

                if out_type == 'std::vector<Tensor>':
                    size_expr = self.outputs['out_size_expr'][i]
                    assert size_expr is not None, \
                        f"{self.api}: The out size expr : '{{expr}}' should be set when output has Tensor[]. You can refer 'split' api."
                    output_create += f"""
{code_indent}  auto {kernel_var} = {set_out_func}({size_expr}, kernel_backend, &std::get<{i}>(api_output));"""
                else:
                    output_create += f"""
{code_indent}  auto {kernel_var} = {set_out_func}(kernel_backend, &std::get<{i}>(api_output));"""

                if share_view and out_names[i] in self.view_map:
                    output_create += view_code(kernel_var, out_names[i])

            kernel_output = ", ".join(output_names)
        else:
            raise ValueError(
                "{} : Output error: the output should not be empty.".format(
                    self.api))

        return kernel_output, output_names, output_create
Z
zyfncg 已提交
174

175 176 177

def header_include():
    """Return the ``#include`` block written at the top of the API header.

    The text starts and ends with a newline and contains only include
    directives and blank separator lines.
    """
    return """
#include <tuple>

#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/common/int_array.h"
#include "paddle/utils/optional.h"
"""


def source_include(header_file_path):
    """Return the preamble written at the top of the API source file.

    Args:
        header_file_path: include path of the generated API header; it is
            emitted as the first ``#include`` directive.

    Returns:
        A string holding the include directives followed by the
        ``DECLARE_bool(conv2d_disable_cudnn);`` flag declaration.
    """
    return f"""
#include "{header_file_path}"
#include <memory>

#include "glog/logging.h"

#include "paddle/phi/api/lib/api_custom_impl.h"
#include "paddle/phi/api/lib/api_gen_utils.h"
#include "paddle/phi/api/lib/data_transform.h"
#include "paddle/phi/api/lib/kernel_dispatch.h"
#include "paddle/phi/api/lib/utils/storage.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/infermeta/binary.h"
#include "paddle/phi/infermeta/multiary.h"
#include "paddle/phi/infermeta/nullary.h"
#include "paddle/phi/infermeta/unary.h"
#include "paddle/phi/infermeta/ternary.h"

#include "paddle/fluid/platform/profiler/event_tracing.h"

DECLARE_bool(conv2d_disable_cudnn);
"""


def api_namespace():
    """Return the C++ namespace wrapper for the generated files.

    Returns:
        A pair ``(opening, closing)``: the text that opens
        ``paddle::experimental`` and the text that closes it again.
    """
    opening = "\nnamespace paddle {\nnamespace experimental {\n\n"
    closing = "\n\n}  // namespace experimental\n}  // namespace paddle\n"
    return opening, closing


224
def generate_api(api_yaml_path, header_file_path, source_file_path):
    """Generate the C++ API header and source from yaml API descriptions.

    Args:
        api_yaml_path: iterable of yaml file paths; a single path given as a
            plain string is accepted as well.
        header_file_path: path of the header file to write.
        source_file_path: path of the source file to write.
    """
    # A bare string would otherwise be iterated character by character.
    if isinstance(api_yaml_path, str):
        api_yaml_path = [api_yaml_path]

    apis = []
    for each_api_yaml in api_yaml_path:
        with open(each_api_yaml, 'r') as f:
            # NOTE(review): yaml.safe_load would be enough if the spec files
            # only contain plain mappings/lists -- confirm before switching.
            api_list = yaml.load(f, Loader=yaml.FullLoader)
        # An empty yaml file loads as None; skip it.
        if api_list:
            apis.extend(api_list)

    namespace = api_namespace()
    include_header_file = "paddle/phi/api/include/api.h"

    # Context managers close both outputs even if generation raises.
    with open(header_file_path, 'w') as header_file, \
            open(source_file_path, 'w') as source_file:
        header_file.write("#pragma once\n")
        header_file.write(header_include())
        header_file.write(namespace[0])

        source_file.write(source_include(include_header_file))
        source_file.write(namespace[0])

        for api in apis:
            forward_api = ForwardAPI(api)
            # Always generate the plain variant here: with the flag cleared
            # get_api_func_name() returns the name without '_intermediate'.
            forward_api.is_dygraph_api = False

            header_file.write(forward_api.gene_api_declaration())
            source_file.write(forward_api.gene_api_code())

        header_file.write(namespace[1])
        source_file.write(namespace[1])


def main():
    """Parse the command-line options and run the API code generation."""
    parser = argparse.ArgumentParser(
        description='Generate PaddlePaddle C++ API files')
    parser.add_argument(
        '--api_yaml_path',
        help='path to api yaml file',
        nargs='+',
        # nargs='+' produces a list, so the default has to be a list too;
        # a plain string default would be iterated character by character.
        default=['python/paddle/utils/code_gen/api.yaml'])

    parser.add_argument(
        '--api_header_path',
        help='output of generated api header code file',
        default='paddle/phi/api/include/api.h')

    parser.add_argument(
        '--api_source_path',
        help='output of generated api source code file',
        default='paddle/phi/api/lib/api.cc')

    options = parser.parse_args()

    api_yaml_path = options.api_yaml_path
    header_file_path = options.api_header_path
    source_file_path = options.api_source_path

    generate_api(api_yaml_path, header_file_path, source_file_path)
287 288 289 290


# Run the generator only when this file is executed as a script.
if __name__ == "__main__":
    main()