api_gen.py 15.6 KB
Newer Older
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
16
import re
17

18 19
import yaml
from api_base import PREFIX_TENSOR_NAME, BaseAPI
20

21 22
# C++ return type used for an inplace output: the declared output type taken
# by (non-const) reference.
inplace_out_type_map = {
    out_type: f"{out_type}&"
    for out_type in ("Tensor", "std::vector<Tensor>")
}

26 27
# C++ return type used for an inplace output whose source input is optional:
# the declared output type wrapped in paddle::optional and taken by reference.
inplace_optional_out_type_map = {
    out_type: f"paddle::optional<{out_type}>&"
    for out_type in ("Tensor", "std::vector<Tensor>")
}

31

32
class ForwardAPI(BaseAPI):
33
    def __init__(self, api_item_yaml):
        """Build a forward API description from one parsed YAML API item.

        After the base-class initialisation, records the intermediate-output
        information and the inplace/view output-to-input mappings declared in
        ``api_item_yaml``.
        """
        super().__init__(api_item_yaml)

        # 'intermediate' key -> (flag, list of intermediate output names).
        dygraph_flag, intermediates = self.parse_intermediate(api_item_yaml)
        self.is_dygraph_api = dygraph_flag
        self.intermediate_outs = intermediates

        # 'inplace' / 'view' keys -> {output name: input name} mappings.
        inplace_pairs, view_pairs = self.parse_inplace_and_view(api_item_yaml)
        self.inplace_map = inplace_pairs
        self.view_map = view_pairs
41 42 43 44 45 46 47

    def get_api_func_name(self):
        """Return the name of the generated C++ function.

        The API name gets an ``_intermediate`` suffix while
        ``is_dygraph_api`` is set; otherwise it is returned unchanged.
        """
        suffix = '_intermediate' if self.is_dygraph_api else ''
        return self.api + suffix

Y
YuanRisheng 已提交
48 49 50
    def gene_input(self, kernel_tensor_type=None, code_indent=''):
        """Generate input-preparation code, including inputs used only by views.

        Starts from the base-class result and appends one ``PrepareData`` line
        for every input that is the source of a view output but was not
        already handled by the base class.

        Args:
            kernel_tensor_type: optional tensor-kind description; when given,
                ``kernel_tensor_type[0]`` is indexed by the input's position
                in ``self.kernel['param']`` and only ``'dense'`` inputs get
                extra code.
            code_indent: prefix placed before each generated line.

        Returns:
            ``(input_name_tensor_map, input_tensor_code)`` from the base
            class, with the extra generated lines appended to the code.
        """
        kernel_param = self.kernel['param']
        base_result = super().gene_input(kernel_tensor_type, code_indent)
        input_name_tensor_map, input_tensor_code = base_result

        for input_name in self.inputs['names']:
            # Only inputs backing a view output need handling here ...
            if input_name not in self.view_map.values():
                continue
            # ... and only if the base class has not prepared them already.
            if input_name in input_name_tensor_map:
                continue

            is_dense = (
                kernel_tensor_type is None
                or kernel_tensor_type[0][kernel_param.index(input_name)]
                == 'dense'
            )
            if not is_dense:
                # Non-dense inputs get no extra preparation code.
                continue

            trans_flag = self.gene_trans_flag(input_name)
            input_tensor_code = input_tensor_code + (
                f"\n{code_indent}  auto {PREFIX_TENSOR_NAME}{input_name} = "
                f"PrepareData({input_name}, kernel.InputAt(0), {trans_flag});"
            )

        return input_name_tensor_map, input_tensor_code

77 78
    def parse_intermediate(self, api_item_yaml):
        """Read the optional ``intermediate`` entry of an API item.

        Args:
            api_item_yaml: mapping for one API; ``'intermediate'``, when
                present, is a comma-separated string of output names.

        Returns:
            ``(True, names)`` with surrounding whitespace stripped from each
            name when the key is present, otherwise ``(False, [])``.
        """
        if 'intermediate' not in api_item_yaml:
            return False, []

        raw_names = api_item_yaml['intermediate'].split(',')
        return True, [name.strip() for name in raw_names]
86

