kernel_registry.h 66.5 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
//   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <cstdint>
#include <cstring>
#include <string>
#include <tuple>
#include <type_traits>
#include <typeindex>
#include <typeinfo>
#include <utility>
#include <vector>

#include "paddle/phi/core/custom_kernel.h"
#include "paddle/phi/core/kernel_factory.h"
#include "paddle/phi/core/kernel_utils.h"
#include "paddle/phi/core/macros.h"
#include "paddle/phi/core/type_defs.h"

#include "paddle/phi/core/enforce.h"

namespace phi {

// Spell a fully qualified phi enum value from its bare enumerator name, e.g.
// DATALAYOUT(NCHW) -> phi::DataLayout::NCHW. The kernel registration macros
// use DATALAYOUT to turn their `layout` argument into a DataLayout value.
#define BACKEND(arg__) phi::Backend::arg__
#define DATALAYOUT(arg__) phi::DataLayout::arg__
#define DATATYPE(arg__) phi::DataType::arg__

template <typename Func>
struct KernelArgsParseFunctor;

template <typename Return_, typename... Args_>
struct KernelArgsParseFunctor<Return_ (*)(Args_...)> {
  using Args = std::tuple<Args_...>;
  enum : std::size_t { Arity = sizeof...(Args_) };
  using Indices = std::make_index_sequence<Arity>;
  template <std::size_t Index>
  using Arg = typename std::tuple_element<Index, Args>::type;

  static void Parse(const KernelKey& default_key, KernelArgsDef* args_def) {
    // TODO(chenweihang): The fluid Tensor's default layout is NCHW,
    // it is not same as kernel's layout, we should fix this error on
    // fluid Tensor
53 54
    auto default_tensor_layout = phi::DataLayout::NCHW;
    if (default_key.layout() != phi::DataLayout::ANY) {
55 56 57 58 59 60 61
      default_tensor_layout = default_key.layout();
    }
    auto args_type = ParseArgType(Indices{});
    for (auto arg_type : args_type) {
      if (arg_type == std::type_index(typeid(const CPUContext&))
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
          ||
62
          arg_type == std::type_index(typeid(const GPUContext&))) {
63 64 65
#elif defined(PADDLE_WITH_XPU)
          ||
          arg_type == std::type_index(typeid(const XPUContext&))) {
66 67 68
#elif defined(PADDLE_WITH_CUSTOM_DEVICE)
          ||
          arg_type == std::type_index(typeid(const CustomContext&))) {
69 70 71 72 73
#else
              ) {
#endif
        // do nothing, skip context arg now
      } else if (arg_type == std::type_index(typeid(const DenseTensor&))) {
H
hong 已提交
74 75 76 77
        args_def->AppendInput(default_key.backend(),
                              default_tensor_layout,
                              default_key.dtype(),
                              arg_type);
78 79
      } else if (arg_type == std::type_index(typeid(
                                 paddle::optional<const DenseTensor&>))) {
H
hong 已提交
80 81 82 83
        args_def->AppendInput(default_key.backend(),
                              default_tensor_layout,
                              default_key.dtype(),
                              arg_type);
84 85 86 87 88 89 90
      } else if (arg_type == std::type_index(typeid(
                                 paddle::optional<
                                     const std::vector<const DenseTensor*>>))) {
        args_def->AppendInput(default_key.backend(),
                              default_tensor_layout,
                              default_key.dtype(),
                              arg_type);
H
hong 已提交
91 92 93 94 95 96
      } else if (arg_type == std::type_index(typeid(
                                 paddle::optional<const SelectedRows&>))) {
        args_def->AppendInput(default_key.backend(),
                              default_tensor_layout,
                              default_key.dtype(),
                              arg_type);
97 98
      } else if (arg_type == std::type_index(typeid(
                                 const std::vector<const DenseTensor*>&))) {
H
hong 已提交
99 100 101 102
        args_def->AppendInput(default_key.backend(),
                              default_tensor_layout,
                              default_key.dtype(),
                              arg_type);
103
      } else if (arg_type == std::type_index(typeid(const SelectedRows&))) {
H
hong 已提交
104 105 106 107
        args_def->AppendInput(default_key.backend(),
                              default_tensor_layout,
                              default_key.dtype(),
                              arg_type);
Z
zhangkaihuo 已提交
108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129
      } else if (arg_type == std::type_index(typeid(const SparseCooTensor&))) {
        args_def->AppendInput(default_key.backend(),
                              default_tensor_layout,
                              default_key.dtype(),
                              arg_type);
      } else if (arg_type == std::type_index(typeid(
                                 paddle::optional<const SparseCooTensor&>))) {
        args_def->AppendInput(default_key.backend(),
                              default_tensor_layout,
                              default_key.dtype(),
                              arg_type);
      } else if (arg_type == std::type_index(typeid(const SparseCsrTensor&))) {
        args_def->AppendInput(default_key.backend(),
                              default_tensor_layout,
                              default_key.dtype(),
                              arg_type);
      } else if (arg_type == std::type_index(typeid(
                                 paddle::optional<const SparseCsrTensor&>))) {
        args_def->AppendInput(default_key.backend(),
                              default_tensor_layout,
                              default_key.dtype(),
                              arg_type);
130
      } else if (arg_type == std::type_index(typeid(DenseTensor*))) {
H
hong 已提交
131 132 133 134
        args_def->AppendOutput(default_key.backend(),
                               default_tensor_layout,
                               default_key.dtype(),
                               arg_type);
135 136
      } else if (arg_type ==
                 std::type_index(typeid(std::vector<DenseTensor*>))) {
H
hong 已提交
137 138 139 140
        args_def->AppendOutput(default_key.backend(),
                               default_tensor_layout,
                               default_key.dtype(),
                               arg_type);
141
      } else if (arg_type == std::type_index(typeid(SelectedRows*))) {
H
hong 已提交
142 143 144 145
        args_def->AppendOutput(default_key.backend(),
                               default_tensor_layout,
                               default_key.dtype(),
                               arg_type);
Z
zhangkaihuo 已提交
146 147 148 149 150 151 152 153 154 155
      } else if (arg_type == std::type_index(typeid(SparseCooTensor*))) {
        args_def->AppendOutput(default_key.backend(),
                               default_tensor_layout,
                               default_key.dtype(),
                               arg_type);
      } else if (arg_type == std::type_index(typeid(SparseCsrTensor*))) {
        args_def->AppendOutput(default_key.backend(),
                               default_tensor_layout,
                               default_key.dtype(),
                               arg_type);
156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172
      } else {
        // Attribute deal with
        // TODO(chenweihang): now here allow any types of attribute, maybe
        // should add limits here
        args_def->AppendAttribute(arg_type);
      }
    }
  }

