api_base.py 35.6 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re

# Prefix of the local C++ variables that hold the kernel-side (prepared)
# input tensors, e.g. `auto input_x = PrepareData(x, ...)`.
PREFIX_TENSOR_NAME = 'input_'
# Prefix used to name the phi::MetaTensor variables declared for infer_meta;
# it replaces the `kernel_` prefix of kernel output names.
PREFIX_META_TENSOR_NAME = 'meta_'


class BaseAPI(object):
22

23 24 25 26 27 28 29 30 31 32 33 34
    def __init__(self, api_item_yaml):
        self.api = self.get_api_name(api_item_yaml)

        # inputs:
        #     names : [], list of input names
        #     input_info : {input_name : type}
        # attrs:
        #     names : [], list of attribute names
        #     attr_info : { attr_name : (type, default_values)}
        # outputs:
        #     names : [], list of output names
        #     types : [], list of output types
35
        #     out_size_expr : [], expression for getting size of vector<Tensor>
36
        self.inputs, self.attrs, self.outputs, self.optional_vars = self.parse_args(
37 38 39 40 41 42 43
            self.api, api_item_yaml)

        self.is_base_api = True
        if 'invoke' in api_item_yaml:
            self.is_base_api = False
            self.invoke = api_item_yaml['invoke']
        else:
44
            if 'infer_meta' in api_item_yaml:
45 46
                self.infer_meta = self.parse_infer_meta(
                    api_item_yaml['infer_meta'])
47 48
            self.kernel = self.parse_kernel(api_item_yaml['kernel'])
            self.data_transform = self.parse_data_transform(api_item_yaml)
49
            self.inplace_map, self.view_map = {}, {}
50 51 52 53

    def get_api_name(self, api_item_yaml):
        """Return the API name stored under the 'api' key of the yaml entry."""
        return api_item_yaml['api']

    def get_api_func_name(self):
        return self.api

57 58 59 60
    def get_input_tensor_args(self, inplace_flag=False):
        input_args = []
        inplace_type_map = {
            "const Tensor&": "Tensor&",
61
            "const paddle::optional<Tensor>&": "paddle::optional<Tensor>&",
62 63 64 65 66
            "const std::vector<Tensor>&": "std::vector<Tensor>&"
        }
        for name in self.inputs['names']:
            name = name.split('@')[0]
            if inplace_flag and name in self.inplace_map.values():
67 68 69
                input_args.append(
                    inplace_type_map[self.inputs['input_info'][name]] + ' ' +
                    name)
70 71 72 73 74 75 76 77 78 79 80 81
            else:
                input_args.append(self.inputs['input_info'][name] + ' ' + name)
        return input_args

    def get_declare_args(self, inplace_flag=False):
        declare_args = self.get_input_tensor_args(inplace_flag)
        for name in self.attrs['names']:
            default_value = ''
            if self.attrs['attr_info'][name][1] is not None:
                default_value = ' = ' + self.attrs['attr_info'][name][1]
            declare_args.append(self.attrs['attr_info'][name][0] + ' ' + name +
                                default_value)
82

83 84 85 86 87 88 89 90
        return ", ".join(declare_args)

    def get_define_args(self, inplace_flag=False):
        define_args = self.get_input_tensor_args(inplace_flag)
        for name in self.attrs['names']:
            define_args.append(self.attrs['attr_info'][name][0] + ' ' + name)

        return ", ".join(define_args)
91

92
    def parse_args(self, api_name, api_item_yaml):
93 94 95 96 97
        optional_vars = []
        if 'optional' in api_item_yaml:
            optional_vars = [
                item.strip() for item in api_item_yaml['optional'].split(',')
            ]
98 99 100
        inputs, attrs = self.parse_input_and_attr(api_name,
                                                  api_item_yaml['args'],
                                                  optional_vars)
101
        output_type_list, output_names, out_size_expr = self.parse_output(
102 103 104 105
            api_name, api_item_yaml['output'])
        return inputs, attrs, {
            'names': output_names,
            'types': output_type_list,
106 107
            'out_size_expr': out_size_expr
        }, optional_vars
108

109
    def parse_input_and_attr(self, api_name, args_config, optional_vars=[]):
110 111 112 113 114 115 116
        inputs = {'names': [], 'input_info': {}}
        attrs = {'names': [], 'attr_info': {}}
        args_str = args_config.strip()
        assert args_str.startswith('(') and args_str.endswith(')'), \
            f"Args declaration should start with '(' and end with ')', please check the args of {api_name} in yaml."
        args_str = args_str[1:-1]
        args_list = args_str.split(',')
Z
zyfncg 已提交
117 118 119 120
        input_types_map = {
            'Tensor': 'const Tensor&',
            'Tensor[]': 'const std::vector<Tensor>&'
        }
