# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
16
import re
17

18 19
import yaml
from api_base import PREFIX_TENSOR_NAME, BaseAPI
20

21 22
# C++ return type used for an inplace output whose mapped input is NOT in
# the API's optional variables: the plain output type becomes a reference.
inplace_out_type_map = {
    "Tensor": "Tensor&",
    "std::vector<Tensor>": "std::vector<Tensor>&",
}

26 27
# C++ return type used for an inplace output whose mapped input IS in the
# API's optional variables: the output is a reference to a paddle::optional.
inplace_optional_out_type_map = {
    "Tensor": "paddle::optional<Tensor>&",
    "std::vector<Tensor>": "paddle::optional<std::vector<Tensor>>&",
}

31

32
class ForwardAPI(BaseAPI):
33
    def __init__(self, api_item_yaml):
        """Build a forward-API description from one parsed YAML entry.

        After the base-class initialisation, four attributes are derived
        from ``api_item_yaml``:

        * ``is_dygraph_api`` / ``intermediate_outs`` -- result of
          ``parse_intermediate``.
        * ``inplace_map`` / ``view_map`` -- result of
          ``parse_inplace_and_view`` (both map output name -> input name).

        Args:
            api_item_yaml: mapping describing one API entry.
        """
        super().__init__(api_item_yaml)

        dygraph_flag, intermediate_names = self.parse_intermediate(
            api_item_yaml
        )
        self.is_dygraph_api = dygraph_flag
        self.intermediate_outs = intermediate_names

        inplace_pairs, view_pairs = self.parse_inplace_and_view(api_item_yaml)
        self.inplace_map = inplace_pairs
        self.view_map = view_pairs
41 42 43 44 45 46 47

    def get_api_func_name(self):
        """Return the generated function name for this API.

        The plain API name is used unless ``is_dygraph_api`` is set, in
        which case ``_intermediate`` is appended.
        """
        if not self.is_dygraph_api:
            return self.api
        return self.api + '_intermediate'

Y
YuanRisheng 已提交
48 49 50
    def gene_input(self, kernel_tensor_type=None, code_indent=''):
        kernel_param = self.kernel['param']
        input_name_tensor_map, input_tensor_code = super().gene_input(
51 52
            kernel_tensor_type, code_indent
        )
Y
YuanRisheng 已提交
53 54 55

        # generate the input that is in view list
        for i, input_name in enumerate(self.inputs['names']):
56 57 58 59 60 61 62 63 64
            if (
                input_name in self.view_map.values()
                and input_name not in input_name_tensor_map.keys()
            ):
                if (
                    kernel_tensor_type is None
                    or kernel_tensor_type[0][kernel_param.index(input_name)]
                    == 'dense'
                ):
Y
YuanRisheng 已提交
65
                    trans_flag = self.gene_trans_flag(input_name)
66 67 68
                    input_tensor_code = (
                        input_tensor_code
                        + f"""
W
wanghuancoder 已提交
69
{code_indent}  auto {PREFIX_TENSOR_NAME}{input_name} = PrepareData({input_name}, kernel.InputAt(0), {trans_flag}, kernel_result.is_stride_kernel);"""
70
                    )
Y
YuanRisheng 已提交
71 72 73 74 75 76
                else:
                    # do nothing
                    pass

        return input_name_tensor_map, input_tensor_code

77 78
    def parse_intermediate(self, api_item_yaml):
        """Read the optional ``intermediate`` entry of an API description.

        Args:
            api_item_yaml: mapping describing one API entry; its
                ``'intermediate'`` value, when present, is a comma-separated
                string of output names.

        Returns:
            ``(True, names)`` with surrounding whitespace stripped from each
            name when the key is present, otherwise ``(False, [])``.
        """
        if 'intermediate' not in api_item_yaml:
            return False, []

        raw_names = api_item_yaml['intermediate'].split(',')
        return True, [name.strip() for name in raw_names]
86