87 88 89 90 91 92 93 94 95 96 97 98 99
    def parse_inplace_and_view(self, api_item_yaml):
        """Parse the optional ``inplace`` and ``view`` entries of an API item.

        Each entry is a comma-separated string; every item must contain an
        ``<input> -> <output>`` pair (surrounding characters such as
        parentheses are ignored by the pattern).

        Args:
            api_item_yaml: mapping for one API, possibly holding ``'inplace'``
                and/or ``'view'`` strings.

        Returns:
            ``(inplace_map, view_map)``: two dicts mapping an output name to
            the input name it is bound to. A missing entry yields an empty
            dict.

        Raises:
            ValueError: if an item does not contain an ``input -> output``
                pair (for example an empty item left by a trailing comma).
            AssertionError: if a parsed name is not among this API's declared
                inputs / outputs.
        """
        inplace_map, view_map = {}, {}
        maps_by_mode = {'inplace': inplace_map, 'view': view_map}

        for mode, mapping in maps_by_mode.items():
            if mode not in api_item_yaml:
                continue

            for item in api_item_yaml[mode].split(','):
                result = re.search(r"(?P<in>\w+)\s*->\s*(?P<out>\w+)", item)
                if result is None:
                    # Without this check a malformed item would surface as an
                    # opaque AttributeError on `None.group`.
                    raise ValueError(
                        f"{self.api} : {mode} format error: expected "
                        f"'input -> output' but got '{item.strip()}'."
                    )
                in_val = result.group('in')
                out_val = result.group('out')
                assert (
                    in_val in self.inputs['names']
                ), f"{self.api} : {mode} input error: the input var name('{in_val}') is not found in the input args of {self.api}."
                assert (
                    out_val in self.outputs['names']
                ), f"{self.api} : {mode} output error: the output var name('{out_val}') is not found in the output args of {self.api}."

                mapping[out_val] = in_val

        return inplace_map, view_map

114 115 116 117 118
    def get_return_type_with_intermediate(self, inplace_flag=False):
        """Return the C++ return type covering every declared output.

        Args:
            inplace_flag: when true, outputs present in ``inplace_map`` are
                mapped to reference types (the ``paddle::optional`` variant
                if the bound input is in ``optional_vars``).

        Returns:
            The single output type, or ``std::tuple<...>`` of all output
            types when the number of outputs is not exactly one.
        """
        resolved_types = []
        for idx, declared_type in enumerate(self.outputs['types']):
            # Only the part of the output name before '@' is used for lookup.
            base_name = self.outputs['names'][idx].split('@')[0]

            if not (inplace_flag and base_name in self.inplace_map):
                resolved_types.append(declared_type)
                continue

            source_is_optional = (
                self.inplace_map[base_name] in self.optional_vars
            )
            type_map = (
                inplace_optional_out_type_map
                if source_is_optional
                else inplace_out_type_map
            )
            resolved_types.append(type_map[declared_type])

        if len(resolved_types) != 1:
            return "std::tuple<" + ", ".join(resolved_types) + ">"
        return resolved_types[0]

    def get_return_type(self, inplace_flag=False):
        """Return the C++ return type exposed by the generated API function.

        Inplace outputs (when ``inplace_flag`` is true) are mapped to
        reference types. Any other output is included only if
        ``is_dygraph_api`` is set or its name is not in ``intermediate_outs``.

        Returns:
            The single selected type, or ``std::tuple<...>`` of the selected
            types when their count is not exactly one.
        """
        selected_types = []
        for idx, declared_type in enumerate(self.outputs['types']):
            # Only the part of the output name before '@' is used for lookup.
            base_name = self.outputs['names'][idx].split('@')[0]

            if inplace_flag and base_name in self.inplace_map:
                source_is_optional = (
                    self.inplace_map[base_name] in self.optional_vars
                )
                type_map = (
                    inplace_optional_out_type_map
                    if source_is_optional
                    else inplace_out_type_map
                )
                selected_types.append(type_map[declared_type])
                continue

            if self.is_dygraph_api or base_name not in self.intermediate_outs:
                selected_types.append(declared_type)

        if len(selected_types) != 1:
            return "std::tuple<" + ", ".join(selected_types) + ">"
        return selected_types[0]
151 152 153

    def gene_return_code(self):
        """Return the C++ ``return`` statement for the generated API body.

        ``api_output`` is returned as a whole when ``is_dygraph_api`` is set
        or there are no intermediate outputs. Otherwise only the outputs not
        listed in ``intermediate_outs`` are returned: a single
        ``std::get<i>`` when exactly one remains, else a ``std::make_tuple``
        of the remaining elements.
        """
        if self.is_dygraph_api or not self.intermediate_outs:
            return "return api_output;"

        # Positions of the outputs that are not intermediate; the name is
        # compared without anything following '@'.
        kept_indices = [
            idx
            for idx, name in enumerate(self.outputs['names'])
            if name.split('@')[0] not in self.intermediate_outs
        ]

        if len(kept_indices) == 1:
            return f"return std::get<{kept_indices[0]}>(api_output);"

        tuple_args = ", ".join(
            f"std::get<{idx}>(api_output)" for idx in kept_indices
        )
        return 'return std::make_tuple(' + tuple_args + ');'