121
        attr_types_map = {
122
            'IntArray': 'const IntArray&',
123
            'Scalar': 'const Scalar&',
124 125 126 127
            'Scalar(int)': 'const Scalar&',
            'Scalar(int64_t)': 'const Scalar&',
            'Scalar(float)': 'const Scalar&',
            'Scalar(dobule)': 'const Scalar&',
128
            'int': 'int',
129 130
            'int32_t': 'int32_t',
            'int64_t': 'int64_t',
131 132 133 134 135
            'long': 'long',
            'size_t': 'size_t',
            'float': 'float',
            'double': 'double',
            'bool': 'bool',
136
            'str': 'const std::string&',
137
            'Place': 'const Place&',
138 139
            'DataLayout': 'DataLayout',
            'DataType': 'DataType',
140 141
            'int64_t[]': 'const std::vector<int64_t>&',
            'int[]': 'const std::vector<int>&'
142 143
        }
        optional_types_trans = {
144
            'Tensor': 'const paddle::optional<Tensor>&',
145 146
            'Tensor[]': 'const paddle::optional<std::vector<Tensor>>&',
            'int': 'paddle::optional<int>',
147 148
            'int32_t': 'paddle::optional<int32_t>',
            'int64_t': 'paddle::optional<int64_t>',
149 150 151
            'float': 'paddle::optional<float>',
            'double': 'paddle::optional<double>',
            'bool': 'paddle::optional<bool>',
152
            'Place': 'paddle::optional<const Place&>',
153
            'DataLayout': 'paddle::optional<DataLayout>',
154
            'DataType': 'paddle::optional<DataType>'
155 156
        }

157 158
        for item in args_list:
            item = item.strip()
Z
zyfncg 已提交
159
            type_and_name = item.split(' ')
160 161
            # match the input tensor
            has_input = False
Z
zyfncg 已提交
162 163 164
            for in_type_symbol, in_type in input_types_map.items():
                if type_and_name[0] == in_type_symbol:
                    input_name = type_and_name[1].strip()
165 166 167 168 169
                    assert len(input_name) > 0, \
                        f"The input tensor name should not be empty. Please check the args of {api_name} in yaml."
                    assert len(attrs['names']) == 0, \
                        f"The input Tensor should appear before attributes. please check the position of {api_name}:input({input_name}) in yaml"

170 171 172
                    if input_name in optional_vars:
                        in_type = optional_types_trans[in_type_symbol]

173 174 175 176 177 178 179 180
                    inputs['names'].append(input_name)
                    inputs['input_info'][input_name] = in_type
                    has_input = True
                    break
            if has_input:
                continue

            # match the attribute
Z
zyfncg 已提交
181 182 183
            for attr_type_symbol, attr_type in attr_types_map.items():
                if type_and_name[0] == attr_type_symbol:
                    attr_name = item[len(attr_type_symbol):].strip()
184 185 186 187 188 189 190 191
                    assert len(attr_name) > 0, \
                        f"The attribute name should not be empty. Please check the args of {api_name} in yaml."
                    default_value = None
                    if '=' in attr_name:
                        attr_infos = attr_name.split('=')
                        attr_name = attr_infos[0].strip()
                        default_value = attr_infos[1].strip()

192 193 194
                    if attr_name in optional_vars:
                        attr_type = optional_types_trans[attr_type_symbol]

195 196 197 198 199
                    default_value_str = "" if default_value is None else '=' + default_value
                    attrs['names'].append(attr_name)
                    attrs['attr_info'][attr_name] = (attr_type, default_value)
                    break

200
        return inputs, attrs
201 202

    def parse_output(self, api_name, output_config):
203

204
        def parse_output_item(output_item):
Z
zyfncg 已提交
205 206 207 208
            output_type_map = {
                'Tensor': 'Tensor',
                'Tensor[]': 'std::vector<Tensor>'
            }
209 210 211 212 213 214 215 216 217 218 219 220 221 222
            result = re.search(
                r"(?P<out_type>[a-zA-Z0-9_[\]]+)\s*(?P<name>\([a-zA-Z0-9_@]+\))?\s*(?P<expr>\{[^\}]+\})?",
                output_item)
            assert result is not None, f"{api_name} : the output config parse error."
            out_type = result.group('out_type')
            assert out_type in output_type_map, \
                f"{api_name} : Output type error: the output type only support Tensor and Tensor[], \
                  but now is {out_type}."

            out_name = 'out' if result.group('name') is None else result.group(
                'name')[1:-1]
            out_size_expr = None if result.group(
                'expr') is None else result.group('expr')[1:-1]
            return output_type_map[out_type], out_name, out_size_expr
223 224 225 226

        temp_list = output_config.split(',')

        if len(temp_list) == 1:
227
            out_type, out_name, size_expr = parse_output_item(temp_list[0])
228
            return [out_type], [out_name], [size_expr]
229 230 231
        else:
            out_type_list = []
            out_name_list = []
232
            out_size_expr_list = []
233
            for output_item in temp_list:
234
                out_type, out_name, size_expr = parse_output_item(output_item)
