// kernel_registry.h
//   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <cstdint>
#include <cstring>
#include <string>
#include <tuple>
#include <type_traits>
#include <typeindex>
#include <typeinfo>
#include <utility>
#include <vector>

#include "paddle/phi/core/custom_kernel.h"
#include "paddle/phi/core/kernel_factory.h"
#include "paddle/phi/core/kernel_utils.h"
#include "paddle/phi/core/macros.h"
#include "paddle/phi/core/type_defs.h"

#include "paddle/phi/core/enforce.h"

namespace phi {
// Shorthands that qualify a bare enumerator name with its phi enum type,
// e.g. BACKEND(CPU) expands to phi::Backend::CPU. The registration macros in
// this file use DATALAYOUT(layout) to turn a layout token into a value.
#define BACKEND(arg__) phi::Backend::arg__
#define DATALAYOUT(arg__) phi::DataLayout::arg__
#define DATATYPE(arg__) phi::DataType::arg__

template <typename Func>
struct KernelArgsParseFunctor;

template <typename Return_, typename... Args_>
struct KernelArgsParseFunctor<Return_ (*)(Args_...)> {
  using Args = std::tuple<Args_...>;
  enum : std::size_t { Arity = sizeof...(Args_) };
  using Indices = std::make_index_sequence<Arity>;
  template <std::size_t Index>
  using Arg = typename std::tuple_element<Index, Args>::type;

  static void Parse(const KernelKey& default_key, KernelArgsDef* args_def) {
    // TODO(chenweihang): The fluid Tensor's default layout is NCHW,
    // it is not same as kernel's layout, we should fix this error on
    // fluid Tensor
53 54
    auto default_tensor_layout = phi::DataLayout::NCHW;
    if (default_key.layout() != phi::DataLayout::ANY) {
55 56 57 58 59 60 61
      default_tensor_layout = default_key.layout();
    }
    auto args_type = ParseArgType(Indices{});
    for (auto arg_type : args_type) {
      if (arg_type == std::type_index(typeid(const CPUContext&))
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
          ||
62
          arg_type == std::type_index(typeid(const GPUContext&))) {
63 64 65
#elif defined(PADDLE_WITH_XPU)
          ||
          arg_type == std::type_index(typeid(const XPUContext&))) {
66 67 68
#elif defined(PADDLE_WITH_CUSTOM_DEVICE)
          ||
          arg_type == std::type_index(typeid(const CustomContext&))) {
69 70 71 72 73
#else
              ) {
#endif
        // do nothing, skip context arg now
      } else if (arg_type == std::type_index(typeid(const DenseTensor&))) {
H
hong 已提交
74 75 76 77
        args_def->AppendInput(default_key.backend(),
                              default_tensor_layout,
                              default_key.dtype(),
                              arg_type);
78 79
      } else if (arg_type == std::type_index(typeid(
                                 paddle::optional<const DenseTensor&>))) {
H
hong 已提交
80 81 82 83
        args_def->AppendInput(default_key.backend(),
                              default_tensor_layout,
                              default_key.dtype(),
                              arg_type);
H
hong 已提交
84 85 86 87 88 89
      } else if (arg_type == std::type_index(typeid(
                                 paddle::optional<const SelectedRows&>))) {
        args_def->AppendInput(default_key.backend(),
                              default_tensor_layout,
                              default_key.dtype(),
                              arg_type);
90 91
      } else if (arg_type == std::type_index(typeid(
                                 const std::vector<const DenseTensor*>&))) {
H
hong 已提交
92 93 94 95
        args_def->AppendInput(default_key.backend(),
                              default_tensor_layout,
                              default_key.dtype(),
                              arg_type);
96
      } else if (arg_type == std::type_index(typeid(const SelectedRows&))) {
H
hong 已提交
97 98 99 100
        args_def->AppendInput(default_key.backend(),
                              default_tensor_layout,
                              default_key.dtype(),
                              arg_type);
101
      } else if (arg_type == std::type_index(typeid(DenseTensor*))) {
H
hong 已提交
102 103 104 105
        args_def->AppendOutput(default_key.backend(),
                               default_tensor_layout,
                               default_key.dtype(),
                               arg_type);
106 107
      } else if (arg_type ==
                 std::type_index(typeid(std::vector<DenseTensor*>))) {
H
hong 已提交
108 109 110 111
        args_def->AppendOutput(default_key.backend(),
                               default_tensor_layout,
                               default_key.dtype(),
                               arg_type);
112
      } else if (arg_type == std::type_index(typeid(SelectedRows*))) {
H
hong 已提交
113 114 115 116
        args_def->AppendOutput(default_key.backend(),
                               default_tensor_layout,
                               default_key.dtype(),
                               arg_type);
117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133
      } else {
        // Attribute deal with
        // TODO(chenweihang): now here allow any types of attribute, maybe
        // should add limits here
        args_def->AppendAttribute(arg_type);
      }
    }
  }

 private:
  template <std::size_t... INDEX>
  static std::vector<std::type_index> ParseArgType(
      std::index_sequence<INDEX...>) {
    return {std::type_index(typeid(Arg<INDEX>))...};
  }
};

// Distinguishes registrations made by the framework itself (INNER) from
// externally registered custom kernels (OUTER). KernelRegistrar stores INNER
// kernels in KernelFactory and forwards OUTER kernels to CustomKernelMap.
enum class RegType : uint8_t {
  INNER = 0,
  OUTER,
};

140 141
// TODO(chenweihang): Polish the kernel selection logic, support the selection
// of ALL_DTYPE kernel, and simplify the constructor
142 143
struct KernelRegistrar {
 public:
144 145 146
  KernelRegistrar(RegType reg_type,
                  const char* kernel_name_cstr,
                  const char* backend_cstr,
147 148 149 150
                  DataLayout layout,
                  DataType dtype,
                  KernelArgsParseFn args_parse_fn,
                  KernelArgsDefFn args_def_fn,
151 152
                  KernelFn kernel_fn,
                  void* variadic_kernel_fn) {
153 154 155
    ConstructKernel(reg_type,
                    kernel_name_cstr,
                    backend_cstr,
156 157 158 159
                    layout,
                    dtype,
                    args_parse_fn,
                    args_def_fn,
160 161
                    kernel_fn,
                    variadic_kernel_fn);
162 163
  }

164 165 166
  KernelRegistrar(RegType reg_type,
                  const char* kernel_name_cstr,
                  const char* backend_cstr,
167 168 169
                  DataLayout layout,
                  KernelArgsParseFn args_parse_fn,
                  KernelArgsDefFn args_def_fn,
170 171
                  KernelFn kernel_fn,
                  void* variadic_kernel_fn) {
172 173 174
    for (size_t dtype = static_cast<size_t>(DataType::BOOL);
         dtype != static_cast<size_t>(DataType::NUM_DATA_TYPES);
         dtype++) {
175 176 177 178 179 180 181
      // NOTE(zhiqiu): why skip these types, because fluid kernel has no kernel
      // of these type.
      if (dtype == static_cast<size_t>(DataType::UINT32) ||
          dtype == static_cast<size_t>(DataType::UINT64) ||
          dtype == static_cast<size_t>(DataType::UINT16)) {
        continue;
      }
182 183 184
      ConstructKernel(reg_type,
                      kernel_name_cstr,
                      backend_cstr,
185 186 187 188
                      layout,
                      static_cast<DataType>(dtype),
                      args_parse_fn,
                      args_def_fn,
189 190
                      kernel_fn,
                      variadic_kernel_fn);
191 192 193 194
    }
  }