87 88 89 90 91 92 93 94 95 96 97 98 99
    def parse_inplace_and_view(self, api_item_yaml):
        """Parse the optional ``inplace`` and ``view`` entries of an API.

        Each entry is a comma-separated string whose items contain a pair
        written as ``input -> output`` (for example ``"(x -> out)"``).

        Args:
            api_item_yaml: mapping describing one API entry.

        Returns:
            ``(inplace_map, view_map)``; both map an output name to the
            input name it is paired with. A map is empty when its entry is
            absent.

        Raises:
            ValueError: if an item contains no ``input -> output`` pair.
            AssertionError: if a named input or output is not declared in
                ``self.inputs['names']`` / ``self.outputs['names']``.
        """
        inplace_map, view_map = {}, {}
        maps_by_mode = {'inplace': inplace_map, 'view': view_map}

        for mode, mapping in maps_by_mode.items():
            if mode not in api_item_yaml:
                continue

            for item in api_item_yaml[mode].split(','):
                result = re.search(r"(?P<in>\w+)\s*->\s*(?P<out>\w+)", item)
                # Report a malformed item explicitly instead of failing with
                # an AttributeError on ``None.group``.
                if result is None:
                    raise ValueError(
                        f"{self.api} : {mode} format error: cannot parse "
                        f"'{item.strip()}', expected 'input -> output'."
                    )
                in_val = result.group('in')
                out_val = result.group('out')
                assert (
                    in_val in self.inputs['names']
                ), f"{self.api} : {mode} input error: the input var name('{in_val}') is not found in the input args of {self.api}."
                assert (
                    out_val in self.outputs['names']
                ), f"{self.api} : {mode} output error: the output var name('{out_val}') is not found in the output args of {self.api}."

                mapping[out_val] = in_val

        return inplace_map, view_map

114 115 116 117 118
    def get_return_type_with_intermediate(self, inplace_flag=False):
        """Return the C++ return type covering every declared output.

        Unlike ``get_return_type`` no output is filtered out. Output names
        are compared with ``self.inplace_map`` after dropping anything that
        follows an ``'@'``.

        Args:
            inplace_flag: when true, outputs found in ``self.inplace_map``
                use the reference type from ``inplace_out_type_map``, or
                from ``inplace_optional_out_type_map`` if the mapped input
                is in ``self.optional_vars``.

        Returns:
            The single output type, or ``std::tuple<...>`` of all types when
            the number of outputs is not exactly one.
        """
        out_type_list = []
        for i, out_type in enumerate(self.outputs['types']):
            out_name = self.outputs['names'][i].split('@')[0]
            if not (inplace_flag and out_name in self.inplace_map):
                out_type_list.append(out_type)
                continue

            if self.inplace_map[out_name] in self.optional_vars:
                type_map = inplace_optional_out_type_map
            else:
                type_map = inplace_out_type_map
            out_type_list.append(type_map[out_type])

        if len(out_type_list) == 1:
            return out_type_list[0]
        return "std::tuple<" + ", ".join(out_type_list) + ">"

    def get_return_type(self, inplace_flag=False):
        """Return the C++ return type of the generated API function.

        Output names are compared after dropping anything that follows an
        ``'@'``. An output is included when it is an inplace output (with
        ``inplace_flag`` set), when ``self.is_dygraph_api`` is set, or when
        its name is not listed in ``self.intermediate_outs``.

        Args:
            inplace_flag: when true, outputs found in ``self.inplace_map``
                use the reference type from ``inplace_out_type_map``, or
                from ``inplace_optional_out_type_map`` if the mapped input
                is in ``self.optional_vars``.

        Returns:
            The single remaining type, or ``std::tuple<...>`` of the
            remaining types when their number is not exactly one.
        """
        out_type_list = []
        for i, out_type in enumerate(self.outputs['types']):
            out_name = self.outputs['names'][i].split('@')[0]
            if inplace_flag and out_name in self.inplace_map:
                if self.inplace_map[out_name] in self.optional_vars:
                    type_map = inplace_optional_out_type_map
                else:
                    type_map = inplace_out_type_map
                out_type_list.append(type_map[out_type])
            elif self.is_dygraph_api or out_name not in self.intermediate_outs:
                out_type_list.append(out_type)

        if len(out_type_list) == 1:
            return out_type_list[0]
        return "std::tuple<" + ", ".join(out_type_list) + ">"
151 152 153

    def gene_return_code(self):
        """Generate the C++ ``return`` statement of the API function.

        Returns:
            ``"return api_output;"`` when ``self.is_dygraph_api`` is set or
            there are no intermediate outputs. Otherwise a statement that
            returns only the outputs whose name (text before any ``'@'``) is
            not in ``self.intermediate_outs``: a single ``std::get<i>`` when
            exactly one output remains, else ``std::make_tuple`` of them.
        """
        if self.is_dygraph_api or not self.intermediate_outs:
            return "return api_output;"

        # Positions of the outputs that are visible to the caller.
        kept_indices = [
            i
            for i, name in enumerate(self.outputs['names'])
            if name.split('@')[0] not in self.intermediate_outs
        ]
        if len(kept_indices) == 1:
            return f"return std::get<{kept_indices[0]}>(api_output);"

        selected_code = ", ".join(
            f"std::get<{i}>(api_output)" for i in kept_indices
        )
        return 'return std::make_tuple(' + selected_code + ');'