 private:
  template <std::size_t... INDEX>
  static std::vector<std::type_index> ParseArgType(
      std::index_sequence<INDEX...>) {
    return {std::type_index(typeid(Arg<INDEX>))...};
  }
};

// NOTE: used for making a difference between inner or outer registration.
// KernelRegistrar stores INNER kernels directly in KernelFactory and hands
// OUTER kernels to CustomKernelMap::RegisterCustomKernel.
enum class RegType : std::uint8_t {
  INNER = 0,
  OUTER = 1,
};

179 180
// TODO(chenweihang): Polish the kernel selection logic, support the selection
// of ALL_DTYPE kernel, and simplify the constructor
181 182
struct KernelRegistrar {
 public:
183 184 185
  KernelRegistrar(RegType reg_type,
                  const char* kernel_name_cstr,
                  const char* backend_cstr,
186 187 188 189
                  DataLayout layout,
                  DataType dtype,
                  KernelArgsParseFn args_parse_fn,
                  KernelArgsDefFn args_def_fn,
190 191
                  KernelFn kernel_fn,
                  void* variadic_kernel_fn) {
192 193 194
    ConstructKernel(reg_type,
                    kernel_name_cstr,
                    backend_cstr,
195 196 197 198
                    layout,
                    dtype,
                    args_parse_fn,
                    args_def_fn,
199 200
                    kernel_fn,
                    variadic_kernel_fn);
201 202
  }

203 204 205
  KernelRegistrar(RegType reg_type,
                  const char* kernel_name_cstr,
                  const char* backend_cstr,
206 207 208
                  DataLayout layout,
                  KernelArgsParseFn args_parse_fn,
                  KernelArgsDefFn args_def_fn,
209 210
                  KernelFn kernel_fn,
                  void* variadic_kernel_fn) {
211 212 213
    for (size_t dtype = static_cast<size_t>(DataType::BOOL);
         dtype != static_cast<size_t>(DataType::NUM_DATA_TYPES);
         dtype++) {
214 215 216 217 218 219 220
      // NOTE(zhiqiu): why skip these types, because fluid kernel has no kernel
      // of these type.
      if (dtype == static_cast<size_t>(DataType::UINT32) ||
          dtype == static_cast<size_t>(DataType::UINT64) ||
          dtype == static_cast<size_t>(DataType::UINT16)) {
        continue;
      }
J
Jack Zhou 已提交
221 222 223 224 225 226 227 228
      // NOTE(zhoushunjie): Only the strings kernels can support pstring dtype
      constexpr char strings_kernels_prefix[] = "strings_";
      if (dtype == static_cast<size_t>(DataType::PSTRING) &&
          strncmp(kernel_name_cstr,
                  strings_kernels_prefix,
                  strlen(strings_kernels_prefix))) {
        continue;
      }
229 230 231
      ConstructKernel(reg_type,
                      kernel_name_cstr,
                      backend_cstr,
232 233 234 235
                      layout,
                      static_cast<DataType>(dtype),
                      args_parse_fn,
                      args_def_fn,
236 237
                      kernel_fn,
                      variadic_kernel_fn);
238 239 240 241
    }
  }

 private:
242 243 244
  void ConstructKernel(RegType reg_type,
                       const char* kernel_name_cstr,
                       const char* backend_cstr,
245 246 247 248
                       DataLayout layout,
                       DataType dtype,
                       KernelArgsParseFn args_parse_fn,
                       KernelArgsDefFn args_def_fn,
249 250
                       KernelFn kernel_fn,
                       void* variadic_kernel_fn) {
Y
YuanRisheng 已提交
251
    std::string kernel_name(kernel_name_cstr);
252 253
    KernelKey kernel_key(
        paddle::experimental::StringToBackend(backend_cstr), layout, dtype);
254
    Kernel kernel(kernel_fn, variadic_kernel_fn);
255
    args_parse_fn(kernel_key, kernel.mutable_args_def());
256
    args_def_fn(kernel_key, &kernel);
257
    if (reg_type == RegType::INNER) {
258 259
      KernelFactory::Instance().kernels()[kernel_name][kernel_key] = kernel;
    } else {
260 261
      CustomKernelMap::Instance().RegisterCustomKernel(
          kernel_name, kernel_key, kernel);
262
    }
263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278
  }
};

/**
 * PD_NARGS(...) expands to the number of comma-separated arguments it is
 * given, e.g. PD_NARGS(float, double) -> 2. It supports 1 to 15 arguments;
 * an empty argument list also yields 1, and more than 15 arguments produce a
 * wrong count.
 *
 * Reference:
 *
 *   https://stackoverflow.com/questions/1872220/is-it-possible-to-iterate-over-arguments-in-variadic-macros
 *   https://stackoverflow.com/questions/9183993/msvc-variadic-macro-expansion?rq=1
 *   https://stackoverflow.com/questions/5134523/msvc-doesnt-expand-va-args-correctly
 *
 * Very carefully tiptoeing around an MSVC bug where it improperly expands
 * __VA_ARGS__ as a single token in argument lists.  See these URLs for details:
 *
 *   http://connect.microsoft.com/VisualStudio/feedback/details/380090/variadic-macro-replacement
 *   http://cplusplus.co.il/2010/07/17/variadic-macro-to-count-number-of-arguments/#comment-644
 */
#define PD_NARGS(...) _PD_NARGS((__VA_ARGS__, _PD_RESQ_N()))
#define _PD_NARGS(...) _PD_ARG_N(__VA_ARGS__)
#define _PD_ARG_N_EXPAND(                                                     \
    _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, N, ...) \
  N
#define _PD_ARG_N(args) _PD_ARG_N_EXPAND args
#define _PD_RESQ_N() 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0

/** PD_REGISTER_KERNEL
 *
 * The most frequently used kernel registration macro, used for kernel
 * registration with only data type as template parameter, and the function
 * pointer of the corresponding data type is automatically instantiated
 * during registration.
 *
 * The trailing arguments are the C++ data types to register (1 to 15). The
 * macro must be used in the global namespace, and it must be followed by a
 * braced body, which becomes the args-def function called with
 * `(const ::phi::KernelKey& kernel_key, ::phi::Kernel* kernel)` for every
 * registered data type. The device context type is derived from `backend`
 * as `::phi::<backend>Context`.
 *
 * Note: `2TA` (as in _PD_REGISTER_2TA_KERNEL) means `2 template arguments`:
 * each kernel is used as `meta_kernel_fn<cpp_dtype, context>`.
 */
#define PD_REGISTER_KERNEL(kernel_name, backend, layout, meta_kernel_fn, ...) \
  _PD_REGISTER_KERNEL(::phi::RegType::INNER,                                  \
                      kernel_name,                                            \
                      backend,                                                \
                      ::phi::backend##Context,                                \
                      layout,                                                 \
                      meta_kernel_fn,                                         \
                      __VA_ARGS__)

// Shared implementation behind PD_REGISTER_KERNEL. `reg_type` selects the
// registry (RegType::INNER / RegType::OUTER) and `context` is the device
// context type used as the kernel's second template argument. The static
// assertion reports "PD_REGISTER_KERNEL must be called in global namespace."
// when the macro is expanded elsewhere.
#define _PD_REGISTER_KERNEL(                                               \
    reg_type, kernel_name, backend, context, layout, meta_kernel_fn, ...)  \
  PD_STATIC_ASSERT_GLOBAL_NAMESPACE(                                       \
      PD_REGISTER_tp_kernel_ns_check_##kernel_name##_##backend##_##layout, \
      "PD_REGISTER_KERNEL must be called in global namespace.");           \
  PD_EXPAND(_PD_REGISTER_2TA_KERNEL(reg_type,                              \
                                    kernel_name,                           \
                                    backend,                               \
                                    context,                               \
                                    layout,                                \
                                    meta_kernel_fn,                        \
                                    __VA_ARGS__))