 private:
195 196 197
  void ConstructKernel(RegType reg_type,
                       const char* kernel_name_cstr,
                       const char* backend_cstr,
198 199 200 201
                       DataLayout layout,
                       DataType dtype,
                       KernelArgsParseFn args_parse_fn,
                       KernelArgsDefFn args_def_fn,
202 203
                       KernelFn kernel_fn,
                       void* variadic_kernel_fn) {
Y
YuanRisheng 已提交
204
    std::string kernel_name(kernel_name_cstr);
205 206
    KernelKey kernel_key(
        paddle::experimental::StringToBackend(backend_cstr), layout, dtype);
207
    Kernel kernel(kernel_fn, variadic_kernel_fn);
208
    args_parse_fn(kernel_key, kernel.mutable_args_def());
209
    args_def_fn(kernel_key, &kernel);
210
    if (reg_type == RegType::INNER) {
211 212
      KernelFactory::Instance().kernels()[kernel_name][kernel_key] = kernel;
    } else {
213 214
      CustomKernelMap::Instance().RegisterCustomKernel(
          kernel_name, kernel_key, kernel);
215
    }
216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231
  }
};

/**
 * PD_NARGS(...) expands to the number of comma-separated arguments it is
 * given, for 1 to 15 arguments. `PD_NARGS()` with no arguments yields 1,
 * because an empty argument still occupies one slot. More than 15 arguments
 * is not supported.
 *
 * Reference:
 *
 *   https://stackoverflow.com/questions/1872220/is-it-possible-to-iterate-over-arguments-in-variadic-macros
 *   https://stackoverflow.com/questions/9183993/msvc-variadic-macro-expansion?rq=1
 *   https://stackoverflow.com/questions/5134523/msvc-doesnt-expand-va-args-correctly
 *
 * The extra parenthesized indirection through _PD_ARG_N works around an MSVC
 * bug in which __VA_ARGS__ is forwarded as a single token in argument lists.
 * See these URLs for details:
 *
 *   http://connect.microsoft.com/VisualStudio/feedback/details/380090/variadic-macro-replacement
 *   http://cplusplus.co.il/2010/07/17/variadic-macro-to-count-number-of-arguments/#comment-644
 */
#define PD_NARGS(...) _PD_NARGS((__VA_ARGS__, _PD_RESQ_N()))
#define _PD_NARGS(...) _PD_ARG_N(__VA_ARGS__)
#define _PD_ARG_N_EXPAND(                                                     \
    _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, N, ...) \
  N
#define _PD_ARG_N(args) _PD_ARG_N_EXPAND args
#define _PD_RESQ_N() 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0

/** PD_REGISTER_KERNEL
 *
 * The most frequently used kernel registration macro. It registers
 * `meta_kernel_fn` on `backend` for each C++ data type listed in `...`. For
 * every listed type `cpp_dtype`, the function pointer
 * `meta_kernel_fn<cpp_dtype, ::phi::<backend>Context>` is instantiated
 * automatically during registration.
 *
 * The macro must be invoked at global namespace scope. It must be followed
 * by a `{ ... }` body, which becomes the args-def function. That function is
 * called with `kernel_key` and `kernel` for every registered data type.
 *
 * Note: `2TA` in the helper names means "2 template arguments" (the data type
 * and the device context).
 */
#define PD_REGISTER_KERNEL(kernel_name, backend, layout, meta_kernel_fn, ...) \
  _PD_REGISTER_KERNEL(::phi::RegType::INNER,                                  \
                      kernel_name,                                            \
                      backend,                                                \
                      ::phi::backend##Context,                                \
                      layout,                                                 \
                      meta_kernel_fn,                                         \
                      __VA_ARGS__)

/**
 * _PD_REGISTER_KERNEL: shared implementation behind PD_REGISTER_KERNEL,
 * parameterized on `reg_type` and the explicit `context` type. It first
 * statically asserts that it is expanded in the global namespace, then
 * delegates to _PD_REGISTER_2TA_KERNEL.
 */
#define _PD_REGISTER_KERNEL(                                               \
    reg_type, kernel_name, backend, context, layout, meta_kernel_fn, ...)  \
  PD_STATIC_ASSERT_GLOBAL_NAMESPACE(                                       \
      PD_REGISTER_tp_kernel_ns_check_##kernel_name##_##backend##_##layout, \
      "PD_REGISTER_KERNEL must be called in global namespace.");           \
  PD_EXPAND(_PD_REGISTER_2TA_KERNEL(reg_type,                              \
                                    kernel_name,                           \
                                    backend,                               \
                                    context,                               \
                                    layout,                                \
                                    meta_kernel_fn,                        \
                                    __VA_ARGS__))
#ifndef _WIN32
/**
 * _PD_REGISTER_2TA_KERNEL expands to four pieces, in order:
 *   1. explicit instantiations of meta_kernel_fn<cpp_dtype, context> for
 *      every listed data type;
 *   2. a forward declaration of the args-def function;
 *   3. the KernelRegistrar objects (via PD_KERNEL_REGISTRAR_INIT);
 *   4. the head of the args-def function definition, whose `{ ... }` body
 *      the caller of the registration macro must supply.
 */
#define _PD_REGISTER_2TA_KERNEL(                                            \
    reg_type, kernel_name, backend, context, layout, meta_kernel_fn, ...)   \
  PD_KERNEL_INSTANTIATION(meta_kernel_fn, backend, context, __VA_ARGS__);   \
  static void __PD_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
      const ::phi::KernelKey& kernel_key, ::phi::Kernel* kernel);           \
  PD_KERNEL_REGISTRAR_INIT(                                                 \
      reg_type,                                                             \
      kernel_name,                                                          \
      backend,                                                              \
      context,                                                              \
      layout,                                                               \
      &__PD_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout,        \
      meta_kernel_fn,                                                       \
      __VA_ARGS__);                                                         \
  void __PD_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout(        \
      const ::phi::KernelKey& kernel_key, ::phi::Kernel* kernel)
#else
/**
 * The explicit instantiation `template decltype(fn) fn;` compiles on GCC and
 * Clang, but MSVC rejects it with an error like:
 *
 *   error C2206: typedef cannot be used for function definition
 *
 * Reference:
 *
 *   https://stackoverflow.com/questions/63989585/explicit-instantiation-of-function-using-decltype-work-on-g-but-not-on-visua
 *
 * MSVC works without the explicit instantiation, so this variant omits
 * PD_KERNEL_INSTANTIATION. It is otherwise identical to the non-Windows one.
 */