167

168 169 170 171 172 173 174
    def gene_output(
        self,
        out_dtype_list,
        out_tensor_type_list=None,
        code_indent='',
        inplace_flag=False,
    ):
175
        kernel_output = []
176
        output_names = []
Z
zyfncg 已提交
177
        output_create = ""
178
        return_type = self.get_return_type_with_intermediate(inplace_flag)
Z
zyfncg 已提交
179

180
        if len(out_dtype_list) == 1:
181
            kernel_output.append('kernel_out')
182
            output_names.append('kernel_out')
183 184 185 186 187
            inplace_assign = (
                " = " + self.inplace_map[self.outputs['names'][0]]
                if inplace_flag and self.outputs['names'][0] in self.inplace_map
                else ""
            )
Z
zyfncg 已提交
188
            output_create = f"""
189
{code_indent}  {return_type} api_output{inplace_assign};"""
190 191 192 193 194 195
            set_out_func = (
                'SetKernelOutput'
                if out_tensor_type_list is None
                or out_tensor_type_list[0] == 'dense'
                else 'SetSelectedRowsKernelOutput'
            )
196
            if return_type == 'std::vector<Tensor>':
197 198 199 200 201 202
                assert (
                    self.outputs['out_size_expr'][0] is not None
                ), f"{self.api}: The out size expr : '{{expr}}' should be set when output has Tensor[]. You can refer 'split' api."
                output_create = (
                    output_create
                    + f"""
Z
zyfncg 已提交
203
{code_indent}  auto kernel_out = {set_out_func}({self.outputs['out_size_expr'][0]}, &api_output);"""
204
                )
205 206

            else:
207 208 209
                output_create = (
                    output_create
                    + f"""
Z
zyfncg 已提交
210
{code_indent}  auto kernel_out = {set_out_func}(&api_output);"""
211 212 213 214 215 216 217 218 219 220
                )

            if (
                not inplace_flag
                and self.view_map is not None
                and self.outputs['names'][0] in self.view_map
            ):
                output_create = (
                    output_create
                    + f"""
221 222 223
{code_indent}  kernel_out->ShareBufferWith(*{PREFIX_TENSOR_NAME}{self.view_map[self.outputs['names'][0]]});
{code_indent}  kernel_out->ShareInplaceVersionCounterWith(*{PREFIX_TENSOR_NAME}{self.view_map[self.outputs['names'][0]]});
{code_indent}  VLOG(3) << "Perform View between Output and Input Tensor, share allocation and inplace version.";"""
224
                )
225

226
        elif len(out_dtype_list) > 1:
Z
zyfncg 已提交
227
            output_create = f"""
228 229 230 231 232 233 234 235
{code_indent}  {return_type} api_output;"""

            if inplace_flag:
                output_create = f"""
{code_indent}  {return_type} api_output{{"""

                for out_name in self.outputs['names']:
                    if out_name in self.inplace_map:
236
                        output_create += self.inplace_map[out_name] + ', '
237 238 239
                    else:
                        output_create += 'Tensor(), '
                output_create = output_create[:-2] + '};'
Z
zyfncg 已提交
240

241
            for i in range(len(out_dtype_list)):
242
                kernel_output.append(f'kernel_out_{i}')
243
                output_names.append(f'kernel_out_{i}')
244 245 246 247 248 249
                set_out_func = (
                    'SetKernelOutput'
                    if out_tensor_type_list is None
                    or out_tensor_type_list[i] == 'dense'
                    else 'SetSelectedRowsKernelOutput'
                )
250 251

                get_out_code = f"&std::get<{i}>(api_output)"
252 253 254 255 256
                if (
                    self.outputs['names'][i] in self.inplace_map
                    and self.inplace_map[self.outputs['names'][i]]
                    in self.optional_vars
                ):
257
                    get_out_code = f"std::get<{i}>(api_output).get_ptr()"
258

259
                if out_dtype_list[i] == 'std::vector<Tensor>':
260 261 262
                    assert (
                        self.outputs['out_size_expr'][i] is not None
                    ), f"{self.api}: The out size expr : '{{expr}}' should be set when output has Tensor[]. You can refer 'split' api."
263 264 265
                    # Special case for inplace vector and inplace optional<vector>
                    if self.outputs['names'][i] in self.inplace_map:
                        set_out_func = "SetInplaceVectorKernelOutput"
266 267 268 269 270 271 272
                        if (
                            self.inplace_map[self.outputs['names'][i]]
                            in self.optional_vars
                        ):
                            set_out_func = (
                                "SetInplaceOptionalVectorKernelOutput"
                            )