// _PD_REGISTER_2TA_KERNEL forward-declares the args-def function, registers
// every requested data type through PD_KERNEL_REGISTRAR_INIT, and ends with
// the args-def function's signature, so the braced body the user writes after
// PD_REGISTER_KERNEL(...) becomes that function's definition.
#ifndef _WIN32
// gcc/clang: explicitly instantiate meta_kernel_fn<cpp_dtype, context> for
// every data type before registering it.
#define _PD_REGISTER_2TA_KERNEL(                                            \
    reg_type, kernel_name, backend, context, layout, meta_kernel_fn, ...)   \
  PD_KERNEL_INSTANTIATION(meta_kernel_fn, backend, context, __VA_ARGS__);   \
  static void __PD_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
      const ::phi::KernelKey& kernel_key, ::phi::Kernel* kernel);           \
  PD_KERNEL_REGISTRAR_INIT(                                                 \
      reg_type,                                                             \
      kernel_name,                                                          \
      backend,                                                              \
      context,                                                              \
      layout,                                                               \
      &__PD_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout,        \
      meta_kernel_fn,                                                       \
      __VA_ARGS__);                                                         \
  void __PD_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout(        \
      const ::phi::KernelKey& kernel_key, ::phi::Kernel* kernel)
#else
/**
 * `template decltype(fn) fn` works on gcc and clang,
 * but msvc fails with an error like:
 *
 *   error C2206: typedef cannot be used for function definition
 *
 * reference:
 *
 *   https://stackoverflow.com/questions/63989585/explicit-instantiation-of-function-using-decltype-work-on-g-but-not-on-visua
 *
 * And msvc works without the explicit template instantiation, so this
 * variant skips PD_KERNEL_INSTANTIATION.
 */
#define _PD_REGISTER_2TA_KERNEL(                                            \
    reg_type, kernel_name, backend, context, layout, meta_kernel_fn, ...)   \
  static void __PD_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
      const ::phi::KernelKey& kernel_key, ::phi::Kernel* kernel);           \
  PD_EXPAND(PD_KERNEL_REGISTRAR_INIT(                                       \
      reg_type,                                                             \
      kernel_name,                                                          \
      backend,                                                              \
      context,                                                              \
      layout,                                                               \
      &__PD_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout,        \
      meta_kernel_fn,                                                       \
      __VA_ARGS__));                                                        \
  void __PD_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout(        \
      const ::phi::KernelKey& kernel_key, ::phi::Kernel* kernel)
#endif

// Explicitly instantiates `meta_kernel_fn<cpp_dtype, context>` for every
// cpp_dtype in __VA_ARGS__ (1 to 15 types, counted with PD_NARGS). The caller
// supplies the terminating semicolon.
#define PD_KERNEL_INSTANTIATION(meta_kernel_fn, backend, context, ...) \
  _PD_KERNEL_INSTANTIATION(                                            \
      PD_NARGS(__VA_ARGS__), meta_kernel_fn, backend, context, __VA_ARGS__)

// Dispatches to _PD_KERNEL_INSTANTIATION_<N>.
#define _PD_KERNEL_INSTANTIATION(N, meta_kernel_fn, backend, context, ...) \
  PD_CONCATENATE(_PD_KERNEL_INSTANTIATION_, N)                             \
  (meta_kernel_fn, backend, context, __VA_ARGS__)

// _PD_KERNEL_INSTANTIATION_<N>(meta_kernel_fn, backend, context, T1, ..., TN)
// emits one explicit instantiation
//   template decltype(meta_kernel_fn<Ti, context>) meta_kernel_fn<Ti, context>
// per data type Ti, separated by semicolons. _1 emits a single instantiation
// without a trailing semicolon (the caller of PD_KERNEL_INSTANTIATION supplies
// it). Every _N with N > 1 reuses _1 for its first type and recurses into
// _<N-1> for the rest. `backend` is forwarded but unused in the expansion.
#define _PD_KERNEL_INSTANTIATION_1(              \
    meta_kernel_fn, backend, context, cpp_dtype) \
  template decltype(                             \
      meta_kernel_fn<cpp_dtype, context>) meta_kernel_fn<cpp_dtype, context>
#define _PD_KERNEL_INSTANTIATION_2(                                         \
    meta_kernel_fn, backend, context, cpp_dtype, ...)                       \
  _PD_KERNEL_INSTANTIATION_1(meta_kernel_fn, backend, context, cpp_dtype); \
  PD_EXPAND(_PD_KERNEL_INSTANTIATION_1(                                     \
      meta_kernel_fn, backend, context, __VA_ARGS__))
#define _PD_KERNEL_INSTANTIATION_3(                                         \
    meta_kernel_fn, backend, context, cpp_dtype, ...)                       \
  _PD_KERNEL_INSTANTIATION_1(meta_kernel_fn, backend, context, cpp_dtype); \
  PD_EXPAND(_PD_KERNEL_INSTANTIATION_2(                                     \
      meta_kernel_fn, backend, context, __VA_ARGS__))
#define _PD_KERNEL_INSTANTIATION_4(                                         \
    meta_kernel_fn, backend, context, cpp_dtype, ...)                       \
  _PD_KERNEL_INSTANTIATION_1(meta_kernel_fn, backend, context, cpp_dtype); \
  PD_EXPAND(_PD_KERNEL_INSTANTIATION_3(                                     \
      meta_kernel_fn, backend, context, __VA_ARGS__))
#define _PD_KERNEL_INSTANTIATION_5(                                         \
    meta_kernel_fn, backend, context, cpp_dtype, ...)                       \
  _PD_KERNEL_INSTANTIATION_1(meta_kernel_fn, backend, context, cpp_dtype); \
  PD_EXPAND(_PD_KERNEL_INSTANTIATION_4(                                     \
      meta_kernel_fn, backend, context, __VA_ARGS__))
#define _PD_KERNEL_INSTANTIATION_6(                                         \
    meta_kernel_fn, backend, context, cpp_dtype, ...)                       \
  _PD_KERNEL_INSTANTIATION_1(meta_kernel_fn, backend, context, cpp_dtype); \
  PD_EXPAND(_PD_KERNEL_INSTANTIATION_5(                                     \
      meta_kernel_fn, backend, context, __VA_ARGS__))
#define _PD_KERNEL_INSTANTIATION_7(                                         \
    meta_kernel_fn, backend, context, cpp_dtype, ...)                       \
  _PD_KERNEL_INSTANTIATION_1(meta_kernel_fn, backend, context, cpp_dtype); \
  PD_EXPAND(_PD_KERNEL_INSTANTIATION_6(                                     \
      meta_kernel_fn, backend, context, __VA_ARGS__))
#define _PD_KERNEL_INSTANTIATION_8(                                         \
    meta_kernel_fn, backend, context, cpp_dtype, ...)                       \
  _PD_KERNEL_INSTANTIATION_1(meta_kernel_fn, backend, context, cpp_dtype); \
  PD_EXPAND(_PD_KERNEL_INSTANTIATION_7(                                     \
      meta_kernel_fn, backend, context, __VA_ARGS__))