235 236
                out_type_list.append(out_type)
                out_name_list.append(out_name)
237
                out_size_expr_list.append(size_expr)
238

239
            return out_type_list, out_name_list, out_size_expr_list
240

241 242 243 244 245 246 247 248 249 250 251 252 253 254
    def parse_infer_meta(self, infer_meta_config):
        """Normalize the yaml 'infer_meta' mapping.

        Makes sure the 'param' key exists (None when the yaml omits it). The
        given dict is updated in place and returned as-is.
        """
        infer_meta_config.setdefault('param', None)
        return infer_meta_config

    def parse_kernel(self, kernel_config):
        # kernel :
        #    func : [], Kernel functions (example: scale, scale_sr)
        #    param : [], Input params of kernel
        #    backend : str, the names of param to choose the kernel backend, default is None
        #    layout : str, the names of param to choose the kernel layout, default is None
        #    data_type : str, the names of param to choose the kernel data_type, default is None
255
        #    dispatch : {}, the key is kernel_func, the value is type of inputs and outputs for kernel (example: {kernel_name : (['dense','sparse_coo']#input,['sparse_coo']#output)})
256 257 258 259 260
        kernel = {
            'func': [],
            'param': None,
            'backend': None,
            'layout': None,
Z
zyfncg 已提交
261
            'data_type': None,
262 263
            'use_gpudnn': 'false',
            'dispatch': {}
264 265 266 267 268 269 270 271 272
        }
        if 'backend' in kernel_config and len(kernel_config['backend']) > 0:
            kernel['backend'] = kernel_config['backend']
        if 'layout' in kernel_config and len(kernel_config['layout']) > 0:
            kernel['layout'] = kernel_config['layout']
        if 'data_type' in kernel_config and len(kernel_config['data_type']) > 0:
            kernel['data_type'] = kernel_config['data_type']
        if 'param' in kernel_config:
            kernel['param'] = kernel_config['param']
273 274 275 276
        if 'use_gpudnn' in kernel_config:
            kernel['use_gpudnn'] = kernel_config['use_gpudnn']
            if isinstance(kernel['use_gpudnn'], bool):
                kernel['use_gpudnn'] = str(kernel['use_gpudnn']).lower()
277 278 279 280 281 282 283 284 285
        kernel_funcs = re.compile(r'([a-zA-Z0-9_]+)\s*({[^}]+})?').findall(
            kernel_config['func'])

        def parse_kernel_in_out_type(in_out_str):
            if len(in_out_str) == 0:
                return None
            tmp_in_out_list = in_out_str[1:-1].split('->')
            inputs = [item.strip() for item in tmp_in_out_list[0].split(',')]
            outputs = [item.strip() for item in tmp_in_out_list[1].split(',')]
286 287 288 289 290 291 292 293 294 295 296

            # check the tensor type
            for item in inputs:
                assert item in [
                    'dense', 'selected_rows', 'sparse_coo', 'sparse_csr'
                ], f"{self.api} : Invalid input tensor type ('{item}'), here we only support 'dense', 'selected_rows', 'sparse_coo' and 'sparse_csr'."
            for item in outputs:
                assert item in [
                    'dense', 'selected_rows', 'sparse_coo', 'sparse_csr'
                ], f"{self.api} : Invalid output tensor type ('{item}'), here we only support 'dense', 'selected_rows', 'sparse_coo' and 'sparse_csr'."

297 298 299 300 301 302
            return (inputs, outputs)

        for func_item in kernel_funcs:
            kernel['func'].append(func_item[0])
            kernel['dispatch'][func_item[0]] = parse_kernel_in_out_type(
                func_item[1])
303 304 305 306 307 308 309 310 311 312 313 314 315 316 317

        return kernel

    def parse_data_transform(self, api_item_yaml):
        """Read the optional 'data_transform' settings of an API.

        Returns:
            dict: {'skip_transform': ..., 'support_trans_dtype': ...}; each
            value is taken from yaml as-is, or [] when missing or empty.
        """
        data_transform = {'skip_transform': [], 'support_trans_dtype': []}
        # `data_transform:` with no value loads as None; treat it as empty.
        transform_config = api_item_yaml.get('data_transform') or {}
        for key in ('skip_transform', 'support_trans_dtype'):
            if key in transform_config and transform_config[key] is not None:
                # NOTE(review): a plain string value is kept as a string, so
                # later `name in value` checks match substrings — confirm yaml
                # always lists whole names.
                data_transform[key] = transform_config[key]

        return data_transform

    # Override by child class
    def get_return_type(self, inplace_flag=False):
        """Return the C++ return type of the generated API.

        The base implementation returns None; child classes provide the type.
        """
        return None

    def gene_api_declaration(self):
323 324 325 326 327
        api_declaration = ""
        api_func_name = self.get_api_func_name()
        if api_func_name[-1] != '_':
            api_declaration = f"""
PADDLE_API {self.get_return_type()} {api_func_name}({self.get_declare_args()});
328 329
"""

