# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
import collections

PREFIX_TENSOR_NAME = 'input_'
PREFIX_META_TENSOR_NAME = 'meta_'


class BaseAPI(object):

    def __init__(self, api_item_yaml):
        self.api = self.get_api_name(api_item_yaml)

        # inputs:
        #     names : [], list of input names
        #     input_info : {input_name : type}
        # attrs:
        #     names : [], list of attribute names
        #     attr_info : { attr_name : (type, default_value)}
        # outputs:
        #     names : [], list of output names
        #     types : [], list of output types
        #     out_size_expr : [], expression for getting size of vector<Tensor>
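        #
        # A hypothetical example of the parsed structures, for an op declared
        # with args (Tensor x, float scale=1.0) and output Tensor(out):
        #     inputs  = {'names': ['x'], 'input_info': {'x': 'const Tensor&'}}
        #     attrs   = {'names': ['scale'],
        #                'attr_info': {'scale': ('float', '1.0')}}
        #     outputs = {'names': ['out'], 'types': ['Tensor'],
        #                'out_size_expr': [None]}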
        self.inputs, self.attrs, self.outputs, self.optional_vars = self.parse_args(
            self.api, api_item_yaml)

        self.is_base_api = True
        if 'invoke' in api_item_yaml:
            self.is_base_api = False
            self.invoke = api_item_yaml['invoke']
        else:
            if 'infer_meta' in api_item_yaml:
                self.infer_meta = self.parse_infer_meta(
                    api_item_yaml['infer_meta'])
            self.kernel = self.parse_kernel(api_item_yaml['kernel'])
            self.data_transform = self.parse_data_transform(api_item_yaml)
            self.inplace_map, self.view_map = {}, {}

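        # Dispatch table: maps an API input's C++ type and the kernel-side
        # tensor kind ('dense' or 'selected_rows') to the generator that
        # emits the input-preparation code for that argument.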
        self.gene_input_func = {
            "const Tensor&": {
                "dense": self.gene_dense_input,
                "selected_rows": self.gene_selected_rows_input
            },
            "const paddle::optional<Tensor>&": {
                "dense": self.gene_dense_input,
                "selected_rows": self.gene_selected_rows_input
            },
            "const std::vector<Tensor>&": {
                "dense": self.gene_vec_dense_input
            },
            "const paddle::optional<std::vector<Tensor>>&": {
                "dense": self.gene_optional_vec_dense_input
            }
        }

    def get_api_name(self, api_item_yaml):
        return api_item_yaml['op']

    def get_api_func_name(self):
        return self.api

    def get_input_tensor_args(self, inplace_flag=False):
        input_args = []
        inplace_type_map = {
            "const Tensor&":
            "Tensor&",
            "const paddle::optional<Tensor>&":
            "paddle::optional<Tensor>&",
            "const std::vector<Tensor>&":
            "std::vector<Tensor>&",
            "const paddle::optional<std::vector<Tensor>>&":
            "paddle::optional<std::vector<Tensor>>&"
        }
        for name in self.inputs['names']:
            name = name.split('@')[0]
            if inplace_flag and name in self.inplace_map.values():
                input_args.append(
                    inplace_type_map[self.inputs['input_info'][name]] + ' ' +
                    name)
            else:
                input_args.append(self.inputs['input_info'][name] + ' ' + name)
        return input_args

    def get_declare_args(self, inplace_flag=False):
        declare_args = self.get_input_tensor_args(inplace_flag)
        for name in self.attrs['names']:
            default_value = ''
            if self.attrs['attr_info'][name][1] is not None:
                default_value = ' = ' + self.attrs['attr_info'][name][1]
            declare_args.append(self.attrs['attr_info'][name][0] + ' ' + name +
                                default_value)

        return ", ".join(declare_args)

    def get_define_args(self, inplace_flag=False):
        define_args = self.get_input_tensor_args(inplace_flag)
        for name in self.attrs['names']:
            define_args.append(self.attrs['attr_info'][name][0] + ' ' + name)

        return ", ".join(define_args)

    def parse_args(self, api_name, api_item_yaml):
        optional_vars = []
        if 'optional' in api_item_yaml:
            optional_vars = [
                item.strip() for item in api_item_yaml['optional'].split(',')
            ]
        inputs, attrs = self.parse_input_and_attr(api_name,
                                                  api_item_yaml['args'],
                                                  optional_vars)
        output_type_list, output_names, out_size_expr = self.parse_output(
            api_name, api_item_yaml['output'])
        return inputs, attrs, {
            'names': output_names,
            'types': output_type_list,
            'out_size_expr': out_size_expr
        }, optional_vars

    def parse_input_and_attr(self, api_name, args_config, optional_vars=None):
        # Guard against sharing a mutable default argument across calls.
        if optional_vars is None:
            optional_vars = []
        inputs = {'names': [], 'input_info': {}}
        attrs = {'names': [], 'attr_info': {}}
        args_str = args_config.strip()
        assert args_str.startswith('(') and args_str.endswith(')'), \
            f"Args declaration should start with '(' and end with ')', please check the args of {api_name} in yaml."
        args_str = args_str[1:-1]
        args_list = args_str.split(',')
        input_types_map = {
            'Tensor': 'const Tensor&',
            'Tensor[]': 'const std::vector<Tensor>&'
        }
        attr_types_map = {
            'IntArray': 'const IntArray&',
            'Scalar': 'const Scalar&',
            'Scalar(int)': 'const Scalar&',
            'Scalar(int64_t)': 'const Scalar&',
            'Scalar(float)': 'const Scalar&',
            'Scalar(double)': 'const Scalar&',
            'Scalar[]': 'const std::vector<phi::Scalar>&',
            'int': 'int',
            'int32_t': 'int32_t',
            'int64_t': 'int64_t',
            'long': 'long',
            'size_t': 'size_t',
            'float': 'float',
            'float[]': 'const std::vector<float>&',
            'double': 'double',
            'bool': 'bool',
            'str': 'const std::string&',
            'str[]': 'const std::vector<std::string>&',
            'Place': 'const Place&',
            'DataLayout': 'DataLayout',
            'DataType': 'DataType',
            'int64_t[]': 'const std::vector<int64_t>&',
            'int[]': 'const std::vector<int>&',
        }
        optional_types_trans = {
            'Tensor': 'const paddle::optional<Tensor>&',
            'Tensor[]': 'const paddle::optional<std::vector<Tensor>>&',
            'int': 'paddle::optional<int>',
            'int32_t': 'paddle::optional<int32_t>',
            'int64_t': 'paddle::optional<int64_t>',
            'float': 'paddle::optional<float>',
            'double': 'paddle::optional<double>',
            'bool': 'paddle::optional<bool>',
            'Place': 'paddle::optional<const Place&>',
            'DataLayout': 'paddle::optional<DataLayout>',
            'DataType': 'paddle::optional<DataType>'
        }

        for item in args_list:
            item = item.strip()
            type_and_name = item.split(' ')
            # match the input tensor
            has_input = False
            for in_type_symbol, in_type in input_types_map.items():
                if type_and_name[0] == in_type_symbol:
                    input_name = type_and_name[1].strip()
                    assert len(input_name) > 0, \
                        f"The input tensor name should not be empty. Please check the args of {api_name} in yaml."
                    assert len(attrs['names']) == 0, \
                        f"The input Tensor should appear before attributes. Please check the position of {api_name}:input({input_name}) in yaml."

                    if input_name in optional_vars:
                        in_type = optional_types_trans[in_type_symbol]

                    inputs['names'].append(input_name)
                    inputs['input_info'][input_name] = in_type
                    has_input = True
                    break
            if has_input:
                continue

            # match the attribute
            for attr_type_symbol, attr_type in attr_types_map.items():
                if type_and_name[0] == attr_type_symbol:
                    attr_name = item[len(attr_type_symbol):].strip()
                    assert len(attr_name) > 0, \
                        f"The attribute name should not be empty. Please check the args of {api_name} in yaml."
                    default_value = None
                    if '=' in attr_name:
                        attr_infos = attr_name.split('=')
                        attr_name = attr_infos[0].strip()
                        default_value = attr_infos[1].strip()

                    if attr_name in optional_vars:
                        attr_type = optional_types_trans[attr_type_symbol]

                    attrs['names'].append(attr_name)
                    attrs['attr_info'][attr_name] = (attr_type, default_value)
                    break

        return inputs, attrs
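
    # E.g. (hypothetical) with `optional : y`, an arg item `Tensor y` is
    # recorded as 'const paddle::optional<Tensor>&' via optional_types_trans.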

    def parse_output(self, api_name, output_config):

        def parse_output_item(output_item):
            output_type_map = {
                'Tensor': 'Tensor',
                'Tensor[]': 'std::vector<Tensor>'
            }
            result = re.search(
                r"(?P<out_type>[a-zA-Z0-9_[\]]+)\s*(?P<name>\([a-zA-Z0-9_@]+\))?\s*(?P<expr>\{[^\}]+\})?",
                output_item)
            assert result is not None, f"{api_name} : failed to parse the output config."
            out_type = result.group('out_type')
            assert out_type in output_type_map, \
                f"{api_name} : Output type error: the output type only supports Tensor and Tensor[], \
                  but received {out_type}."

            out_name = 'out' if result.group('name') is None else result.group(
                'name')[1:-1]
            out_size_expr = None if result.group(
                'expr') is None else result.group('expr')[1:-1]
            return output_type_map[out_type], out_name, out_size_expr
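
        # Hypothetical output configs accepted by the regex above:
        #     "Tensor(out)"              -> ('Tensor', 'out', None)
        #     "Tensor[](out){x.size()}"  -> ('std::vector<Tensor>', 'out', 'x.size()')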

        temp_list = output_config.split(',')

        if len(temp_list) == 1:
            out_type, out_name, size_expr = parse_output_item(temp_list[0])
            return [out_type], [out_name], [size_expr]
        else:
            out_type_list = []
            out_name_list = []
            out_size_expr_list = []
            for output_item in temp_list:
                out_type, out_name, size_expr = parse_output_item(output_item)
                out_type_list.append(out_type)
                out_name_list.append(out_name)
                out_size_expr_list.append(size_expr)

            return out_type_list, out_name_list, out_size_expr_list

    def parse_infer_meta(self, infer_meta_config):
        infer_meta = infer_meta_config
        if 'param' not in infer_meta_config:
            infer_meta['param'] = None

        return infer_meta

    def parse_kernel(self, kernel_config):
        # kernel :
        #    func : [], Kernel functions (example: scale, scale_sr)
        #    param : [], Input params of kernel
        #    backend : str, the name(s) of the param(s) used to choose the kernel backend, default is None
        #    layout : str, the name(s) of the param(s) used to choose the kernel layout, default is None
        #    data_type : str, the name(s) of the param(s) used to choose the kernel data_type, default is None
        #    dispatch : {}, key is the kernel function, value is the input/output tensor types (example: {kernel_name : (['dense','sparse_coo']#input, ['sparse_coo']#output)})
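        #
        # A hypothetical yaml fragment in this shape:
        #    kernel :
        #      func : scale {dense -> dense}, scale_sr {selected_rows -> selected_rows}
        #      param : [x, scale, bias, bias_after_scale]
        #      data_type : x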
        kernel = {
            'func': [],
            'param': None,
            'backend': None,
            'layout': None,
            'data_type': None,
            'use_gpudnn': 'false',
            'dispatch': {}
        }
        if 'backend' in kernel_config and len(kernel_config['backend']) > 0:
            kernel['backend'] = kernel_config['backend']
        if 'layout' in kernel_config and len(kernel_config['layout']) > 0:
            kernel['layout'] = kernel_config['layout']
        if 'data_type' in kernel_config and len(kernel_config['data_type']) > 0:
            kernel['data_type'] = kernel_config['data_type']
        if 'param' in kernel_config:
            kernel['param'] = kernel_config['param']
        if 'use_gpudnn' in kernel_config:
            kernel['use_gpudnn'] = kernel_config['use_gpudnn']
            if isinstance(kernel['use_gpudnn'], bool):
                kernel['use_gpudnn'] = str(kernel['use_gpudnn']).lower()
        kernel_funcs = re.compile(r'([a-zA-Z0-9_]+)\s*({[^}]+})?').findall(
            kernel_config['func'])

        def parse_kernel_in_out_type(in_out_str):
            if len(in_out_str) == 0:
                return None
            tmp_in_out_list = in_out_str[1:-1].split('->')
            inputs = [item.strip() for item in tmp_in_out_list[0].split(',')]
            outputs = [item.strip() for item in tmp_in_out_list[1].split(',')]

            # check the tensor type
            for item in inputs:
                assert item in [
                    'dense', 'selected_rows', 'sparse_coo', 'sparse_csr'
                ], f"{self.api} : Invalid input tensor type ('{item}'), here we only support 'dense', 'selected_rows', 'sparse_coo' and 'sparse_csr'."
            for item in outputs:
                assert item in [
                    'dense', 'selected_rows', 'sparse_coo', 'sparse_csr'
                ], f"{self.api} : Invalid output tensor type ('{item}'), here we only support 'dense', 'selected_rows', 'sparse_coo' and 'sparse_csr'."

            return (inputs, outputs)
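
        # E.g. (hypothetical) "{dense, dense -> selected_rows}" parses to
        # (['dense', 'dense'], ['selected_rows']).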

        for func_item in kernel_funcs:
            kernel['func'].append(func_item[0])
            kernel['dispatch'][func_item[0]] = parse_kernel_in_out_type(
                func_item[1])

        return kernel

    def parse_data_transform(self, api_item_yaml):
        data_transform = {'skip_transform': [], 'support_trans_dtype': []}
        if 'data_transform' in api_item_yaml:
            if 'skip_transform' in api_item_yaml['data_transform']:
                data_transform['skip_transform'] = api_item_yaml[
                    'data_transform']['skip_transform']
            if 'support_trans_dtype' in api_item_yaml['data_transform']:
                data_transform['support_trans_dtype'] = api_item_yaml[
                    'data_transform']['support_trans_dtype']

        return data_transform
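
    # The skip_transform / support_trans_dtype entries are consulted by
    # gene_trans_flag below when choosing the transform flag passed to
    # PrepareData in the generated code.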

    # Override by child class
    def get_return_type(self, inplace_flag=False):
        return None

    def gene_api_declaration(self):
        api_declaration = ""
        api_func_name = self.get_api_func_name()
        if api_func_name[-1] != '_':
            api_declaration = f"""
PADDLE_API {self.get_return_type()} {api_func_name}({self.get_declare_args()});
"""

        if self.is_base_api and len(self.inplace_map) > 0:
            if api_func_name[-1] != '_':
                api_func_name += '_'
            api_declaration = api_declaration + f"""
PADDLE_API {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_declare_args(inplace_flag=True)});
"""

        return api_declaration
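
    # For a hypothetical subclass where get_return_type() yields 'Tensor',
    # the generated declaration looks like:
    #     PADDLE_API Tensor scale(const Tensor& x, const Scalar& scale = 1.0);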

    # Overridden by the backward API generator
    def gene_kernel_backend_select(self):
        backend_select_code = ""
        if self.kernel['backend'] is not None:
            if '>' in self.kernel['backend']:
                vars_list = self.kernel['backend'].split('>')
                assert len(
                    vars_list
                ) == 2, f"{self.api} api: The number of params to set backend with '>' only allows 2, but received {len(vars_list)}."
                assert (vars_list[0].strip() in self.attrs['names']) and (self.attrs['attr_info'][vars_list[0].strip()][0] == 'const Place&'), \
                    f"{self.api} api: When using '>' to set the kernel backend, the first param should be an attribute with Place type."
                backend_select_code = f"""
  kernel_backend = ParseBackendWithInputOrder({vars_list[0].strip()}, {vars_list[1].strip()});
"""

            else:
                backend_args = [
                    ele.strip() for ele in self.kernel['backend'].split(',')
                ]
                backend_select_code = f"""
  kernel_backend = ParseBackend({", ".join(backend_args)});
"""

        return backend_select_code

    def gene_kernel_select(self) -> str:
        api = self.api
        input_names = self.inputs['names']
        attrs = self.attrs
        kernel = self.kernel

        kernel_key_item_init = """
  Backend kernel_backend = Backend::UNDEFINED;
  DataLayout kernel_layout = DataLayout::UNDEFINED;
  DataType kernel_data_type = DataType::UNDEFINED;
"""
        # Check the tensor options
        attr_backend_count = 0
        attr_layout_count = 0
        attr_data_type_count = 0
        for attr_name in attrs['names']:
            if attrs['attr_info'][attr_name][0] == 'const Place&':
                assert kernel['backend'] is not None, \
                    f"{api} api: When there is a parameter with 'Place' type in attributes, you must set backend of kernel manually."
                attr_backend_count = attr_backend_count + 1
            if attrs['attr_info'][attr_name][0] == 'DataLayout':
                assert kernel['layout'] is not None, \
                    f"{api} api: When there is a parameter with 'DataLayout' type in attributes, you must set layout of kernel manually."
                attr_layout_count = attr_layout_count + 1
            if attrs['attr_info'][attr_name][0] == 'DataType':
                assert kernel['data_type'] is not None, \
                    f"{api} api: When there is a parameter with 'DataType' type in attributes, you must set data_type of kernel manually."
                attr_data_type_count = attr_data_type_count + 1

        # preprocess kernel configures
        kernel_select_code = self.gene_kernel_backend_select()

        if kernel['layout'] is not None:
            if '>' in kernel['layout']:
                vars_list = kernel['layout'].split('>')
                assert len(
                    vars_list
                ) == 2, f"{api} api: The number of params to set layout with '>' only allows 2, but received {len(vars_list)}."
                assert vars_list[0].strip() in attrs['names'] and attrs['attr_info'][vars_list[0].strip()][0] == 'DataLayout', \
                    f"{api} api: When use '>' to set kernel layout, the first param should be a attribute with DataLayout type."
                kernel_select_code = kernel_select_code + f"""
  kernel_layout = ParseLayoutWithInputOrder({vars_list[0].strip()}, {vars_list[1].strip()});
"""

            else:
                vars_list = kernel['layout'].split(',')
                assert len(
                    vars_list
                ) == 1, f"{api} api: The number of params to set layout must be 1, but received {len(vars_list)}."
                kernel_select_code = kernel_select_code + f"""
  kernel_layout = ParseLayout({vars_list[0].strip()});
"""

        if kernel['data_type'] is not None:
            if '>' in kernel['data_type']:
                vars_list = kernel['data_type'].split('>')
                assert len(
                    vars_list
                ) == 2, f"{api} api: The number of params to set data_type with '>' only allows 2, but received {len(vars_list)}."
                assert vars_list[0].strip() in attrs['names'] and attrs['attr_info'][vars_list[0].strip()][0] == 'DataType', \
                    f"{api} api: When use '>' to set kernel data_type, the first param should be a attribute with DataType type."
                kernel_select_code = kernel_select_code + f"""
  kernel_data_type = ParseDataTypeWithInputOrder({vars_list[0].strip()}, {vars_list[1].strip()});
"""

            else:
                vars_list = kernel['data_type'].split(',')
                assert len(
                    vars_list
                ) == 1, f"{api} api: The number of params to set data_type only allows 1, but received {len(vars_list)}."
                kernel_select_code = kernel_select_code + f"""
  kernel_data_type = ParseDataType({vars_list[0].strip()});
"""

        if len(input_names) == 0:
            assert attr_backend_count > 0 and attr_data_type_count > 0, \
                f"{api} api: When there is no input tensor, the args must have 'Place' and 'DataType'."

        kernel_select_args = ""
        for input_name in input_names:
            kernel_select_args = kernel_select_args + input_name + ", "

        if len(kernel_select_args) > 2:
            kernel_select_args = kernel_select_args[:-2]

        kernel_select_code = kernel_key_item_init + kernel_select_code

        if len(input_names) > 0:
            kernel_select_code = kernel_select_code + f"""
  if (kernel_backend == Backend::UNDEFINED
        || kernel_layout == DataLayout::UNDEFINED
        || kernel_data_type == DataType::UNDEFINED ) {{
    auto kernel_key_set = ParseKernelKeyByInputArgs({kernel_select_args});
    auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
    if (kernel_backend == Backend::UNDEFINED) {{
      kernel_backend = kernel_key.backend();
    }}
    if (kernel_layout == DataLayout::UNDEFINED) {{
      kernel_layout = kernel_key.layout();
    }}
    if (kernel_data_type == DataType::UNDEFINED) {{
      kernel_data_type = kernel_key.dtype();
    }}
  }}"""

        return kernel_select_code

    def gene_infer_meta(self, kernel_output_names, code_indent) -> str:
        input_names = self.inputs['names']
        attr_names = self.attrs['names']
        infer_meta = self.infer_meta

        infer_meta_params = infer_meta['param'] if infer_meta[
            'param'] is not None else input_names + attr_names
        # generate meta tensors
        meta_tensor_code = ""
        param_code = ""
        for param in infer_meta_params:
            if param in input_names:
                if self.inputs['input_info'][param] == "const Tensor&":
                    param_code = param_code + "MakeMetaTensor(*" + PREFIX_TENSOR_NAME + param + "), "
                elif self.inputs['input_info'][
                        param] == "const std::vector<Tensor>&":
                    meta_tensor_code = meta_tensor_code + f"""
{code_indent}  auto {param}_meta_vec = MakeMetaTensor({PREFIX_TENSOR_NAME}{param});
{code_indent}  std::vector<const phi::MetaTensor*> {param}_metas({param}_meta_vec.size());
{code_indent}  for (size_t i = 0; i < {param}_meta_vec.size(); ++i) {{
{code_indent}    {param}_metas[i] = &{param}_meta_vec[i];
{code_indent}  }}
"""
                    param_code = param_code + param + "_metas, "
                elif self.inputs['input_info'][
                        param] == "const paddle::optional<std::vector<Tensor>>&":
                    meta_tensor_code = meta_tensor_code + f"""
{code_indent}  auto {param}_meta_vec = MakeMetaTensor({PREFIX_TENSOR_NAME}{param});
{code_indent}  paddle::optional<std::vector<const phi::MetaTensor*>> {param}_metas({param}_meta_vec.size());
{code_indent}  for (size_t i = 0; i < {param}_meta_vec.size(); ++i) {{
{code_indent}    {param}_metas->at(i) = &{param}_meta_vec[i];
{code_indent}  }}
"""
                    param_code = param_code + param + "_metas, "
                elif param in self.optional_vars:
                    param_code = param_code + "MakeMetaTensor(" + PREFIX_TENSOR_NAME + param + "), "
                else:
                    raise ValueError(
                        f"{self.api} : Param of infer_meta error : {self.inputs['input_info'][param]} type is not supported."
                    )
            elif param in attr_names:
                param_code = param_code + param + ", "
            elif isinstance(param, str):
                param_code = param_code + "\"" + param + "\", "
            elif isinstance(param, bool):
                param_code = param_code + str(param).lower() + ", "
            else:
                param_code = param_code + str(param) + ", "

        for i, out_name in enumerate(kernel_output_names):
            if self.outputs['types'][i] == 'std::vector<Tensor>':
                meta_tensor_code = meta_tensor_code + f"""
{code_indent}  auto {out_name}_{PREFIX_META_TENSOR_NAME}vec = MakeMetaTensor({out_name});
{code_indent}  std::vector<phi::MetaTensor*> {out_name}_metas({out_name}_{PREFIX_META_TENSOR_NAME}vec.size());
{code_indent}  for (size_t i = 0; i < {out_name}_{PREFIX_META_TENSOR_NAME}vec.size(); ++i) {{
{code_indent}    {out_name}_metas[i] = {out_name}[i] ? &{out_name}_{PREFIX_META_TENSOR_NAME}vec[i] : nullptr;
{code_indent}  }}"""

                param_code = param_code + out_name + '_metas, '
            else:
                meta_tensor_code = meta_tensor_code + code_indent + "  phi::MetaTensor " + out_name.replace(
                    'kernel_',
                    PREFIX_META_TENSOR_NAME) + "(" + out_name + ");\n"
                if len(kernel_output_names) == 1:
                    param_code = param_code + f"&{out_name.replace('kernel_', PREFIX_META_TENSOR_NAME)}, "
                else:
                    param_code = param_code + f"{out_name} ? &{out_name.replace('kernel_', PREFIX_META_TENSOR_NAME)} : nullptr, "

        param_code = param_code[:-2]
        return f"""{meta_tensor_code}
{code_indent}  phi::{infer_meta['func']}({param_code});
"""

    def gene_trans_flag(self, input_name):
        trans_flag = "{}"
        if input_name in self.data_transform['skip_transform']:
            trans_flag = "{true}"
        elif input_name in self.data_transform['support_trans_dtype']:
            trans_flag = "{false, true}"
        return trans_flag
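
    # The returned string is spliced into the generated C++ as the
    # brace-initialized transform-flag argument of PrepareData: "{}" keeps
    # the default behavior, "{true}" skips data transform for this input,
    # and "{false, true}" marks it as supporting dtype transform.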

    def gene_dense_input(self,
                         input_name,
                         input_name_tensor_map,
                         code_indent=''):
        input_tensor_code = ""
        trans_flag = self.gene_trans_flag(input_name)
        input_names = self.inputs['names']
        attr_names = self.attrs['names']
        kernel_param = self.kernel['param']
        if kernel_param is None:
            kernel_param = input_names + attr_names

        input_name_tensor_map[input_name].append(
            (f"{PREFIX_TENSOR_NAME}{input_name}", False))
        input_tensor_code = input_tensor_code + f"""
{code_indent}  auto {PREFIX_TENSOR_NAME}{input_name} = PrepareData({input_name}, kernel.InputAt({kernel_param.index(input_name)}), {trans_flag});"""
        return input_tensor_code

    def gene_selected_rows_input(self,
                                 input_name,
                                 input_name_tensor_map,
                                 code_indent=''):
        input_tensor_code = ""
        trans_flag = self.gene_trans_flag(input_name)
        input_names = self.inputs['names']
        attr_names = self.attrs['names']
        kernel_param = self.kernel['param']
        if kernel_param is None:
            kernel_param = input_names + attr_names

        input_name_tensor_map[input_name].append(
            (f"{PREFIX_TENSOR_NAME}{input_name}", False))
        input_tensor_code = input_tensor_code + f"""
{code_indent}  auto {PREFIX_TENSOR_NAME}{input_name} = TensorToSelectedRows({input_name});
"""
        return input_tensor_code

    def gene_optional_vec_dense_input(self,
                                      input_name,
                                      input_name_tensor_map,
                                      code_indent=''):
        input_tensor_code = ""
        trans_flag = self.gene_trans_flag(input_name)
        input_names = self.inputs['names']
        attr_names = self.attrs['names']
        kernel_param = self.kernel['param']
        if kernel_param is None:
            kernel_param = input_names + attr_names
        if input_name in self.inplace_map.values():
            input_name_tensor_map[input_name].append(
                (f"{PREFIX_TENSOR_NAME}{input_name}", True))
            input_tensor_code = input_tensor_code + f"""
{code_indent}  paddle::optional<std::vector<const phi::DenseTensor*>> {PREFIX_TENSOR_NAME}{input_name} = TensorToConstDenseTensorPtr({input_name});"""
        else:
            input_name_tensor_map[input_name].append(
                (f"{PREFIX_TENSOR_NAME}{input_name}_vec", True))
            input_tensor_code = input_tensor_code + f"""
{code_indent}  auto {PREFIX_TENSOR_NAME}{input_name}_vec = PrepareData({input_name}, kernel.InputAt({kernel_param.index(input_name)}), {trans_flag});
{code_indent}  paddle::optional<std::vector<const phi::DenseTensor*>> {PREFIX_TENSOR_NAME}{input_name};
{code_indent}  if ({PREFIX_TENSOR_NAME}{input_name}_vec){{
{code_indent}    {PREFIX_TENSOR_NAME}{input_name} = paddle::optional<std::vector<const phi::DenseTensor*>>({PREFIX_TENSOR_NAME}{input_name}_vec->size());
{code_indent}    for (size_t i = 0; i < {PREFIX_TENSOR_NAME}{input_name}_vec->size(); ++i) {{
{code_indent}      {PREFIX_TENSOR_NAME}{input_name}->at(i) = &{PREFIX_TENSOR_NAME}{input_name}_vec->at(i);
{code_indent}    }}
{code_indent}  }}"""
        return input_tensor_code

    def gene_vec_dense_input(self,
                             input_name,
                             input_name_tensor_map,
                             code_indent=''):
        input_tensor_code = ""
        trans_flag = self.gene_trans_flag(input_name)
        input_names = self.inputs['names']
        attr_names = self.attrs['names']
        kernel_param = self.kernel['param']
        if kernel_param is None:
            kernel_param = input_names + attr_names

        if input_name in self.inplace_map.values():
            input_name_tensor_map[input_name].append(
                (f"{PREFIX_TENSOR_NAME}{input_name}", True))
            input_tensor_code = input_tensor_code + f"""
{code_indent}  std::vector<const phi::DenseTensor*> {PREFIX_TENSOR_NAME}{input_name} = TensorToConstDenseTensorPtr({input_name});"""
        else:
            input_name_tensor_map[input_name].append(
                (f"{PREFIX_TENSOR_NAME}{input_name}_vec", True))
            input_tensor_code = input_tensor_code + f"""
{code_indent}  auto {PREFIX_TENSOR_NAME}{input_name}_vec = PrepareData({input_name}, kernel.InputAt({kernel_param.index(input_name)}), {trans_flag});
{code_indent}  std::vector<const phi::DenseTensor*> {PREFIX_TENSOR_NAME}{input_name}({PREFIX_TENSOR_NAME}{input_name}_vec->size());
{code_indent}  for (size_t i = 0; i < {PREFIX_TENSOR_NAME}{input_name}.size(); ++i) {{
{code_indent}    {PREFIX_TENSOR_NAME}{input_name}[i] = &{PREFIX_TENSOR_NAME}{input_name}_vec->at(i);
{code_indent}  }}"""
        return input_tensor_code

    def gene_input(self, kernel_tensor_type=None, code_indent=''):
        input_names = self.inputs['names']
        attr_names = self.attrs['names']
        kernel_param = self.kernel['param']
        if kernel_param is None:
            kernel_param = input_names + attr_names
        input_name_tensor_map = collections.defaultdict(list)
        input_tensor_code = ""
        for i, input_name in enumerate(input_names):
            # set input code
            if input_name in kernel_param:
                # input is dense tensor
                api_tensor_type = self.inputs['input_info'][input_name]
                phi_tensor_type = 'dense' if kernel_tensor_type is None else kernel_tensor_type[
                    0][kernel_param.index(input_name)]
                if api_tensor_type in self.gene_input_func.keys():
                    input_tensor_code += self.gene_input_func[api_tensor_type][
                        phi_tensor_type](input_name, input_name_tensor_map,
                                         code_indent)
                else:
                    # do nothing
                    pass
            else:
                if input_name in self.infer_meta['param']:
                    if input_name in self.optional_vars:
                        input_tensor_code = input_tensor_code + f"""
{code_indent}  paddle::optional<phi::TensorBase> {PREFIX_TENSOR_NAME}{input_name} = {input_name} ? paddle::optional<phi::TensorBase>(*{input_name}->impl()) : paddle::none;"""

                    else:
                        if self.inputs['input_info'][
                                input_name] == "const std::vector<Tensor>&":
                            input_tensor_code = input_tensor_code + f"""
{code_indent}  auto {PREFIX_TENSOR_NAME}{input_name}_uq_ptr = TensorToDenseTensor({input_name});
{code_indent}  const auto& {PREFIX_TENSOR_NAME}{input_name} = *{PREFIX_TENSOR_NAME}{input_name}_uq_ptr;"""
                        else:
                            input_tensor_code = input_tensor_code + f"""
{code_indent}  auto {PREFIX_TENSOR_NAME}{input_name} = {input_name}.impl();"""

        return input_name_tensor_map, input_tensor_code

    def get_kernel_args(self, kernel_tensor_type=None, code_indent=''):
        dense_input_trans_map = {
            'const Tensor&':
            'const phi::DenseTensor&',
            'const std::vector<Tensor>&':
            'const std::vector<const phi::DenseTensor*>&',
            'const paddle::optional<Tensor&>':
            'paddle::optional<const phi::DenseTensor&>',
            'const paddle::optional<Tensor>&':
            'const paddle::optional<phi::DenseTensor>&',
            'const paddle::optional<std::vector<Tensor>>&':
            'const paddle::optional<std::vector<const phi::DenseTensor*>>&'
        }
        dense_out_trans_map = {
            'Tensor': 'phi::DenseTensor*',
            'std::vector<Tensor>': 'std::vector<phi::DenseTensor*>&'
        }
        sr_input_trans_map = {
            'const Tensor&':
            'const phi::SelectedRows&',
            'const paddle::optional<Tensor>&':
            'const paddle::optional<phi::SelectedRows>&'
        }
        sr_out_trans_map = {'Tensor': 'phi::SelectedRows*'}
        input_names = self.inputs['names']
        input_infos = self.inputs['input_info']
        kernel_args_type_list = ['const platform::DeviceContext&']

        attr_names = self.attrs['names']
        kernel_param = self.kernel['param']
        if kernel_param is None:
            kernel_param = input_names + attr_names

        input_name_tensor_map, input_tensor_code = self.gene_input(
            kernel_tensor_type, code_indent)

        input_tensor_code = input_tensor_code + f"""
{code_indent}  if(platform::RecordOpInfoSupplement::IsEnabled()){{"""
        single_tensor_names = []
        list_tensor_names = []
        for input_name, input_tensors in input_name_tensor_map.items():
            has_vector_tensor = False
            for input_tensor, is_vector in input_tensors:
                if is_vector:
                    has_vector_tensor = True
            if not has_vector_tensor:
                single_tensor_names.append(input_name)
            else:
                list_tensor_names.append(input_name)
        if not single_tensor_names:
            input_tensor_code = input_tensor_code + f"""
{code_indent}     std::vector<std::pair<const char*, std::vector<phi::DDim>>> input_shapes;"""
        else:
            for input_name in single_tensor_names:
                if input_name in self.optional_vars:
                    input_tensors = input_name_tensor_map[input_name]
                    input_tensor_code = input_tensor_code + f"""
{code_indent}     std::vector<phi::DDim> {input_name}_record_shapes;"""
                    for input_tensor, _ in input_tensors:
                        input_tensor_code = input_tensor_code + f"""
{code_indent}     if({input_tensor}){{
{code_indent}       {input_name}_record_shapes.push_back((*{input_tensor}).dims());
{code_indent}     }}"""

            input_tensor_code = input_tensor_code + f"""
{code_indent}     std::vector<std::pair<const char*, std::vector<phi::DDim>>> input_shapes{{"""
            for input_name in single_tensor_names[:-1]:
                if input_name in self.optional_vars:
                    input_tensor_code = input_tensor_code + f"""
{code_indent}     {{"{input_name}", {input_name}_record_shapes}},"""
                else:
                    input_tensor_code = input_tensor_code + f"""
{code_indent}     {{"{input_name}", {{"""
                    input_tensors = input_name_tensor_map[input_name]
                    for input_tensor, _ in input_tensors[:-1]:
                        input_tensor_code = input_tensor_code + f"""
{code_indent}     (*{input_tensor}).dims(),"""
                    input_tensor_code = input_tensor_code + f"""
{code_indent}     (*{input_tensors[-1][0]}).dims()}}}},"""
            if single_tensor_names[-1] in self.optional_vars:
                input_tensor_code = input_tensor_code + f"""
{code_indent}     {{"{single_tensor_names[-1]}",
{code_indent}     {single_tensor_names[-1]}_record_shapes}}}};"""
            else:
                input_tensor_code = input_tensor_code + f"""
{code_indent}     {{"{single_tensor_names[-1]}", {{"""
                input_tensors = input_name_tensor_map[single_tensor_names[-1]]
                for input_tensor, _ in input_tensors[:-1]:
                    input_tensor_code = input_tensor_code + f"""
{code_indent}     (*{input_tensor}).dims(),"""
                input_tensor_code = input_tensor_code + f"""
{code_indent}     (*{input_tensors[-1][0]}).dims()}}}}}};"""
        if list_tensor_names:
            input_tensor_code = input_tensor_code + f"""
{code_indent}     std::vector<phi::DDim> ddims_vec;"""
        for input_name in list_tensor_names:
            input_tensor_code = input_tensor_code + f"""
{code_indent}     ddims_vec.clear();"""
            for input_tensor, is_vector in input_name_tensor_map[input_name]:
                if is_vector:
                    input_tensor_truncate = input_tensor[:-4]
                    if input_name in self.inplace_map.values():
                        input_tensor_truncate = input_tensor

                    if input_name in self.optional_vars:
                        input_tensor_code = input_tensor_code + f"""
{code_indent}     if ({input_tensor_truncate}){{
{code_indent}       ddims_vec.reserve({input_tensor_truncate}->size());
{code_indent}       for (size_t i = 0; i < {input_tensor_truncate}->size(); ++i) {{
{code_indent}         ddims_vec.emplace_back((*{input_tensor_truncate}->at(i)).dims());
{code_indent}       }}
{code_indent}     }}"""
                    else:
                        input_tensor_code = input_tensor_code + f"""
{code_indent}     ddims_vec.reserve({input_tensor_truncate}.size());
{code_indent}     for (size_t i = 0; i < {input_tensor_truncate}.size(); ++i) {{
{code_indent}       ddims_vec.emplace_back((*{input_tensor_truncate}[i]).dims());
{code_indent}     }}"""
                else:
                    input_tensor_code = input_tensor_code + f"""
                  ddims_vec.emplace_back((*{input_tensor}).dims());
{code_indent}     """
            input_tensor_code = input_tensor_code + f"""
{code_indent}     input_shapes.emplace_back("{input_name}", ddims_vec);"""

        input_tensor_code = input_tensor_code + f"""
{code_indent}     platform::RecordOpInfoSupplement("{self.api}", input_shapes);
{code_indent}  }}"""
        kernel_args = ["*dev_ctx"]
        for param in kernel_param:
            if param in input_names:
                if param in self.optional_vars:
                    kernel_args.append(PREFIX_TENSOR_NAME + param)
                else:
                    if self.inputs['input_info'][param] == "const Tensor&":
                        kernel_args.append("*" + PREFIX_TENSOR_NAME + param)
                    elif self.inputs['input_info'][
                            param] == "const std::vector<Tensor>&":
                        kernel_args.append(PREFIX_TENSOR_NAME + param)
                    else:
                        # do nothing
                        pass
                # input is dense tensor
                if kernel_tensor_type is None or kernel_tensor_type[0][
                        kernel_param.index(param)] == 'dense':
                    kernel_args_type_list.append(
                        dense_input_trans_map[input_infos[param]])
                else:  # input is selected_rows
                    kernel_args_type_list.append(
                        sr_input_trans_map[input_infos[param]])
            elif param in attr_names:
                # set attr for kernel_context
                if 'IntArray' in self.attrs['attr_info'][param][0]:
                    kernel_args_type_list.append('const phi::IntArray&')
                    param = 'phi::IntArray(' + param + ')'
                elif 'vector<phi::Scalar>' in self.attrs['attr_info'][param][0]:
                    kernel_args_type_list.append(
                        'const std::vector<phi::Scalar>&')
                elif 'Scalar' in self.attrs['attr_info'][param][0]:
                    kernel_args_type_list.append('const phi::Scalar&')
                    param = 'phi::Scalar(' + param + ')'
                else:
                    kernel_args_type_list.append(
                        self.attrs['attr_info'][param][0])
                kernel_args.append(param)
            elif isinstance(param, bool):
                kernel_args.append(str(param).lower())
            else:
                kernel_args.append(str(param))

        for i, out_type in enumerate(self.outputs['types']):
            # output is dense tensor
            if kernel_tensor_type is None or kernel_tensor_type[1][i] == 'dense':
                kernel_args_type_list.append(dense_out_trans_map[out_type])
            else:  # output is selected_rows
                kernel_args_type_list.append(sr_out_trans_map[out_type])

        kernel_signature = "void(*)(" + ", ".join(kernel_args_type_list) + ")"
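
        # For a hypothetical dense kernel taking two Tensor inputs and
        # producing one Tensor, this builds something like:
        #     void(*)(const platform::DeviceContext&, const phi::DenseTensor&,
        #             const phi::DenseTensor&, phi::DenseTensor*)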

        return input_tensor_code, ", ".join(kernel_args), kernel_signature

    # Override by child class
    def gene_return_code(self):
        return "return api_output;"

    # Override by child class
    def gene_output(self,
                    out_dtype_list,
                    out_tensor_type_list=None,
                    code_indent='',
                    inplace_flag=False):
        return None, None, None

    def gen_kernel_code(self, kernel_name, code_indent, inplace_flag=False):
        kernel_dispatch = self.kernel['dispatch'][kernel_name]
        input_tensors, kernel_args, kernel_signature = self.get_kernel_args(
            kernel_dispatch, code_indent)
        out_tensor_type_list = kernel_dispatch[1] if kernel_dispatch else None
        outputs_args, kernel_output_names, output_create = self.gene_output(
            self.outputs['types'], out_tensor_type_list, code_indent,
            inplace_flag)
        fallback_kernel_output_trans = ""
        for kernel_out in outputs_args:
            fallback_kernel_output_trans += (f"""
{code_indent}    TransDataBackend({kernel_out}, kernel_backend, {kernel_out});"""
                                             )
        cudnn_args = '' if self.kernel[
            'use_gpudnn'] == 'false' else ', ' + self.kernel['use_gpudnn']
        return f"""
{code_indent}  VLOG(6) << "{self.api} API kernel key: [" << kernel_backend << ", " << kernel_layout << ", "<< kernel_data_type << "]";
{code_indent}  auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError(
{code_indent}      "{kernel_name}", {{kernel_backend, kernel_layout, kernel_data_type}}{cudnn_args});
{code_indent}  const auto& kernel = kernel_result.kernel;
{code_indent}  VLOG(6) << "{kernel_name} kernel: " << kernel;
{code_indent}  auto* dev_ctx = GetDeviceContextByBackend(kernel_result.has_fallback_cpu ? Backend::CPU : kernel_backend);
{input_tensors}
{output_create}
{code_indent}  paddle::platform::RecordEvent *infer_shape_record_event = nullptr;
{code_indent}  if(paddle::platform::RecordEvent::IsEnabled()){{
{code_indent}    infer_shape_record_event = new paddle::platform::RecordEvent(\"{self.api} infer_meta\", paddle::platform::TracerEventType::OperatorInner, 1);
{code_indent}  }}
{self.gene_infer_meta(kernel_output_names, code_indent)}
{code_indent}  if(infer_shape_record_event != nullptr){{
{code_indent}    delete infer_shape_record_event;
{code_indent}  }}
{code_indent}  using kernel_signature = {kernel_signature};
{code_indent}  auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
{code_indent}  paddle::platform::RecordEvent* kernel_record_event = nullptr;
{code_indent}  if(paddle::platform::RecordEvent::IsEnabled()){{
{code_indent}    kernel_record_event = new paddle::platform::RecordEvent(\"{self.api} compute\", paddle::platform::TracerEventType::OperatorInner, 1);
{code_indent}  }}
{code_indent}    (*kernel_fn)({kernel_args}, {", ".join(outputs_args)});
{code_indent}  if(kernel_record_event != nullptr){{
{code_indent}    delete kernel_record_event;
{code_indent}  }}
{code_indent}  if (kernel_result.has_fallback_cpu) {{
{fallback_kernel_output_trans}
{code_indent}  }}
{code_indent}  {self.gene_return_code()}"""

    def get_condition_code(self, kernel_name):
        assert self.kernel['dispatch'][kernel_name], \
                f"{self.api} api: the tensor type of inputs and outputs for kernel isn't set, see also 'kernel:func' of 'scale' in ops.yaml."
        input_types = self.kernel['dispatch'][kernel_name][0]
        condition_list = []
        for i, in_type in enumerate(input_types):
            if in_type == "dense":
                if self.inputs['names'][i] in self.optional_vars:
                    condition_list.append(
                        f"(!{self.inputs['names'][i]} || {self.inputs['names'][i]}->is_dense_tensor())"
                    )
                else:
                    condition_list.append(
                        f"{self.inputs['names'][i]}.is_dense_tensor()")
            else:
                if self.inputs['names'][i] in self.optional_vars:
                    condition_list.append(
                        f"(!{self.inputs['names'][i]} || {self.inputs['names'][i]}->is_selected_rows())"
                    )
                else:
                    condition_list.append(
                        f"{self.inputs['names'][i]}.is_selected_rows()")
        return " && ".join(condition_list)

    def gene_dispatch_code(self, kernel_name, inplace_flag=False):
        return f"""
  if ({self.get_condition_code(kernel_name)}) {{
{self.gen_kernel_code(kernel_name, '  ', inplace_flag)}
  }}
"""

    def gene_base_api_code(self, inplace_flag=False):
        api_func_name = self.get_api_func_name()
        if inplace_flag and api_func_name[-1] != '_':
            api_func_name += '_'
        api_code = f"""
PADDLE_API {self.get_return_type(inplace_flag)} {api_func_name}({self.get_define_args(inplace_flag)}) {{
{self.gene_kernel_select()}
"""

        if len(self.kernel['func']) > 1:
            kernel_dispatch_code = ''
            for kernel_name in self.kernel['func']:
                kernel_dispatch_code += self.gene_dispatch_code(
                    kernel_name, inplace_flag)
            return api_code + f"""
{kernel_dispatch_code}
  PADDLE_THROW(phi::errors::Unimplemented(
          "The kernel of ({self.api}) for input tensors is unimplemented, please check the type of input tensors."));
}}
"""
        else:
            return api_code + self.gen_kernel_code(self.kernel['func'][0], '',
                                                   inplace_flag) + """
}
"""

    def gene_invoke_code(self, invoke_code, params_code):
        return f"""
PADDLE_API {self.get_return_type()} {self.api}({params_code}) {{
  return {invoke_code};
}}"""

    def gene_api_code(self):
        if self.is_base_api:
            api_code = self.gene_base_api_code()
            if len(self.inplace_map) > 0:
                if self.api[-1] == '_':
                    api_code = ""
                api_code = api_code + self.gene_base_api_code(inplace_flag=True)
            return api_code

        else:
            invoke_func_name = self.invoke.split('(')[0].strip()
            if invoke_func_name in self.attrs['names']:
                # Rename the attribute that clashes with the name of the invoked api.
                pattern = r'\W' + invoke_func_name + '[^A-Za-z0-9_(]'

                def adjust_name(matched):
                    matched_str = matched.group()
                    return matched_str[0:-1] + '_val' + matched_str[-1]

                invoke_code = re.sub(pattern, adjust_name, self.invoke)
                params_code = re.sub(pattern, adjust_name,
                                     self.get_define_args())
            else:
                invoke_code = self.invoke
                params_code = self.get_define_args()
            return self.gene_invoke_code(invoke_code, params_code)