#define _PD_KERNEL_INSTANTIATION_9(                                         \
    meta_kernel_fn, backend, context, cpp_dtype, ...)                       \
  _PD_KERNEL_INSTANTIATION_1(meta_kernel_fn, backend, context, cpp_dtype); \
  PD_EXPAND(_PD_KERNEL_INSTANTIATION_8(                                     \
      meta_kernel_fn, backend, context, __VA_ARGS__))
#define _PD_KERNEL_INSTANTIATION_10(                                        \
    meta_kernel_fn, backend, context, cpp_dtype, ...)                       \
  _PD_KERNEL_INSTANTIATION_1(meta_kernel_fn, backend, context, cpp_dtype); \
  PD_EXPAND(_PD_KERNEL_INSTANTIATION_9(                                     \
      meta_kernel_fn, backend, context, __VA_ARGS__))
#define _PD_KERNEL_INSTANTIATION_11(                                        \
    meta_kernel_fn, backend, context, cpp_dtype, ...)                       \
  _PD_KERNEL_INSTANTIATION_1(meta_kernel_fn, backend, context, cpp_dtype); \
  PD_EXPAND(_PD_KERNEL_INSTANTIATION_10(                                    \
      meta_kernel_fn, backend, context, __VA_ARGS__))
#define _PD_KERNEL_INSTANTIATION_12(                                        \
    meta_kernel_fn, backend, context, cpp_dtype, ...)                       \
  _PD_KERNEL_INSTANTIATION_1(meta_kernel_fn, backend, context, cpp_dtype); \
  PD_EXPAND(_PD_KERNEL_INSTANTIATION_11(                                    \
      meta_kernel_fn, backend, context, __VA_ARGS__))
#define _PD_KERNEL_INSTANTIATION_13(                                        \
    meta_kernel_fn, backend, context, cpp_dtype, ...)                       \
  _PD_KERNEL_INSTANTIATION_1(meta_kernel_fn, backend, context, cpp_dtype); \
  PD_EXPAND(_PD_KERNEL_INSTANTIATION_12(                                    \
      meta_kernel_fn, backend, context, __VA_ARGS__))
#define _PD_KERNEL_INSTANTIATION_14(                                        \
    meta_kernel_fn, backend, context, cpp_dtype, ...)                       \
  _PD_KERNEL_INSTANTIATION_1(meta_kernel_fn, backend, context, cpp_dtype); \
  PD_EXPAND(_PD_KERNEL_INSTANTIATION_13(                                    \
      meta_kernel_fn, backend, context, __VA_ARGS__))
#define _PD_KERNEL_INSTANTIATION_15(                                        \
    meta_kernel_fn, backend, context, cpp_dtype, ...)                       \
  _PD_KERNEL_INSTANTIATION_1(meta_kernel_fn, backend, context, cpp_dtype); \
  PD_EXPAND(_PD_KERNEL_INSTANTIATION_14(                                    \
      meta_kernel_fn, backend, context, __VA_ARGS__))

// Counts the data types in __VA_ARGS__ with PD_NARGS and forwards to
// _PD_KERNEL_REGISTRAR_INIT, which defines one static KernelRegistrar per
// data type.
#define PD_KERNEL_REGISTRAR_INIT(reg_type,                   \
                                 kernel_name,                \
                                 backend,                    \
                                 context,                    \
                                 layout,                     \
                                 args_def_fn,                \
                                 meta_kernel_fn,             \
                                 ...)                        \
  PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT(PD_NARGS(__VA_ARGS__), \
                                      reg_type,              \
                                      kernel_name,           \
                                      backend,               \
                                      context,               \
                                      layout,                \
                                      args_def_fn,           \
                                      meta_kernel_fn,        \
                                      __VA_ARGS__))

// clang-format off

/* The pre-commit hook always reformats this macro incorrectly,
  and multi-line macros cannot be skipped with NOLINT.*/
// Dispatches to _PD_KERNEL_REGISTRAR_INIT_<N>, passing PD_ID as the
// registrar id used to name the first static registrar.
#define _PD_KERNEL_REGISTRAR_INIT(N,                       \
                                  reg_type,                \
                                  kernel_name,             \
                                  backend,                 \
                                  context,                 \
                                  layout,                  \
                                  args_def_fn,             \
                                  meta_kernel_fn,          \
                                  ...)                     \
  PD_EXPAND(PD_CONCATENATE(_PD_KERNEL_REGISTRAR_INIT_, N) ( \
    reg_type,                                              \
    kernel_name,                                           \
    backend,                                               \
    context,                                               \
    layout,                                                \
    PD_ID,                                                 \
    args_def_fn,                                           \
    meta_kernel_fn,                                        \
    __VA_ARGS__))

// clang-format on

// Defines one `static const ::phi::KernelRegistrar` that registers
// `meta_kernel_fn<cpp_dtype, context>` under the DataType returned by
// CppTypeToDataType<cpp_dtype>::Type(). `registrar_id` (PD_ID at the call
// sites) is appended to the variable name; presumably PD_ID yields a fresh
// value per expansion so sibling registrars do not collide (confirm).
// It also defines `TouchKernelSymbolFor_<kernel>_<backend>_<layout>()`, which
// returns 0. Every _N macro ends by expanding _1 exactly once, so this function
// is defined once per registered kernel. It is presumably referenced elsewhere
// to keep this registration linked in (not visible here).
#define _PD_KERNEL_REGISTRAR_INIT_1(reg_type,                                 \
                                    kernel_name,                              \
                                    backend,                                  \
                                    context,                                  \
                                    layout,                                   \
                                    registrar_id,                             \
                                    args_def_fn,                              \
                                    meta_kernel_fn,                           \
                                    cpp_dtype)                                \
  static const ::phi::KernelRegistrar PD_CONCATENATE(                         \
      __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
      reg_type,                                                               \
      #kernel_name,                                                           \
      #backend,                                                               \
      DATALAYOUT(layout),                                                     \
      ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(),           \
      ::phi::KernelArgsParseFunctor<decltype(                                 \
          &meta_kernel_fn<cpp_dtype, context>)>::Parse,                       \
      args_def_fn,                                                            \
      PHI_KERNEL(meta_kernel_fn<cpp_dtype, context>),                         \
      PHI_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, context>));               \
  int TouchKernelSymbolFor_##kernel_name##_##backend##_##layout() { return 0; }
// Defines the static KernelRegistrar for the first data type (named with
// `registrar_id`), then delegates the remaining data type to
// _PD_KERNEL_REGISTRAR_INIT_1 with a fresh PD_ID.
#define _PD_KERNEL_REGISTRAR_INIT_2(reg_type,                                 \
                                    kernel_name,                              \
                                    backend,                                  \
                                    context,                                  \
                                    layout,                                   \
                                    registrar_id,                             \
                                    args_def_fn,                              \
                                    meta_kernel_fn,                           \
                                    cpp_dtype,                                \
                                    ...)                                      \
  static const ::phi::KernelRegistrar PD_CONCATENATE(                         \
      __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
      reg_type,                                                               \
      #kernel_name,                                                           \
      #backend,                                                               \
      DATALAYOUT(layout),                                                     \
      ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(),           \
      ::phi::KernelArgsParseFunctor<decltype(                                 \
          &meta_kernel_fn<cpp_dtype, context>)>::Parse,                       \
      args_def_fn,                                                            \
      PHI_KERNEL(meta_kernel_fn<cpp_dtype, context>),                         \
      PHI_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, context>));               \
  PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_1(reg_type,                             \
                                        kernel_name,                          \
                                        backend,                              \
                                        context,                              \
                                        layout,                               \
                                        PD_ID,                                \
                                        args_def_fn,                          \
                                        meta_kernel_fn,                       \
                                        __VA_ARGS__))