#define _PD_REGISTER_2TA_KERNEL(                                            \
    reg_type, kernel_name, backend, context, layout, meta_kernel_fn, ...)   \
  static void __PD_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
      const ::phi::KernelKey& kernel_key, ::phi::Kernel* kernel);           \
  PD_EXPAND(PD_KERNEL_REGISTRAR_INIT(                                       \
      reg_type,                                                             \
      kernel_name,                                                          \
      backend,                                                              \
      context,                                                              \
      layout,                                                               \
      &__PD_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout,        \
      meta_kernel_fn,                                                       \
      __VA_ARGS__));                                                        \
  void __PD_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout(        \
      const ::phi::KernelKey& kernel_key, ::phi::Kernel* kernel)
#endif

/**
 * PD_KERNEL_INSTANTIATION: explicitly instantiates
 * `meta_kernel_fn<cpp_dtype, context>` once per C++ data type in `...`. It
 * counts the types with PD_NARGS and dispatches to
 * _PD_KERNEL_INSTANTIATION_<N>.
 */
#define PD_KERNEL_INSTANTIATION(meta_kernel_fn, backend, context, ...) \
  _PD_KERNEL_INSTANTIATION(                                            \
      PD_NARGS(__VA_ARGS__), meta_kernel_fn, backend, context, __VA_ARGS__)

#define _PD_KERNEL_INSTANTIATION(N, meta_kernel_fn, backend, context, ...) \
  PD_CONCATENATE(_PD_KERNEL_INSTANTIATION_, N)                             \
  (meta_kernel_fn, backend, context, __VA_ARGS__)

/**
 * _PD_KERNEL_INSTANTIATION_<N> emits one explicit instantiation of
 * `meta_kernel_fn<cpp_dtype, context>` for its first data type. It then
 * recurses into level N-1 with the remaining types.
 *
 * `backend` is forwarded but not used. Level 1 leaves off the trailing
 * semicolon; the caller (_PD_REGISTER_2TA_KERNEL) supplies it.
 */
#define _PD_KERNEL_INSTANTIATION_1(              \
    meta_kernel_fn, backend, context, cpp_dtype) \
  template decltype(                             \
      meta_kernel_fn<cpp_dtype, context>) meta_kernel_fn<cpp_dtype, context>
#define _PD_KERNEL_INSTANTIATION_2(                                           \
    meta_kernel_fn, backend, context, cpp_dtype, ...)                         \
  template decltype(                                                          \
      meta_kernel_fn<cpp_dtype, context>) meta_kernel_fn<cpp_dtype, context>; \
  PD_EXPAND(_PD_KERNEL_INSTANTIATION_1(                                       \
      meta_kernel_fn, backend, context, __VA_ARGS__))
#define _PD_KERNEL_INSTANTIATION_3(                                           \
    meta_kernel_fn, backend, context, cpp_dtype, ...)                         \
  template decltype(                                                          \
      meta_kernel_fn<cpp_dtype, context>) meta_kernel_fn<cpp_dtype, context>; \
  PD_EXPAND(_PD_KERNEL_INSTANTIATION_2(                                       \
      meta_kernel_fn, backend, context, __VA_ARGS__))
#define _PD_KERNEL_INSTANTIATION_4(                                           \
    meta_kernel_fn, backend, context, cpp_dtype, ...)                         \
  template decltype(                                                          \
      meta_kernel_fn<cpp_dtype, context>) meta_kernel_fn<cpp_dtype, context>; \
  PD_EXPAND(_PD_KERNEL_INSTANTIATION_3(                                       \
      meta_kernel_fn, backend, context, __VA_ARGS__))
#define _PD_KERNEL_INSTANTIATION_5(                                           \
    meta_kernel_fn, backend, context, cpp_dtype, ...)                         \
  template decltype(                                                          \
      meta_kernel_fn<cpp_dtype, context>) meta_kernel_fn<cpp_dtype, context>; \
  PD_EXPAND(_PD_KERNEL_INSTANTIATION_4(                                       \
      meta_kernel_fn, backend, context, __VA_ARGS__))
#define _PD_KERNEL_INSTANTIATION_6(                                           \
    meta_kernel_fn, backend, context, cpp_dtype, ...)                         \
  template decltype(                                                          \
      meta_kernel_fn<cpp_dtype, context>) meta_kernel_fn<cpp_dtype, context>; \
  PD_EXPAND(_PD_KERNEL_INSTANTIATION_5(                                       \
      meta_kernel_fn, backend, context, __VA_ARGS__))
#define _PD_KERNEL_INSTANTIATION_7(                                           \
    meta_kernel_fn, backend, context, cpp_dtype, ...)                         \
  template decltype(                                                          \
      meta_kernel_fn<cpp_dtype, context>) meta_kernel_fn<cpp_dtype, context>; \
  PD_EXPAND(_PD_KERNEL_INSTANTIATION_6(                                       \
      meta_kernel_fn, backend, context, __VA_ARGS__))
#define _PD_KERNEL_INSTANTIATION_8(                                           \
    meta_kernel_fn, backend, context, cpp_dtype, ...)                         \
  template decltype(                                                          \
      meta_kernel_fn<cpp_dtype, context>) meta_kernel_fn<cpp_dtype, context>; \
  PD_EXPAND(_PD_KERNEL_INSTANTIATION_7(                                       \
      meta_kernel_fn, backend, context, __VA_ARGS__))
#define _PD_KERNEL_INSTANTIATION_9(                                           \
    meta_kernel_fn, backend, context, cpp_dtype, ...)                         \
  template decltype(                                                          \
      meta_kernel_fn<cpp_dtype, context>) meta_kernel_fn<cpp_dtype, context>; \
  PD_EXPAND(_PD_KERNEL_INSTANTIATION_8(                                       \
      meta_kernel_fn, backend, context, __VA_ARGS__))
#define _PD_KERNEL_INSTANTIATION_10(                                          \
    meta_kernel_fn, backend, context, cpp_dtype, ...)                         \
  template decltype(                                                          \
      meta_kernel_fn<cpp_dtype, context>) meta_kernel_fn<cpp_dtype, context>; \
  PD_EXPAND(_PD_KERNEL_INSTANTIATION_9(                                       \
      meta_kernel_fn, backend, context, __VA_ARGS__))
#define _PD_KERNEL_INSTANTIATION_11(                                          \
    meta_kernel_fn, backend, context, cpp_dtype, ...)                         \
  template decltype(                                                          \
      meta_kernel_fn<cpp_dtype, context>) meta_kernel_fn<cpp_dtype, context>; \
  PD_EXPAND(_PD_KERNEL_INSTANTIATION_10(                                      \
      meta_kernel_fn, backend, context, __VA_ARGS__))
#define _PD_KERNEL_INSTANTIATION_12(                                          \
    meta_kernel_fn, backend, context, cpp_dtype, ...)                         \
  template decltype(                                                          \
      meta_kernel_fn<cpp_dtype, context>) meta_kernel_fn<cpp_dtype, context>; \
  PD_EXPAND(_PD_KERNEL_INSTANTIATION_11(                                      \
      meta_kernel_fn, backend, context, __VA_ARGS__))