330 331 332
        if self.is_base_api and len(self.inplace_map) > 0:
            if api_func_name[-1] != '_':
                api_func_name += '_'
333
            api_declaration = api_declaration + f"""
334
PADDLE_API {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_declare_args(inplace_flag=True)});
335 336 337 338
"""

        return api_declaration

339 340 341 342 343 344 345 346 347
    # Backward API Override this method
    def gene_kernel_backend_select(self):
        backend_select_code = ""
        if self.kernel['backend'] is not None:
            if '>' in self.kernel['backend']:
                vars_list = self.kernel['backend'].split('>')
                assert len(
                    vars_list
                ) == 2, f"{self.api} api: The number of params to set backend with '>' only allows 2, but received {len(vars_list)}."
348
                assert (vars_list[0].strip() in self.attrs['names']) and (self.attrs['attr_info'][vars_list[0].strip()][0] == 'const Place&'), \
349 350 351 352 353 354 355 356 357 358 359 360 361 362 363
                    f"{self.api} api: When use '>' to set kernel backend, the first param should be a attribute with Place type."
                backend_select_code = f"""
  kernel_backend = ParseBackendWithInputOrder({vars_list[0].strip()}, {vars_list[1].strip()});
"""

            else:
                backend_args = [
                    ele.strip() for ele in self.kernel['backend'].split(',')
                ]
                backend_select_code = f"""
  kernel_backend = ParseBackend({", ".join(backend_args)});
"""

        return backend_select_code

364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379
    def gene_kernel_select(self) -> str:
        api = self.api
        input_names = self.inputs['names']
        attrs = self.attrs
        kernel = self.kernel

        kernel_key_item_init = """
  Backend kernel_backend = Backend::UNDEFINED;
  DataLayout kernel_layout = DataLayout::UNDEFINED;
  DataType kernel_data_type = DataType::UNDEFINED;
"""
        # Check the tensor options
        attr_backend_count = 0
        attr_layout_count = 0
        attr_data_type_count = 0
        for attr_name in attrs['names']:
380
            if attrs['attr_info'][attr_name][0] == 'const Place&':
381
                assert kernel['backend'] is not None, \
382
                    f"{api} api: When there is a parameter with 'Place' type in attributes, you must set backend of kernel manually."
383 384 385 386 387 388 389 390 391 392 393
                attr_backend_count = attr_backend_count + 1
            if attrs['attr_info'][attr_name][0] == 'DataLayout':
                assert kernel['layout'] is not None, \
                    f"{api} api: When there is a parameter with 'DataLayout' type in attributes, you must set layout of kernel manually."
                attr_layout_count = attr_layout_count + 1
            if attrs['attr_info'][attr_name][0] == 'DataType':
                assert kernel['data_type'] is not None, \
                    f"{api} api: When there is a parameter with 'DataType' type in attributes, you must set data_type of kernel manually."
                attr_data_type_count = attr_data_type_count + 1

        # preprocess kernel configures
394
        kernel_select_code = self.gene_kernel_backend_select()
395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432

        if kernel['layout'] is not None:
            if '>' in kernel['layout']:
                vars_list = kernel['layout'].split('>')
                assert len(
                    vars_list
                ) == 2, f"{api} api: The number of params to set layout with '>' only allows 2, but received {len(vars_list)}."
                assert vars_list[0].strip() in attrs['names'] and attrs['attr_info'][vars_list[0].strip()][0] == 'DataLayout', \
                    f"{api} api: When use '>' to set kernel layout, the first param should be a attribute with DataLayout type."
                kernel_select_code = kernel_select_code + f"""
  kernel_layout = ParseLayoutWithInputOrder({vars_list[0].strip()}, {vars_list[1].strip()});
"""

            else:
                vars_list = kernel['layout'].split(',')
                assert len(
                    vars_list
                ) == 1, f"{api} api: The number of params to set layout must be 1, but received {len(vars_list)}."
                kernel_select_code = kernel_select_code + f"""
  kernel_layout = ParseLayout({vars_list[0].strip()});
"""

        if kernel['data_type'] is not None:
            if '>' in kernel['data_type']:
                vars_list = kernel['data_type'].split('>')
                assert len(
                    vars_list
                ) == 2, f"{api} api: The number of params to set data_type with '>' only allows 2, but received {len(vars_list)}."
                assert vars_list[0].strip() in attrs['names'] and attrs['attr_info'][vars_list[0].strip()][0] == 'DataType', \
                    f"{api} api: When use '>' to set kernel data_type, the first param should be a attribute with DataType type."
                kernel_select_code = kernel_select_code + f"""
  kernel_data_type = ParseDataTypeWithInputOrder({vars_list[0].strip()}, {vars_list[1].strip()});
"""

            else:
                vars_list = kernel['data_type'].split(',')
                assert len(
                    vars_list
433
                ) == 1, f"{api} api: The number of params to set data_type only allows 1, but received {len(vars_list)}."