273
                            get_out_code = f"std::get<{i}>(api_output)"
274 275 276
                    output_create = (
                        output_create
                        + f"""
Z
zyfncg 已提交
277
{code_indent}  auto kernel_out_{i} = {set_out_func}({self.outputs['out_size_expr'][i]}, {get_out_code});"""
278
                    )
279 280

                else:
281 282 283
                    output_create = (
                        output_create
                        + f"""
Z
zyfncg 已提交
284
{code_indent}  auto kernel_out_{i} = {set_out_func}({get_out_code});"""
285
                    )
Z
zyfncg 已提交
286

287 288 289 290 291
                if (
                    not inplace_flag
                    and self.view_map is not None
                    and self.outputs['names'][i] in self.view_map
                ):
Y
YuanRisheng 已提交
292
                    if out_dtype_list[i] == 'Tensor':
293 294 295
                        output_create = (
                            output_create
                            + f"""
Y
YuanRisheng 已提交
296 297 298
    {code_indent}  kernel_out_{i}->ShareBufferWith(*{PREFIX_TENSOR_NAME}{self.view_map[self.outputs['names'][i]]});
    {code_indent}  kernel_out_{i}->ShareInplaceVersionCounterWith(*{PREFIX_TENSOR_NAME}{self.view_map[self.outputs['names'][i]]});
    {code_indent}  VLOG(3) << "Perform View between Output and Input Tensor, share allocation and inplace version.";"""
299
                        )
Y
YuanRisheng 已提交
300 301
                    else:
                        raise ValueError(
302 303 304 305
                            "{} : Output error: only support Tensor type when use view in yaml. But get {}".format(
                                self.api, out_dtype_list[i]
                            )
                        )
Z
zyfncg 已提交
306 307 308
        else:
            raise ValueError(
                "{} : Output error: the output should not be empty.".format(
309 310 311
                    self.api
                )
            )
Z
zyfncg 已提交
312

313
        return kernel_output, output_names, output_create
Z
zyfncg 已提交
314

315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344
    def reset_view_after_fallback(
        self, out_dtype_list, code_indent='', inplace_flag=False
    ):
        remap_code = ''

        if len(out_dtype_list) == 1:
            if (
                not inplace_flag
                and self.view_map is not None
                and self.outputs['names'][0] in self.view_map
            ):
                remap_code += f"""
{code_indent}    phi::DenseTensor * {self.view_map[self.outputs['names'][0]]}_remap = static_cast<phi::DenseTensor*>({self.view_map[self.outputs['names'][0]]}.impl().get());
{code_indent}    {self.view_map[self.outputs['names'][0]]}_remap->ShareBufferWith(*kernel_out);
{code_indent}    kernel_out->ShareInplaceVersionCounterWith(*{self.view_map[self.outputs['names'][0]]}_remap);
"""
        elif len(out_dtype_list) > 1:
            for i in range(len(out_dtype_list)):
                if (
                    not inplace_flag
                    and self.view_map is not None
                    and self.outputs['names'][i] in self.view_map
                ):
                    remap_code += f"""
{code_indent}    phi::DenseTensor * {self.view_map[self.outputs['names'][i]]}_remap = static_cast<phi::DenseTensor*>({self.view_map[self.outputs['names'][i]]}.impl().get());
{code_indent}    {self.view_map[self.outputs['names'][i]]}_remap->ShareBufferWith(*kernel_out_{i});
{code_indent}    kernel_out_{i}->ShareInplaceVersionCounterWith(*{self.view_map[self.outputs['names'][i]]}_remap);
"""
        return remap_code

345 346 347

def header_include():
    """Return the ``#include`` block written at the top of the API header."""
    return """
#include <tuple>

#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/common/int_array.h"
#include "paddle/utils/optional.h"
"""