#define _PD_KERNEL_INSTANTIATION_13(                                          \
    meta_kernel_fn, backend, context, cpp_dtype, ...)                         \
  template decltype(                                                          \
      meta_kernel_fn<cpp_dtype, context>) meta_kernel_fn<cpp_dtype, context>; \
  PD_EXPAND(_PD_KERNEL_INSTANTIATION_12(                                      \
      meta_kernel_fn, backend, context, __VA_ARGS__))
#define _PD_KERNEL_INSTANTIATION_14(                                          \
    meta_kernel_fn, backend, context, cpp_dtype, ...)                         \
  template decltype(                                                          \
      meta_kernel_fn<cpp_dtype, context>) meta_kernel_fn<cpp_dtype, context>; \
  PD_EXPAND(_PD_KERNEL_INSTANTIATION_13(                                      \
      meta_kernel_fn, backend, context, __VA_ARGS__))
#define _PD_KERNEL_INSTANTIATION_15(                                          \
    meta_kernel_fn, backend, context, cpp_dtype, ...)                         \
  template decltype(                                                          \
      meta_kernel_fn<cpp_dtype, context>) meta_kernel_fn<cpp_dtype, context>; \
  PD_EXPAND(_PD_KERNEL_INSTANTIATION_14(                                      \
      meta_kernel_fn, backend, context, __VA_ARGS__))

/**
 * PD_KERNEL_REGISTRAR_INIT: counts the C++ data types in `...` with PD_NARGS
 * and forwards everything to _PD_KERNEL_REGISTRAR_INIT. That macro defines
 * one KernelRegistrar per data type.
 */
#define PD_KERNEL_REGISTRAR_INIT(reg_type,                   \
                                 kernel_name,                \
                                 backend,                    \
                                 context,                    \
                                 layout,                     \
                                 args_def_fn,                \
                                 meta_kernel_fn,             \
                                 ...)                        \
  PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT(PD_NARGS(__VA_ARGS__), \
                                      reg_type,              \
                                      kernel_name,           \
                                      backend,               \
                                      context,               \
                                      layout,                \
                                      args_def_fn,           \
                                      meta_kernel_fn,        \
                                      __VA_ARGS__))

// clang-format off

/* pre-commit (clang-format) always reformats this macro incorrectly, and
   multi-line macros cannot be exempted with NOLINT, so formatting is turned
   off for this region.

   _PD_KERNEL_REGISTRAR_INIT dispatches on the data-type count N to
   _PD_KERNEL_REGISTRAR_INIT_<N>. It passes PD_ID as the registrar id of the
   outermost level (PD_ID is defined elsewhere; presumably it yields a
   per-expansion unique token — confirm). */
#define _PD_KERNEL_REGISTRAR_INIT(N,                       \
                                  reg_type,                \
                                  kernel_name,             \
                                  backend,                 \
                                  context,                 \
                                  layout,                  \
                                  args_def_fn,             \
                                  meta_kernel_fn,          \
                                  ...)                     \
  PD_EXPAND(PD_CONCATENATE(_PD_KERNEL_REGISTRAR_INIT_, N) ( \
    reg_type,                                              \
    kernel_name,                                           \
    backend,                                               \
    context,                                               \
    layout,                                                \
    PD_ID,                                                 \
    args_def_fn,                                           \
    meta_kernel_fn,                                        \
    __VA_ARGS__))

// clang-format on

/**
 * _PD_KERNEL_REGISTRAR_INIT_1: base case of the registrar recursion.
 *
 * Defines a `static const ::phi::KernelRegistrar` whose name is built from
 * kernel/backend/layout and `registrar_id`. It registers
 * `meta_kernel_fn<cpp_dtype, context>` for the data type
 * `CppTypeToDataType<cpp_dtype>::Type()`, with its argument definitions
 * parsed by KernelArgsParseFunctor.
 *
 * Only this base case defines
 * `TouchKernelSymbolFor_<kernel_name>_<backend>_<layout>()`, so that
 * function is emitted once per registration.
 * NOTE(review): presumably it exists so other translation units can
 * reference it and keep this registration linked in — confirm.
 */
#define _PD_KERNEL_REGISTRAR_INIT_1(reg_type,                                 \
                                    kernel_name,                              \
                                    backend,                                  \
                                    context,                                  \
                                    layout,                                   \
                                    registrar_id,                             \
                                    args_def_fn,                              \
                                    meta_kernel_fn,                           \
                                    cpp_dtype)                                \
  static const ::phi::KernelRegistrar PD_CONCATENATE(                         \
      __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
      reg_type,                                                               \
      #kernel_name,                                                           \
      #backend,                                                               \
      DATALAYOUT(layout),                                                     \
      ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(),           \
      ::phi::KernelArgsParseFunctor<decltype(                                 \
          &meta_kernel_fn<cpp_dtype, context>)>::Parse,                       \
      args_def_fn,                                                            \
      PHI_KERNEL(meta_kernel_fn<cpp_dtype, context>),                         \
      PHI_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, context>));               \
  int TouchKernelSymbolFor_##kernel_name##_##backend##_##layout() { return 0; }
/**
 * _PD_KERNEL_REGISTRAR_INIT_2: defines the registrar for the first data type
 * (named with `registrar_id`). It then recurses into level 1 with the
 * remaining type and a new PD_ID.
 */
#define _PD_KERNEL_REGISTRAR_INIT_2(reg_type,                                 \
                                    kernel_name,                              \
                                    backend,                                  \
                                    context,                                  \
                                    layout,                                   \
                                    registrar_id,                             \
                                    args_def_fn,                              \
                                    meta_kernel_fn,                           \
                                    cpp_dtype,                                \
                                    ...)                                      \
  static const ::phi::KernelRegistrar PD_CONCATENATE(                         \
      __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
      reg_type,                                                               \
      #kernel_name,                                                           \
      #backend,                                                               \
      DATALAYOUT(layout),                                                     \
      ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(),           \
      ::phi::KernelArgsParseFunctor<decltype(                                 \
          &meta_kernel_fn<cpp_dtype, context>)>::Parse,                       \
      args_def_fn,                                                            \
      PHI_KERNEL(meta_kernel_fn<cpp_dtype, context>),                         \
      PHI_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, context>));               \
  PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_1(reg_type,                             \
                                        kernel_name,                          \
                                        backend,                              \
                                        context,                              \
                                        layout,                               \
                                        PD_ID,                                \
                                        args_def_fn,                          \
                                        meta_kernel_fn,                       \
                                        __VA_ARGS__))
/**
 * _PD_KERNEL_REGISTRAR_INIT_3: defines the registrar for the first data type
 * (named with `registrar_id`). It then recurses into level 2 with the
 * remaining types and a new PD_ID.
 */