434 435 436 437 438
                kernel_select_code = kernel_select_code + f"""
  kernel_data_type = ParseDataType({vars_list[0].strip()});
"""

        if len(input_names) == 0:
439
            assert attr_backend_count > 0 and attr_data_type_count > 0, \
440
                f"{api} api: When there is no input tensor, the args must have 'Place' and 'DataType'."
441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456

        kernel_select_args = ""
        for input_name in input_names:
            kernel_select_args = kernel_select_args + input_name + ", "

        if len(kernel_select_args) > 2:
            kernel_select_args = kernel_select_args[:-2]

        kernel_select_code = kernel_key_item_init + kernel_select_code

        if len(input_names) > 0:
            kernel_select_code = kernel_select_code + f"""
  if (kernel_backend == Backend::UNDEFINED
        || kernel_layout == DataLayout::UNDEFINED
        || kernel_data_type == DataType::UNDEFINED ) {{
    auto kernel_key_set = ParseKernelKeyByInputArgs({kernel_select_args});
457
    auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
458 459 460 461 462 463 464 465 466 467 468 469 470
    if (kernel_backend == Backend::UNDEFINED) {{
      kernel_backend = kernel_key.backend();
    }}
    if (kernel_layout == DataLayout::UNDEFINED) {{
      kernel_layout = kernel_key.layout();
    }}
    if (kernel_data_type == DataType::UNDEFINED) {{
      kernel_data_type = kernel_key.dtype();
    }}
  }}"""

        return kernel_select_code

471
    def gene_infer_meta(self, kernel_output_names, code_indent) -> str:
472 473 474 475
        input_names = self.inputs['names']
        attr_names = self.attrs['names']
        infer_meta = self.infer_meta

476 477
        infer_meta_params = infer_meta['param'] if infer_meta[
            'param'] is not None else input_names + attr_names
478 479 480 481 482
        # generate meta tensors
        meta_tensor_code = ""
        param_code = ""
        for param in infer_meta_params:
            if param in input_names:
483 484 485 486 487
                if self.inputs['input_info'][param] == "const Tensor&":
                    param_code = param_code + "MakeMetaTensor(*" + PREFIX_TENSOR_NAME + param + "), "
                elif self.inputs['input_info'][
                        param] == "const std::vector<Tensor>&":
                    meta_tensor_code = meta_tensor_code + f"""
488
{code_indent}  auto {param}_meta_vec = MakeMetaTensor({PREFIX_TENSOR_NAME}{param});
489
{code_indent}  std::vector<const phi::MetaTensor*> {param}_metas({param}_meta_vec.size());
490 491 492 493 494 495 496
{code_indent}  for (size_t i = 0; i < {param}_meta_vec.size(); ++i) {{
{code_indent}    {param}_metas[i] = &{param}_meta_vec[i];
{code_indent}  }}
"""

                    param_code = param_code + param + "_metas, "
                elif param in self.optional_vars:
497
                    param_code = param_code + "MakeMetaTensor(" + PREFIX_TENSOR_NAME + param + "), "
498
                else:
499 500 501
                    raise ValueError(
                        f"{self.api} : Param of infer_meta error : {self.inputs['input_info'][param]} type is not supported."
                    )
502 503 504 505 506 507 508 509 510
            elif param in attr_names:
                param_code = param_code + param + ", "
            elif isinstance(param, str):
                param_code = param_code + "\"" + param + "\", "
            elif isinstance(param, bool):
                param_code = param_code + str(param).lower() + ", "
            else:
                param_code = param_code + str(param) + ", "

511 512 513 514 515 516
        for i, out_name in enumerate(kernel_output_names):
            if self.outputs['types'][i] == 'std::vector<Tensor>':
                meta_tensor_code = meta_tensor_code + f"""
{code_indent}  auto {out_name}_{PREFIX_META_TENSOR_NAME}vec = MakeMetaTensor({out_name});
{code_indent}  std::vector<phi::MetaTensor*> {out_name}_metas({out_name}_{PREFIX_META_TENSOR_NAME}vec.size());
{code_indent}  for (size_t i = 0; i < {out_name}_{PREFIX_META_TENSOR_NAME}vec.size(); ++i) {{
517
{code_indent}    {out_name}_metas[i] = {out_name}[i] ? &{out_name}_{PREFIX_META_TENSOR_NAME}vec[i] : nullptr;
518 519 520 521 522 523 524
{code_indent}  }}"""

                param_code = param_code + out_name + '_metas, '
            else:
                meta_tensor_code = meta_tensor_code + code_indent + "  phi::MetaTensor " + out_name.replace(
                    'kernel_',
                    PREFIX_META_TENSOR_NAME) + "(" + out_name + ");\n"
525 526 527 528
                if len(kernel_output_names) == 1:
                    param_code = param_code + f"&{out_name.replace('kernel_', PREFIX_META_TENSOR_NAME)}, "
                else:
                    param_code = param_code + f"{out_name} ? &{out_name.replace('kernel_', PREFIX_META_TENSOR_NAME)} : nullptr, "