// Registers the kernel for the first dtype (`cpp_dtype`) by defining a static
// ::phi::KernelRegistrar whose name is built from kernel_name/backend/layout
// and registrar_id, then expands _PD_KERNEL_REGISTRAR_INIT_2 for the remaining
// dtypes (__VA_ARGS__). PD_ID is passed as the next registrar_id — presumably
// it yields a fresh id per expansion so variable names do not collide.
#define _PD_KERNEL_REGISTRAR_INIT_3(reg_type,                                 \
                                    kernel_name,                              \
                                    backend,                                  \
                                    context,                                  \
                                    layout,                                   \
                                    registrar_id,                             \
                                    args_def_fn,                              \
                                    meta_kernel_fn,                           \
                                    cpp_dtype,                                \
                                    ...)                                      \
  static const ::phi::KernelRegistrar PD_CONCATENATE(                         \
      __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
      reg_type,                                                               \
      #kernel_name,                                                           \
      #backend,                                                               \
      DATALAYOUT(layout),                                                     \
      ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(),           \
      ::phi::KernelArgsParseFunctor<decltype(                                 \
          &meta_kernel_fn<cpp_dtype, context>)>::Parse,                       \
      args_def_fn,                                                            \
      PHI_KERNEL(meta_kernel_fn<cpp_dtype, context>),                         \
      PHI_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, context>));               \
  PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_2(reg_type,                             \
                                        kernel_name,                          \
                                        backend,                              \
                                        context,                              \
                                        layout,                               \
                                        PD_ID,                                \
                                        args_def_fn,                          \
                                        meta_kernel_fn,                       \
                                        __VA_ARGS__))

// Registers the kernel for the first dtype (`cpp_dtype`) by defining a static
// ::phi::KernelRegistrar whose name is built from kernel_name/backend/layout
// and registrar_id, then expands _PD_KERNEL_REGISTRAR_INIT_3 for the remaining
// dtypes (__VA_ARGS__). PD_ID is passed as the next registrar_id — presumably
// it yields a fresh id per expansion so variable names do not collide.
#define _PD_KERNEL_REGISTRAR_INIT_4(reg_type,                                 \
                                    kernel_name,                              \
                                    backend,                                  \
                                    context,                                  \
                                    layout,                                   \
                                    registrar_id,                             \
                                    args_def_fn,                              \
                                    meta_kernel_fn,                           \
                                    cpp_dtype,                                \
                                    ...)                                      \
  static const ::phi::KernelRegistrar PD_CONCATENATE(                         \
      __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
      reg_type,                                                               \
      #kernel_name,                                                           \
      #backend,                                                               \
      DATALAYOUT(layout),                                                     \
      ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(),           \
      ::phi::KernelArgsParseFunctor<decltype(                                 \
          &meta_kernel_fn<cpp_dtype, context>)>::Parse,                       \
      args_def_fn,                                                            \
      PHI_KERNEL(meta_kernel_fn<cpp_dtype, context>),                         \
      PHI_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, context>));               \
  PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_3(reg_type,                             \
                                        kernel_name,                          \
                                        backend,                              \
                                        context,                              \
                                        layout,                               \
                                        PD_ID,                                \
                                        args_def_fn,                          \
                                        meta_kernel_fn,                       \
                                        __VA_ARGS__))

// Registers the kernel for the first dtype (`cpp_dtype`) by defining a static
// ::phi::KernelRegistrar whose name is built from kernel_name/backend/layout
// and registrar_id, then expands _PD_KERNEL_REGISTRAR_INIT_4 for the remaining
// dtypes (__VA_ARGS__). PD_ID is passed as the next registrar_id — presumably
// it yields a fresh id per expansion so variable names do not collide.
#define _PD_KERNEL_REGISTRAR_INIT_5(reg_type,                                 \
                                    kernel_name,                              \
                                    backend,                                  \
                                    context,                                  \
                                    layout,                                   \
                                    registrar_id,                             \
                                    args_def_fn,                              \
                                    meta_kernel_fn,                           \
                                    cpp_dtype,                                \
                                    ...)                                      \
  static const ::phi::KernelRegistrar PD_CONCATENATE(                         \
      __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
      reg_type,                                                               \
      #kernel_name,                                                           \
      #backend,                                                               \
      DATALAYOUT(layout),                                                     \
      ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(),           \
      ::phi::KernelArgsParseFunctor<decltype(                                 \
          &meta_kernel_fn<cpp_dtype, context>)>::Parse,                       \
      args_def_fn,                                                            \
      PHI_KERNEL(meta_kernel_fn<cpp_dtype, context>),                         \
      PHI_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, context>));               \
  PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_4(reg_type,                             \
                                        kernel_name,                          \
                                        backend,                              \
                                        context,                              \
                                        layout,                               \
                                        PD_ID,                                \
                                        args_def_fn,                          \
                                        meta_kernel_fn,                       \
                                        __VA_ARGS__))

// Registers the kernel for the first dtype (`cpp_dtype`) by defining a static
// ::phi::KernelRegistrar whose name is built from kernel_name/backend/layout
// and registrar_id, then expands _PD_KERNEL_REGISTRAR_INIT_5 for the remaining
// dtypes (__VA_ARGS__). PD_ID is passed as the next registrar_id — presumably
// it yields a fresh id per expansion so variable names do not collide.
#define _PD_KERNEL_REGISTRAR_INIT_6(reg_type,                                 \
                                    kernel_name,                              \
                                    backend,                                  \
                                    context,                                  \
                                    layout,                                   \
                                    registrar_id,                             \
                                    args_def_fn,                              \
                                    meta_kernel_fn,                           \
                                    cpp_dtype,                                \
                                    ...)                                      \
  static const ::phi::KernelRegistrar PD_CONCATENATE(                         \
      __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
      reg_type,                                                               \
      #kernel_name,                                                           \
      #backend,                                                               \
      DATALAYOUT(layout),                                                     \
      ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(),           \
      ::phi::KernelArgsParseFunctor<decltype(                                 \
          &meta_kernel_fn<cpp_dtype, context>)>::Parse,                       \
      args_def_fn,                                                            \
      PHI_KERNEL(meta_kernel_fn<cpp_dtype, context>),                         \
      PHI_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, context>));               \
  PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_5(reg_type,                             \
                                        kernel_name,                          \
                                        backend,                              \
                                        context,                              \
                                        layout,                               \
                                        PD_ID,                                \
                                        args_def_fn,                          \
                                        meta_kernel_fn,                       \
                                        __VA_ARGS__))