#define _PD_KERNEL_REGISTRAR_INIT_3(reg_type,                                 \
                                    kernel_name,                              \
                                    backend,                                  \
                                    context,                                  \
                                    layout,                                   \
                                    registrar_id,                             \
                                    args_def_fn,                              \
                                    meta_kernel_fn,                           \
                                    cpp_dtype,                                \
                                    ...)                                      \
  static const ::phi::KernelRegistrar PD_CONCATENATE(                         \
      __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
      reg_type,                                                               \
      #kernel_name,                                                           \
      #backend,                                                               \
      DATALAYOUT(layout),                                                     \
      ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(),           \
      ::phi::KernelArgsParseFunctor<decltype(                                 \
          &meta_kernel_fn<cpp_dtype, context>)>::Parse,                       \
      args_def_fn,                                                            \
      PHI_KERNEL(meta_kernel_fn<cpp_dtype, context>),                         \
      PHI_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, context>));               \
  PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_2(reg_type,                             \
                                        kernel_name,                          \
                                        backend,                              \
                                        context,                              \
                                        layout,                               \
                                        PD_ID,                                \
                                        args_def_fn,                          \
                                        meta_kernel_fn,                       \
                                        __VA_ARGS__))
#define _PD_KERNEL_REGISTRAR_INIT_4(reg_type,                                 \
544 545 546 547 548 549 550 551 552
                                    kernel_name,                              \
                                    backend,                                  \
                                    context,                                  \
                                    layout,                                   \
                                    registrar_id,                             \
                                    args_def_fn,                              \
                                    meta_kernel_fn,                           \
                                    cpp_dtype,                                \
                                    ...)                                      \
553
  static const ::phi::KernelRegistrar PD_CONCATENATE(                         \
554 555 556 557 558 559 560 561 562
      __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
      reg_type,                                                               \
      #kernel_name,                                                           \
      #backend,                                                               \
      DATALAYOUT(layout),                                                     \
      ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(),           \
      ::phi::KernelArgsParseFunctor<decltype(                                 \
          &meta_kernel_fn<cpp_dtype, context>)>::Parse,                       \
      args_def_fn,                                                            \
563 564 565
      PHI_KERNEL(meta_kernel_fn<cpp_dtype, context>),                         \
      PHI_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, context>));               \
  PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_3(reg_type,                             \
566 567 568 569
                                        kernel_name,                          \
                                        backend,                              \
                                        context,                              \
                                        layout,                               \
570
                                        PD_ID,                                \
571 572
                                        args_def_fn,                          \
                                        meta_kernel_fn,                       \
573
                                        __VA_ARGS__))
// Registers meta_kernel_fn<cpp_dtype, context> through a static
// ::phi::KernelRegistrar whose name is built from kernel/backend/layout plus
// registrar_id, then expands _PD_KERNEL_REGISTRAR_INIT_4 over the remaining
// dtypes in __VA_ARGS__ (the numeric suffix presumably counts the dtypes still
// to register — confirm). PD_ID is defined elsewhere and is passed as the next
// registrar_id so each registrar object gets a distinct name (assumed unique
// per use — confirm).
#define _PD_KERNEL_REGISTRAR_INIT_5(reg_type,                                 \
                                    kernel_name,                              \
                                    backend,                                  \
                                    context,                                  \
                                    layout,                                   \
                                    registrar_id,                             \
                                    args_def_fn,                              \
                                    meta_kernel_fn,                           \
                                    cpp_dtype,                                \
                                    ...)                                      \
  static const ::phi::KernelRegistrar PD_CONCATENATE(                         \
      __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
      reg_type,                                                               \
      #kernel_name,                                                           \
      #backend,                                                               \
      DATALAYOUT(layout),                                                     \
      ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(),           \
      ::phi::KernelArgsParseFunctor<decltype(                                 \
          &meta_kernel_fn<cpp_dtype, context>)>::Parse,                       \
      args_def_fn,                                                            \
      PHI_KERNEL(meta_kernel_fn<cpp_dtype, context>),                         \
      PHI_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, context>));               \
  PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_4(reg_type,                             \
                                        kernel_name,                          \
                                        backend,                              \
                                        context,                              \
                                        layout,                               \
                                        PD_ID,                                \
                                        args_def_fn,                          \
                                        meta_kernel_fn,                       \
                                        __VA_ARGS__))
// Registers meta_kernel_fn<cpp_dtype, context> through a static
// ::phi::KernelRegistrar whose name is built from kernel/backend/layout plus
// registrar_id, then expands _PD_KERNEL_REGISTRAR_INIT_5 over the remaining
// dtypes in __VA_ARGS__ (the numeric suffix presumably counts the dtypes still
// to register — confirm). PD_ID is defined elsewhere and is passed as the next
// registrar_id so each registrar object gets a distinct name (assumed unique
// per use — confirm).
#define _PD_KERNEL_REGISTRAR_INIT_6(reg_type,                                 \
                                    kernel_name,                              \
                                    backend,                                  \
                                    context,                                  \
                                    layout,                                   \
                                    registrar_id,                             \
                                    args_def_fn,                              \
                                    meta_kernel_fn,                           \
                                    cpp_dtype,                                \
                                    ...)                                      \
  static const ::phi::KernelRegistrar PD_CONCATENATE(                         \
      __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
      reg_type,                                                               \
      #kernel_name,                                                           \
      #backend,                                                               \
      DATALAYOUT(layout),                                                     \
      ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(),           \
      ::phi::KernelArgsParseFunctor<decltype(                                 \
          &meta_kernel_fn<cpp_dtype, context>)>::Parse,                       \
      args_def_fn,                                                            \
      PHI_KERNEL(meta_kernel_fn<cpp_dtype, context>),                         \
      PHI_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, context>));               \
  PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_5(reg_type,                             \
                                        kernel_name,                          \
                                        backend,                              \
                                        context,                              \
                                        layout,                               \
                                        PD_ID,                                \
                                        args_def_fn,                          \
                                        meta_kernel_fn,                       \
                                        __VA_ARGS__))
// Registers meta_kernel_fn<cpp_dtype, context> through a static
// ::phi::KernelRegistrar whose name is built from kernel/backend/layout plus
// registrar_id, then expands _PD_KERNEL_REGISTRAR_INIT_6 over the remaining
// dtypes in __VA_ARGS__ (the numeric suffix presumably counts the dtypes still
// to register — confirm). PD_ID is defined elsewhere and is passed as the next
// registrar_id so each registrar object gets a distinct name (assumed unique
// per use — confirm).
#define _PD_KERNEL_REGISTRAR_INIT_7(reg_type,                                 \
                                    kernel_name,                              \
                                    backend,                                  \
                                    context,                                  \
                                    layout,                                   \
                                    registrar_id,                             \
                                    args_def_fn,                              \
                                    meta_kernel_fn,                           \
                                    cpp_dtype,                                \
                                    ...)                                      \
  static const ::phi::KernelRegistrar PD_CONCATENATE(                         \
      __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
      reg_type,                                                               \
      #kernel_name,                                                           \
      #backend,                                                               \
      DATALAYOUT(layout),                                                     \
      ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(),           \
      ::phi::KernelArgsParseFunctor<decltype(                                 \
          &meta_kernel_fn<cpp_dtype, context>)>::Parse,                       \
      args_def_fn,                                                            \
      PHI_KERNEL(meta_kernel_fn<cpp_dtype, context>),                         \
      PHI_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, context>));               \
  PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_6(reg_type,                             \
                                        kernel_name,                          \
                                        backend,                              \
                                        context,                              \
                                        layout,                               \
                                        PD_ID,                                \
                                        args_def_fn,                          \
                                        meta_kernel_fn,                       \
                                        __VA_ARGS__))