167

168 169 170 171 172 173 174
    def gene_output(
        self,
        out_dtype_list,
        out_tensor_type_list=None,
        code_indent='',
        inplace_flag=False,
    ):
        """Generate the C++ code that declares ``api_output`` and its kernel outputs.

        Args:
            out_dtype_list: C++ type name of each kernel output; its length
                selects the single-output or multi-output code path.
            out_tensor_type_list: optional per-output tensor kind. ``None`` or
                ``'dense'`` selects ``SetKernelOutput``; any other value
                selects ``SetSelectedRowsKernelOutput``.
            code_indent: prefix placed before each generated line.
            inplace_flag: whether the inplace variant is being generated;
                outputs found in ``inplace_map`` are then initialised from
                their bound inputs and view handling is skipped.

        Returns:
            ``(kernel_output, output_names, output_create)``: two lists with
            the generated kernel-output variable names (``kernel_out`` or
            ``kernel_out_<i>``) and the generated C++ code as one string.

        Raises:
            ValueError: if ``out_dtype_list`` is empty, or if an output in
                ``view_map`` (multi-output path) is not of type ``Tensor``.
            AssertionError: if a vector output has no ``out_size_expr``.
        """
        kernel_output = []
        output_names = []
        output_create = ""
        return_type = self.get_return_type_with_intermediate(inplace_flag)

        if len(out_dtype_list) == 1:
            # Single output: one `api_output` variable, one `kernel_out`.
            kernel_output.append('kernel_out')
            output_names.append('kernel_out')
            # An inplace output is initialised from the input it is bound to.
            inplace_assign = (
                " = " + self.inplace_map[self.outputs['names'][0]]
                if inplace_flag and self.outputs['names'][0] in self.inplace_map
                else ""
            )
            output_create = f"""
{code_indent}  {return_type} api_output{inplace_assign};"""
            set_out_func = (
                'SetKernelOutput'
                if out_tensor_type_list is None
                or out_tensor_type_list[0] == 'dense'
                else 'SetSelectedRowsKernelOutput'
            )
            # NOTE(review): this compares `return_type`, which carries a
            # trailing '&' for inplace outputs, so an inplace vector output
            # takes the `else` branch below — confirm this is intended.
            if return_type == 'std::vector<Tensor>':
                # A vector output needs an explicit size expression.
                assert (
                    self.outputs['out_size_expr'][0] is not None
                ), f"{self.api}: The out size expr : '{{expr}}' should be set when output has Tensor[]. You can refer 'split' api."
                output_create = (
                    output_create
                    + f"""
{code_indent}  auto kernel_out = {set_out_func}({self.outputs['out_size_expr'][0]}, &api_output);"""
                )

            else:
                output_create = (
                    output_create
                    + f"""
{code_indent}  auto kernel_out = {set_out_func}(&api_output);"""
                )

            # Non-inplace view output: share the buffer and the inplace
            # version counter with the input named in `view_map`.
            if (
                not inplace_flag
                and self.view_map is not None
                and self.outputs['names'][0] in self.view_map
            ):
                output_create = (
                    output_create
                    + f"""
{code_indent}  kernel_out->ShareBufferWith(*{PREFIX_TENSOR_NAME}{self.view_map[self.outputs['names'][0]]});
{code_indent}  kernel_out->ShareInplaceVersionCounterWith(*{PREFIX_TENSOR_NAME}{self.view_map[self.outputs['names'][0]]});
{code_indent}  VLOG(3) << "Perform View between Output and Input Tensor, share allocation and inplace version.";"""
                )

        elif len(out_dtype_list) > 1:
            # Multiple outputs: `api_output` is accessed through std::get<i>
            # and each element gets its own `kernel_out_<i>`.
            output_create = f"""
{code_indent}  {return_type} api_output;"""

            if inplace_flag:
                # Replace the plain declaration with a brace initialiser:
                # inplace outputs come from their bound inputs, all other
                # positions are filled with `Tensor()`.
                output_create = f"""
{code_indent}  {return_type} api_output{{"""

                for out_name in self.outputs['names']:
                    if out_name in self.inplace_map:
                        output_create += self.inplace_map[out_name] + ', '
                    else:
                        output_create += 'Tensor(), '
                # Drop the trailing ', ' and close the initialiser.
                output_create = output_create[:-2] + '};'

            for i in range(len(out_dtype_list)):
                kernel_output.append(f'kernel_out_{i}')
                output_names.append(f'kernel_out_{i}')
                set_out_func = (
                    'SetKernelOutput'
                    if out_tensor_type_list is None
                    or out_tensor_type_list[i] == 'dense'
                    else 'SetSelectedRowsKernelOutput'
                )

                # Default: pass the address of the i-th tuple element. When
                # the output is inplace with an input listed in
                # `optional_vars`, pass `.get_ptr()` of the element instead.
                get_out_code = f"&std::get<{i}>(api_output)"
                if (
                    self.outputs['names'][i] in self.inplace_map
                    and self.inplace_map[self.outputs['names'][i]]
                    in self.optional_vars
                ):
                    get_out_code = f"std::get<{i}>(api_output).get_ptr()"

                if out_dtype_list[i] == 'std::vector<Tensor>':
                    # A vector output needs an explicit size expression.
                    assert (
                        self.outputs['out_size_expr'][i] is not None
                    ), f"{self.api}: The out size expr : '{{expr}}' should be set when output has Tensor[]. You can refer 'split' api."
                    # Special case for inplace vector and inplace optional<vector>
                    if self.outputs['names'][i] in self.inplace_map:
                        set_out_func = "SetInplaceVectorKernelOutput"
                        if (
                            self.inplace_map[self.outputs['names'][i]]
                            in self.optional_vars
                        ):
                            set_out_func = (
                                "SetInplaceOptionalVectorKernelOutput"
                            )
                            # The optional-vector setter receives the tuple
                            # element itself rather than a pointer.
                            get_out_code = f"std::get<{i}>(api_output)"
                    output_create = (
                        output_create
                        + f"""
{code_indent}  auto kernel_out_{i} = {set_out_func}({self.outputs['out_size_expr'][i]}, {get_out_code});"""
                    )

                else:
                    output_create = (
                        output_create
                        + f"""
{code_indent}  auto kernel_out_{i} = {set_out_func}({get_out_code});"""
                    )

                # Non-inplace view output: only plain `Tensor` outputs may
                # share storage with the input named in `view_map`.
                if (
                    not inplace_flag
                    and self.view_map is not None
                    and self.outputs['names'][i] in self.view_map
                ):
                    if out_dtype_list[i] == 'Tensor':
                        # NOTE(review): the generated lines below start with
                        # four extra spaces compared with the single-output
                        # branch; this only affects whitespace of the output.
                        output_create = (
                            output_create
                            + f"""
    {code_indent}  kernel_out_{i}->ShareBufferWith(*{PREFIX_TENSOR_NAME}{self.view_map[self.outputs['names'][i]]});
    {code_indent}  kernel_out_{i}->ShareInplaceVersionCounterWith(*{PREFIX_TENSOR_NAME}{self.view_map[self.outputs['names'][i]]});
    {code_indent}  VLOG(3) << "Perform View between Output and Input Tensor, share allocation and inplace version.";"""
                        )
                    else:
                        raise ValueError(
                            "{} : Output error: only support Tensor type when use view in yaml. But get {}".format(
                                self.api, out_dtype_list[i]
                            )
                        )
        else:
            # An API without any kernel output cannot be generated.
            raise ValueError(
                "{} : Output error: the output should not be empty.".format(
                    self.api
                )
            )

        return kernel_output, output_names, output_create