529

530 531
        param_code = param_code[:-2]
        return f"""{meta_tensor_code}
532
{code_indent}  phi::{infer_meta['func']}({param_code});
533 534
"""

535 536
    def get_kernel_args(self, kernel_tensor_type=None, code_indent=''):
        dense_input_trans_map = {
537 538
            'const Tensor&':
            'const phi::DenseTensor&',
539
            'const std::vector<Tensor>&':
540
            'const std::vector<const phi::DenseTensor*>&',
H
hong 已提交
541 542
            'const paddle::optional<Tensor&>':
            'paddle::optional<const phi::DenseTensor&>',
543 544
            'const paddle::optional<Tensor>&':
            'const paddle::optional<phi::DenseTensor>&',
545 546
            'const paddle::optional<std::vector<Tensor>>&':
            'paddle::optional<const std::vector<phi::DenseTensor>&>'
547
        }
548
        dense_out_trans_map = {
549 550
            'Tensor': 'phi::DenseTensor*',
            'std::vector<Tensor>': 'std::vector<phi::DenseTensor*>&'
551
        }
552 553 554 555 556 557 558
        sr_input_trans_map = {
            'const Tensor&':
            'const phi::SelectedRows&',
            'const paddle::optional<Tensor>&':
            'const paddle::optional<phi::SelectedRows>&'
        }
        sr_out_trans_map = {'Tensor': 'phi::SelectedRows*'}
559 560 561 562 563 564 565 566 567 568 569 570 571
        input_names = self.inputs['names']
        input_infos = self.inputs['input_info']
        kernel_args_type_list = ['const platform::DeviceContext&']

        attr_names = self.attrs['names']
        kernel_param = self.kernel['param']
        if kernel_param is None:
            kernel_param = input_names + attr_names

        input_tensor_code = ""
        for i, input_name in enumerate(input_names):
            # set input code
            if input_name in kernel_param:
572 573 574 575 576 577 578 579 580 581 582
                # input is dense tensor
                if kernel_tensor_type is None or kernel_tensor_type[0][
                        kernel_param.index(input_name)] == 'dense':
                    trans_flag = "{}"
                    if input_name in self.data_transform['skip_transform']:
                        trans_flag = "{true}"
                    elif input_name in self.data_transform[
                            'support_trans_dtype']:
                        trans_flag = "{false, true}"
                    if input_name in self.optional_vars:
                        input_tensor_code = input_tensor_code + f"""
583
{code_indent}  auto {PREFIX_TENSOR_NAME}{input_name} = PrepareData({input_name}, kernel.InputAt({i}), {trans_flag});"""
584

585 586 587 588
                    else:
                        if self.inputs['input_info'][
                                input_name] == "const Tensor&":
                            input_tensor_code = input_tensor_code + f"""
589
{code_indent}  auto {PREFIX_TENSOR_NAME}{input_name} = PrepareData({input_name}, kernel.InputAt({i}), {trans_flag});"""
590

591 592 593
                        elif self.inputs['input_info'][
                                input_name] == "const std::vector<Tensor>&":
                            input_tensor_code = input_tensor_code + f"""
594 595 596 597 598 599
{code_indent}  auto {PREFIX_TENSOR_NAME}{input_name}_vec = PrepareData({input_name}, kernel.InputAt({i}), {trans_flag});
{code_indent}  std::vector<const phi::DenseTensor*> {PREFIX_TENSOR_NAME}{input_name}({PREFIX_TENSOR_NAME}{input_name}_vec->size());
{code_indent}  for (size_t i = 0; i < {PREFIX_TENSOR_NAME}{input_name}.size(); ++i) {{
{code_indent}    {PREFIX_TENSOR_NAME}{input_name}[i] = &{PREFIX_TENSOR_NAME}{input_name}_vec->at(i);
{code_indent}  }}"""

600 601 602 603
                        else:
                            # do nothing
                            pass
                else:  # input is selected_rows
604
                    input_tensor_code = input_tensor_code + f"""
605 606 607 608 609 610
{code_indent}  auto {PREFIX_TENSOR_NAME}{input_name} = TensorToSelectedRows({input_name});"""
            else:
                if input_name in self.infer_meta['param']:
                    if input_name in self.optional_vars:
                        input_tensor_code = input_tensor_code + f"""
{code_indent}  paddle::optional<phi::TensorBase> {PREFIX_TENSOR_NAME}{input_name} = {input_name} ? paddle::optional<phi::TensorBase>(*{input_name}->impl()) : paddle::none;"""
611

612 613 614
                    else:
                        input_tensor_code = input_tensor_code + f"""
{code_indent}  auto {PREFIX_TENSOR_NAME}{input_name} = {input_name}.impl();"""
615

616
        kernel_args = ["*dev_ctx"]
617 618
        for param in kernel_param:
            if param in input_names:
619
                if param in self.optional_vars:
620
                    kernel_args.append(PREFIX_TENSOR_NAME + param)
621
                else:
622
                    if self.inputs['input_info'][param] == "const Tensor&":
623
                        kernel_args.append("*" + PREFIX_TENSOR_NAME + param)
624
                    elif self.inputs['input_info'][
625
                            param] == "const std::vector<Tensor>&":
626
                        kernel_args.append(PREFIX_TENSOR_NAME + param)
627 628 629
                    else:
                        # do nothing
                        pass
630 631 632
                # input is dense tensor
                if kernel_tensor_type is None or kernel_tensor_type[0][
                        kernel_param.index(param)] == 'dense':
633
                    kernel_args_type_list.append(
634 635 636 637
                        dense_input_trans_map[input_infos[param]])
                else:  # input is selected_rows
                    kernel_args_type_list.append(
                        sr_input_trans_map[input_infos[param]])
638 639
            elif param in attr_names:
                # set attr for kernel_context
640 641 642
                if 'IntArray' in self.attrs['attr_info'][param][0]:
                    kernel_args_type_list.append('const phi::IntArray&')
                    param = 'phi::IntArray(' + param + ')'
643
                elif 'Scalar' in self.attrs['attr_info'][param][0]:
644 645
                    kernel_args_type_list.append('const phi::Scalar&')
                    param = 'phi::Scalar(' + param + ')'
646
                else:
647 648
                    kernel_args_type_list.append(
                        self.attrs['attr_info'][param][0])
649
                kernel_args.append(param)
650
            elif isinstance(param, bool):
651
                kernel_args.append(str(param).lower())
652
            else:
653
                kernel_args.append(str(param))
654

655 656 657 658 659 660
        for i, out_type in enumerate(self.outputs['types']):
            # output is dense tensor
            if kernel_tensor_type is None or kernel_tensor_type[1][i] == 'dense':
                kernel_args_type_list.append(dense_out_trans_map[out_type])
            else:  # output is selected_rows
                kernel_args_type_list.append(sr_out_trans_map[out_type])
661 662 663

        kernel_signature = "void(*)(" + ", ".join(kernel_args_type_list) + ")"

664
        return input_tensor_code, ", ".join(kernel_args), kernel_signature
665

666 667
    # Override by child class
    def gene_return_code(self):
        """Return the C++ statement that ends a generated API body.

        Child classes may override this hook; the base version returns the
        local ``api_output``.
        """
        return "return api_output;"

    # Override by child class
    def gene_output(self,
                    out_dtype_list,
                    out_tensor_type_list=None,
                    code_indent='',
                    inplace_flag=False):
        """Hook for generating the C++ code that creates the API outputs.

        Child classes override this. The base implementation generates
        nothing and returns ``(None, None, None)``. ``gen_kernel_code``
        unpacks the result as ``(outputs_args, kernel_output_names,
        output_create)``.

        Args:
            out_dtype_list: output types (``self.outputs['types']`` when
                called from ``gen_kernel_code``).
            out_tensor_type_list: per-output tensor kinds from the kernel
                dispatch entry, or None.
            code_indent (str): extra indentation for generated lines.
            inplace_flag (bool): whether the in-place variant is generated.
        """
        return None, None, None

    def gen_kernel_code(self, kernel_name, code_indent, inplace_flag=False):
        kernel_dispatch = self.kernel['dispatch'][kernel_name]
680
        input_tensors, kernel_args, kernel_signature = self.get_kernel_args(
681 682
            kernel_dispatch, code_indent)
        out_tensor_type_list = kernel_dispatch[1] if kernel_dispatch else None
683
        outputs_args, kernel_output_names, output_create = self.gene_output(
684 685
            self.outputs['types'], out_tensor_type_list, code_indent,
            inplace_flag)
Z
zyfncg 已提交
686
        cudnn_args = '' if self.kernel[
687
            'use_gpudnn'] == 'false' else ', ' + self.kernel['use_gpudnn']
688
        return f"""
F
From00 已提交
689
{code_indent}  VLOG(6) << "{self.api} API kernel key: [" << kernel_backend << ", " << kernel_layout << ", "<< kernel_data_type << "]";
690
{code_indent}  const auto& kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
691 692
{code_indent}      "{kernel_name}", {{kernel_backend, kernel_layout, kernel_data_type}}{cudnn_args});
{code_indent}  VLOG(6) << "{kernel_name} kernel: " << kernel;
693 694 695 696 697 698 699 700

{code_indent}  auto* dev_ctx = GetDeviceContextByBackend(kernel_backend);
{input_tensors}
{output_create}
{self.gene_infer_meta(kernel_output_names, code_indent)}

{code_indent}  using kernel_signature = {kernel_signature};
{code_indent}  auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
701
{code_indent}  {{
702
{code_indent}    paddle::platform::RecordEvent kernel_record_event(\"{kernel_name} compute\", paddle::platform::TracerEventType::OperatorInner, 1);
703 704
{code_indent}    (*kernel_fn)({kernel_args}, {outputs_args});
{code_indent}  }}
705

706
{code_indent}  {self.gene_return_code()}"""
707