// Registers meta_kernel_fn<cpp_dtype, context> through a static
// ::phi::KernelRegistrar whose name is built from kernel/backend/layout plus
// registrar_id, then expands _PD_KERNEL_REGISTRAR_INIT_7 over the remaining
// dtypes in __VA_ARGS__ (the numeric suffix presumably counts the dtypes still
// to register — confirm). PD_ID is defined elsewhere and is passed as the next
// registrar_id so each registrar object gets a distinct name (assumed unique
// per use — confirm).
#define _PD_KERNEL_REGISTRAR_INIT_8(reg_type,                                 \
                                    kernel_name,                              \
                                    backend,                                  \
                                    context,                                  \
                                    layout,                                   \
                                    registrar_id,                             \
                                    args_def_fn,                              \
                                    meta_kernel_fn,                           \
                                    cpp_dtype,                                \
                                    ...)                                      \
  static const ::phi::KernelRegistrar PD_CONCATENATE(                         \
      __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
      reg_type,                                                               \
      #kernel_name,                                                           \
      #backend,                                                               \
      DATALAYOUT(layout),                                                     \
      ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(),           \
      ::phi::KernelArgsParseFunctor<decltype(                                 \
          &meta_kernel_fn<cpp_dtype, context>)>::Parse,                       \
      args_def_fn,                                                            \
      PHI_KERNEL(meta_kernel_fn<cpp_dtype, context>),                         \
      PHI_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, context>));               \
  PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_7(reg_type,                             \
                                        kernel_name,                          \
                                        backend,                              \
                                        context,                              \
                                        layout,                               \
                                        PD_ID,                                \
                                        args_def_fn,                          \
                                        meta_kernel_fn,                       \
                                        __VA_ARGS__))
// Registers meta_kernel_fn<cpp_dtype, context> through a static
// ::phi::KernelRegistrar whose name is built from kernel/backend/layout plus
// registrar_id, then expands _PD_KERNEL_REGISTRAR_INIT_8 over the remaining
// dtypes in __VA_ARGS__ (the numeric suffix presumably counts the dtypes still
// to register — confirm). PD_ID is defined elsewhere and is passed as the next
// registrar_id so each registrar object gets a distinct name (assumed unique
// per use — confirm).
#define _PD_KERNEL_REGISTRAR_INIT_9(reg_type,                                 \
                                    kernel_name,                              \
                                    backend,                                  \
                                    context,                                  \
                                    layout,                                   \
                                    registrar_id,                             \
                                    args_def_fn,                              \
                                    meta_kernel_fn,                           \
                                    cpp_dtype,                                \
                                    ...)                                      \
  static const ::phi::KernelRegistrar PD_CONCATENATE(                         \
      __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
      reg_type,                                                               \
      #kernel_name,                                                           \
      #backend,                                                               \
      DATALAYOUT(layout),                                                     \
      ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(),           \
      ::phi::KernelArgsParseFunctor<decltype(                                 \
          &meta_kernel_fn<cpp_dtype, context>)>::Parse,                       \
      args_def_fn,                                                            \
      PHI_KERNEL(meta_kernel_fn<cpp_dtype, context>),                         \
      PHI_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, context>));               \
  PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_8(reg_type,                             \
                                        kernel_name,                          \
                                        backend,                              \
                                        context,                              \
                                        layout,                               \
                                        PD_ID,                                \
                                        args_def_fn,                          \
                                        meta_kernel_fn,                       \
                                        __VA_ARGS__))
// Registers meta_kernel_fn<cpp_dtype, context> through a static
// ::phi::KernelRegistrar whose name is built from kernel/backend/layout plus
// registrar_id, then expands _PD_KERNEL_REGISTRAR_INIT_9 over the remaining
// dtypes in __VA_ARGS__ (the numeric suffix presumably counts the dtypes still
// to register — confirm). PD_ID is defined elsewhere and is passed as the next
// registrar_id so each registrar object gets a distinct name (assumed unique
// per use — confirm).
#define _PD_KERNEL_REGISTRAR_INIT_10(reg_type,                                \
                                     kernel_name,                             \
                                     backend,                                 \
                                     context,                                 \
                                     layout,                                  \
                                     registrar_id,                            \
                                     args_def_fn,                             \
                                     meta_kernel_fn,                          \
                                     cpp_dtype,                               \
                                     ...)                                     \
  static const ::phi::KernelRegistrar PD_CONCATENATE(                         \
      __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
      reg_type,                                                               \
      #kernel_name,                                                           \
      #backend,                                                               \
      DATALAYOUT(layout),                                                     \
      ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(),           \
      ::phi::KernelArgsParseFunctor<decltype(                                 \
          &meta_kernel_fn<cpp_dtype, context>)>::Parse,                       \
      args_def_fn,                                                            \
      PHI_KERNEL(meta_kernel_fn<cpp_dtype, context>),                         \
      PHI_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, context>));               \
  PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_9(reg_type,                             \
                                        kernel_name,                          \
                                        backend,                              \
                                        context,                              \
                                        layout,                               \
                                        PD_ID,                                \
                                        args_def_fn,                          \
                                        meta_kernel_fn,                       \
                                        __VA_ARGS__))
// Registers meta_kernel_fn<cpp_dtype, context> through a static
// ::phi::KernelRegistrar whose name is built from kernel/backend/layout plus
// registrar_id, then expands _PD_KERNEL_REGISTRAR_INIT_10 over the remaining
// dtypes in __VA_ARGS__ (the numeric suffix presumably counts the dtypes still
// to register — confirm). PD_ID is defined elsewhere and is passed as the next
// registrar_id so each registrar object gets a distinct name (assumed unique
// per use — confirm).
#define _PD_KERNEL_REGISTRAR_INIT_11(reg_type,                                \
                                     kernel_name,                             \
                                     backend,                                 \
                                     context,                                 \
                                     layout,                                  \
                                     registrar_id,                            \
                                     args_def_fn,                             \
                                     meta_kernel_fn,                          \
                                     cpp_dtype,                               \
                                     ...)                                     \
  static const ::phi::KernelRegistrar PD_CONCATENATE(                         \
      __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
      reg_type,                                                               \
      #kernel_name,                                                           \
      #backend,                                                               \
      DATALAYOUT(layout),                                                     \
      ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(),           \
      ::phi::KernelArgsParseFunctor<decltype(                                 \
          &meta_kernel_fn<cpp_dtype, context>)>::Parse,                       \
      args_def_fn,                                                            \
      PHI_KERNEL(meta_kernel_fn<cpp_dtype, context>),                         \
      PHI_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, context>));               \
  PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_10(reg_type,                            \
                                         kernel_name,                         \
                                         backend,                             \
                                         context,                             \
                                         layout,                              \
                                         PD_ID,                               \
                                         args_def_fn,                         \
                                         meta_kernel_fn,                      \
                                         __VA_ARGS__))