Z
zyfncg 已提交
314

315 316 317

def header_include():
    """Return the ``#include`` preamble written to the generated API header."""
    include_lines = (
        '',
        '#include <tuple>',
        '',
        '#include "paddle/phi/api/include/tensor.h"',
        '#include "paddle/phi/common/scalar.h"',
        '#include "paddle/phi/common/int_array.h"',
        '#include "paddle/utils/optional.h"',
        '',
    )
    # The empty first/last entries give the leading and trailing newline.
    return "\n".join(include_lines)


def source_include(header_file_path):
    """Return the preamble written to the generated API source file.

    Args:
        header_file_path: path placed in the first ``#include "..."`` line.

    Returns:
        The include lines followed by the ``DECLARE_bool`` flag declaration,
        starting and ending with a newline.
    """
    preamble_lines = (
        '',
        f'#include "{header_file_path}"',
        '#include <memory>',
        '',
        '#include "glog/logging.h"',
        '',
        '#include "paddle/phi/api/lib/api_custom_impl.h"',
        '#include "paddle/phi/api/lib/api_gen_utils.h"',
        '#include "paddle/phi/api/lib/data_transform.h"',
        '#include "paddle/phi/api/lib/kernel_dispatch.h"',
        '#include "paddle/phi/common/type_traits.h"',
        '#include "paddle/phi/core/kernel_registry.h"',
        '#include "paddle/phi/infermeta/binary.h"',
        '#include "paddle/phi/infermeta/multiary.h"',
        '#include "paddle/phi/infermeta/nullary.h"',
        '#include "paddle/phi/infermeta/unary.h"',
        '#include "paddle/phi/infermeta/ternary.h"',
        '',
        '#include "paddle/fluid/platform/profiler/event_tracing.h"',
        '#include "paddle/fluid/platform/profiler/supplement_tracing.h"',
        '',
        'DECLARE_bool(conv2d_disable_cudnn);',
        '',
    )
    # The empty first/last entries give the leading and trailing newline.
    return "\n".join(preamble_lines)


