# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
import re

PREFIX_TENSOR_NAME = 'input_'
PREFIX_META_TENSOR_NAME = 'meta_'


class BaseAPI(object):
    def __init__(self, api_item_yaml):
        self.api = self.get_api_name(api_item_yaml)

        # inputs:
        #     names : [], list of input names
        #     input_info : {input_name : type}
        # attrs:
        #     names : [], list of attribute names
        #     attr_info : { attr_name : (type, default_values)}
        # outputs:
        #     names : [], list of output names
        #     types : [], list of output types
        #     out_size_expr : [], expression for getting size of vector<Tensor>
        (
            self.inputs,
            self.attrs,
            self.outputs,
            self.optional_vars,
        ) = self.parse_args(self.api, api_item_yaml)
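
        # For reference, a minimal (illustrative, not verbatim) yaml op entry
        # this class consumes; 'scale' is the example this file itself cites:
        #   - op : scale
        #     args : (Tensor x, Scalar scale=1.0, float bias=0.0, bool bias_after_scale=true)
        #     output : Tensor(out)
        #     infer_meta :
        #       func : UnchangedInferMeta
        #       param : [x]
        #     kernel :
        #       func : scale, scale_sr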

        self.is_base_api = True
        if 'invoke' in api_item_yaml:
            self.is_base_api = False
            self.invoke = api_item_yaml['invoke']
        else:
            if 'infer_meta' in api_item_yaml:
                self.infer_meta = self.parse_infer_meta(
                    api_item_yaml['infer_meta']
                )
            self.kernel = self.parse_kernel(api_item_yaml['kernel'])
            self.data_transform = self.parse_data_transform(api_item_yaml)
            self.inplace_map, self.view_map = {}, {}

        self.gene_input_func = {
            "const Tensor&": {
                "dense": self.gene_dense_input,
                "selected_rows": self.gene_selected_rows_input,
            },
            "const paddle::optional<Tensor>&": {
                "dense": self.gene_dense_input,
                "selected_rows": self.gene_selected_rows_input,
            },
            "const std::vector<Tensor>&": {"dense": self.gene_vec_dense_input},
            "const paddle::optional<std::vector<Tensor>>&": {
                "dense": self.gene_optional_vec_dense_input
            },
        }

    def get_api_name(self, api_item_yaml):
        return api_item_yaml['op']

    def get_api_func_name(self):
        return self.api

    def get_input_tensor_args(self, inplace_flag=False):
        input_args = []
        inplace_type_map = {
            "const Tensor&": "Tensor&",
            "const paddle::optional<Tensor>&": "paddle::optional<Tensor>&",
            "const std::vector<Tensor>&": "std::vector<Tensor>&",
            "const paddle::optional<std::vector<Tensor>>&": "paddle::optional<std::vector<Tensor>>&",
        }
        for name in self.inputs['names']:
            name = name.split('@')[0]
            if inplace_flag and name in self.inplace_map.values():
                input_args.append(
                    inplace_type_map[self.inputs['input_info'][name]]
                    + ' '
                    + name
                )
            else:
                input_args.append(self.inputs['input_info'][name] + ' ' + name)
        return input_args

    def get_declare_args(self, inplace_flag=False):
        declare_args = self.get_input_tensor_args(inplace_flag)
        for name in self.attrs['names']:
            default_value = ''
            if self.attrs['attr_info'][name][1] is not None:
                default_value = ' = ' + self.attrs['attr_info'][name][1]
            declare_args.append(
                self.attrs['attr_info'][name][0] + ' ' + name + default_value
            )

        return ", ".join(declare_args)

    def get_define_args(self, inplace_flag=False):
        define_args = self.get_input_tensor_args(inplace_flag)
        for name in self.attrs['names']:
            define_args.append(self.attrs['attr_info'][name][0] + ' ' + name)

        return ", ".join(define_args)

    def parse_args(self, api_name, api_item_yaml):
        optional_vars = []
        if 'optional' in api_item_yaml:
            optional_vars = [
                item.strip() for item in api_item_yaml['optional'].split(',')
            ]
        inputs, attrs = self.parse_input_and_attr(
            api_name, api_item_yaml['args'], optional_vars
        )
        output_type_list, output_names, out_size_expr = self.parse_output(
            api_name, api_item_yaml['output']
        )
        return (
            inputs,
            attrs,
            {
                'names': output_names,
                'types': output_type_list,
                'out_size_expr': out_size_expr,
            },
            optional_vars,
        )

    def parse_input_and_attr(self, api_name, args_config, optional_vars=[]):
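        # Parses an args declaration string; an illustrative (assumed) form:
        #   (Tensor x, Tensor[] y, int axis=-1, Scalar alpha)
        # Inputs listed in optional_vars get their C++ types wrapped via
        # optional_types_trans below.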
        inputs = {'names': [], 'input_info': {}}
        attrs = {'names': [], 'attr_info': {}}
        args_str = args_config.strip()
        assert args_str.startswith('(') and args_str.endswith(
            ')'
        ), f"Args declaration should start with '(' and end with ')', please check the args of {api_name} in yaml."
        args_str = args_str[1:-1]
        args_list = args_str.split(',')
        input_types_map = {
            'Tensor': 'const Tensor&',
            'Tensor[]': 'const std::vector<Tensor>&',
        }
        attr_types_map = {
            'IntArray': 'const IntArray&',
            'Scalar': 'const Scalar&',
            'Scalar(int)': 'const Scalar&',
            'Scalar(int64_t)': 'const Scalar&',
            'Scalar(float)': 'const Scalar&',
            'Scalar(double)': 'const Scalar&',
            'Scalar[]': 'const std::vector<phi::Scalar>&',
            'int': 'int',
            'int32_t': 'int32_t',
            'int64_t': 'int64_t',
            'long': 'long',
            'size_t': 'size_t',
            'float': 'float',
            'float[]': 'const std::vector<float>&',
            'double': 'double',
            'bool': 'bool',
            'str': 'const std::string&',
            'str[]': 'const std::vector<std::string>&',
            'Place': 'const Place&',
            'DataLayout': 'DataLayout',
            'DataType': 'DataType',
            'int64_t[]': 'const std::vector<int64_t>&',
            'int[]': 'const std::vector<int>&',
        }
        optional_types_trans = {
            'Tensor': 'const paddle::optional<Tensor>&',
            'Tensor[]': 'const paddle::optional<std::vector<Tensor>>&',
            'int': 'paddle::optional<int>',
            'int32_t': 'paddle::optional<int32_t>',
            'int64_t': 'paddle::optional<int64_t>',
            'float': 'paddle::optional<float>',
            'double': 'paddle::optional<double>',
            'bool': 'paddle::optional<bool>',
            'Place': 'paddle::optional<const Place&>',
            'DataLayout': 'paddle::optional<DataLayout>',
            'DataType': 'paddle::optional<DataType>',
        }

        for item in args_list:
            item = item.strip()
            type_and_name = item.split(' ')
            # match the input tensor
            has_input = False
            for in_type_symbol, in_type in input_types_map.items():
                if type_and_name[0] == in_type_symbol:
                    input_name = type_and_name[1].strip()
                    assert (
                        len(input_name) > 0
                    ), f"The input tensor name should not be empty. Please check the args of {api_name} in yaml."
                    assert (
                        len(attrs['names']) == 0
                    ), f"The input Tensor should appear before attributes. Please check the position of {api_name}:input({input_name}) in yaml."

                    if input_name in optional_vars:
                        in_type = optional_types_trans[in_type_symbol]

                    inputs['names'].append(input_name)
                    inputs['input_info'][input_name] = in_type
                    has_input = True
                    break
            if has_input:
                continue

            # match the attribute
            for attr_type_symbol, attr_type in attr_types_map.items():
                if type_and_name[0] == attr_type_symbol:
                    attr_name = item[len(attr_type_symbol) :].strip()
                    assert (
                        len(attr_name) > 0
                    ), f"The attribute name should not be empty. Please check the args of {api_name} in yaml."
                    default_value = None
                    if '=' in attr_name:
                        attr_infos = attr_name.split('=')
                        attr_name = attr_infos[0].strip()
                        default_value = attr_infos[1].strip()

                    if attr_name in optional_vars:
                        attr_type = optional_types_trans[attr_type_symbol]

                    default_value_str = (
                        "" if default_value is None else '=' + default_value
                    )
                    attrs['names'].append(attr_name)
                    attrs['attr_info'][attr_name] = (attr_type, default_value)
                    break

        return inputs, attrs

    def parse_output(self, api_name, output_config):
        def parse_output_item(output_item):
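            # Illustrative output_config items accepted by the regex below:
            #   'Tensor'                   -> type Tensor, name defaults to 'out'
            #   'Tensor(out)'              -> type Tensor, name 'out'
            #   'Tensor[](out){x.size()}'  -> vector output with a size expression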
            output_type_map = {
                'Tensor': 'Tensor',
                'Tensor[]': 'std::vector<Tensor>',
            }
            result = re.search(
                r"(?P<out_type>[a-zA-Z0-9_[\]]+)\s*(?P<name>\([a-zA-Z0-9_@]+\))?\s*(?P<expr>\{[^\}]+\})?",
                output_item,
            )
            assert (
                result is not None
            ), f"{api_name} : the output config parse error."
            out_type = result.group('out_type')
            assert (
                out_type in output_type_map
            ), f"{api_name} : Output type error: the output type only supports Tensor and Tensor[], but now is {out_type}."

            out_name = (
                'out'
                if result.group('name') is None
                else result.group('name')[1:-1]
            )
            out_size_expr = (
                None
                if result.group('expr') is None
                else result.group('expr')[1:-1]
            )
            return output_type_map[out_type], out_name, out_size_expr

        temp_list = output_config.split(',')

        if len(temp_list) == 1:
            out_type, out_name, size_expr = parse_output_item(temp_list[0])
            return [out_type], [out_name], [size_expr]
        else:
            out_type_list = []
            out_name_list = []
            out_size_expr_list = []
            for output_item in temp_list:
                out_type, out_name, size_expr = parse_output_item(output_item)
                out_type_list.append(out_type)
                out_name_list.append(out_name)
                out_size_expr_list.append(size_expr)

            return out_type_list, out_name_list, out_size_expr_list

    def parse_infer_meta(self, infer_meta_config):
        infer_meta = infer_meta_config
        if 'param' not in infer_meta_config:
            infer_meta['param'] = None

        return infer_meta

    def parse_kernel(self, kernel_config):
        # kernel :
        #    func : [], Kernel functions (example: scale, scale_sr)
        #    param : [], Input params of kernel
        #    backend : str, the name of the param used to choose the kernel backend, default is None
        #    layout : str, the name of the param used to choose the kernel layout, default is None
        #    data_type : str, the name of the param used to choose the kernel data_type, default is None
        #    dispatch : {}, the key is kernel_func, the value is type of inputs and outputs for kernel (example: {kernel_name : (['dense','sparse_coo']#input,['sparse_coo']#output)})
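        # An illustrative (assumed) kernel section matching this schema:
        #    kernel :
        #      func : scale {dense -> dense}, scale_sr {selected_rows -> selected_rows}
        #      param : [x, scale, bias, bias_after_scale]
        #      data_type : x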
        kernel = {
            'func': [],
            'param': None,
            'backend': None,
            'layout': None,
            'data_type': None,
            'use_gpudnn': 'false',
            'dispatch': {},
        }
        if 'backend' in kernel_config and len(kernel_config['backend']) > 0:
            kernel['backend'] = kernel_config['backend']
        if 'layout' in kernel_config and len(kernel_config['layout']) > 0:
            kernel['layout'] = kernel_config['layout']
        if 'data_type' in kernel_config and len(kernel_config['data_type']) > 0:
            kernel['data_type'] = kernel_config['data_type']
        if 'param' in kernel_config:
            kernel['param'] = kernel_config['param']
        if 'use_gpudnn' in kernel_config:
            kernel['use_gpudnn'] = kernel_config['use_gpudnn']
            if isinstance(kernel['use_gpudnn'], bool):
                kernel['use_gpudnn'] = str(kernel['use_gpudnn']).lower()
        kernel_funcs = re.compile(r'([a-zA-Z0-9_]+)\s*({[^}]+})?').findall(
            kernel_config['func']
        )

        def parse_kernel_in_out_type(in_out_str):
            if len(in_out_str) == 0:
                return None
            tmp_in_out_list = in_out_str[1:-1].split('->')
            inputs = [item.strip() for item in tmp_in_out_list[0].split(',')]
            outputs = [item.strip() for item in tmp_in_out_list[1].split(',')]

            # check the tensor type
            for item in inputs:
                assert item in [
                    'dense',
                    'selected_rows',
                    'sparse_coo',
                    'sparse_csr',
                ], f"{self.api} : Invalid input tensor type ('{item}'), here we only support 'dense', 'selected_rows', 'sparse_coo' and 'sparse_csr'."
            for item in outputs:
                assert item in [
                    'dense',
                    'selected_rows',
                    'sparse_coo',
                    'sparse_csr',
                ], f"{self.api} : Invalid output tensor type ('{item}'), here we only support 'dense', 'selected_rows', 'sparse_coo' and 'sparse_csr'."

            return (inputs, outputs)

        for func_item in kernel_funcs:
            kernel['func'].append(func_item[0])
            kernel['dispatch'][func_item[0]] = parse_kernel_in_out_type(
                func_item[1]
            )

        return kernel

    def parse_data_transform(self, api_item_yaml):
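        # Reads an optional yaml section shaped like (illustrative):
        #   data_transform :
        #     skip_transform : x
        #     support_trans_dtype : dtype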
        data_transform = {'skip_transform': [], 'support_trans_dtype': []}
        if 'data_transform' in api_item_yaml:
            if 'skip_transform' in api_item_yaml['data_transform']:
                data_transform['skip_transform'] = api_item_yaml[
                    'data_transform'
                ]['skip_transform']
            if 'support_trans_dtype' in api_item_yaml['data_transform']:
                data_transform['support_trans_dtype'] = api_item_yaml[
                    'data_transform'
                ]['support_trans_dtype']

        return data_transform

    # Override by child class
    def get_return_type(self, inplace_flag=False):
        return None

    def gene_api_declaration(self):
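        # Sketch of the emitted declaration for a hypothetical 'scale' op
        # (not verbatim generator output):
        #   PADDLE_API Tensor scale(const Tensor& x, const Scalar& scale = 1.0,
        #                           float bias = 0.0, bool bias_after_scale = true);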
        api_declaration = ""
        api_func_name = self.get_api_func_name()
        if api_func_name[-1] != '_':
            api_declaration = f"""
PADDLE_API {self.get_return_type()} {api_func_name}({self.get_declare_args()});
386 387
"""

        if self.is_base_api and len(self.inplace_map) > 0:
            if api_func_name[-1] != '_':
                api_func_name += '_'
            api_declaration = (
                api_declaration
                + f"""
PADDLE_API {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_declare_args(inplace_flag=True)});
"""
            )

        return api_declaration

    # Backward API overrides this method
    def gene_kernel_backend_select(self):
        backend_select_code = ""
        if self.kernel['backend'] is not None:
            if '>' in self.kernel['backend']:
                vars_list = self.kernel['backend'].split('>')
                assert (
                    len(vars_list) == 2
                ), f"{self.api} api: The number of params to set backend with '>' only allows 2, but received {len(vars_list)}."
                assert (vars_list[0].strip() in self.attrs['names']) and (
                    self.attrs['attr_info'][vars_list[0].strip()][0]
                    == 'const Place&'
                ), f"{self.api} api: When using '>' to set the kernel backend, the first param should be an attribute with Place type."
                backend_select_code = f"""
  kernel_backend = ParseBackendWithInputOrder({vars_list[0].strip()}, {vars_list[1].strip()});
"""

            else:
                backend_args = [
                    ele.strip() for ele in self.kernel['backend'].split(',')
                ]
                backend_select_code = f"""
  kernel_backend = ParseBackend({", ".join(backend_args)});
"""

        return backend_select_code

    def gene_kernel_select(self) -> str:
        api = self.api
        input_names = self.inputs['names']
        attrs = self.attrs
        kernel = self.kernel

        kernel_key_item_init = """
  Backend kernel_backend = Backend::UNDEFINED;
  DataLayout kernel_layout = DataLayout::UNDEFINED;
  DataType kernel_data_type = DataType::UNDEFINED;
"""
        # Check the tensor options
        attr_backend_count = 0
        attr_layout_count = 0
        attr_data_type_count = 0
        for attr_name in attrs['names']:
            if attrs['attr_info'][attr_name][0] == 'const Place&':
                assert (
                    kernel['backend'] is not None
                ), f"{api} api: When there is a parameter with 'Place' type in attributes, you must set the backend of the kernel manually."
                attr_backend_count = attr_backend_count + 1
            if attrs['attr_info'][attr_name][0] == 'DataLayout':
                assert (
                    kernel['layout'] is not None
                ), f"{api} api: When there is a parameter with 'DataLayout' type in attributes, you must set the layout of the kernel manually."
                attr_layout_count = attr_layout_count + 1
            if attrs['attr_info'][attr_name][0] == 'DataType':
                assert (
                    kernel['data_type'] is not None
                ), f"{api} api: When there is a parameter with 'DataType' type in attributes, you must set the data_type of the kernel manually."
                attr_data_type_count = attr_data_type_count + 1

        # preprocess kernel configuration
        kernel_select_code = self.gene_kernel_backend_select()

        if kernel['layout'] is not None:
            if '>' in kernel['layout']:
                vars_list = kernel['layout'].split('>')
                assert (
                    len(vars_list) == 2
                ), f"{api} api: The number of params to set layout with '>' only allows 2, but received {len(vars_list)}."
                assert (
                    vars_list[0].strip() in attrs['names']
                    and attrs['attr_info'][vars_list[0].strip()][0]
                    == 'DataLayout'
                ), f"{api} api: When using '>' to set the kernel layout, the first param should be an attribute with DataLayout type."
                kernel_select_code = (
                    kernel_select_code
                    + f"""
  kernel_layout = ParseLayoutWithInputOrder({vars_list[0].strip()}, {vars_list[1].strip()});
"""
                )

            else:
                vars_list = kernel['layout'].split(',')
                assert (
                    len(vars_list) == 1
                ), f"{api} api: The number of params to set layout must be 1, but received {len(vars_list)}."
                kernel_select_code = (
                    kernel_select_code
                    + f"""
  kernel_layout = ParseLayout({vars_list[0].strip()});
"""
                )

        if kernel['data_type'] is not None:
            if '>' in kernel['data_type']:
                vars_list = kernel['data_type'].split('>')
                assert (
                    len(vars_list) == 2
                ), f"{api} api: The number of params to set data_type with '>' only allows 2, but received {len(vars_list)}."
                assert (
                    vars_list[0].strip() in attrs['names']
                    and attrs['attr_info'][vars_list[0].strip()][0]
                    == 'DataType'
                ), f"{api} api: When using '>' to set the kernel data_type, the first param should be an attribute with DataType type."
                kernel_select_code = (
                    kernel_select_code
                    + f"""
  kernel_data_type = ParseDataTypeWithInputOrder({vars_list[0].strip()}, {vars_list[1].strip()});
"""
                )

            else:
                vars_list = kernel['data_type'].split(',')
                assert (
                    len(vars_list) == 1
                ), f"{api} api: The number of params to set data_type only allows 1, but received {len(vars_list)}."
                kernel_select_code = (
                    kernel_select_code
                    + f"""
  kernel_data_type = ParseDataType({vars_list[0].strip()});
"""
                )

        if len(input_names) == 0:
            assert (
                attr_backend_count > 0 and attr_data_type_count > 0
            ), f"{api} api: When there is no input tensor, the args must have 'Place' and 'DataType'."

        kernel_select_args = ""
        for input_name in input_names:
            kernel_select_args = kernel_select_args + input_name + ", "

        if len(kernel_select_args) > 2:
            kernel_select_args = kernel_select_args[:-2]

        kernel_select_code = kernel_key_item_init + kernel_select_code

        if len(input_names) > 0:
            kernel_select_code = (
                kernel_select_code
                + f"""
  if (kernel_backend == Backend::UNDEFINED
        || kernel_layout == DataLayout::UNDEFINED
        || kernel_data_type == DataType::UNDEFINED ) {{
    auto kernel_key_set = ParseKernelKeyByInputArgs({kernel_select_args});
    auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
    if (kernel_backend == Backend::UNDEFINED) {{
      kernel_backend = kernel_key.backend();
    }}
    if (kernel_layout == DataLayout::UNDEFINED) {{
      kernel_layout = kernel_key.layout();
    }}
    if (kernel_data_type == DataType::UNDEFINED) {{
      kernel_data_type = kernel_key.dtype();
    }}
  }}"""
            )

        return kernel_select_code

    def gene_infer_meta(self, kernel_output_names, code_indent) -> str:
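        # Wires outputs into phi::MetaTensor objects and emits the infer_meta
        # call; for a single dense output the generated C++ ends roughly like
        # (sketch, assuming an UnchangedInferMeta-style func):
        #   phi::UnchangedInferMeta(MakeMetaTensor(*input_x), &meta_out);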
        input_names = self.inputs['names']
        attr_names = self.attrs['names']
        infer_meta = self.infer_meta

        infer_meta_params = (
            infer_meta['param']
            if infer_meta['param'] is not None
            else input_names + attr_names
        )
        # generate meta tensors
        meta_tensor_code = ""
        param_code = ""
        for param in infer_meta_params:
            if param in input_names:
                if self.inputs['input_info'][param] == "const Tensor&":
                    param_code = (
                        param_code
                        + "MakeMetaTensor(*"
                        + PREFIX_TENSOR_NAME
                        + param
                        + "), "
                    )
                elif (
                    self.inputs['input_info'][param]
                    == "const std::vector<Tensor>&"
                ):
                    meta_tensor_code = (
                        meta_tensor_code
                        + f"""
{code_indent}  auto {param}_meta_vec = MakeMetaTensor({PREFIX_TENSOR_NAME}{param});
{code_indent}  std::vector<const phi::MetaTensor*> {param}_metas({param}_meta_vec.size());
{code_indent}  for (size_t i = 0; i < {param}_meta_vec.size(); ++i) {{
{code_indent}    {param}_metas[i] = &{param}_meta_vec[i];
{code_indent}  }}
"""
                    )
                    param_code = param_code + param + "_metas, "
                elif (
                    self.inputs['input_info'][param]
                    == "const paddle::optional<std::vector<Tensor>>&"
                ):
                    meta_tensor_code = (
                        meta_tensor_code
                        + f"""
{code_indent}  auto {param}_meta_vec = MakeMetaTensor({PREFIX_TENSOR_NAME}{param});
{code_indent}  paddle::optional<std::vector<const phi::MetaTensor*>> {param}_metas({param}_meta_vec.size());
{code_indent}  for (size_t i = 0; i < {param}_meta_vec.size(); ++i) {{
{code_indent}    {param}_metas->at(i) = &{param}_meta_vec[i];
{code_indent}  }}
"""
                    )
                    param_code = param_code + param + "_metas, "
                elif param in self.optional_vars:
                    param_code = (
                        param_code
                        + "MakeMetaTensor("
                        + PREFIX_TENSOR_NAME
                        + param
                        + "), "
                    )
                else:
                    raise ValueError(
                        f"{self.api} : Param of infer_meta error : {self.inputs['input_info'][param]} type is not supported."
                    )
            elif param in attr_names:
                param_code = param_code + param + ", "
            elif isinstance(param, str):
                param_code = param_code + "\"" + param + "\", "
            elif isinstance(param, bool):
                param_code = param_code + str(param).lower() + ", "
            else:
                param_code = param_code + str(param) + ", "

        for i, out_name in enumerate(kernel_output_names):
            if self.outputs['types'][i] == 'std::vector<Tensor>':
                meta_tensor_code = (
                    meta_tensor_code
                    + f"""
{code_indent}  auto {out_name}_{PREFIX_META_TENSOR_NAME}vec = MakeMetaTensor({out_name});
{code_indent}  std::vector<phi::MetaTensor*> {out_name}_metas({out_name}_{PREFIX_META_TENSOR_NAME}vec.size());
{code_indent}  for (size_t i = 0; i < {out_name}_{PREFIX_META_TENSOR_NAME}vec.size(); ++i) {{
{code_indent}    {out_name}_metas[i] = {out_name}[i] ? &{out_name}_{PREFIX_META_TENSOR_NAME}vec[i] : nullptr;
{code_indent}  }}"""
                )

                param_code = param_code + out_name + '_metas, '
            else:
                meta_tensor_code = (
                    meta_tensor_code
                    + code_indent
                    + "  phi::MetaTensor "
                    + out_name.replace('kernel_', PREFIX_META_TENSOR_NAME)
                    + "("
                    + out_name
                    + ");\n"
                )
                if len(kernel_output_names) == 1:
                    param_code = (
                        param_code
                        + f"&{out_name.replace('kernel_', PREFIX_META_TENSOR_NAME)}, "
                    )
                else:
                    param_code = (
                        param_code
                        + f"{out_name} ? &{out_name.replace('kernel_', PREFIX_META_TENSOR_NAME)} : nullptr, "
                    )

        param_code = param_code[:-2]
        return f"""{meta_tensor_code}
{code_indent}  phi::{infer_meta['func']}({param_code});
"""

    def gene_trans_flag(self, input_name):
        trans_flag = "{}"
        if input_name in self.data_transform['skip_transform']:
            trans_flag = "{true}"
        elif input_name in self.data_transform['support_trans_dtype']:
            trans_flag = "{false, true}"
        return trans_flag

    def gene_dense_input(
        self, input_name, input_name_tensor_map, code_indent=''
    ):
        input_tensor_code = ""
        trans_flag = self.gene_trans_flag(input_name)
        input_names = self.inputs['names']
        attr_names = self.attrs['names']
        kernel_param = self.kernel['param']
        if kernel_param is None:
            kernel_param = input_names + attr_names

        input_name_tensor_map[input_name].append(
            (f"{PREFIX_TENSOR_NAME}{input_name}", False)
        )
        input_tensor_code = (
            input_tensor_code
            + f"""
{code_indent}  auto {PREFIX_TENSOR_NAME}{input_name} = PrepareData({input_name}, kernel.InputAt({kernel_param.index(input_name)}), {trans_flag});"""
        )
        return input_tensor_code

    def gene_selected_rows_input(
        self, input_name, input_name_tensor_map, code_indent=''
    ):
        input_tensor_code = ""
        trans_flag = self.gene_trans_flag(input_name)
        input_names = self.inputs['names']
        attr_names = self.attrs['names']
        kernel_param = self.kernel['param']
        if kernel_param is None:
            kernel_param = input_names + attr_names

        input_name_tensor_map[input_name].append(
            (f"{PREFIX_TENSOR_NAME}{input_name}", False)
        )
        input_tensor_code = (
            input_tensor_code
            + f"""
{code_indent}  auto {PREFIX_TENSOR_NAME}{input_name} = TensorToSelectedRows({input_name});
"""
        )
        return input_tensor_code

    def gene_optional_vec_dense_input(
        self, input_name, input_name_tensor_map, code_indent=''
    ):
        input_tensor_code = ""
        trans_flag = self.gene_trans_flag(input_name)
        input_names = self.inputs['names']
        attr_names = self.attrs['names']
        kernel_param = self.kernel['param']
        if kernel_param is None:
            kernel_param = input_names + attr_names
        if input_name in self.inplace_map.values():
            input_name_tensor_map[input_name].append(
                (f"{PREFIX_TENSOR_NAME}{input_name}", True)
            )
            input_tensor_code = (
                input_tensor_code
                + f"""
{code_indent}  paddle::optional<std::vector<const phi::DenseTensor*>> {PREFIX_TENSOR_NAME}{input_name} = TensorToConstDenseTensorPtr({input_name});"""
            )
        else:
            input_name_tensor_map[input_name].append(
                (f"{PREFIX_TENSOR_NAME}{input_name}_vec", True)
            )
            input_tensor_code = (
                input_tensor_code
                + f"""
{code_indent}  auto {PREFIX_TENSOR_NAME}{input_name}_vec = PrepareData({input_name}, kernel.InputAt({kernel_param.index(input_name)}), {trans_flag});
{code_indent}  paddle::optional<std::vector<const phi::DenseTensor*>> {PREFIX_TENSOR_NAME}{input_name};
{code_indent}  if ({PREFIX_TENSOR_NAME}{input_name}_vec){{
{code_indent}    {PREFIX_TENSOR_NAME}{input_name} = paddle::optional<std::vector<const phi::DenseTensor*>>({PREFIX_TENSOR_NAME}{input_name}_vec->size());
{code_indent}    for (size_t i = 0; i < {PREFIX_TENSOR_NAME}{input_name}_vec->size(); ++i) {{
{code_indent}      {PREFIX_TENSOR_NAME}{input_name}->at(i) = &{PREFIX_TENSOR_NAME}{input_name}_vec->at(i);
{code_indent}    }}
{code_indent}  }}"""
            )
        return input_tensor_code

    def gene_vec_dense_input(
        self, input_name, input_name_tensor_map, code_indent=''
    ):
        input_tensor_code = ""
        trans_flag = self.gene_trans_flag(input_name)
        input_names = self.inputs['names']
        attr_names = self.attrs['names']
        kernel_param = self.kernel['param']
        if kernel_param is None:
            kernel_param = input_names + attr_names

        if input_name in self.inplace_map.values():
            input_name_tensor_map[input_name].append(
                (f"{PREFIX_TENSOR_NAME}{input_name}", True)
            )
            input_tensor_code = (
                input_tensor_code
                + f"""
{code_indent}  std::vector<const phi::DenseTensor*> {PREFIX_TENSOR_NAME}{input_name} = TensorToConstDenseTensorPtr({input_name});"""
            )
        else:
            input_name_tensor_map[input_name].append(
                (f"{PREFIX_TENSOR_NAME}{input_name}_vec", True)
            )
            input_tensor_code = (
                input_tensor_code
                + f"""
{code_indent}  auto {PREFIX_TENSOR_NAME}{input_name}_vec = PrepareData({input_name}, kernel.InputAt({kernel_param.index(input_name)}), {trans_flag});
{code_indent}  std::vector<const phi::DenseTensor*> {PREFIX_TENSOR_NAME}{input_name}({PREFIX_TENSOR_NAME}{input_name}_vec->size());
{code_indent}  for (size_t i = 0; i < {PREFIX_TENSOR_NAME}{input_name}.size(); ++i) {{
{code_indent}    {PREFIX_TENSOR_NAME}{input_name}[i] = &{PREFIX_TENSOR_NAME}{input_name}_vec->at(i);
{code_indent}  }}"""
            )
        return input_tensor_code

    def gene_input(self, kernel_tensor_type=None, code_indent=''):
        input_names = self.inputs['names']
        attr_names = self.attrs['names']
        kernel_param = self.kernel['param']
        if kernel_param is None:
            kernel_param = input_names + attr_names
        input_name_tensor_map = collections.defaultdict(list)
        input_tensor_code = ""
        for i, input_name in enumerate(input_names):
            # set input code
            if input_name in kernel_param:
                # input is dense tensor
                api_tensor_type = self.inputs['input_info'][input_name]
                phi_tensor_type = (
                    'dense'
                    if kernel_tensor_type is None
                    else kernel_tensor_type[0][kernel_param.index(input_name)]
                )
                if api_tensor_type in self.gene_input_func.keys():
                    input_tensor_code += self.gene_input_func[api_tensor_type][
                        phi_tensor_type
                    ](input_name, input_name_tensor_map, code_indent)
                else:
                    # do nothing
                    pass
            else:
                if input_name in self.infer_meta['param']:
                    if input_name in self.optional_vars:
                        input_tensor_code = (
                            input_tensor_code
                            + f"""
{code_indent}  paddle::optional<phi::TensorBase> {PREFIX_TENSOR_NAME}{input_name} = {input_name} ? paddle::optional<phi::TensorBase>(*{input_name}->impl()) : paddle::none;"""
                        )

                    else:
                        if (
                            self.inputs['input_info'][input_name]
                            == "const std::vector<Tensor>&"
                        ):
                            input_tensor_code = (
                                input_tensor_code
                                + f"""
{code_indent}  auto {PREFIX_TENSOR_NAME}{input_name}_uq_ptr = TensorToDenseTensor({input_name});
{code_indent}  const auto& {PREFIX_TENSOR_NAME}{input_name} = *{PREFIX_TENSOR_NAME}{input_name}_uq_ptr;"""
                            )
                        else:
                            input_tensor_code = (
                                input_tensor_code
                                + f"""
{code_indent}  auto {PREFIX_TENSOR_NAME}{input_name} = {input_name}.impl();"""
                            )

        return input_name_tensor_map, input_tensor_code

    def get_kernel_args(self, kernel_tensor_type=None, code_indent=''):
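        # Builds the kernel argument list and its variadic signature; for one
        # dense input and one dense output the signature is roughly (sketch):
        #   void(*)(const platform::DeviceContext&, const phi::DenseTensor&, phi::DenseTensor*)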
        dense_input_trans_map = {
            'const Tensor&': 'const phi::DenseTensor&',
            'const std::vector<Tensor>&': 'const std::vector<const phi::DenseTensor*>&',
            'const paddle::optional<Tensor&>': 'paddle::optional<const phi::DenseTensor&>',
            'const paddle::optional<Tensor>&': 'const paddle::optional<phi::DenseTensor>&',
            'const paddle::optional<std::vector<Tensor>>&': 'const paddle::optional<std::vector<const phi::DenseTensor*>>&',
        }
        dense_out_trans_map = {
            'Tensor': 'phi::DenseTensor*',
            'std::vector<Tensor>': 'std::vector<phi::DenseTensor*>&',
        }
        sr_input_trans_map = {
            'const Tensor&': 'const phi::SelectedRows&',
            'const paddle::optional<Tensor>&': 'const paddle::optional<phi::SelectedRows>&',
        }
        sr_out_trans_map = {'Tensor': 'phi::SelectedRows*'}
        input_names = self.inputs['names']
        input_infos = self.inputs['input_info']
        kernel_args_type_list = ['const platform::DeviceContext&']

        attr_names = self.attrs['names']
        kernel_param = self.kernel['param']
        if kernel_param is None:
            kernel_param = input_names + attr_names

        input_name_tensor_map, input_tensor_code = self.gene_input(
            kernel_tensor_type, code_indent
        )

        input_tensor_code = (
            input_tensor_code
            + f"""
{code_indent}  if(platform::RecordOpInfoSupplement::IsEnabled()){{"""
        )
        single_tensor_names = []
        list_tensor_names = []
        for input_name, input_tensors in input_name_tensor_map.items():
            has_vector_tensor = False
            for input_tensor, is_vector in input_tensors:
                if is_vector is True:
                    has_vector_tensor = True
            if has_vector_tensor is False:
                single_tensor_names.append(input_name)
            else:
                list_tensor_names.append(input_name)
        if not single_tensor_names:
            input_tensor_code = (
                input_tensor_code
                + f"""
{code_indent}     std::vector<std::pair<const char*, std::vector<phi::DDim>>> input_shapes;"""
            )
        else:
            for input_name in single_tensor_names:
                if input_name in self.optional_vars:
                    input_tensors = input_name_tensor_map[input_name]
                    input_tensor_code = (
                        input_tensor_code
                        + f"""
{code_indent}     std::vector<phi::DDim> {input_name}_record_shapes;"""
                    )
                    for input_tensor, _ in input_tensors:
                        input_tensor_code = (
                            input_tensor_code
                            + f"""
{code_indent}     if({input_tensor}){{
{code_indent}       {input_name}_record_shapes.push_back((*{input_tensor}).dims());
{code_indent}     }}"""
                        )

            input_tensor_code = (
                input_tensor_code
                + f"""
{code_indent}     std::vector<std::pair<const char*, std::vector<phi::DDim>>> input_shapes{{"""
            )
            for input_name in single_tensor_names[:-1]:
                if input_name in self.optional_vars:
                    input_tensor_code = (
                        input_tensor_code
                        + f"""
{code_indent}     {{"{input_name}", {input_name}_record_shapes}},"""
                    )
                else:
                    input_tensor_code = (
                        input_tensor_code
                        + f"""
{code_indent}     {{"{input_name}", {{"""
                    )
                    input_tensors = input_name_tensor_map[input_name]
                    for input_tensor, _ in input_tensors[:-1]:
                        input_tensor_code = (
                            input_tensor_code
                            + f"""
{code_indent}     (*{input_tensor}).dims(),"""
                        )
                    input_tensor_code = (
                        input_tensor_code
                        + f"""
{code_indent}     (*{input_tensors[-1][0]}).dims()}}}},"""
                    )
            if single_tensor_names[-1] in self.optional_vars:
                input_tensor_code = (
                    input_tensor_code
                    + f"""
{code_indent}     {{"{single_tensor_names[-1]}",
{code_indent}     {single_tensor_names[-1]}_record_shapes}}}};"""
                )
            else:
                input_tensor_code = (
                    input_tensor_code
                    + f"""
{code_indent}     {{"{single_tensor_names[-1]}", {{"""
                )
                input_tensors = input_name_tensor_map[single_tensor_names[-1]]
                for input_tensor, _ in input_tensors[:-1]:
                    input_tensor_code = (
                        input_tensor_code
                        + f"""
{code_indent}     (*{input_tensor}).dims(),"""
                    )
                input_tensor_code = (
                    input_tensor_code
                    + f"""
{code_indent}     (*{input_tensors[-1][0]}).dims()}}}}}};"""
                )
        if list_tensor_names:
            input_tensor_code = (
                input_tensor_code
                + f"""
{code_indent}     std::vector<phi::DDim> ddims_vec;"""
            )
        for input_name in list_tensor_names:
            input_tensor_code = (
                input_tensor_code
                + f"""
{code_indent}     ddims_vec.clear();"""
            )
            for input_tensor, is_vector in input_name_tensor_map[input_name]:
                if is_vector:
                    input_tensor_truncate = input_tensor[:-4]
                    if input_name in self.inplace_map.values():
                        input_tensor_truncate = input_tensor

                    if input_name in self.optional_vars:
                        input_tensor_code = (
                            input_tensor_code
                            + f"""
{code_indent}     if ({input_tensor_truncate}){{
{code_indent}       ddims_vec.reserve({input_tensor_truncate}->size());
{code_indent}       for (size_t i = 0; i < {input_tensor_truncate}->size(); ++i) {{
{code_indent}         ddims_vec.emplace_back((*{input_tensor_truncate}->at(i)).dims());
{code_indent}       }}
{code_indent}     }}"""
                        )
                    else:
                        input_tensor_code = (
                            input_tensor_code
                            + f"""
{code_indent}     ddims_vec.reserve({input_tensor_truncate}.size());
{code_indent}     for (size_t i = 0; i < {input_tensor_truncate}.size(); ++i) {{
{code_indent}       ddims_vec.emplace_back((*{input_tensor_truncate}[i]).dims());
{code_indent}     }}"""
                        )
                else:
                    input_tensor_code = (
                        input_tensor_code
                        + f"""
                  ddims_vec.emplace_back((*{input_tensor}).dims());
{code_indent}     """
                    )
            input_tensor_code = (
                input_tensor_code
                + f"""
{code_indent}     input_shapes.emplace_back("{input_name}", ddims_vec);"""
            )

        input_tensor_code = (
            input_tensor_code
            + f"""
{code_indent}     platform::RecordOpInfoSupplement("{self.api}", input_shapes);
{code_indent}  }}"""
        )
        kernel_args = ["*dev_ctx"]
        for param in kernel_param:
            if param in input_names:
                if param in self.optional_vars:
                    kernel_args.append(PREFIX_TENSOR_NAME + param)
                else:
                    if self.inputs['input_info'][param] == "const Tensor&":
                        kernel_args.append("*" + PREFIX_TENSOR_NAME + param)
                    elif (
                        self.inputs['input_info'][param]
                        == "const std::vector<Tensor>&"
                    ):
                        kernel_args.append(PREFIX_TENSOR_NAME + param)
                    else:
                        # do nothing
                        pass
                # input is dense tensor
                if (
                    kernel_tensor_type is None
                    or kernel_tensor_type[0][kernel_param.index(param)]
                    == 'dense'
                ):
                    kernel_args_type_list.append(
                        dense_input_trans_map[input_infos[param]]
                    )
                else:  # input is selected_rows
                    kernel_args_type_list.append(
                        sr_input_trans_map[input_infos[param]]
                    )
            elif param in attr_names:
                # set attr for kernel_context
                if 'IntArray' in self.attrs['attr_info'][param][0]:
                    kernel_args_type_list.append('const phi::IntArray&')
                    param = 'phi::IntArray(' + param + ')'
                elif 'vector<phi::Scalar>' in self.attrs['attr_info'][param][0]:
                    kernel_args_type_list.append(
                        'const std::vector<phi::Scalar>&'
                    )
                elif 'Scalar' in self.attrs['attr_info'][param][0]:
                    kernel_args_type_list.append('const phi::Scalar&')
                    param = 'phi::Scalar(' + param + ')'
                else:
                    kernel_args_type_list.append(
                        self.attrs['attr_info'][param][0]
                    )
                kernel_args.append(param)
            elif isinstance(param, bool):
                kernel_args.append(str(param).lower())
            else:
                kernel_args.append(str(param))

        for i, out_type in enumerate(self.outputs['types']):
            # output is dense tensor
            if (
                kernel_tensor_type is None
                or kernel_tensor_type[1][i] == 'dense'
            ):
                kernel_args_type_list.append(dense_out_trans_map[out_type])
            else:  # output is selected_rows
                kernel_args_type_list.append(sr_out_trans_map[out_type])

        kernel_signature = "void(*)(" + ", ".join(kernel_args_type_list) + ")"

        return input_tensor_code, ", ".join(kernel_args), kernel_signature

    # Override by child class
    def gene_return_code(self):
        return "return api_output;"

    # Override by child class
    def gene_output(
        self,
        out_dtype_list,
        out_tensor_type_list=None,
        code_indent='',
        inplace_flag=False,
    ):
        return None, None, None

    def gen_kernel_code(self, kernel_name, code_indent, inplace_flag=False):
        kernel_dispatch = self.kernel['dispatch'][kernel_name]
        input_tensors, kernel_args, kernel_signature = self.get_kernel_args(
            kernel_dispatch, code_indent
        )
        out_tensor_type_list = kernel_dispatch[1] if kernel_dispatch else None
        outputs_args, kernel_output_names, output_create = self.gene_output(
            self.outputs['types'],
            out_tensor_type_list,
            code_indent,
            inplace_flag,
        )
        fallback_kernel_output_trans = ""
        for kernel_out in outputs_args:
            fallback_kernel_output_trans += f"""
{code_indent}    TransDataBackend({kernel_out}, kernel_backend, {kernel_out});"""
        cudnn_args = (
            ''
            if self.kernel['use_gpudnn'] == 'false'
            else ', ' + self.kernel['use_gpudnn']
        )
        return f"""
{code_indent}  VLOG(6) << "{self.api} API kernel key: [" << kernel_backend << ", " << kernel_layout << ", "<< kernel_data_type << "]";
{code_indent}  auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError(
{code_indent}      "{kernel_name}", {{kernel_backend, kernel_layout, kernel_data_type}}{cudnn_args});
{code_indent}  const auto& kernel = kernel_result.kernel;
{code_indent}  VLOG(6) << "{kernel_name} kernel: " << kernel;
{code_indent}  auto* dev_ctx = GetDeviceContextByBackend(kernel_result.has_fallback_cpu ? Backend::CPU : kernel_backend);
{input_tensors}
{output_create}
{code_indent}  paddle::platform::RecordEvent *infer_shape_record_event = nullptr;
{code_indent}  if(paddle::platform::RecordEvent::IsEnabled()){{
{code_indent}    infer_shape_record_event = new paddle::platform::RecordEvent(\"{self.api} infer_meta\", paddle::platform::TracerEventType::OperatorInner, 1);
{code_indent}  }}
{self.gene_infer_meta(kernel_output_names, code_indent)}
{code_indent}  if(infer_shape_record_event != nullptr){{
{code_indent}    delete infer_shape_record_event;
{code_indent}  }}
{code_indent}  using kernel_signature = {kernel_signature};
{code_indent}  auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
{code_indent}  paddle::platform::RecordEvent* kernel_record_event = nullptr;
{code_indent}  if(paddle::platform::RecordEvent::IsEnabled()){{
{code_indent}    kernel_record_event = new paddle::platform::RecordEvent(\"{self.api} compute\", paddle::platform::TracerEventType::OperatorInner, 1);
{code_indent}  }}
{code_indent}  (*kernel_fn)({kernel_args}, {", ".join(outputs_args)});
{code_indent}  if(kernel_record_event != nullptr){{
{code_indent}    delete kernel_record_event;
{code_indent}  }}
{code_indent}  if (kernel_result.has_fallback_cpu) {{
{fallback_kernel_output_trans}
{code_indent}  }}
{code_indent}  {self.gene_return_code()}"""

    def get_condition_code(self, kernel_name):
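        # Produces the C++ guard used by gene_dispatch_code; for a kernel whose
        # inputs are (dense, selected_rows) it is roughly (x, y stand in for
        # the real input names):
        #   x.is_dense_tensor() && y.is_selected_rows()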
        assert self.kernel['dispatch'][
            kernel_name
        ], f"{self.api} api: the tensor type of inputs and outputs for kernel isn't set, see also 'kernel:func' of 'scale' in ops.yaml."
        input_types = self.kernel['dispatch'][kernel_name][0]
        condition_list = []
        for i, in_type in enumerate(input_types):
            if in_type == "dense":
                if self.inputs['names'][i] in self.optional_vars:
                    condition_list.append(
                        f"(!{self.inputs['names'][i]} || {self.inputs['names'][i]}->is_dense_tensor())"
                    )
                else:
                    condition_list.append(
                        f"{self.inputs['names'][i]}.is_dense_tensor()"
                    )
            else:
                if self.inputs['names'][i] in self.optional_vars:
                    condition_list.append(
                        f"(!{self.inputs['names'][i]} || {self.inputs['names'][i]}->is_selected_rows())"
                    )
                else:
                    condition_list.append(
                        f"{self.inputs['names'][i]}.is_selected_rows()"
                    )
        return " && ".join(condition_list)

    def gene_dispatch_code(self, kernel_name, inplace_flag=False):
        return f"""
  if ({self.get_condition_code(kernel_name)}) {{
{self.gen_kernel_code(kernel_name, '  ', inplace_flag)}
  }}
"""

    def gene_base_api_code(self, inplace_flag=False):
        api_func_name = self.get_api_func_name()
        if inplace_flag and api_func_name[-1] != '_':
            api_func_name += '_'
        api_code = f"""
PADDLE_API {self.get_return_type(inplace_flag)} {api_func_name}({self.get_define_args(inplace_flag)}) {{
{self.gene_kernel_select()}
"""

        if len(self.kernel['func']) > 1:
            kernel_dispatch_code = ''
            for kernel_name in self.kernel['func']:
                kernel_dispatch_code += self.gene_dispatch_code(
                    kernel_name, inplace_flag
                )
            return (
                api_code
                + f"""
{kernel_dispatch_code}
  PADDLE_THROW(phi::errors::Unimplemented(
          "The kernel of ({self.api}) for input tensors is unimplemented, please check the type of input tensors."));
}}
"""
            )
        else:
            return (
                api_code
                + self.gen_kernel_code(self.kernel['func'][0], '', inplace_flag)
                + """
}
"""
            )

    def gene_invoke_code(self, invoke_code, params_code):
        return f"""
PADDLE_API {self.get_return_type()} {self.api}({params_code}) {{
  return {invoke_code};
}}"""

    def gene_api_code(self):
        if self.is_base_api:
            api_code = self.gene_base_api_code()
            if len(self.inplace_map) > 0:
                if self.api[-1] == '_':
                    api_code = ""
                api_code = api_code + self.gene_base_api_code(inplace_flag=True)
            return api_code

        else:
            invoke_func_name = self.invoke.split('(')[0].strip()
            if invoke_func_name in self.attrs['names']:
                # Adjust the param whose name is the same as the invoked api.
                pattern = r'\W' + invoke_func_name + '[^A-Za-z0-9_(]'

                def adjust_name(matched):
                    matched_str = matched.group()
                    return matched_str[0:-1] + '_val' + matched_str[-1]

                invoke_code = re.sub(pattern, adjust_name, self.invoke)
                params_code = re.sub(
                    pattern, adjust_name, self.get_define_args()
                )
            else:
                invoke_code = self.invoke
                params_code = self.get_define_args()
            return self.gene_invoke_code(invoke_code, params_code)