// Registers meta_kernel_fn<cpp_dtype, context> through a static
// ::phi::KernelRegistrar whose name is built from kernel/backend/layout plus
// registrar_id, then expands _PD_KERNEL_REGISTRAR_INIT_11 over the remaining
// dtypes in __VA_ARGS__ (the numeric suffix presumably counts the dtypes still
// to register — confirm). PD_ID is defined elsewhere and is passed as the next
// registrar_id so each registrar object gets a distinct name (assumed unique
// per use — confirm).
#define _PD_KERNEL_REGISTRAR_INIT_12(reg_type,                                \
                                     kernel_name,                             \
                                     backend,                                 \
                                     context,                                 \
                                     layout,                                  \
                                     registrar_id,                            \
                                     args_def_fn,                             \
                                     meta_kernel_fn,                          \
                                     cpp_dtype,                               \
                                     ...)                                     \
  static const ::phi::KernelRegistrar PD_CONCATENATE(                         \
      __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
      reg_type,                                                               \
      #kernel_name,                                                           \
      #backend,                                                               \
      DATALAYOUT(layout),                                                     \
      ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(),           \
      ::phi::KernelArgsParseFunctor<decltype(                                 \
          &meta_kernel_fn<cpp_dtype, context>)>::Parse,                       \
      args_def_fn,                                                            \
      PHI_KERNEL(meta_kernel_fn<cpp_dtype, context>),                         \
      PHI_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, context>));               \
  PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_11(reg_type,                            \
                                         kernel_name,                         \
                                         backend,                             \
                                         context,                             \
                                         layout,                              \
                                         PD_ID,                               \
                                         args_def_fn,                         \
                                         meta_kernel_fn,                      \
                                         __VA_ARGS__))
// Registers meta_kernel_fn<cpp_dtype, context> through a static
// ::phi::KernelRegistrar whose name is built from kernel/backend/layout plus
// registrar_id, then expands _PD_KERNEL_REGISTRAR_INIT_12 over the remaining
// dtypes in __VA_ARGS__ (the numeric suffix presumably counts the dtypes still
// to register — confirm). PD_ID is defined elsewhere and is passed as the next
// registrar_id so each registrar object gets a distinct name (assumed unique
// per use — confirm).
#define _PD_KERNEL_REGISTRAR_INIT_13(reg_type,                                \
                                     kernel_name,                             \
                                     backend,                                 \
                                     context,                                 \
                                     layout,                                  \
                                     registrar_id,                            \
                                     args_def_fn,                             \
                                     meta_kernel_fn,                          \
                                     cpp_dtype,                               \
                                     ...)                                     \
  static const ::phi::KernelRegistrar PD_CONCATENATE(                         \
      __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
      reg_type,                                                               \
      #kernel_name,                                                           \
      #backend,                                                               \
      DATALAYOUT(layout),                                                     \
      ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(),           \
      ::phi::KernelArgsParseFunctor<decltype(                                 \
          &meta_kernel_fn<cpp_dtype, context>)>::Parse,                       \
      args_def_fn,                                                            \
      PHI_KERNEL(meta_kernel_fn<cpp_dtype, context>),                         \
      PHI_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, context>));               \
  PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_12(reg_type,                            \
                                         kernel_name,                         \
                                         backend,                             \
                                         context,                             \
                                         layout,                              \
                                         PD_ID,                               \
                                         args_def_fn,                         \
                                         meta_kernel_fn,                      \
                                         __VA_ARGS__))
// Registers meta_kernel_fn<cpp_dtype, context> through a static
// ::phi::KernelRegistrar whose name is built from kernel/backend/layout plus
// registrar_id, then expands _PD_KERNEL_REGISTRAR_INIT_13 over the remaining
// dtypes in __VA_ARGS__ (the numeric suffix presumably counts the dtypes still
// to register — confirm). PD_ID is defined elsewhere and is passed as the next
// registrar_id so each registrar object gets a distinct name (assumed unique
// per use — confirm).
#define _PD_KERNEL_REGISTRAR_INIT_14(reg_type,                                \
                                     kernel_name,                             \
                                     backend,                                 \
                                     context,                                 \
                                     layout,                                  \
                                     registrar_id,                            \
                                     args_def_fn,                             \
                                     meta_kernel_fn,                          \
                                     cpp_dtype,                               \
                                     ...)                                     \
  static const ::phi::KernelRegistrar PD_CONCATENATE(                         \
      __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
      reg_type,                                                               \
      #kernel_name,                                                           \
      #backend,                                                               \
      DATALAYOUT(layout),                                                     \
      ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(),           \
      ::phi::KernelArgsParseFunctor<decltype(                                 \
          &meta_kernel_fn<cpp_dtype, context>)>::Parse,                       \
      args_def_fn,                                                            \
      PHI_KERNEL(meta_kernel_fn<cpp_dtype, context>),                         \
      PHI_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, context>));               \
  PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_13(reg_type,                            \
                                         kernel_name,                         \
                                         backend,                             \
                                         context,                             \
                                         layout,                              \
                                         PD_ID,                               \
                                         args_def_fn,                         \
                                         meta_kernel_fn,                      \
                                         __VA_ARGS__))
// Registers meta_kernel_fn<cpp_dtype, context> through a static
// ::phi::KernelRegistrar whose name is built from kernel/backend/layout plus
// registrar_id, then expands _PD_KERNEL_REGISTRAR_INIT_14 over the remaining
// dtypes in __VA_ARGS__ (the numeric suffix presumably counts the dtypes still
// to register — confirm). PD_ID is defined elsewhere and is passed as the next
// registrar_id so each registrar object gets a distinct name (assumed unique
// per use — confirm).
#define _PD_KERNEL_REGISTRAR_INIT_15(reg_type,                                \
                                     kernel_name,                             \
                                     backend,                                 \
                                     context,                                 \
                                     layout,                                  \
                                     registrar_id,                            \
                                     args_def_fn,                             \
                                     meta_kernel_fn,                          \
                                     cpp_dtype,                               \
                                     ...)                                     \
  static const ::phi::KernelRegistrar PD_CONCATENATE(                         \
      __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
      reg_type,                                                               \
      #kernel_name,                                                           \
      #backend,                                                               \
      DATALAYOUT(layout),                                                     \
      ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(),           \
      ::phi::KernelArgsParseFunctor<decltype(                                 \
          &meta_kernel_fn<cpp_dtype, context>)>::Parse,                       \
      args_def_fn,                                                            \
      PHI_KERNEL(meta_kernel_fn<cpp_dtype, context>),                         \
      PHI_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, context>));               \
  PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_14(reg_type,                            \
                                         kernel_name,                         \
                                         backend,                             \
                                         context,                             \
                                         layout,                              \
                                         PD_ID,                               \
                                         args_def_fn,                         \
                                         meta_kernel_fn,                      \
                                         __VA_ARGS__))
/** PD_REGISTER_GENERAL_KERNEL
 *
 * Basic kernel registration macro, used to register an instantiated kernel
 * function with one template argument.
 *
 * Registers with ::phi::RegType::INNER and forwards every other argument
 * (including `dtype`) unchanged to _PD_REGISTER_GENERAL_KERNEL.
 */