def api_namespace():
    """Return the ``(opening, closing)`` C++ namespace text.

    The first element opens ``paddle::experimental``; the second closes both
    namespaces again.
    """
    opening = (
        '\n'
        'namespace paddle {\n'
        'namespace experimental {\n'
        '\n'
    )
    closing = (
        '\n'
        '\n'
        '}  // namespace experimental\n'
        '}  // namespace paddle\n'
    )
    return opening, closing
366 367


368
def generate_api(api_yaml_path, header_file_path, source_file_path):
    """Generate the C++ API header and source files from YAML API definitions.

    Args:
        api_yaml_path: iterable of YAML file paths; each file is expected to
            hold a list of API items (files that load as empty are skipped).
        header_file_path: path of the header file to write.
        source_file_path: path of the source file to write.
    """
    apis = []

    for each_api_yaml in api_yaml_path:
        with open(each_api_yaml, 'r') as f:
            # NOTE(review): yaml.load with FullLoader is only appropriate for
            # trusted input; these files are presumably the repository's own
            # op definitions — confirm before pointing this at external data.
            api_list = yaml.load(f, Loader=yaml.FullLoader)
        if api_list:
            apis.extend(api_list)

    namespace = api_namespace()
    include_header_file = "paddle/phi/api/include/api.h"

    # Context managers make sure both output files are flushed and closed
    # even if generating one of the APIs raises.
    with open(header_file_path, 'w') as header_file, \
            open(source_file_path, 'w') as source_file:
        header_file.write("#pragma once\n")
        header_file.write(header_include())
        header_file.write(namespace[0])

        source_file.write(source_include(include_header_file))
        source_file.write(namespace[0])

        for api in apis:
            forward_api = ForwardAPI(api)
            # Always generate the non-dygraph form here: with the flag
            # cleared, get_api_func_name() returns the plain API name and
            # intermediate outputs are left out of the return type.
            if forward_api.is_dygraph_api:
                forward_api.is_dygraph_api = False

            header_file.write(forward_api.gene_api_declaration())
            source_file.write(forward_api.gene_api_code())

        header_file.write(namespace[1])
        source_file.write(namespace[1])


def main():
    """Parse the command-line options and run the API generator."""
    parser = argparse.ArgumentParser(
        description='Generate PaddlePaddle C++ API files'
    )

    # (flag, keyword arguments) for every supported command-line option.
    option_specs = (
        (
            '--api_yaml_path',
            {
                'help': 'path to api yaml file',
                'nargs': '+',
                'default': ['paddle/phi/api/yaml/ops.yaml'],
            },
        ),
        (
            '--api_header_path',
            {
                'help': 'output of generated api header code file',
                'default': 'paddle/phi/api/include/api.h',
            },
        ),
        (
            '--api_source_path',
            {
                'help': 'output of generated api source code file',
                'default': 'paddle/phi/api/lib/api.cc',
            },
        ),
    )
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)

    args = parser.parse_args()
    generate_api(
        args.api_yaml_path, args.api_header_path, args.api_source_path
    )
435 436 437 438


# Run the generator only when this file is executed as a script.
if __name__ == '__main__':
    main()