# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
import collections

PREFIX_TENSOR_NAME = 'input_'
PREFIX_META_TENSOR_NAME = 'meta_'
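
# A BaseAPI instance is built from one parsed yaml entry. An illustrative
# entry of the shape the parse_* methods below expect (the authoritative
# schema lives in ops.yaml; 'scale' is the op this file's messages also
# reference, and the infer_meta func name is a placeholder):
#
#   - op : scale
#     args : (Tensor x, Scalar scale=1.0, float bias=0.0, bool bias_after_scale=true)
#     output : Tensor(out)
#     infer_meta :
#       func : ScaleInferMeta
#     kernel :
#       func : scale {dense -> dense}, scale_sr {selected_rows -> selected_rows}
#       data_type : x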


class BaseAPI:
    def __init__(self, api_item_yaml):
        self.api = self.get_api_name(api_item_yaml)

        # inputs:
        #     names : [], list of input names
        #     input_info : {input_name : type}
        # attrs:
        #     names : [], list of attribute names
        #     attr_info : { attr_name : (type, default_values)}
        # outputs:
        #     names : [], list of output names
        #     types : [], list of output types
        #     out_size_expr : [], expression for getting size of vector<Tensor>
        (
            self.inputs,
            self.attrs,
            self.outputs,
            self.optional_vars,
        ) = self.parse_args(self.api, api_item_yaml)

        self.is_base_api = True
        if 'invoke' in api_item_yaml:
            self.is_base_api = False
            self.invoke = api_item_yaml['invoke']
        else:
            if 'infer_meta' in api_item_yaml:
                self.infer_meta = self.parse_infer_meta(
                    api_item_yaml['infer_meta']
                )
            self.kernel = self.parse_kernel(api_item_yaml['kernel'])
            self.data_transform = self.parse_data_transform(api_item_yaml)
            self.inplace_map, self.view_map = {}, {}

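        # Maps the API-level argument type and the kernel-side tensor type
        # ('dense' or 'selected_rows') to the generator that emits the
        # input-preparation code for that argument.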
        self.gene_input_func = {
            "const Tensor&": {
                "dense": self.gene_dense_input,
                "selected_rows": self.gene_selected_rows_input,
            },
            "const paddle::optional<Tensor>&": {
                "dense": self.gene_dense_input,
                "selected_rows": self.gene_selected_rows_input,
            },
            "const std::vector<Tensor>&": {"dense": self.gene_vec_dense_input},
            "const paddle::optional<std::vector<Tensor>>&": {
                "dense": self.gene_optional_vec_dense_input
            },
        }

    def get_api_name(self, api_item_yaml):
        return api_item_yaml['op']

    def get_api_func_name(self):
        return self.api

    def get_input_tensor_args(self, inplace_flag=False):
        input_args = []
        inplace_type_map = {
            "const Tensor&": "Tensor&",
            "const paddle::optional<Tensor>&": "paddle::optional<Tensor>&",
            "const std::vector<Tensor>&": "std::vector<Tensor>&",
            "const paddle::optional<std::vector<Tensor>>&": "paddle::optional<std::vector<Tensor>>&",
        }
        for name in self.inputs['names']:
            name = name.split('@')[0]
            if inplace_flag and name in self.inplace_map.values():
                input_args.append(
                    inplace_type_map[self.inputs['input_info'][name]]
                    + ' '
                    + name
                )
            else:
                input_args.append(self.inputs['input_info'][name] + ' ' + name)
        return input_args

    def get_declare_args(self, inplace_flag=False):
        declare_args = self.get_input_tensor_args(inplace_flag)
        for name in self.attrs['names']:
            default_value = ''
            if self.attrs['attr_info'][name][1] is not None:
                default_value = ' = ' + self.attrs['attr_info'][name][1]
            declare_args.append(
                self.attrs['attr_info'][name][0] + ' ' + name + default_value
            )

        return ", ".join(declare_args)

    def get_define_args(self, inplace_flag=False):
        define_args = self.get_input_tensor_args(inplace_flag)
        for name in self.attrs['names']:
            define_args.append(self.attrs['attr_info'][name][0] + ' ' + name)

        return ", ".join(define_args)

    def parse_args(self, api_name, api_item_yaml):
        optional_vars = []
        if 'optional' in api_item_yaml:
            optional_vars = [
                item.strip() for item in api_item_yaml['optional'].split(',')
            ]
        inputs, attrs = self.parse_input_and_attr(
            api_name, api_item_yaml['args'], optional_vars
        )
        output_type_list, output_names, out_size_expr = self.parse_output(
            api_name, api_item_yaml['output']
        )
        return (
            inputs,
            attrs,
            {
                'names': output_names,
                'types': output_type_list,
                'out_size_expr': out_size_expr,
            },
            optional_vars,
        )

    def parse_input_and_attr(self, api_name, args_config, optional_vars=[]):
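        # Sketch of the parse: "(Tensor x, Tensor[] y, int axis=-1)" yields
        #   inputs = {'names': ['x', 'y'],
        #             'input_info': {'x': 'const Tensor&',
        #                            'y': 'const std::vector<Tensor>&'}}
        #   attrs  = {'names': ['axis'],
        #             'attr_info': {'axis': ('int', '-1')}}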
        inputs = {'names': [], 'input_info': {}}
        attrs = {'names': [], 'attr_info': {}}
        args_str = args_config.strip()
        assert args_str.startswith('(') and args_str.endswith(
            ')'
        ), f"Args declaration should start with '(' and end with ')', please check the args of {api_name} in yaml."
        args_str = args_str[1:-1]
        args_list = args_str.split(',')
        input_types_map = {
            'Tensor': 'const Tensor&',
            'Tensor[]': 'const std::vector<Tensor>&',
        }
        attr_types_map = {
            'IntArray': 'const IntArray&',
            'Scalar': 'const Scalar&',
            'Scalar(int)': 'const Scalar&',
            'Scalar(int64_t)': 'const Scalar&',
            'Scalar(float)': 'const Scalar&',
            'Scalar(double)': 'const Scalar&',
            # historical misspelling kept so existing yaml entries still match
            'Scalar(dobule)': 'const Scalar&',
            'Scalar[]': 'const std::vector<phi::Scalar>&',
            'int': 'int',
            'int32_t': 'int32_t',
            'int64_t': 'int64_t',
            'long': 'long',
            'size_t': 'size_t',
            'float': 'float',
            'float[]': 'const std::vector<float>&',
            'double': 'double',
            'bool': 'bool',
            'str': 'const std::string&',
            'str[]': 'const std::vector<std::string>&',
            'Place': 'const Place&',
            'DataLayout': 'DataLayout',
            'DataType': 'DataType',
            'int64_t[]': 'const std::vector<int64_t>&',
            'int[]': 'const std::vector<int>&',
        }
        optional_types_trans = {
            'Tensor': 'const paddle::optional<Tensor>&',
            'Tensor[]': 'const paddle::optional<std::vector<Tensor>>&',
            'int': 'paddle::optional<int>',
            'int32_t': 'paddle::optional<int32_t>',
            'int64_t': 'paddle::optional<int64_t>',
            'float': 'paddle::optional<float>',
            'double': 'paddle::optional<double>',
            'bool': 'paddle::optional<bool>',
            'Place': 'paddle::optional<const Place&>',
            'DataLayout': 'paddle::optional<DataLayout>',
            'DataType': 'paddle::optional<DataType>',
        }

        for item in args_list:
            item = item.strip()
            type_and_name = item.split(' ')
            # match the input tensor
            has_input = False
            for in_type_symbol, in_type in input_types_map.items():
                if type_and_name[0] == in_type_symbol:
                    input_name = type_and_name[1].strip()
                    assert (
                        len(input_name) > 0
                    ), f"The input tensor name should not be empty. Please check the args of {api_name} in yaml."
                    assert (
                        len(attrs['names']) == 0
                    ), f"The input Tensor should appear before attributes. Please check the position of {api_name}:input({input_name}) in yaml."

                    if input_name in optional_vars:
                        in_type = optional_types_trans[in_type_symbol]

                    inputs['names'].append(input_name)
                    inputs['input_info'][input_name] = in_type
                    has_input = True
                    break
            if has_input:
                continue

            # match the attribute
            for attr_type_symbol, attr_type in attr_types_map.items():
                if type_and_name[0] == attr_type_symbol:
                    attr_name = item[len(attr_type_symbol) :].strip()
                    assert (
                        len(attr_name) > 0
                    ), f"The attribute name should not be empty. Please check the args of {api_name} in yaml."
                    default_value = None
                    if '=' in attr_name:
                        attr_infos = attr_name.split('=')
                        attr_name = attr_infos[0].strip()
                        default_value = attr_infos[1].strip()

                    if attr_name in optional_vars:
                        attr_type = optional_types_trans[attr_type_symbol]

                    default_value_str = (
                        "" if default_value is None else '=' + default_value
                    )
                    attrs['names'].append(attr_name)
                    attrs['attr_info'][attr_name] = (attr_type, default_value)
                    break

        return inputs, attrs

    def parse_output(self, api_name, output_config):
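        # Sketch of the parse: "Tensor(out), Tensor[](res){input.size()}" yields
        #   types: ['Tensor', 'std::vector<Tensor>']
        #   names: ['out', 'res']
        #   out_size_expr: [None, 'input.size()']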
        def parse_output_item(output_item):
            output_type_map = {
                'Tensor': 'Tensor',
                'Tensor[]': 'std::vector<Tensor>',
            }
            result = re.search(
                r"(?P<out_type>[a-zA-Z0-9_[\]]+)\s*(?P<name>\([a-zA-Z0-9_@]+\))?\s*(?P<expr>\{[^\}]+\})?",
                output_item,
            )
            assert (
                result is not None
            ), f"{api_name} : the output config parse error."
            out_type = result.group('out_type')
            assert (
                out_type in output_type_map
            ), f"{api_name} : Output type error: the output type only supports Tensor and Tensor[], \
                  but got {out_type}."

            out_name = (
                'out'
                if result.group('name') is None
                else result.group('name')[1:-1]
            )
            out_size_expr = (
                None
                if result.group('expr') is None
                else result.group('expr')[1:-1]
            )
            return output_type_map[out_type], out_name, out_size_expr

        temp_list = output_config.split(',')

        if len(temp_list) == 1:
            out_type, out_name, size_expr = parse_output_item(temp_list[0])
            return [out_type], [out_name], [size_expr]
        else:
            out_type_list = []
            out_name_list = []
            out_size_expr_list = []
            for output_item in temp_list:
                out_type, out_name, size_expr = parse_output_item(output_item)
                out_type_list.append(out_type)
                out_name_list.append(out_name)
                out_size_expr_list.append(size_expr)

            return out_type_list, out_name_list, out_size_expr_list

    def parse_infer_meta(self, infer_meta_config):
        infer_meta = infer_meta_config
        if 'param' not in infer_meta_config:
            infer_meta['param'] = None

        return infer_meta

    def parse_kernel(self, kernel_config):
        # kernel :
        #    func : [], Kernel functions (example: scale, scale_sr)
        #    param : [], Input params of kernel
        #    backend : str, the names of param to choose the kernel backend, default is None
        #    layout : str, the names of param to choose the kernel layout, default is None
        #    data_type : str, the names of param to choose the kernel data_type, default is None
        #    dispatch : {}, the key is kernel_func, the value is type of inputs and outputs for kernel (example: {kernel_name : (['dense','sparse_coo']#input,['sparse_coo']#output)})
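        # Sketch: func "scale {dense -> dense}, scale_sr {selected_rows -> selected_rows}"
        # parses to
        #   kernel['func']     = ['scale', 'scale_sr']
        #   kernel['dispatch'] = {'scale': (['dense'], ['dense']),
        #                         'scale_sr': (['selected_rows'], ['selected_rows'])}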
        kernel = {
            'func': [],
            'param': None,
            'backend': None,
            'layout': None,
            'data_type': None,
            'dispatch': {},
        }
        if 'backend' in kernel_config and len(kernel_config['backend']) > 0:
            kernel['backend'] = kernel_config['backend']
        if 'layout' in kernel_config and len(kernel_config['layout']) > 0:
            kernel['layout'] = kernel_config['layout']
        if 'data_type' in kernel_config and len(kernel_config['data_type']) > 0:
            kernel['data_type'] = kernel_config['data_type']
        if 'param' in kernel_config:
            kernel['param'] = kernel_config['param']
        kernel_funcs = re.compile(r'([a-zA-Z0-9_]+)\s*({[^}]+})?').findall(
            kernel_config['func']
        )

        def parse_kernel_in_out_type(in_out_str):
            if len(in_out_str) == 0:
                return None
            tmp_in_out_list = in_out_str[1:-1].split('->')
            inputs = [item.strip() for item in tmp_in_out_list[0].split(',')]
            outputs = [item.strip() for item in tmp_in_out_list[1].split(',')]

            # check the tensor type
            for item in inputs:
                assert item in [
                    'dense',
                    'selected_rows',
                    'sparse_coo',
                    'sparse_csr',
                ], f"{self.api} : Invalid input tensor type ('{item}'), here we only support 'dense', 'selected_rows', 'sparse_coo' and 'sparse_csr'."
            for item in outputs:
                assert item in [
                    'dense',
                    'selected_rows',
                    'sparse_coo',
                    'sparse_csr',
                ], f"{self.api} : Invalid output tensor type ('{item}'), here we only support 'dense', 'selected_rows', 'sparse_coo' and 'sparse_csr'."

            return (inputs, outputs)

        for func_item in kernel_funcs:
            kernel['func'].append(func_item[0])
            kernel['dispatch'][func_item[0]] = parse_kernel_in_out_type(
                func_item[1]
            )

        return kernel

    def parse_data_transform(self, api_item_yaml):
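        # Reads the optional yaml block (illustrative input name):
        #   data_transform :
        #     skip_transform : x
        # and records which inputs skip the data transform or support dtype
        # transform.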
        data_transform = {'skip_transform': [], 'support_trans_dtype': []}
        if 'data_transform' in api_item_yaml:
            if 'skip_transform' in api_item_yaml['data_transform']:
                data_transform['skip_transform'] = api_item_yaml[
                    'data_transform'
                ]['skip_transform']
            if 'support_trans_dtype' in api_item_yaml['data_transform']:
                data_transform['support_trans_dtype'] = api_item_yaml[
                    'data_transform'
                ]['support_trans_dtype']

        return data_transform

    # Override by child class
    def get_return_type(self, inplace_flag=False):
        return None

    def gene_api_declaration(self):
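        # Emits roughly (sketch; the concrete return type comes from the
        # child class's get_return_type):
        #   PADDLE_API Tensor scale(const Tensor& x, const Scalar& scale = 1.0, ...);
        # plus an inplace overload `scale_` when the op has an inplace_map.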
        api_declaration = ""
        api_func_name = self.get_api_func_name()
        if api_func_name[-1] != '_':
            api_declaration = f"""
PADDLE_API {self.get_return_type()} {api_func_name}({self.get_declare_args()});
"""

        if self.is_base_api and len(self.inplace_map) > 0:
            if api_func_name[-1] != '_':
                api_func_name += '_'
            api_declaration = (
                api_declaration
                + f"""
PADDLE_API {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_declare_args(inplace_flag=True)});
"""
            )

        return api_declaration

    # Overridden by the backward API
    def gene_kernel_backend_select(self):
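        # For `backend : place > x` this emits
        #   kernel_backend = ParseBackendWithInputOrder(place, x);
        # and for `backend : x` it emits kernel_backend = ParseBackend(x);
        # (sketch of the generated C++; names are illustrative).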
        backend_select_code = ""
        if self.kernel['backend'] is not None:
            if '>' in self.kernel['backend']:
                vars_list = self.kernel['backend'].split('>')
                assert (
                    len(vars_list) == 2
                ), f"{self.api} api: The number of params to set backend with '>' only allows 2, but received {len(vars_list)}."
                assert (vars_list[0].strip() in self.attrs['names']) and (
                    self.attrs['attr_info'][vars_list[0].strip()][0]
                    == 'const Place&'
                ), f"{self.api} api: When using '>' to set the kernel backend, the first param should be an attribute with Place type."
                backend_select_code = f"""
  kernel_backend = ParseBackendWithInputOrder({vars_list[0].strip()}, {vars_list[1].strip()});
"""

            else:
                backend_args = [
                    ele.strip() for ele in self.kernel['backend'].split(',')
                ]
                backend_select_code = f"""
  kernel_backend = ParseBackend({", ".join(backend_args)});
"""

        return backend_select_code

    def gene_kernel_select(self) -> str:
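        # Emits the C++ that resolves kernel_backend / kernel_layout /
        # kernel_data_type: first from the yaml config parsed above, then,
        # for anything still UNDEFINED, from the input tensors via
        # ParseKernelKeyByInputArgs.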
        api = self.api
        input_names = self.inputs['names']
        attrs = self.attrs
        kernel = self.kernel

        kernel_key_item_init = """
  Backend kernel_backend = Backend::UNDEFINED;
  DataLayout kernel_layout = DataLayout::UNDEFINED;
  DataType kernel_data_type = DataType::UNDEFINED;
"""
        # Check the tensor options
        attr_backend_count = 0
        attr_layout_count = 0
        attr_data_type_count = 0
        for attr_name in attrs['names']:
            if attrs['attr_info'][attr_name][0] == 'const Place&':
                assert (
                    kernel['backend'] is not None
                ), f"{api} api: When there is a parameter with 'Place' type in attributes, you must set backend of kernel manually."
                attr_backend_count = attr_backend_count + 1
            if attrs['attr_info'][attr_name][0] == 'DataLayout':
                assert (
                    kernel['layout'] is not None
                ), f"{api} api: When there is a parameter with 'DataLayout' type in attributes, you must set layout of kernel manually."
                attr_layout_count = attr_layout_count + 1
            if attrs['attr_info'][attr_name][0] == 'DataType':
                assert (
                    kernel['data_type'] is not None
                ), f"{api} api: When there is a parameter with 'DataType' type in attributes, you must set data_type of kernel manually."
                attr_data_type_count = attr_data_type_count + 1

        # preprocess kernel configures
        kernel_select_code = self.gene_kernel_backend_select()

        if kernel['layout'] is not None:
            if '>' in kernel['layout']:
                vars_list = kernel['layout'].split('>')
                assert (
                    len(vars_list) == 2
                ), f"{api} api: The number of params to set layout with '>' only allows 2, but received {len(vars_list)}."
                assert (
                    vars_list[0].strip() in attrs['names']
                    and attrs['attr_info'][vars_list[0].strip()][0]
                    == 'DataLayout'
                ), f"{api} api: When using '>' to set the kernel layout, the first param should be an attribute with DataLayout type."
                kernel_select_code = (
                    kernel_select_code
                    + f"""
  kernel_layout = ParseLayoutWithInputOrder({vars_list[0].strip()}, {vars_list[1].strip()});
"""
                )

            else:
                vars_list = kernel['layout'].split(',')
                assert (
                    len(vars_list) == 1
                ), f"{api} api: The number of params to set layout must be 1, but received {len(vars_list)}."
                kernel_select_code = (
                    kernel_select_code
                    + f"""
  kernel_layout = ParseLayout({vars_list[0].strip()});
"""
                )

        if kernel['data_type'] is not None:
            if '>' in kernel['data_type']:
                vars_list = kernel['data_type'].split('>')
                assert (
                    len(vars_list) == 2
                ), f"{api} api: The number of params to set data_type with '>' only allows 2, but received {len(vars_list)}."
                assert (
                    vars_list[0].strip() in attrs['names']
                    and attrs['attr_info'][vars_list[0].strip()][0]
                    == 'DataType'
                ), f"{api} api: When using '>' to set the kernel data_type, the first param should be an attribute with DataType type."
                kernel_select_code = (
                    kernel_select_code
                    + f"""
  kernel_data_type = ParseDataTypeWithInputOrder({vars_list[0].strip()}, {vars_list[1].strip()});
"""
                )

            else:
                vars_list = kernel['data_type'].split(',')
                assert (
                    len(vars_list) == 1
                ), f"{api} api: The number of params to set data_type only allows 1, but received {len(vars_list)}."
                kernel_select_code = (
                    kernel_select_code
                    + f"""
  kernel_data_type = ParseDataType({vars_list[0].strip()});
"""
                )

        if len(input_names) == 0:
            assert (
                attr_backend_count > 0 and attr_data_type_count > 0
            ), f"{api} api: When there is no input tensor, the args must have 'Place' and 'DataType'."

        kernel_select_args = ""
        for input_name in input_names:
            kernel_select_args = kernel_select_args + input_name + ", "

        if len(kernel_select_args) > 2:
            kernel_select_args = kernel_select_args[:-2]

        kernel_select_code = kernel_key_item_init + kernel_select_code

        if len(input_names) > 0:
            kernel_select_code = (
                kernel_select_code
                + f"""
  if (kernel_backend == Backend::UNDEFINED
        || kernel_layout == DataLayout::UNDEFINED
        || kernel_data_type == DataType::UNDEFINED ) {{
    auto kernel_key_set = ParseKernelKeyByInputArgs({kernel_select_args});
    auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
    if (kernel_backend == Backend::UNDEFINED) {{
      kernel_backend = kernel_key.backend();
    }}
    if (kernel_layout == DataLayout::UNDEFINED) {{
      kernel_layout = kernel_key.layout();
    }}
    if (kernel_data_type == DataType::UNDEFINED) {{
      kernel_data_type = kernel_key.dtype();
    }}
  }}"""
            )

        return kernel_select_code

    def gene_infer_meta(self, kernel_output_names, code_indent) -> str:
        input_names = self.inputs['names']
        attr_names = self.attrs['names']
        infer_meta = self.infer_meta

        infer_meta_params = (
            infer_meta['param']
            if infer_meta['param'] is not None
            else input_names + attr_names
        )
        # generate meta tensors
        meta_tensor_code = ""
        param_code = ""
        for param in infer_meta_params:
            if param in input_names:
                if self.inputs['input_info'][param] == "const Tensor&":
                    param_code = (
                        param_code
                        + "MakeMetaTensor(*"
                        + PREFIX_TENSOR_NAME
                        + param
                        + "), "
                    )
                elif (
                    self.inputs['input_info'][param]
                    == "const std::vector<Tensor>&"
                ):
                    meta_tensor_code = (
                        meta_tensor_code
                        + f"""
{code_indent}  auto {param}_meta_vec = MakeMetaTensor({PREFIX_TENSOR_NAME}{param});
{code_indent}  std::vector<const phi::MetaTensor*> {param}_metas({param}_meta_vec.size());
{code_indent}  for (size_t i = 0; i < {param}_meta_vec.size(); ++i) {{
{code_indent}    {param}_metas[i] = &{param}_meta_vec[i];
{code_indent}  }}
"""
                    )
                    param_code = param_code + param + "_metas, "
                elif (
                    self.inputs['input_info'][param]
                    == "const paddle::optional<std::vector<Tensor>>&"
                ):
                    meta_tensor_code = (
                        meta_tensor_code
                        + f"""
{code_indent}  auto {param}_meta_vec = MakeMetaTensor({PREFIX_TENSOR_NAME}{param});
{code_indent}  paddle::optional<std::vector<const phi::MetaTensor*>> {param}_metas({param}_meta_vec.size());
{code_indent}  for (size_t i = 0; i < {param}_meta_vec.size(); ++i) {{
{code_indent}    {param}_metas->at(i) = &{param}_meta_vec[i];
{code_indent}  }}
"""
                    )
                    param_code = param_code + param + "_metas, "
                elif param in self.optional_vars:
                    param_code = (
                        param_code
                        + "MakeMetaTensor("
                        + PREFIX_TENSOR_NAME
                        + param
                        + "), "
                    )
                else:
                    raise ValueError(
                        f"{self.api} : Param of infer_meta error : {self.inputs['input_info'][param]} type is not supported."
                    )
            elif param in attr_names:
                param_code = param_code + param + ", "
            elif isinstance(param, str):
                param_code = param_code + "\"" + param + "\", "
            elif isinstance(param, bool):
                param_code = param_code + str(param).lower() + ", "
            else:
                param_code = param_code + str(param) + ", "

        for i, out_name in enumerate(kernel_output_names):
            if self.outputs['types'][i] == 'std::vector<Tensor>':
                meta_tensor_code = (
                    meta_tensor_code
                    + f"""
{code_indent}  auto {out_name}_{PREFIX_META_TENSOR_NAME}vec = MakeMetaTensor({out_name});
{code_indent}  std::vector<phi::MetaTensor*> {out_name}_metas({out_name}_{PREFIX_META_TENSOR_NAME}vec.size());
{code_indent}  for (size_t i = 0; i < {out_name}_{PREFIX_META_TENSOR_NAME}vec.size(); ++i) {{
{code_indent}    {out_name}_metas[i] = {out_name}[i] ? &{out_name}_{PREFIX_META_TENSOR_NAME}vec[i] : nullptr;
{code_indent}  }}"""
                )

                param_code = param_code + out_name + '_metas, '
            else:
                meta_tensor_code = (
                    meta_tensor_code
                    + code_indent
                    + "  phi::MetaTensor "
                    + out_name.replace('kernel_', PREFIX_META_TENSOR_NAME)
                    + "("
                    + out_name
                    + ");\n"
                )
                if len(kernel_output_names) == 1:
                    param_code = (
                        param_code
                        + f"&{out_name.replace('kernel_', PREFIX_META_TENSOR_NAME)}, "
                    )
                else:
                    param_code = (
                        param_code
                        + f"{out_name} ? &{out_name.replace('kernel_', PREFIX_META_TENSOR_NAME)} : nullptr, "
                    )

        param_code = param_code[:-2]
        return f"""{meta_tensor_code}
{code_indent}  phi::{infer_meta['func']}({param_code});
"""

    def gene_trans_flag(self, input_name):
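        # trans_flag is the C++ initializer list passed to PrepareData:
        # "{true}" for inputs listed under skip_transform, "{false, true}"
        # for inputs listed under support_trans_dtype, "{}" otherwise.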
        trans_flag = "{}"
        if input_name in self.data_transform['skip_transform']:
            trans_flag = "{true}"
        elif input_name in self.data_transform['support_trans_dtype']:
            trans_flag = "{false, true}"
        return trans_flag

    def gene_dense_input(
        self, input_name, input_name_tensor_map, code_indent=''
    ):
        input_tensor_code = ""
        trans_flag = self.gene_trans_flag(input_name)
        input_names = self.inputs['names']
        attr_names = self.attrs['names']
        kernel_param = self.kernel['param']
        if kernel_param is None:
            kernel_param = input_names + attr_names

        input_name_tensor_map[input_name].append(
            (f"{PREFIX_TENSOR_NAME}{input_name}", False)
        )
        input_tensor_code = (
            input_tensor_code
            + f"""
{code_indent}  auto {PREFIX_TENSOR_NAME}{input_name} = PrepareData({input_name}, kernel.InputAt({kernel_param.index(input_name)}), {trans_flag});"""
        )
        return input_tensor_code

    def gene_selected_rows_input(
        self, input_name, input_name_tensor_map, code_indent=''
    ):
        input_tensor_code = ""
        trans_flag = self.gene_trans_flag(input_name)
        input_names = self.inputs['names']
        attr_names = self.attrs['names']
        kernel_param = self.kernel['param']
        if kernel_param is None:
            kernel_param = input_names + attr_names

        input_name_tensor_map[input_name].append(
            (f"{PREFIX_TENSOR_NAME}{input_name}", False)
        )
        input_tensor_code = (
            input_tensor_code
            + f"""
{code_indent}  auto {PREFIX_TENSOR_NAME}{input_name} = PrepareDataForSelectedRows({input_name}, kernel.InputAt({kernel_param.index(input_name)}), {trans_flag});
"""
        )
        return input_tensor_code

    def gene_optional_vec_dense_input(
        self, input_name, input_name_tensor_map, code_indent=''
    ):
        input_tensor_code = ""
        trans_flag = self.gene_trans_flag(input_name)
        input_names = self.inputs['names']
        attr_names = self.attrs['names']
        kernel_param = self.kernel['param']
        if kernel_param is None:
            kernel_param = input_names + attr_names
        if input_name in self.inplace_map.values():
            input_name_tensor_map[input_name].append(
                (f"{PREFIX_TENSOR_NAME}{input_name}", True)
            )
            input_tensor_code = (
                input_tensor_code
                + f"""
{code_indent}  paddle::optional<std::vector<const phi::DenseTensor*>> {PREFIX_TENSOR_NAME}{input_name} = TensorToConstDenseTensorPtr({input_name});"""
            )
        else:
            input_name_tensor_map[input_name].append(
                (f"{PREFIX_TENSOR_NAME}{input_name}_vec", True)
            )
            input_tensor_code = (
                input_tensor_code
                + f"""
{code_indent}  auto {PREFIX_TENSOR_NAME}{input_name}_vec = PrepareData({input_name}, kernel.InputAt({kernel_param.index(input_name)}), {trans_flag});
{code_indent}  paddle::optional<std::vector<const phi::DenseTensor*>> {PREFIX_TENSOR_NAME}{input_name};
{code_indent}  if ({PREFIX_TENSOR_NAME}{input_name}_vec){{
{code_indent}    {PREFIX_TENSOR_NAME}{input_name} = paddle::optional<std::vector<const phi::DenseTensor*>>({PREFIX_TENSOR_NAME}{input_name}_vec->size());
{code_indent}    for (size_t i = 0; i < {PREFIX_TENSOR_NAME}{input_name}_vec->size(); ++i) {{
{code_indent}      {PREFIX_TENSOR_NAME}{input_name}->at(i) = &{PREFIX_TENSOR_NAME}{input_name}_vec->at(i);
{code_indent}    }}
{code_indent}  }}"""
            )
        return input_tensor_code

    def gene_vec_dense_input(
        self, input_name, input_name_tensor_map, code_indent=''
    ):
        input_tensor_code = ""
        trans_flag = self.gene_trans_flag(input_name)
        input_names = self.inputs['names']
        attr_names = self.attrs['names']
        kernel_param = self.kernel['param']
        if kernel_param is None:
            kernel_param = input_names + attr_names

        if input_name in self.inplace_map.values():
            input_name_tensor_map[input_name].append(
                (f"{PREFIX_TENSOR_NAME}{input_name}", True)
            )
            input_tensor_code = (
                input_tensor_code
                + f"""
{code_indent}  std::vector<const phi::DenseTensor*> {PREFIX_TENSOR_NAME}{input_name} = TensorToConstDenseTensorPtr({input_name});"""
            )
        else:
            input_name_tensor_map[input_name].append(
                (f"{PREFIX_TENSOR_NAME}{input_name}_vec", True)
            )
            input_tensor_code = (
                input_tensor_code
                + f"""
{code_indent}  auto {PREFIX_TENSOR_NAME}{input_name}_vec = PrepareData({input_name}, kernel.InputAt({kernel_param.index(input_name)}), {trans_flag});
{code_indent}  std::vector<const phi::DenseTensor*> {PREFIX_TENSOR_NAME}{input_name}({PREFIX_TENSOR_NAME}{input_name}_vec->size());
{code_indent}  for (size_t i = 0; i < {PREFIX_TENSOR_NAME}{input_name}.size(); ++i) {{
{code_indent}    {PREFIX_TENSOR_NAME}{input_name}[i] = &{PREFIX_TENSOR_NAME}{input_name}_vec->at(i);
{code_indent}  }}"""
            )
        return input_tensor_code

    def gene_input(self, kernel_tensor_type=None, code_indent=''):
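        # Dispatches each kernel input to the matching gene_*_input generator
        # via self.gene_input_func and records, per input, the generated C++
        # variable names together with an is-vector flag.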
        input_names = self.inputs['names']
        attr_names = self.attrs['names']
        kernel_param = self.kernel['param']
        if kernel_param is None:
            kernel_param = input_names + attr_names
        input_name_tensor_map = collections.defaultdict(list)
        input_tensor_code = ""
        for i, input_name in enumerate(input_names):
            # set input code
            if input_name in kernel_param:
                # resolve the kernel-side tensor type for this input ('dense' by default)
                api_tensor_type = self.inputs['input_info'][input_name]
                phi_tensor_type = (
                    'dense'
                    if kernel_tensor_type is None
                    else kernel_tensor_type[0][kernel_param.index(input_name)]
                )
                if api_tensor_type in self.gene_input_func.keys():
                    input_tensor_code += self.gene_input_func[api_tensor_type][
                        phi_tensor_type
                    ](input_name, input_name_tensor_map, code_indent)
                else:
                    # do nothing
                    pass
            else:
                if input_name in self.infer_meta['param']:
                    if input_name in self.optional_vars:
                        input_tensor_code = (
                            input_tensor_code
                            + f"""
{code_indent}  paddle::optional<phi::TensorBase> {PREFIX_TENSOR_NAME}{input_name} = {input_name} ? paddle::optional<phi::TensorBase>(*{input_name}->impl()) : paddle::none;"""
                        )

                    else:
                        if (
                            self.inputs['input_info'][input_name]
                            == "const std::vector<Tensor>&"
                        ):
                            input_tensor_code = (
                                input_tensor_code
                                + f"""
{code_indent}  auto {PREFIX_TENSOR_NAME}{input_name}_uq_ptr = TensorToDenseTensor({input_name});
{code_indent}  const auto& {PREFIX_TENSOR_NAME}{input_name} = *{PREFIX_TENSOR_NAME}{input_name}_uq_ptr;"""
                            )
                        else:
                            input_tensor_code = (
                                input_tensor_code
                                + f"""
{code_indent}  auto {PREFIX_TENSOR_NAME}{input_name} = {input_name}.impl();"""
                            )

        return input_name_tensor_map, input_tensor_code

    def get_kernel_args(self, kernel_tensor_type=None, code_indent=''):
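        # Returns a 3-tuple: the input-preparation C++ code, the comma-joined
        # kernel call arguments, and the kernel function signature string.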
        dense_input_trans_map = {
            'const Tensor&': 'const phi::DenseTensor&',
            'const std::vector<Tensor>&': 'const std::vector<const phi::DenseTensor*>&',
            'const paddle::optional<Tensor&>': 'paddle::optional<const phi::DenseTensor&>',
            'const paddle::optional<Tensor>&': 'const paddle::optional<phi::DenseTensor>&',
            'const paddle::optional<std::vector<Tensor>>&': 'const paddle::optional<std::vector<const phi::DenseTensor*>>&',
        }
        dense_out_trans_map = {
            'Tensor': 'phi::DenseTensor*',
            'std::vector<Tensor>': 'std::vector<phi::DenseTensor*>&',
        }
        sr_input_trans_map = {
            'const Tensor&': 'const phi::SelectedRows&',
            'const paddle::optional<Tensor>&': 'const paddle::optional<phi::SelectedRows>&',
        }
        sr_out_trans_map = {'Tensor': 'phi::SelectedRows*'}
        input_names = self.inputs['names']
        input_infos = self.inputs['input_info']
        kernel_args_type_list = ['const platform::DeviceContext&']

        attr_names = self.attrs['names']
        kernel_param = self.kernel['param']
        if kernel_param is None:
            kernel_param = input_names + attr_names

        input_name_tensor_map, input_tensor_code = self.gene_input(
            kernel_tensor_type, code_indent
        )

        input_tensor_code = (
            input_tensor_code
            + f"""
{code_indent}  if(platform::RecordOpInfoSupplement::IsEnabled()){{"""
        )
        single_tensor_names = []
        list_tensor_names = []
        for input_name, input_tensors in input_name_tensor_map.items():
            has_vector_tensor = False
            for input_tensor, is_vector in input_tensors:
                if is_vector is True:
                    has_vector_tensor = True
            if has_vector_tensor is False:
                single_tensor_names.append(input_name)
            else:
                list_tensor_names.append(input_name)
        if not single_tensor_names:
            input_tensor_code = (
                input_tensor_code
                + f"""
{code_indent}     std::vector<std::pair<const char*, std::vector<phi::DDim>>> input_shapes;"""
            )
        else:
            for input_name in single_tensor_names:
                if input_name in self.optional_vars:
                    input_tensors = input_name_tensor_map[input_name]
                    input_tensor_code = (
                        input_tensor_code
                        + f"""
{code_indent}     std::vector<phi::DDim> {input_name}_record_shapes;"""
                    )
                    for input_tensor, _ in input_tensors:
                        input_tensor_code = (
                            input_tensor_code
                            + f"""
{code_indent}     if({input_tensor}){{
{code_indent}       {input_name}_record_shapes.push_back((*{input_tensor}).dims());
{code_indent}     }}"""
                        )

            input_tensor_code = (
                input_tensor_code
                + f"""
{code_indent}     std::vector<std::pair<const char*, std::vector<phi::DDim>>> input_shapes{{"""
            )
            for input_name in single_tensor_names[:-1]:
                if input_name in self.optional_vars:
                    input_tensor_code = (
                        input_tensor_code
                        + f"""
{code_indent}     {{"{input_name}", {input_name}_record_shapes}},"""
                    )
                else:
                    input_tensor_code = (
                        input_tensor_code
                        + f"""
{code_indent}     {{"{input_name}", {{"""
                    )
                    input_tensors = input_name_tensor_map[input_name]
                    for input_tensor, _ in input_tensors[:-1]:
                        input_tensor_code = (
                            input_tensor_code
                            + f"""
{code_indent}     (*{input_tensor}).dims(),"""
                        )
                    input_tensor_code = (
                        input_tensor_code
                        + f"""
{code_indent}     (*{input_tensors[-1][0]}).dims()}}}},"""
                    )
            if single_tensor_names[-1] in self.optional_vars:
                input_tensor_code = (
                    input_tensor_code
                    + f"""
{code_indent}     {{"{single_tensor_names[-1]}",
{code_indent}     {single_tensor_names[-1]}_record_shapes}}}};"""
                )
            else:
                input_tensor_code = (
                    input_tensor_code
                    + f"""
{code_indent}     {{"{single_tensor_names[-1]}", {{"""
                )
                input_tensors = input_name_tensor_map[single_tensor_names[-1]]
                for input_tensor, _ in input_tensors[:-1]:
                    input_tensor_code = (
                        input_tensor_code
                        + f"""
{code_indent}     (*{input_tensor}).dims(),"""
                    )
                input_tensor_code = (
                    input_tensor_code
                    + f"""
{code_indent}     (*{input_tensors[-1][0]}).dims()}}}}}};"""
                )
        if list_tensor_names:
            input_tensor_code = (
                input_tensor_code
                + f"""
{code_indent}     std::vector<phi::DDim> ddims_vec;"""
            )
        for input_name in list_tensor_names:
            input_tensor_code = (
                input_tensor_code
                + f"""
{code_indent}     ddims_vec.clear();"""
            )
            for input_tensor, is_vector in input_name_tensor_map[input_name]:
                if is_vector:
                    input_tensor_truncate = input_tensor[:-4]
                    if input_name in self.inplace_map.values():
                        input_tensor_truncate = input_tensor

                    if input_name in self.optional_vars:
                        input_tensor_code = (
                            input_tensor_code
                            + f"""
{code_indent}     if ({input_tensor_truncate}){{
{code_indent}       ddims_vec.reserve({input_tensor_truncate}->size());
{code_indent}       for (size_t i = 0; i < {input_tensor_truncate}->size(); ++i) {{
{code_indent}         ddims_vec.emplace_back((*{input_tensor_truncate}->at(i)).dims());
{code_indent}       }}
{code_indent}     }}"""
                        )
                    else:
                        input_tensor_code = (
                            input_tensor_code
                            + f"""
{code_indent}     ddims_vec.reserve({input_tensor_truncate}.size());
{code_indent}     for (size_t i = 0; i < {input_tensor_truncate}.size(); ++i) {{
{code_indent}       ddims_vec.emplace_back((*{input_tensor_truncate}[i]).dims());
{code_indent}     }}"""
                        )
                else:
                    input_tensor_code = (
                        input_tensor_code
                        + f"""
                  ddims_vec.emplace_back((*{input_tensor}).dims());
{code_indent}     """
                    )
            input_tensor_code = (
                input_tensor_code
                + f"""
{code_indent}     input_shapes.emplace_back("{input_name}", ddims_vec);"""
            )

        input_tensor_code = (
            input_tensor_code
            + f"""
{code_indent}     platform::RecordOpInfoSupplement("{self.api}", input_shapes);
{code_indent}  }}"""
        )
        kernel_args = ["*dev_ctx"]
        for param in kernel_param:
            if param in input_names:
                if param in self.optional_vars:
                    kernel_args.append(PREFIX_TENSOR_NAME + param)
                else:
                    if self.inputs['input_info'][param] == "const Tensor&":
                        kernel_args.append("*" + PREFIX_TENSOR_NAME + param)
                    elif (
                        self.inputs['input_info'][param]
                        == "const std::vector<Tensor>&"
                    ):
                        kernel_args.append(PREFIX_TENSOR_NAME + param)
                    else:
                        # do nothing
                        pass
                # choose the kernel parameter type: dense tensor or selected_rows
                if (
                    kernel_tensor_type is None
                    or kernel_tensor_type[0][kernel_param.index(param)]
                    == 'dense'
                ):
                    kernel_args_type_list.append(
                        dense_input_trans_map[input_infos[param]]
                    )
                else:  # input is selected_rows
                    kernel_args_type_list.append(
                        sr_input_trans_map[input_infos[param]]
                    )
            elif param in attr_names:
                # set attr for kernel_context
                if 'IntArray' in self.attrs['attr_info'][param][0]:
                    kernel_args_type_list.append('const phi::IntArray&')
                    param = 'phi::IntArray(' + param + ')'
                elif 'vector<phi::Scalar>' in self.attrs['attr_info'][param][0]:
                    kernel_args_type_list.append(
                        'const std::vector<phi::Scalar>&'
                    )
                    param = param
                elif 'Scalar' in self.attrs['attr_info'][param][0]:
                    kernel_args_type_list.append('const phi::Scalar&')
                    param = 'phi::Scalar(' + param + ')'
                else:
                    kernel_args_type_list.append(
                        self.attrs['attr_info'][param][0]
                    )
                kernel_args.append(param)
            elif isinstance(param, bool):
                kernel_args.append(str(param).lower())
            else:
                kernel_args.append(str(param))

        for i, out_type in enumerate(self.outputs['types']):
            # output is dense tensor
            if (
                kernel_tensor_type is None
                or kernel_tensor_type[1][i] == 'dense'
            ):
                kernel_args_type_list.append(dense_out_trans_map[out_type])
            else:  # output is selected_rows
                kernel_args_type_list.append(sr_out_trans_map[out_type])

        kernel_signature = "void(*)(" + ", ".join(kernel_args_type_list) + ")"

        return input_tensor_code, ", ".join(kernel_args), kernel_signature

    # Override by child class
    def gene_return_code(self):
        return "return api_output;"

    # Override by child class
    def gene_output(
        self,
        out_dtype_list,
        out_tensor_type_list=None,
        code_indent='',
        inplace_flag=False,
    ):
        return None, None, None

    def gen_kernel_code(self, kernel_name, code_indent, inplace_flag=False):
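        # Assembles the full C++ body for one kernel function: kernel lookup,
        # input preparation, infer_meta, the kernel call, and the output
        # transform used when execution falls back to CPU.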
        kernel_dispatch = self.kernel['dispatch'][kernel_name]
        input_tensors, kernel_args, kernel_signature = self.get_kernel_args(
            kernel_dispatch, code_indent
        )
        out_tensor_type_list = kernel_dispatch[1] if kernel_dispatch else None
        outputs_args, kernel_output_names, output_create = self.gene_output(
            self.outputs['types'],
            out_tensor_type_list,
            code_indent,
            inplace_flag,
        )
        fallback_kernel_output_trans = ""
        for kernel_out in outputs_args:
            fallback_kernel_output_trans += f"""
{code_indent}    TransDataBackend({kernel_out}, kernel_backend, {kernel_out});"""
        return f"""
{code_indent}  VLOG(6) << "{self.api} API kernel key: [" << kernel_backend << ", " << kernel_layout << ", "<< kernel_data_type << "]";
{code_indent}  auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError(
{code_indent}      "{kernel_name}", {{kernel_backend, kernel_layout, kernel_data_type}});
{code_indent}  const auto& kernel = kernel_result.kernel;
{code_indent}  VLOG(6) << "{kernel_name} kernel: " << kernel;
{code_indent}  auto* dev_ctx = GetDeviceContextByBackend(kernel_result.has_fallback_cpu ? Backend::CPU : kernel_backend);
{input_tensors}
{output_create}
{code_indent}  paddle::platform::RecordEvent *infer_shape_record_event = nullptr;
{code_indent}  if(paddle::platform::RecordEvent::IsEnabled()){{
{code_indent}    infer_shape_record_event = new paddle::platform::RecordEvent(\"{self.api} infer_meta\", paddle::platform::TracerEventType::OperatorInner, 1);
{code_indent}  }}
{self.gene_infer_meta(kernel_output_names, code_indent)}
{code_indent}  if(infer_shape_record_event != nullptr){{
{code_indent}    delete infer_shape_record_event;
{code_indent}  }}
{code_indent}  using kernel_signature = {kernel_signature};
{code_indent}  auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
{code_indent}  paddle::platform::RecordEvent* kernel_record_event = nullptr;
{code_indent}  if(paddle::platform::RecordEvent::IsEnabled()){{
{code_indent}    kernel_record_event = new paddle::platform::RecordEvent(\"{self.api} compute\", paddle::platform::TracerEventType::OperatorInner, 1);
{code_indent}  }}
{code_indent}    (*kernel_fn)({kernel_args}, {", ".join(outputs_args)});
{code_indent}  if(kernel_record_event != nullptr){{
{code_indent}    delete kernel_record_event;
{code_indent}  }}
{code_indent}  if (kernel_result.has_fallback_cpu) {{
{fallback_kernel_output_trans}
{code_indent}  }}
{code_indent}  {self.gene_return_code()}"""

    def get_condition_code(self, kernel_name):
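        # Sketch: for dispatch {'scale_sr': (['selected_rows'], ...)} with
        # input `x`, this returns the C++ guard "x.is_selected_rows()".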
        assert self.kernel['dispatch'][
            kernel_name
        ], f"{self.api} api: the tensor type of inputs and outputs for kernel isn't set, see also 'kernel:func' of 'scale' in ops.yaml."
        input_types = self.kernel['dispatch'][kernel_name][0]
        condition_list = []
        for i, in_type in enumerate(input_types):
            if in_type == "dense":
                if self.inputs['names'][i] in self.optional_vars:
                    condition_list.append(
                        f"(!{self.inputs['names'][i]} || {self.inputs['names'][i]}->is_dense_tensor())"
                    )
                else:
                    condition_list.append(
                        f"{self.inputs['names'][i]}.is_dense_tensor()"
                    )
            else:
                if self.inputs['names'][i] in self.optional_vars:
                    condition_list.append(
                        f"(!{self.inputs['names'][i]} || {self.inputs['names'][i]}->is_selected_rows())"
                    )
                else:
                    condition_list.append(
                        f"{self.inputs['names'][i]}.is_selected_rows()"
                    )
        return " && ".join(condition_list)

    def gene_dispatch_code(self, kernel_name, inplace_flag=False):
        return f"""
  if ({self.get_condition_code(kernel_name)}) {{
{self.gen_kernel_code(kernel_name, '  ', inplace_flag)}
  }}
"""

    def gene_base_api_code(self, inplace_flag=False):
        api_func_name = self.get_api_func_name()
        if inplace_flag and api_func_name[-1] != '_':
            api_func_name += '_'
        api_code = f"""
PADDLE_API {self.get_return_type(inplace_flag)} {api_func_name}({self.get_define_args(inplace_flag)}) {{
{self.gene_kernel_select()}
"""

        if len(self.kernel['func']) > 1:
            kernel_dispatch_code = ''
            for kernel_name in self.kernel['func']:
                kernel_dispatch_code += self.gene_dispatch_code(
                    kernel_name, inplace_flag
                )
            return (
                api_code
                + f"""
{kernel_dispatch_code}
  PADDLE_THROW(phi::errors::Unimplemented(
          "The kernel of ({self.api}) for input tensors is unimplemented, please check the type of input tensors."));
}}
"""
            )
        else:
            return (
                api_code
                + self.gen_kernel_code(self.kernel['func'][0], '', inplace_flag)
                + """
}
"""
            )

    def gene_invoke_code(self, invoke_code, params_code):
        return f"""
PADDLE_API {self.get_return_type()} {self.api}({params_code}) {{
  return {invoke_code};
}}"""

    def gene_api_code(self):
        if self.is_base_api:
            api_code = self.gene_base_api_code()
            if len(self.inplace_map) > 0:
                if self.api[-1] == '_':
                    api_code = ""
                api_code = api_code + self.gene_base_api_code(inplace_flag=True)
            return api_code

        else:
            invoke_func_name = self.invoke.split('(')[0].strip()
            if invoke_func_name in self.attrs['names']:
                # Adjust the param whose name is the same as the invoked api.
                pattern = r'\W' + invoke_func_name + '[^A-Za-z0-9_(]'

                def adjust_name(matched):
                    matched_str = matched.group()
                    return matched_str[0:-1] + '_val' + matched_str[-1]

                invoke_code = re.sub(pattern, adjust_name, self.invoke)
                params_code = re.sub(
                    pattern, adjust_name, self.get_define_args()
                )
            else:
                invoke_code = self.invoke
                params_code = self.get_define_args()
            return self.gene_invoke_code(invoke_code, params_code)