708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730
    def get_condition_code(self, kernel_name):
        """Build the C++ condition under which `kernel_name` is dispatched.

        Each input is checked against the tensor kind listed for it in
        ``self.kernel['dispatch'][kernel_name][0]``. ``"dense"`` maps to
        ``is_dense_tensor()`` and any other kind to ``is_selected_rows()``.
        Optional inputs also pass the check when they are empty. The
        individual checks are joined with ``&&``.
        """
        dispatch_info = self.kernel['dispatch'][kernel_name]
        assert dispatch_info, \
                f"{self.api} api: the tensor type of inputs and outputs for kernel isn't set, see also 'kernel:func' of 'scale' in api.yaml."
        input_names = self.inputs['names']
        checks = []
        for idx, tensor_kind in enumerate(dispatch_info[0]):
            name = input_names[idx]
            is_optional = name in self.optional_vars
            if tensor_kind != "dense":
                checks.append(f"(!{name} || {name}->is_selected_rows())"
                              if is_optional else
                              f"{name}.is_selected_rows()")
            else:
                checks.append(f"(!{name} || {name}->is_dense_tensor())"
                              if is_optional else
                              f"{name}.is_dense_tensor()")
        return " && ".join(checks)

    def gene_dispatch_code(self, kernel_name, inplace_flag=False):
        return f"""
  if ({self.get_condition_code(kernel_name)}) {{
{self.gen_kernel_code(kernel_name, '  ', inplace_flag)}
  }}
"""
738

739
    def gene_base_api_code(self, inplace_flag=False):
740 741 742
        api_func_name = self.get_api_func_name()
        if inplace_flag and api_func_name[-1] != '_':
            api_func_name += '_'
743
        api_code = f"""
744
PADDLE_API {self.get_return_type(inplace_flag)} {api_func_name}({self.get_define_args(inplace_flag)}) {{
745
{self.gene_kernel_select()}
746
"""
747

748 749 750 751 752
        if len(self.kernel['func']) > 1:
            kernel_dispatch_code = ''
            for kernel_name in self.kernel['func']:
                kernel_dispatch_code += self.gene_dispatch_code(
                    kernel_name, inplace_flag)
753
            return api_code + f"""
754 755 756
{kernel_dispatch_code}
  PADDLE_THROW(phi::errors::Unimplemented(
          "The kernel of ({self.api}) for input tensors is unimplemented, please check the type of input tensors."));
757
}}
758
"""
759
        else:
760 761
            return api_code + self.gen_kernel_code(self.kernel['func'][0], '',
                                                   inplace_flag) + """
762
}
763 764
"""

765 766
    def gene_invoke_code(self, invoke_code, params_code):
        return f"""
767
PADDLE_API {self.get_return_type()} {self.api}({params_code}) {{
768 769 770
  return {invoke_code};
}}"""

771 772 773
    def gene_api_code(self):
        """Generate the C++ definition(s) emitted for this API.

        Base APIs get a kernel-dispatching body from ``gene_base_api_code``.
        If ``inplace_map`` is non-empty, an in-place variant is generated as
        well. If the API name already ends with ``_``, only the in-place
        variant is kept.

        APIs declared with ``invoke`` become a thin wrapper that forwards to
        the invoked call. If an attribute has the same name as the invoked
        function, that attribute is renamed to ``<name>_val`` in both the
        forwarded call and the parameter list. This stops the parameter from
        shadowing the function in the generated C++.

        Returns:
            str: the generated C++ source.
        """
        if self.is_base_api:
            api_code = self.gene_base_api_code()
            if self.inplace_map:
                if self.api.endswith('_'):
                    api_code = ""
                api_code = api_code + self.gene_base_api_code(inplace_flag=True)
            return api_code

        invoke_func_name = self.invoke.split('(')[0].strip()
        invoke_code = self.invoke
        params_code = self.get_define_args()
        if invoke_func_name in self.attrs['names']:
            # Adjust the param whose name is same with api invoked.
            # Match the name only as a whole identifier that is not a call
            # (not followed by optional spaces and '(').
            # Lookarounds do not consume the neighbouring characters, so
            # matches still work in two cases the previous pattern missed:
            #   - at the end of the string, e.g. when the attribute is the
            #     last define arg;
            #   - for back-to-back occurrences such as "(name,name)".
            pattern = (r'(?<!\w)' + re.escape(invoke_func_name) +
                       r'(?!\w)(?!\s*\()')
            renamed = invoke_func_name + '_val'
            invoke_code = re.sub(pattern, lambda _m: renamed, invoke_code)
            params_code = re.sub(pattern, lambda _m: renamed, params_code)
        return self.gene_invoke_code(invoke_code, params_code)