// Registers the kernel for the first dtype (`cpp_dtype`) by defining a static
// ::phi::KernelRegistrar whose name is built from kernel_name/backend/layout
// and registrar_id, then expands _PD_KERNEL_REGISTRAR_INIT_6 for the remaining
// dtypes (__VA_ARGS__). PD_ID is passed as the next registrar_id — presumably
// it yields a fresh id per expansion so variable names do not collide.
#define _PD_KERNEL_REGISTRAR_INIT_7(reg_type,                                 \
                                    kernel_name,                              \
                                    backend,                                  \
                                    context,                                  \
                                    layout,                                   \
                                    registrar_id,                             \
                                    args_def_fn,                              \
                                    meta_kernel_fn,                           \
                                    cpp_dtype,                                \
                                    ...)                                      \
  static const ::phi::KernelRegistrar PD_CONCATENATE(                         \
      __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
      reg_type,                                                               \
      #kernel_name,                                                           \
      #backend,                                                               \
      DATALAYOUT(layout),                                                     \
      ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(),           \
      ::phi::KernelArgsParseFunctor<decltype(                                 \
          &meta_kernel_fn<cpp_dtype, context>)>::Parse,                       \
      args_def_fn,                                                            \
      PHI_KERNEL(meta_kernel_fn<cpp_dtype, context>),                         \
      PHI_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, context>));               \
  PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_6(reg_type,                             \
                                        kernel_name,                          \
                                        backend,                              \
                                        context,                              \
                                        layout,                               \
                                        PD_ID,                                \
                                        args_def_fn,                          \
                                        meta_kernel_fn,                       \
                                        __VA_ARGS__))

// Registers the kernel for the first dtype (`cpp_dtype`) by defining a static
// ::phi::KernelRegistrar whose name is built from kernel_name/backend/layout
// and registrar_id, then expands _PD_KERNEL_REGISTRAR_INIT_7 for the remaining
// dtypes (__VA_ARGS__). PD_ID is passed as the next registrar_id — presumably
// it yields a fresh id per expansion so variable names do not collide.
#define _PD_KERNEL_REGISTRAR_INIT_8(reg_type,                                 \
                                    kernel_name,                              \
                                    backend,                                  \
                                    context,                                  \
                                    layout,                                   \
                                    registrar_id,                             \
                                    args_def_fn,                              \
                                    meta_kernel_fn,                           \
                                    cpp_dtype,                                \
                                    ...)                                      \
  static const ::phi::KernelRegistrar PD_CONCATENATE(                         \
      __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
      reg_type,                                                               \
      #kernel_name,                                                           \
      #backend,                                                               \
      DATALAYOUT(layout),                                                     \
      ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(),           \
      ::phi::KernelArgsParseFunctor<decltype(                                 \
          &meta_kernel_fn<cpp_dtype, context>)>::Parse,                       \
      args_def_fn,                                                            \
      PHI_KERNEL(meta_kernel_fn<cpp_dtype, context>),                         \
      PHI_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, context>));               \
  PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_7(reg_type,                             \
                                        kernel_name,                          \
                                        backend,                              \
                                        context,                              \
                                        layout,                               \
                                        PD_ID,                                \
                                        args_def_fn,                          \
                                        meta_kernel_fn,                       \
                                        __VA_ARGS__))

// Registers the kernel for the first dtype (`cpp_dtype`) by defining a static
// ::phi::KernelRegistrar whose name is built from kernel_name/backend/layout
// and registrar_id, then expands _PD_KERNEL_REGISTRAR_INIT_8 for the remaining
// dtypes (__VA_ARGS__). PD_ID is passed as the next registrar_id — presumably
// it yields a fresh id per expansion so variable names do not collide.
#define _PD_KERNEL_REGISTRAR_INIT_9(reg_type,                                 \
                                    kernel_name,                              \
                                    backend,                                  \
                                    context,                                  \
                                    layout,                                   \
                                    registrar_id,                             \
                                    args_def_fn,                              \
                                    meta_kernel_fn,                           \
                                    cpp_dtype,                                \
                                    ...)                                      \
  static const ::phi::KernelRegistrar PD_CONCATENATE(                         \
      __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
      reg_type,                                                               \
      #kernel_name,                                                           \
      #backend,                                                               \
      DATALAYOUT(layout),                                                     \
      ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(),           \
      ::phi::KernelArgsParseFunctor<decltype(                                 \
          &meta_kernel_fn<cpp_dtype, context>)>::Parse,                       \
      args_def_fn,                                                            \
      PHI_KERNEL(meta_kernel_fn<cpp_dtype, context>),                         \
      PHI_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, context>));               \
  PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_8(reg_type,                             \
                                        kernel_name,                          \
                                        backend,                              \
                                        context,                              \
                                        layout,                               \
                                        PD_ID,                                \
                                        args_def_fn,                          \
                                        meta_kernel_fn,                       \
                                        __VA_ARGS__))

// Registers the kernel for the first dtype (`cpp_dtype`) by defining a static
// ::phi::KernelRegistrar whose name is built from kernel_name/backend/layout
// and registrar_id, then expands _PD_KERNEL_REGISTRAR_INIT_9 for the remaining
// dtypes (__VA_ARGS__). PD_ID is passed as the next registrar_id — presumably
// it yields a fresh id per expansion so variable names do not collide.
#define _PD_KERNEL_REGISTRAR_INIT_10(reg_type,                                \
                                     kernel_name,                             \
                                     backend,                                 \
                                     context,                                 \
                                     layout,                                  \
                                     registrar_id,                            \
                                     args_def_fn,                             \
                                     meta_kernel_fn,                          \
                                     cpp_dtype,                               \
                                     ...)                                     \
  static const ::phi::KernelRegistrar PD_CONCATENATE(                         \
      __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
      reg_type,                                                               \
      #kernel_name,                                                           \
      #backend,                                                               \
      DATALAYOUT(layout),                                                     \
      ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(),           \
      ::phi::KernelArgsParseFunctor<decltype(                                 \
          &meta_kernel_fn<cpp_dtype, context>)>::Parse,                       \
      args_def_fn,                                                            \
      PHI_KERNEL(meta_kernel_fn<cpp_dtype, context>),                         \
      PHI_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, context>));               \
  PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_9(reg_type,                             \
                                        kernel_name,                          \
                                        backend,                              \
                                        context,                              \
                                        layout,                               \
                                        PD_ID,                                \
                                        args_def_fn,                          \
                                        meta_kernel_fn,                       \
                                        __VA_ARGS__))

// Registers the kernel for the first dtype (`cpp_dtype`) by defining a static
// ::phi::KernelRegistrar whose name is built from kernel_name/backend/layout
// and registrar_id, then expands _PD_KERNEL_REGISTRAR_INIT_10 for the
// remaining dtypes (__VA_ARGS__). PD_ID is passed as the next registrar_id —
// presumably it yields a fresh id per expansion so names do not collide.
#define _PD_KERNEL_REGISTRAR_INIT_11(reg_type,                                \
                                     kernel_name,                             \
                                     backend,                                 \
                                     context,                                 \
                                     layout,                                  \
                                     registrar_id,                            \
                                     args_def_fn,                             \
                                     meta_kernel_fn,                          \
                                     cpp_dtype,                               \
                                     ...)                                     \
  static const ::phi::KernelRegistrar PD_CONCATENATE(                         \
      __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
      reg_type,                                                               \
      #kernel_name,                                                           \
      #backend,                                                               \
      DATALAYOUT(layout),                                                     \
      ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(),           \
      ::phi::KernelArgsParseFunctor<decltype(                                 \
          &meta_kernel_fn<cpp_dtype, context>)>::Parse,                       \
      args_def_fn,                                                            \
      PHI_KERNEL(meta_kernel_fn<cpp_dtype, context>),                         \
      PHI_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, context>));               \
  PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_10(reg_type,                            \
                                         kernel_name,                         \
                                         backend,                             \
                                         context,                             \
                                         layout,                              \
                                         PD_ID,                               \
                                         args_def_fn,                         \
                                         meta_kernel_fn,                      \
                                         __VA_ARGS__))