def source_include(header_file_path):
    """Return the preamble written at the top of the generated source file.

    Args:
        header_file_path: path placed in the first ``#include`` line.

    Returns:
        A string with the ``#include`` lines (the distributed-only include
        is wrapped in ``#ifdef PADDLE_WITH_DISTRIBUTE``) followed by two
        ``PD_DECLARE_*`` flag declarations.
    """
    return f"""
#include "{header_file_path}"
#include <memory>

#include "glog/logging.h"
#include "paddle/utils/flags.h"

#include "paddle/phi/api/lib/api_custom_impl.h"
#include "paddle/phi/api/lib/api_gen_utils.h"
#include "paddle/phi/api/lib/api_registry.h"
#include "paddle/phi/api/lib/data_transform.h"
#include "paddle/phi/api/include/tensor_utils.h"
#include "paddle/phi/api/lib/kernel_dispatch.h"
#include "paddle/phi/common/type_traits.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/infermeta/binary.h"
#include "paddle/phi/infermeta/multiary.h"
#include "paddle/phi/infermeta/nullary.h"
#include "paddle/phi/infermeta/unary.h"
#include "paddle/phi/infermeta/ternary.h"

#include "paddle/phi/api/profiler/event_tracing.h"
#include "paddle/phi/api/profiler/supplement_tracing.h"

#ifdef PADDLE_WITH_DISTRIBUTE
#include "paddle/phi/infermeta/spmd_rules/rules.h"
#endif

PD_DECLARE_bool(conv2d_disable_cudnn);
PD_DECLARE_int32(low_precision_op_list);
"""


def api_namespace():
    """Return the opening and closing text of the generated C++ namespaces.

    Returns:
        A 2-tuple ``(opening, closing)`` for
        ``namespace paddle { namespace experimental { ... } }``.
    """
    opening = """
namespace paddle {
namespace experimental {

"""
    closing = """

}  // namespace experimental
}  // namespace paddle
"""
    return (opening, closing)
404 405


406 407
def declare_extension_api():
    """Return the C++ snippet declaring the ``from_blob`` extension API."""
    return """
namespace paddle {
PD_DECLARE_API(from_blob);
}  // namespace paddle
"""


414 415 416
def generate_api(
    api_yaml_path, is_fused_ops_yaml, header_file_path, source_file_path
):
    """Generate the C++ API header and source files from YAML descriptions.

    Args:
        api_yaml_path: iterable of YAML file paths; the non-empty lists
            loaded from them are concatenated.
        is_fused_ops_yaml: when exactly ``True``, the source includes
            ``fused_api.h`` instead of ``api.h`` and only entries whose
            ``support_dygraph_mode`` is exactly ``True`` are generated.
        header_file_path: path of the header file to write.
        source_file_path: path of the source file to write.
    """
    apis = []
    for each_api_yaml in api_yaml_path:
        with open(each_api_yaml, 'r') as f:
            # NOTE(review): yaml.FullLoader is used; this presumes the YAML
            # files come from a trusted source -- confirm.
            api_list = yaml.load(f, Loader=yaml.FullLoader)
        if api_list:
            apis.extend(api_list)

    include_header_file = (
        "paddle/phi/api/include/fused_api.h"
        if is_fused_ops_yaml is True
        else "paddle/phi/api/include/api.h"
    )
    # not all fused ops support dygraph
    if is_fused_ops_yaml is True:
        apis = [
            api
            for api in apis
            if "support_dygraph_mode" in api
            and api["support_dygraph_mode"] is True
        ]

    namespace = api_namespace()

    # Context managers close both output files even if generation raises.
    with open(header_file_path, 'w') as header_file, open(
        source_file_path, 'w'
    ) as source_file:
        header_file.write("#pragma once\n")
        header_file.write(header_include())
        header_file.write(namespace[0])

        source_file.write(source_include(include_header_file))
        source_file.write(namespace[0])

        for api in apis:
            forward_api = ForwardAPI(api)
            # The dygraph flag is always cleared before generating code here.
            if forward_api.is_dygraph_api:
                forward_api.is_dygraph_api = False

            header_file.write(forward_api.gene_api_declaration())
            source_file.write(forward_api.gene_api_code())

        header_file.write(namespace[1])
        source_file.write(namespace[1])

        source_file.write(declare_extension_api())


def main():
    """Parse the command-line options and run ``generate_api``."""
    parser = argparse.ArgumentParser(
        description='Generate PaddlePaddle C++ API files'
    )
    parser.add_argument(
        '--api_yaml_path',
        help='path to api yaml file',
        nargs='+',
        default=['paddle/phi/api/yaml/ops.yaml'],
    )
    parser.add_argument(
        '--is_fused_ops_yaml',
        help='flag of fused ops yaml',
        action='store_true',
    )
    parser.add_argument(
        '--api_header_path',
        help='output of generated api header code file',
        default='paddle/phi/api/include/api.h',
    )
    parser.add_argument(
        '--api_source_path',
        help='output of generated api source code file',
        default='paddle/phi/api/lib/api.cc',
    )

    options = parser.parse_args()

    generate_api(
        options.api_yaml_path,
        options.is_fused_ops_yaml,
        options.api_header_path,
        options.api_source_path,
    )
508 509 510 511


# Run the generator only when this file is executed as a script.
if __name__ == '__main__':
    main()