#define PD_REGISTER_GENERAL_KERNEL(                 \
    kernel_name, backend, layout, kernel_fn, dtype) \
  _PD_REGISTER_GENERAL_KERNEL(                      \
      ::phi::RegType::INNER, kernel_name, backend, layout, kernel_fn, dtype)

// Checks at compile time, via PD_STATIC_ASSERT_GLOBAL_NAMESPACE (defined
// elsewhere), that the registration is written at global namespace scope. It
// then forwards all arguments to the platform-specific
// __PD_REGISTER_GENERAL_KERNEL. The diagnostic names PD_REGISTER_GENERAL_KERNEL,
// the public macro users actually invoke.
#define _PD_REGISTER_GENERAL_KERNEL(                                         \
    reg_type, kernel_name, backend, layout, kernel_fn, dtype)                \
  PD_STATIC_ASSERT_GLOBAL_NAMESPACE(                                         \
      PD_REGISTER_no_t_kernel_ns_check_##kernel_name##_##backend##_##layout, \
      "PD_REGISTER_GENERAL_KERNEL must be called in global namespace.");     \
  __PD_REGISTER_GENERAL_KERNEL(                                              \
      reg_type, kernel_name, backend, layout, kernel_fn, dtype)

#ifndef _WIN32
// Non-Windows expansion of the general (non-dtype-templated) registration:
//   1. `template decltype(kernel_fn) kernel_fn;` explicitly instantiates
//      kernel_fn;
//   2. declares the file-local args-def hook
//      __PD_KERNEL_args_def_FN_<kernel>_<backend>_<layout>;
//   3. defines a static ::phi::KernelRegistrar that registers kernel_fn
//      together with a pointer to that hook;
//   4. defines TouchKernelSymbolFor_<kernel>_<backend>_<layout>() with external
//      linkage (presumably so other translation units can reference it and pull
//      this registration in at link time — confirm);
//   5. ends with the hook's signature, so the `{ ... }` written after the macro
//      invocation becomes the hook's body.
// `dtype` is accepted for signature parity but is not used in this branch.
#define __PD_REGISTER_GENERAL_KERNEL(                                       \
    reg_type, kernel_name, backend, layout, kernel_fn, dtype)               \
  template decltype(kernel_fn) kernel_fn;                                   \
  static void __PD_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
      const ::phi::KernelKey& kernel_key, ::phi::Kernel* kernel);           \
  static const ::phi::KernelRegistrar                                       \
      __reg_pt_kernel_##kernel_name##_##backend##_##layout(                 \
          reg_type,                                                         \
          #kernel_name,                                                     \
          #backend,                                                         \
          DATALAYOUT(layout),                                               \
          ::phi::KernelArgsParseFunctor<decltype(&kernel_fn)>::Parse,       \
          &__PD_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout,    \
          PHI_KERNEL(kernel_fn),                                            \
          PHI_VARIADIC_KERNEL(kernel_fn));                                  \
  int TouchKernelSymbolFor_##kernel_name##_##backend##_##layout() {         \
    return 0;                                                               \
  }                                                                         \
  void __PD_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout(        \
      const ::phi::KernelKey& kernel_key, ::phi::Kernel* kernel)
#else
#define __PD_REGISTER_GENERAL_KERNEL(                                       \
957
    reg_type, kernel_name, backend, layout, kernel_fn, dtype)               \
958
  static void __PD_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
959 960
      const ::phi::KernelKey& kernel_key, ::phi::Kernel* kernel);           \
  static const ::phi::KernelRegistrar                                       \
961
      __reg_pt_kernel_##kernel_name##_##backend##_##layout(                 \
962
          reg_type,                                                         \
963
          #kernel_name,                                                     \
964
          #backend,                                                         \
965
          DATALAYOUT(layout),                                               \
966
          ::phi::KernelArgsParseFunctor<decltype(&kernel_fn)>::Parse,       \
967 968 969
          &__PD_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout,    \
          PHI_KERNEL(kernel_fn),                                            \
          PHI_VARIADIC_KERNEL(kernel_fn));                                  \
970 971 972
  int TouchKernelSymbolFor_##kernel_name##_##backend##_##layout() {         \
    return 0;                                                               \
  }                                                                         \
973
  void __PD_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout(        \
974
      const ::phi::KernelKey& kernel_key, ::phi::Kernel* kernel)
975 976
#endif

977
/** PD_DECLARE_KERNEL
 *
 * Used to export the symbols of the file where the kernel is located,
 * so that they are not removed by the linker.
 */
982
// Implementation note: after the global-namespace check, this declares the
// TouchKernelSymbolFor_<name>_<backend>_<layout>() function defined by the
// registration macros (e.g. __PD_REGISTER_GENERAL_KERNEL) and initializes a
// static int by calling it, so the including translation unit references
// that symbol.
#define PD_DECLARE_KERNEL(kernel_name, backend, layout)                   \
  PD_STATIC_ASSERT_GLOBAL_NAMESPACE(                                      \
      PD_DECLARE_tp_kernel_ns_check_##kernel_name##_##backend##_##layout, \
      "PD_DECLARE_KERNEL must be called in global namespace.");           \
  extern int TouchKernelSymbolFor_##kernel_name##_##backend##_##layout(); \
  UNUSED static int                                                       \
      __declare_kernel_symbol_for_##kernel_name##_##backend##_##layout =  \
          TouchKernelSymbolFor_##kernel_name##_##backend##_##layout()
990

991
/** PD_REGISTER_BUILTIN_KERNEL
 *
 * Used to register kernels for built-in backends.
 * Supports CPU, GPU and XPU.
 */
996 997
// Forwards to _PD_REGISTER_KERNEL with ::phi::RegType::OUTER. The device
// context type is formed by token-pasting the backend name with "Context"
// (e.g. backend CPU -> ::phi::CPUContext); remaining arguments, including the
// variadic list, are forwarded unchanged.
#define PD_REGISTER_BUILTIN_KERNEL(                    \
    kernel_name, backend, layout, meta_kernel_fn, ...) \
  _PD_REGISTER_KERNEL(::phi::RegType::OUTER,           \
                      kernel_name,                     \
                      backend,                         \
                      ::phi::backend##Context,         \
                      layout,                          \
                      meta_kernel_fn,                  \
                      __VA_ARGS__)

1006
/** PD_REGISTER_PLUGIN_KERNEL
 *
 * Used to register kernels for plug-in backends.
 * Supports user-defined backends such as 'Ascend910'.
 */
1011
// Forwards to _PD_REGISTER_KERNEL with ::phi::RegType::OUTER and the device
// context fixed to ::phi::CustomContext regardless of the backend name;
// remaining arguments, including the variadic list, are forwarded unchanged.
#define PD_REGISTER_PLUGIN_KERNEL(                     \
    kernel_name, backend, layout, meta_kernel_fn, ...) \
  _PD_REGISTER_KERNEL(::phi::RegType::OUTER,           \
                      kernel_name,                     \
                      backend,                         \
                      ::phi::CustomContext,            \
                      layout,                          \
                      meta_kernel_fn,                  \
                      __VA_ARGS__)

1021
}  // namespace phi