// Registers the kernel for the first dtype (`cpp_dtype`) by defining a static
// ::phi::KernelRegistrar whose name is built from kernel_name/backend/layout
// and registrar_id, then expands _PD_KERNEL_REGISTRAR_INIT_11 for the
// remaining dtypes (__VA_ARGS__). PD_ID is passed as the next registrar_id —
// presumably it yields a fresh id per expansion so names do not collide.
#define _PD_KERNEL_REGISTRAR_INIT_12(reg_type,                                \
                                     kernel_name,                             \
                                     backend,                                 \
                                     context,                                 \
                                     layout,                                  \
                                     registrar_id,                            \
                                     args_def_fn,                             \
                                     meta_kernel_fn,                          \
                                     cpp_dtype,                               \
                                     ...)                                     \
  static const ::phi::KernelRegistrar PD_CONCATENATE(                         \
      __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
      reg_type,                                                               \
      #kernel_name,                                                           \
      #backend,                                                               \
      DATALAYOUT(layout),                                                     \
      ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(),           \
      ::phi::KernelArgsParseFunctor<decltype(                                 \
          &meta_kernel_fn<cpp_dtype, context>)>::Parse,                       \
      args_def_fn,                                                            \
      PHI_KERNEL(meta_kernel_fn<cpp_dtype, context>),                         \
      PHI_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, context>));               \
  PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_11(reg_type,                            \
                                         kernel_name,                         \
                                         backend,                             \
                                         context,                             \
                                         layout,                              \
                                         PD_ID,                               \
                                         args_def_fn,                         \
                                         meta_kernel_fn,                      \
                                         __VA_ARGS__))

// Registers the kernel for the first dtype (`cpp_dtype`) by defining a static
// ::phi::KernelRegistrar whose name is built from kernel_name/backend/layout
// and registrar_id, then expands _PD_KERNEL_REGISTRAR_INIT_12 for the
// remaining dtypes (__VA_ARGS__). PD_ID is passed as the next registrar_id —
// presumably it yields a fresh id per expansion so names do not collide.
#define _PD_KERNEL_REGISTRAR_INIT_13(reg_type,                                \
                                     kernel_name,                             \
                                     backend,                                 \
                                     context,                                 \
                                     layout,                                  \
                                     registrar_id,                            \
                                     args_def_fn,                             \
                                     meta_kernel_fn,                          \
                                     cpp_dtype,                               \
                                     ...)                                     \
  static const ::phi::KernelRegistrar PD_CONCATENATE(                         \
      __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
      reg_type,                                                               \
      #kernel_name,                                                           \
      #backend,                                                               \
      DATALAYOUT(layout),                                                     \
      ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(),           \
      ::phi::KernelArgsParseFunctor<decltype(                                 \
          &meta_kernel_fn<cpp_dtype, context>)>::Parse,                       \
      args_def_fn,                                                            \
      PHI_KERNEL(meta_kernel_fn<cpp_dtype, context>),                         \
      PHI_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, context>));               \
  PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_12(reg_type,                            \
                                         kernel_name,                         \
                                         backend,                             \
                                         context,                             \
                                         layout,                              \
                                         PD_ID,                               \
                                         args_def_fn,                         \
                                         meta_kernel_fn,                      \
                                         __VA_ARGS__))

// Registers the kernel for the first dtype (`cpp_dtype`) by defining a static
// ::phi::KernelRegistrar whose name is built from kernel_name/backend/layout
// and registrar_id, then expands _PD_KERNEL_REGISTRAR_INIT_13 for the
// remaining dtypes (__VA_ARGS__). PD_ID is passed as the next registrar_id —
// presumably it yields a fresh id per expansion so names do not collide.
#define _PD_KERNEL_REGISTRAR_INIT_14(reg_type,                                \
                                     kernel_name,                             \
                                     backend,                                 \
                                     context,                                 \
                                     layout,                                  \
                                     registrar_id,                            \
                                     args_def_fn,                             \
                                     meta_kernel_fn,                          \
                                     cpp_dtype,                               \
                                     ...)                                     \
  static const ::phi::KernelRegistrar PD_CONCATENATE(                         \
      __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
      reg_type,                                                               \
      #kernel_name,                                                           \
      #backend,                                                               \
      DATALAYOUT(layout),                                                     \
      ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(),           \
      ::phi::KernelArgsParseFunctor<decltype(                                 \
          &meta_kernel_fn<cpp_dtype, context>)>::Parse,                       \
      args_def_fn,                                                            \
      PHI_KERNEL(meta_kernel_fn<cpp_dtype, context>),                         \
      PHI_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, context>));               \
  PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_13(reg_type,                            \
                                         kernel_name,                         \
                                         backend,                             \
                                         context,                             \
                                         layout,                              \
                                         PD_ID,                               \
                                         args_def_fn,                         \
                                         meta_kernel_fn,                      \
                                         __VA_ARGS__))

#define _PD_KERNEL_REGISTRAR_INIT_15(reg_type,                                \
932 933 934 935 936 937 938 939 940
                                     kernel_name,                             \
                                     backend,                                 \
                                     context,                                 \
                                     layout,                                  \
                                     registrar_id,                            \
                                     args_def_fn,                             \
                                     meta_kernel_fn,                          \
                                     cpp_dtype,                               \
                                     ...)                                     \
941
  static const ::phi::KernelRegistrar PD_CONCATENATE(                         \
942 943 944 945 946 947 948 949 950
      __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
      reg_type,                                                               \
      #kernel_name,                                                           \
      #backend,                                                               \
      DATALAYOUT(layout),                                                     \
      ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(),           \
      ::phi::KernelArgsParseFunctor<decltype(                                 \
          &meta_kernel_fn<cpp_dtype, context>)>::Parse,                       \
      args_def_fn,                                                            \
951 952 953
      PHI_KERNEL(meta_kernel_fn<cpp_dtype, context>),                         \
      PHI_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, context>));               \
  PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_14(reg_type,                            \
954 955 956 957
                                         kernel_name,                         \
                                         backend,                             \
                                         context,                             \
                                         layout,                              \
958
                                         PD_ID,                               \
959 960
                                         args_def_fn,                         \
                                         meta_kernel_fn,                      \
961
                                         __VA_ARGS__))
/** PD_REGISTER_GENERAL_KERNEL
 *
 * Basic kernel registration macro, used to register an instantiated kernel
 * function with one template argument. Unlike the per-dtype registrar chain,
 * `kernel_fn` is registered as-is (no `<cpp_dtype, context>` instantiation
 * is generated) and the registration uses ::phi::RegType::INNER.
 *
 * NOTE(review): `dtype` is forwarded but is not referenced by the
 * __PD_REGISTER_GENERAL_KERNEL definitions in this file; confirm whether it
 * is still needed.
 */
#define PD_REGISTER_GENERAL_KERNEL(                 \
    kernel_name, backend, layout, kernel_fn, dtype) \
  _PD_REGISTER_GENERAL_KERNEL(                      \
      ::phi::RegType::INNER, kernel_name, backend, layout, kernel_fn, dtype)

// Guards PD_REGISTER_GENERAL_KERNEL with PD_STATIC_ASSERT_GLOBAL_NAMESPACE
// (defined elsewhere) so misuse outside the global namespace produces a
// diagnostic naming the macro the user actually wrote, then delegates to the
// platform-specific __PD_REGISTER_GENERAL_KERNEL.
#define _PD_REGISTER_GENERAL_KERNEL(                                         \
    reg_type, kernel_name, backend, layout, kernel_fn, dtype)                \
  PD_STATIC_ASSERT_GLOBAL_NAMESPACE(                                         \
      PD_REGISTER_no_t_kernel_ns_check_##kernel_name##_##backend##_##layout, \
      "PD_REGISTER_GENERAL_KERNEL must be called in global namespace.");     \
  __PD_REGISTER_GENERAL_KERNEL(                                              \
      reg_type, kernel_name, backend, layout, kernel_fn, dtype)

// __PD_REGISTER_GENERAL_KERNEL expands, for one kernel, into:
//   * a forward declaration of the static args-def callback
//     __PD_KERNEL_args_def_FN_<kernel>_<backend>_<layout>;
//   * a static ::phi::KernelRegistrar __reg_pt_kernel_<kernel>_<backend>_<layout>
//     constructed with `kernel_fn`, KernelArgsParseFunctor<...>::Parse and a
//     pointer to that callback;
//   * TouchKernelSymbolFor_<kernel>_<backend>_<layout>() returning 0, which
//     PD_DECLARE_KERNEL declares `extern` and calls;
//   * the callback's declarator WITHOUT a body, so the macro invocation must
//     be followed by the callback body `{ ... }`.
// The two variants differ only in that non-Windows builds first explicitly
// instantiate `kernel_fn` (`template decltype(kernel_fn) kernel_fn;`).
// NOTE(review): `dtype` is accepted but not referenced in either variant.
#ifndef _WIN32
#define __PD_REGISTER_GENERAL_KERNEL(                                       \
    reg_type, kernel_name, backend, layout, kernel_fn, dtype)               \
  template decltype(kernel_fn) kernel_fn;                                   \
  static void __PD_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
      const ::phi::KernelKey& kernel_key, ::phi::Kernel* kernel);           \
  static const ::phi::KernelRegistrar                                       \
      __reg_pt_kernel_##kernel_name##_##backend##_##layout(                 \
          reg_type,                                                         \
          #kernel_name,                                                     \
          #backend,                                                         \
          DATALAYOUT(layout),                                               \
          ::phi::KernelArgsParseFunctor<decltype(&kernel_fn)>::Parse,       \
          &__PD_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout,    \
          PHI_KERNEL(kernel_fn),                                            \
          PHI_VARIADIC_KERNEL(kernel_fn));                                  \
  int TouchKernelSymbolFor_##kernel_name##_##backend##_##layout() {         \
    return 0;                                                               \
  }                                                                         \
  void __PD_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout(        \
      const ::phi::KernelKey& kernel_key, ::phi::Kernel* kernel)
#else
#define __PD_REGISTER_GENERAL_KERNEL(                                       \
    reg_type, kernel_name, backend, layout, kernel_fn, dtype)               \
  static void __PD_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
      const ::phi::KernelKey& kernel_key, ::phi::Kernel* kernel);           \
  static const ::phi::KernelRegistrar                                       \
      __reg_pt_kernel_##kernel_name##_##backend##_##layout(                 \
          reg_type,                                                         \
          #kernel_name,                                                     \
          #backend,                                                         \
          DATALAYOUT(layout),                                               \
          ::phi::KernelArgsParseFunctor<decltype(&kernel_fn)>::Parse,       \
          &__PD_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout,    \
          PHI_KERNEL(kernel_fn),                                            \
          PHI_VARIADIC_KERNEL(kernel_fn));                                  \
  int TouchKernelSymbolFor_##kernel_name##_##backend##_##layout() {         \
    return 0;                                                               \
  }                                                                         \
  void __PD_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout(        \
      const ::phi::KernelKey& kernel_key, ::phi::Kernel* kernel)
#endif

/** PD_DECLARE_KERNEL
 *
 * Used to export the symbols of the file where the kernel is located, so
 * that they are not removed by the linker.
 *
 * It declares `extern int TouchKernelSymbolFor_<kernel>_<backend>_<layout>()`
 * (the function emitted by the kernel registration macros) and calls it to
 * initialize a static int marked UNUSED, creating a reference to the
 * translation unit that registers the kernel. Must be used at global
 * namespace scope.
 */
#define PD_DECLARE_KERNEL(kernel_name, backend, layout)                   \
  PD_STATIC_ASSERT_GLOBAL_NAMESPACE(                                      \
      PD_DECLARE_tp_kernel_ns_check_##kernel_name##_##backend##_##layout, \
      "PD_DECLARE_KERNEL must be called in global namespace.");           \
  extern int TouchKernelSymbolFor_##kernel_name##_##backend##_##layout(); \
  UNUSED static int                                                       \
      __declare_kernel_symbol_for_##kernel_name##_##backend##_##layout =  \
          TouchKernelSymbolFor_##kernel_name##_##backend##_##layout()

/** PD_REGISTER_BUILTIN_KERNEL
 *
 * Used to register kernels for built-in backends.
 * Supports CPU, GPU and XPU.
 *
 * The device context type is derived from the backend name by token pasting
 * (`::phi::<backend>Context`, e.g. `CPU` -> ::phi::CPUContext). The trailing
 * arguments (presumably the list of data types) are forwarded unchanged to
 * _PD_REGISTER_KERNEL with ::phi::RegType::OUTER.
 */
#define PD_REGISTER_BUILTIN_KERNEL(                    \
    kernel_name, backend, layout, meta_kernel_fn, ...) \
  _PD_REGISTER_KERNEL(::phi::RegType::OUTER,           \
                      kernel_name,                     \
                      backend,                         \
                      ::phi::backend##Context,         \
                      layout,                          \
                      meta_kernel_fn,                  \
                      __VA_ARGS__)

/** PD_REGISTER_PLUGIN_KERNEL
 *
 * Used to register kernels for plug-in backends.
 * Supports user-defined backends such as 'Ascend910'.
 *
 * Every plug-in backend uses ::phi::CustomContext as its device context; the
 * trailing arguments (presumably the list of data types) are forwarded
 * unchanged to _PD_REGISTER_KERNEL with ::phi::RegType::OUTER.
 */
#define PD_REGISTER_PLUGIN_KERNEL(                     \
    kernel_name, backend, layout, meta_kernel_fn, ...) \
  _PD_REGISTER_KERNEL(::phi::RegType::OUTER,           \
                      kernel_name,                     \
                      backend,                         \
                      ::phi::CustomContext,            \
                      layout,                          \
                      meta_kernel_fn,                  \
                      __VA_ARGS__)

}  // namespace phi