kernel_registry.h 52.0 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
//   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <cstring>
#include <string>
#include <tuple>
#include <type_traits>
#include <typeindex>
#include <typeinfo>
#include <utility>
#include <vector>

#include "paddle/phi/core/kernel_factory.h"
#include "paddle/phi/core/kernel_utils.h"
#include "paddle/phi/core/macros.h"
#include "paddle/phi/core/type_defs.h"

#include "paddle/phi/core/enforce.h"
30

31
namespace phi {
32

33 34 35
// Shorthand for enumerators of the phi key enums. Each macro appends its
// argument to the enum's scope, e.g. BACKEND(CPU) -> phi::Backend::CPU.
#define BACKEND(backend__) phi::Backend::backend__
#define DATALAYOUT(layout__) phi::DataLayout::layout__
#define DATATYPE(dtype__) phi::DataType::dtype__
36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51

template <typename Func>
struct KernelArgsParseFunctor;

template <typename Return_, typename... Args_>
struct KernelArgsParseFunctor<Return_ (*)(Args_...)> {
  using Args = std::tuple<Args_...>;
  enum : std::size_t { Arity = sizeof...(Args_) };
  using Indices = std::make_index_sequence<Arity>;
  template <std::size_t Index>
  using Arg = typename std::tuple_element<Index, Args>::type;

  static void Parse(const KernelKey& default_key, KernelArgsDef* args_def) {
    // TODO(chenweihang): The fluid Tensor's default layout is NCHW,
    // it is not same as kernel's layout, we should fix this error on
    // fluid Tensor
52 53
    auto default_tensor_layout = phi::DataLayout::NCHW;
    if (default_key.layout() != phi::DataLayout::ANY) {
54 55 56 57 58 59 60
      default_tensor_layout = default_key.layout();
    }
    auto args_type = ParseArgType(Indices{});
    for (auto arg_type : args_type) {
      if (arg_type == std::type_index(typeid(const CPUContext&))
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
          ||
61
          arg_type == std::type_index(typeid(const GPUContext&))) {
62 63 64
#elif defined(PADDLE_WITH_XPU)
          ||
          arg_type == std::type_index(typeid(const XPUContext&))) {
65 66 67 68 69
#else
              ) {
#endif
        // do nothing, skip context arg now
      } else if (arg_type == std::type_index(typeid(const DenseTensor&))) {
H
hong 已提交
70 71 72 73
        args_def->AppendInput(default_key.backend(),
                              default_tensor_layout,
                              default_key.dtype(),
                              arg_type);
74 75
      } else if (arg_type == std::type_index(typeid(
                                 paddle::optional<const DenseTensor&>))) {
H
hong 已提交
76 77 78 79
        args_def->AppendInput(default_key.backend(),
                              default_tensor_layout,
                              default_key.dtype(),
                              arg_type);
80 81
      } else if (arg_type ==
                 std::type_index(typeid(const std::vector<DenseTensor>&))) {
H
hong 已提交
82 83 84 85
        args_def->AppendInput(default_key.backend(),
                              default_tensor_layout,
                              default_key.dtype(),
                              arg_type);
86
      } else if (arg_type == std::type_index(typeid(const SelectedRows&))) {
H
hong 已提交
87 88 89 90
        args_def->AppendInput(default_key.backend(),
                              default_tensor_layout,
                              default_key.dtype(),
                              arg_type);
91
      } else if (arg_type == std::type_index(typeid(DenseTensor*))) {
H
hong 已提交
92 93 94 95
        args_def->AppendOutput(default_key.backend(),
                               default_tensor_layout,
                               default_key.dtype(),
                               arg_type);
96 97
      } else if (arg_type ==
                 std::type_index(typeid(std::vector<DenseTensor*>))) {
H
hong 已提交
98 99 100 101
        args_def->AppendOutput(default_key.backend(),
                               default_tensor_layout,
                               default_key.dtype(),
                               arg_type);
102
      } else if (arg_type == std::type_index(typeid(SelectedRows*))) {
H
hong 已提交
103 104 105 106
        args_def->AppendOutput(default_key.backend(),
                               default_tensor_layout,
                               default_key.dtype(),
                               arg_type);
107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123
      } else {
        // Attribute deal with
        // TODO(chenweihang): now here allow any types of attribute, maybe
        // should add limits here
        args_def->AppendAttribute(arg_type);
      }
    }
  }

 private:
  template <std::size_t... INDEX>
  static std::vector<std::type_index> ParseArgType(
      std::index_sequence<INDEX...>) {
    return {std::type_index(typeid(Arg<INDEX>))...};
  }
};

124 125
// TODO(chenweihang): Polish the kernel selection logic, support the selection
// of ALL_DTYPE kernel, and simplify the constructor
126 127 128 129 130 131 132 133
struct KernelRegistrar {
 public:
  KernelRegistrar(const char* kernel_name_cstr,
                  Backend backend,
                  DataLayout layout,
                  DataType dtype,
                  KernelArgsParseFn args_parse_fn,
                  KernelArgsDefFn args_def_fn,
134 135
                  KernelFn kernel_fn,
                  void* variadic_kernel_fn) {
136 137 138 139 140 141
    ConstructKernel(kernel_name_cstr,
                    backend,
                    layout,
                    dtype,
                    args_parse_fn,
                    args_def_fn,
142 143
                    kernel_fn,
                    variadic_kernel_fn);
144 145 146 147 148 149 150
  }

  KernelRegistrar(const char* kernel_name_cstr,
                  Backend backend,
                  DataLayout layout,
                  KernelArgsParseFn args_parse_fn,
                  KernelArgsDefFn args_def_fn,
151 152
                  KernelFn kernel_fn,
                  void* variadic_kernel_fn) {
153 154 155
    for (size_t dtype = static_cast<size_t>(DataType::BOOL);
         dtype != static_cast<size_t>(DataType::NUM_DATA_TYPES);
         dtype++) {
156 157 158 159 160 161 162
      // NOTE(zhiqiu): why skip these types, because fluid kernel has no kernel
      // of these type.
      if (dtype == static_cast<size_t>(DataType::UINT32) ||
          dtype == static_cast<size_t>(DataType::UINT64) ||
          dtype == static_cast<size_t>(DataType::UINT16)) {
        continue;
      }
163 164 165 166 167 168
      ConstructKernel(kernel_name_cstr,
                      backend,
                      layout,
                      static_cast<DataType>(dtype),
                      args_parse_fn,
                      args_def_fn,
169 170
                      kernel_fn,
                      variadic_kernel_fn);
171 172 173 174 175 176 177 178 179 180
    }
  }

 private:
  void ConstructKernel(const char* kernel_name_cstr,
                       Backend backend,
                       DataLayout layout,
                       DataType dtype,
                       KernelArgsParseFn args_parse_fn,
                       KernelArgsDefFn args_def_fn,
181 182
                       KernelFn kernel_fn,
                       void* variadic_kernel_fn) {
Y
YuanRisheng 已提交
183
    std::string kernel_name(kernel_name_cstr);
184
    KernelKey kernel_key(backend, layout, dtype);
185
    Kernel kernel(kernel_fn, variadic_kernel_fn);
186
    args_parse_fn(kernel_key, kernel.mutable_args_def());
187
    args_def_fn(kernel_key, &kernel);
188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206
    KernelFactory::Instance().kernels()[kernel_name][kernel_key] = kernel;
  }
};

/**
 * Reference:
 *
 *   https://stackoverflow.com/questions/1872220/is-it-possible-to-iterate-over-arguments-in-variadic-macros
 *   https://stackoverflow.com/questions/9183993/msvc-variadic-macro-expansion?rq=1
 *   https://stackoverflow.com/questions/5134523/msvc-doesnt-expand-va-args-correctly
 *
 * Very carefully tiptoeing around an MSVC bug where it improperly expands
 * __VA_ARGS__ as a single token in argument lists.  See these URLs for details:
 *
 *   http://connect.microsoft.com/VisualStudio/feedback/details/380090/variadic-macro-replacement
 *   http://cplusplus.co.il/2010/07/17/variadic-macro-to-count-number-of-arguments/#comment-644
 *
 * PT_NARGS(...) expands to the number of arguments passed to it, for 1 to 15
 * arguments. It appends the descending list 15..0 to the arguments and picks
 * the 16th element of the combined list. An empty invocation still counts as
 * one (empty) argument and yields 1; more than 15 arguments are not supported.
 */
#define PT_NARGS(...) _PT_NARGS((__VA_ARGS__, _PT_RESQ_N()))
#define _PT_NARGS(...) _PT_ARG_N(__VA_ARGS__)
#define _PT_ARG_N_EXPAND(                                                     \
    _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, N, ...) \
  N
#define _PT_ARG_N(args) _PT_ARG_N_EXPAND args
#define _PT_RESQ_N() 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0
212

213 214 215 216 217 218
/** PT_REGISTER_KERNEL
 *
 * The most frequently used kernel registration macro, used for kernel
 * registration with only data type as template parameter, and the function
 * pointer of the corresponding data type is automatically instantiated
 * during registration.
 *
 * The variadic arguments are the C++ data types to register. The expansion
 * (see _PT_REGISTER_2TA_KERNEL) ends with the signature of the args-def
 * function, so the invocation must be followed by a `{ ... }` body, which
 * receives `kernel_key` and `kernel`.
 *
 * Note: `2TA` means `2 template argument`
 */
#define PT_REGISTER_KERNEL(kernel_name, backend, layout, meta_kernel_fn, ...) \
  PT_STATIC_ASSERT_GLOBAL_NAMESPACE(                                          \
      pt_register_tp_kernel_ns_check_##kernel_name##_##backend##_##layout,    \
      "PT_REGISTER_KERNEL must be called in global namespace.");              \
  PT_EXPAND(_PT_REGISTER_2TA_KERNEL(                                          \
      kernel_name, backend, layout, meta_kernel_fn, __VA_ARGS__))
228

229
#ifndef _WIN32
/**
 * Non-Windows expansion: explicitly instantiates `meta_kernel_fn` for every
 * listed data type, forward-declares the args-def function, defines the
 * KernelRegistrar objects, and finally opens the definition of the args-def
 * function. The caller supplies that function's `{ ... }` body.
 */
#define _PT_REGISTER_2TA_KERNEL(                                            \
    kernel_name, backend, layout, meta_kernel_fn, ...)                      \
  PT_KERNEL_INSTANTIATION(meta_kernel_fn, backend, __VA_ARGS__);            \
  static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
      const ::phi::KernelKey& kernel_key, ::phi::Kernel* kernel);           \
  PT_KERNEL_REGISTRAR_INIT(                                                 \
      kernel_name,                                                          \
      backend,                                                              \
      layout,                                                               \
      &__PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout,        \
      meta_kernel_fn,                                                       \
      __VA_ARGS__);                                                         \
  void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout(        \
      const ::phi::KernelKey& kernel_key, ::phi::Kernel* kernel)
#else
/**
 * `template decltype(fn) fn` works on gcc and clang,
 * but fails on msvc with an error like:
 *
 *   error C2206: typedef cannot be used for function definition
 *
 * reference:
 *
 *   https://stackoverflow.com/questions/63989585/explicit-instantiation-of-function-using-decltype-work-on-g-but-not-on-visua
 *
 * msvc works without the explicit template instantiation, so this variant
 * omits PT_KERNEL_INSTANTIATION and is otherwise the same as the non-Windows
 * one.
 */
#define _PT_REGISTER_2TA_KERNEL(                                            \
    kernel_name, backend, layout, meta_kernel_fn, ...)                      \
  static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
      const ::phi::KernelKey& kernel_key, ::phi::Kernel* kernel);           \
  PT_EXPAND(PT_KERNEL_REGISTRAR_INIT(                                       \
      kernel_name,                                                          \
      backend,                                                              \
      layout,                                                               \
      &__PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout,        \
      meta_kernel_fn,                                                       \
      __VA_ARGS__));                                                        \
  void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout(        \
      const ::phi::KernelKey& kernel_key, ::phi::Kernel* kernel)
#endif

272 273 274
/**
 * PT_KERNEL_INSTANTIATION(meta_kernel_fn, backend, dtypes...)
 *
 * Emits one explicit template instantiation of
 * `meta_kernel_fn<dtype, ::phi::<backend>Context>` per listed data type
 * (1 to 15 types). The argument count from PT_NARGS selects
 * _PT_KERNEL_INSTANTIATION_<N>; each _<N> instantiates the first data type
 * through _PT_KERNEL_INSTANTIATION_1 and hands the rest to _<N-1>.
 *
 * The expansion does not end with a semicolon; the caller adds it.
 */
#define PT_KERNEL_INSTANTIATION(meta_kernel_fn, backend, ...) \
  _PT_KERNEL_INSTANTIATION(                                   \
      PT_NARGS(__VA_ARGS__), meta_kernel_fn, backend, __VA_ARGS__)

#define _PT_KERNEL_INSTANTIATION(N, meta_kernel_fn, backend, ...) \
  PT_CONCATENATE(_PT_KERNEL_INSTANTIATION_, N)                    \
  (meta_kernel_fn, backend, __VA_ARGS__)

#define _PT_KERNEL_INSTANTIATION_1(meta_kernel_fn, backend, cpp_dtype)  \
  template decltype(meta_kernel_fn<cpp_dtype, ::phi::backend##Context>) \
      meta_kernel_fn<cpp_dtype, ::phi::backend##Context>
#define _PT_KERNEL_INSTANTIATION_2(meta_kernel_fn, backend, cpp_dtype, ...) \
  _PT_KERNEL_INSTANTIATION_1(meta_kernel_fn, backend, cpp_dtype);           \
  PT_EXPAND(_PT_KERNEL_INSTANTIATION_1(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_3(meta_kernel_fn, backend, cpp_dtype, ...) \
  _PT_KERNEL_INSTANTIATION_1(meta_kernel_fn, backend, cpp_dtype);           \
  PT_EXPAND(_PT_KERNEL_INSTANTIATION_2(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_4(meta_kernel_fn, backend, cpp_dtype, ...) \
  _PT_KERNEL_INSTANTIATION_1(meta_kernel_fn, backend, cpp_dtype);           \
  PT_EXPAND(_PT_KERNEL_INSTANTIATION_3(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_5(meta_kernel_fn, backend, cpp_dtype, ...) \
  _PT_KERNEL_INSTANTIATION_1(meta_kernel_fn, backend, cpp_dtype);           \
  PT_EXPAND(_PT_KERNEL_INSTANTIATION_4(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_6(meta_kernel_fn, backend, cpp_dtype, ...) \
  _PT_KERNEL_INSTANTIATION_1(meta_kernel_fn, backend, cpp_dtype);           \
  PT_EXPAND(_PT_KERNEL_INSTANTIATION_5(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_7(meta_kernel_fn, backend, cpp_dtype, ...) \
  _PT_KERNEL_INSTANTIATION_1(meta_kernel_fn, backend, cpp_dtype);           \
  PT_EXPAND(_PT_KERNEL_INSTANTIATION_6(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_8(meta_kernel_fn, backend, cpp_dtype, ...) \
  _PT_KERNEL_INSTANTIATION_1(meta_kernel_fn, backend, cpp_dtype);           \
  PT_EXPAND(_PT_KERNEL_INSTANTIATION_7(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_9(meta_kernel_fn, backend, cpp_dtype, ...) \
  _PT_KERNEL_INSTANTIATION_1(meta_kernel_fn, backend, cpp_dtype);           \
  PT_EXPAND(_PT_KERNEL_INSTANTIATION_8(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_10(meta_kernel_fn, backend, cpp_dtype, ...) \
  _PT_KERNEL_INSTANTIATION_1(meta_kernel_fn, backend, cpp_dtype);            \
  PT_EXPAND(_PT_KERNEL_INSTANTIATION_9(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_11(meta_kernel_fn, backend, cpp_dtype, ...) \
  _PT_KERNEL_INSTANTIATION_1(meta_kernel_fn, backend, cpp_dtype);            \
  PT_EXPAND(_PT_KERNEL_INSTANTIATION_10(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_12(meta_kernel_fn, backend, cpp_dtype, ...) \
  _PT_KERNEL_INSTANTIATION_1(meta_kernel_fn, backend, cpp_dtype);            \
  PT_EXPAND(_PT_KERNEL_INSTANTIATION_11(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_13(meta_kernel_fn, backend, cpp_dtype, ...) \
  _PT_KERNEL_INSTANTIATION_1(meta_kernel_fn, backend, cpp_dtype);            \
  PT_EXPAND(_PT_KERNEL_INSTANTIATION_12(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_14(meta_kernel_fn, backend, cpp_dtype, ...) \
  _PT_KERNEL_INSTANTIATION_1(meta_kernel_fn, backend, cpp_dtype);            \
  PT_EXPAND(_PT_KERNEL_INSTANTIATION_13(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_15(meta_kernel_fn, backend, cpp_dtype, ...) \
  _PT_KERNEL_INSTANTIATION_1(meta_kernel_fn, backend, cpp_dtype);            \
  PT_EXPAND(_PT_KERNEL_INSTANTIATION_14(meta_kernel_fn, backend, __VA_ARGS__))
339

340 341 342 343 344 345 346 347 348
/**
 * PT_KERNEL_REGISTRAR_INIT(kernel_name, backend, layout, args_def_fn,
 *                          meta_kernel_fn, dtypes...)
 *
 * Counts the listed data types with PT_NARGS and forwards to
 * _PT_KERNEL_REGISTRAR_INIT, which dispatches to the matching
 * _PT_KERNEL_REGISTRAR_INIT_<N> and passes PT_ID as the registrar id.
 */
#define PT_KERNEL_REGISTRAR_INIT(                                   \
    kernel_name, backend, layout, args_def_fn, meta_kernel_fn, ...) \
  PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT(PT_NARGS(__VA_ARGS__),        \
                                      kernel_name,                  \
                                      backend,                      \
                                      layout,                       \
                                      args_def_fn,                  \
                                      meta_kernel_fn,               \
                                      __VA_ARGS__))

// clang-format off

/* pre-commit always reformats this macro incorrectly,
  and multi-line macros cannot be skipped with NOLINT.*/
#define _PT_KERNEL_REGISTRAR_INIT(N,                       \
                                  kernel_name,             \
                                  backend,                 \
                                  layout,                  \
                                  args_def_fn,             \
                                  meta_kernel_fn,          \
                                  ...)                     \
  PT_EXPAND(PT_CONCATENATE(_PT_KERNEL_REGISTRAR_INIT_, N) ( \
    kernel_name,                                           \
    backend,                                               \
    layout,                                                \
    PT_ID,                                                 \
    args_def_fn,                                           \
    meta_kernel_fn,                                        \
    __VA_ARGS__))

// clang-format on

372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389
/**
 * _PT_KERNEL_REGISTRAR_INIT_<N>: defines one static ::phi::KernelRegistrar
 * for the first data type in the list, then recurses into _<N-1> for the
 * remaining data types, passing a fresh PT_ID as the registrar id so the
 * variable names stay distinct.
 *
 * _PT_KERNEL_REGISTRAR_INIT_1 ends the recursion and additionally defines
 * `TouchKernelSymbolFor_<kernel_name>_<backend>_<layout>()`, which returns 0.
 */
#define _PT_KERNEL_REGISTRAR_INIT_1(kernel_name,                               \
                                    backend,                                   \
                                    layout,                                    \
                                    registrar_id,                              \
                                    args_def_fn,                               \
                                    meta_kernel_fn,                            \
                                    cpp_dtype)                                 \
  static const ::phi::KernelRegistrar PT_CONCATENATE(                          \
      __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)(  \
      #kernel_name,                                                            \
      BACKEND(backend),                                                        \
      DATALAYOUT(layout),                                                      \
      ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(),            \
      ::phi::KernelArgsParseFunctor<decltype(                                  \
          &meta_kernel_fn<cpp_dtype, ::phi::backend##Context>)>::Parse,        \
      args_def_fn,                                                             \
      PT_KERNEL(meta_kernel_fn<cpp_dtype, ::phi::backend##Context>),           \
      PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, ::phi::backend##Context>)); \
  int TouchKernelSymbolFor_##kernel_name##_##backend##_##layout() { return 0; }
#define _PT_KERNEL_REGISTRAR_INIT_2(kernel_name,                               \
                                    backend,                                   \
                                    layout,                                    \
                                    registrar_id,                              \
                                    args_def_fn,                               \
                                    meta_kernel_fn,                            \
                                    cpp_dtype,                                 \
                                    ...)                                       \
  static const ::phi::KernelRegistrar PT_CONCATENATE(                          \
      __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)(  \
      #kernel_name,                                                            \
      BACKEND(backend),                                                        \
      DATALAYOUT(layout),                                                      \
      ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(),            \
      ::phi::KernelArgsParseFunctor<decltype(                                  \
          &meta_kernel_fn<cpp_dtype, ::phi::backend##Context>)>::Parse,        \
      args_def_fn,                                                             \
      PT_KERNEL(meta_kernel_fn<cpp_dtype, ::phi::backend##Context>),           \
      PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, ::phi::backend##Context>)); \
  PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_1(kernel_name,                           \
                                        backend,                               \
                                        layout,                                \
                                        PT_ID,                                 \
                                        args_def_fn,                           \
                                        meta_kernel_fn,                        \
                                        __VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT_3(kernel_name,                               \
                                    backend,                                   \
                                    layout,                                    \
                                    registrar_id,                              \
                                    args_def_fn,                               \
                                    meta_kernel_fn,                            \
                                    cpp_dtype,                                 \
                                    ...)                                       \
  static const ::phi::KernelRegistrar PT_CONCATENATE(                          \
      __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)(  \
      #kernel_name,                                                            \
      BACKEND(backend),                                                        \
      DATALAYOUT(layout),                                                      \
      ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(),            \
      ::phi::KernelArgsParseFunctor<decltype(                                  \
          &meta_kernel_fn<cpp_dtype, ::phi::backend##Context>)>::Parse,        \
      args_def_fn,                                                             \
      PT_KERNEL(meta_kernel_fn<cpp_dtype, ::phi::backend##Context>),           \
      PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, ::phi::backend##Context>)); \
  PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_2(kernel_name,                           \
                                        backend,                               \
                                        layout,                                \
                                        PT_ID,                                 \
                                        args_def_fn,                           \
                                        meta_kernel_fn,                        \
                                        __VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT_4(kernel_name,                               \
                                    backend,                                   \
                                    layout,                                    \
                                    registrar_id,                              \
                                    args_def_fn,                               \
                                    meta_kernel_fn,                            \
                                    cpp_dtype,                                 \
                                    ...)                                       \
  static const ::phi::KernelRegistrar PT_CONCATENATE(                          \
      __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)(  \
      #kernel_name,                                                            \
      BACKEND(backend),                                                        \
      DATALAYOUT(layout),                                                      \
      ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(),            \
      ::phi::KernelArgsParseFunctor<decltype(                                  \
          &meta_kernel_fn<cpp_dtype, ::phi::backend##Context>)>::Parse,        \
      args_def_fn,                                                             \
      PT_KERNEL(meta_kernel_fn<cpp_dtype, ::phi::backend##Context>),           \
      PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, ::phi::backend##Context>)); \
  PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_3(kernel_name,                           \
                                        backend,                               \
                                        layout,                                \
                                        PT_ID,                                 \
                                        args_def_fn,                           \
                                        meta_kernel_fn,                        \
                                        __VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT_5(kernel_name,                               \
                                    backend,                                   \
                                    layout,                                    \
                                    registrar_id,                              \
                                    args_def_fn,                               \
                                    meta_kernel_fn,                            \
                                    cpp_dtype,                                 \
                                    ...)                                       \
  static const ::phi::KernelRegistrar PT_CONCATENATE(                          \
      __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)(  \
      #kernel_name,                                                            \
      BACKEND(backend),                                                        \
      DATALAYOUT(layout),                                                      \
      ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(),            \
      ::phi::KernelArgsParseFunctor<decltype(                                  \
          &meta_kernel_fn<cpp_dtype, ::phi::backend##Context>)>::Parse,        \
      args_def_fn,                                                             \
      PT_KERNEL(meta_kernel_fn<cpp_dtype, ::phi::backend##Context>),           \
      PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, ::phi::backend##Context>)); \
  PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_4(kernel_name,                           \
                                        backend,                               \
                                        layout,                                \
                                        PT_ID,                                 \
                                        args_def_fn,                           \
                                        meta_kernel_fn,                        \
                                        __VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT_6(kernel_name,                               \
                                    backend,                                   \
                                    layout,                                    \
                                    registrar_id,                              \
                                    args_def_fn,                               \
                                    meta_kernel_fn,                            \
                                    cpp_dtype,                                 \
                                    ...)                                       \
  static const ::phi::KernelRegistrar PT_CONCATENATE(                          \
      __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)(  \
      #kernel_name,                                                            \
      BACKEND(backend),                                                        \
      DATALAYOUT(layout),                                                      \
      ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(),            \
      ::phi::KernelArgsParseFunctor<decltype(                                  \
          &meta_kernel_fn<cpp_dtype, ::phi::backend##Context>)>::Parse,        \
      args_def_fn,                                                             \
      PT_KERNEL(meta_kernel_fn<cpp_dtype, ::phi::backend##Context>),           \
      PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, ::phi::backend##Context>)); \
  PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_5(kernel_name,                           \
                                        backend,                               \
                                        layout,                                \
                                        PT_ID,                                 \
                                        args_def_fn,                           \
                                        meta_kernel_fn,                        \
                                        __VA_ARGS__))
521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545
/** _PT_KERNEL_REGISTRAR_INIT_7
 *
 * One step of the recursive per-dtype registrar expansion.
 *
 * Defines a static const ::phi::KernelRegistrar whose variable name is
 * __reg_pt_kernel_<kernel_name>_<backend>_<layout>_ concatenated with
 * `registrar_id`, registering `meta_kernel_fn<cpp_dtype, backend##Context>`
 * under the data type given by CppTypeToDataType<cpp_dtype>::Type(). The
 * remaining dtypes in `__VA_ARGS__` are then forwarded to
 * _PT_KERNEL_REGISTRAR_INIT_6, with `PT_ID` as the next registrar id.
 * (PT_ID is defined elsewhere; presumably it yields a unique id - confirm.)
 *
 * Every line of the definition except the last must end with a backslash:
 * the trailing `__VA_ARGS__))` line is part of the macro body.
 */
#define _PT_KERNEL_REGISTRAR_INIT_7(kernel_name,                               \
                                    backend,                                   \
                                    layout,                                    \
                                    registrar_id,                              \
                                    args_def_fn,                               \
                                    meta_kernel_fn,                            \
                                    cpp_dtype,                                 \
                                    ...)                                       \
  static const ::phi::KernelRegistrar PT_CONCATENATE(                          \
      __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)(  \
      #kernel_name,                                                            \
      BACKEND(backend),                                                        \
      DATALAYOUT(layout),                                                      \
      ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(),            \
      ::phi::KernelArgsParseFunctor<decltype(                                  \
          &meta_kernel_fn<cpp_dtype, ::phi::backend##Context>)>::Parse,        \
      args_def_fn,                                                             \
      PT_KERNEL(meta_kernel_fn<cpp_dtype, ::phi::backend##Context>),           \
      PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, ::phi::backend##Context>)); \
  PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_6(kernel_name,                           \
                                        backend,                               \
                                        layout,                                \
                                        PT_ID,                                 \
                                        args_def_fn,                           \
                                        meta_kernel_fn,                        \
                                        __VA_ARGS__))
547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571
/** _PT_KERNEL_REGISTRAR_INIT_8
 *
 * One step of the recursive per-dtype registrar expansion.
 *
 * Defines a static const ::phi::KernelRegistrar whose variable name is
 * __reg_pt_kernel_<kernel_name>_<backend>_<layout>_ concatenated with
 * `registrar_id`, registering `meta_kernel_fn<cpp_dtype, backend##Context>`
 * under the data type given by CppTypeToDataType<cpp_dtype>::Type(). The
 * remaining dtypes in `__VA_ARGS__` are then forwarded to
 * _PT_KERNEL_REGISTRAR_INIT_7, with `PT_ID` as the next registrar id.
 * (PT_ID is defined elsewhere; presumably it yields a unique id - confirm.)
 *
 * Every line of the definition except the last must end with a backslash:
 * the trailing `__VA_ARGS__))` line is part of the macro body.
 */
#define _PT_KERNEL_REGISTRAR_INIT_8(kernel_name,                               \
                                    backend,                                   \
                                    layout,                                    \
                                    registrar_id,                              \
                                    args_def_fn,                               \
                                    meta_kernel_fn,                            \
                                    cpp_dtype,                                 \
                                    ...)                                       \
  static const ::phi::KernelRegistrar PT_CONCATENATE(                          \
      __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)(  \
      #kernel_name,                                                            \
      BACKEND(backend),                                                        \
      DATALAYOUT(layout),                                                      \
      ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(),            \
      ::phi::KernelArgsParseFunctor<decltype(                                  \
          &meta_kernel_fn<cpp_dtype, ::phi::backend##Context>)>::Parse,        \
      args_def_fn,                                                             \
      PT_KERNEL(meta_kernel_fn<cpp_dtype, ::phi::backend##Context>),           \
      PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, ::phi::backend##Context>)); \
  PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_7(kernel_name,                           \
                                        backend,                               \
                                        layout,                                \
                                        PT_ID,                                 \
                                        args_def_fn,                           \
                                        meta_kernel_fn,                        \
                                        __VA_ARGS__))
573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597
/** _PT_KERNEL_REGISTRAR_INIT_9
 *
 * One step of the recursive per-dtype registrar expansion.
 *
 * Defines a static const ::phi::KernelRegistrar whose variable name is
 * __reg_pt_kernel_<kernel_name>_<backend>_<layout>_ concatenated with
 * `registrar_id`, registering `meta_kernel_fn<cpp_dtype, backend##Context>`
 * under the data type given by CppTypeToDataType<cpp_dtype>::Type(). The
 * remaining dtypes in `__VA_ARGS__` are then forwarded to
 * _PT_KERNEL_REGISTRAR_INIT_8, with `PT_ID` as the next registrar id.
 * (PT_ID is defined elsewhere; presumably it yields a unique id - confirm.)
 *
 * Every line of the definition except the last must end with a backslash:
 * the trailing `__VA_ARGS__))` line is part of the macro body.
 */
#define _PT_KERNEL_REGISTRAR_INIT_9(kernel_name,                               \
                                    backend,                                   \
                                    layout,                                    \
                                    registrar_id,                              \
                                    args_def_fn,                               \
                                    meta_kernel_fn,                            \
                                    cpp_dtype,                                 \
                                    ...)                                       \
  static const ::phi::KernelRegistrar PT_CONCATENATE(                          \
      __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)(  \
      #kernel_name,                                                            \
      BACKEND(backend),                                                        \
      DATALAYOUT(layout),                                                      \
      ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(),            \
      ::phi::KernelArgsParseFunctor<decltype(                                  \
          &meta_kernel_fn<cpp_dtype, ::phi::backend##Context>)>::Parse,        \
      args_def_fn,                                                             \
      PT_KERNEL(meta_kernel_fn<cpp_dtype, ::phi::backend##Context>),           \
      PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, ::phi::backend##Context>)); \
  PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_8(kernel_name,                           \
                                        backend,                               \
                                        layout,                                \
                                        PT_ID,                                 \
                                        args_def_fn,                           \
                                        meta_kernel_fn,                        \
                                        __VA_ARGS__))
599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623
/** _PT_KERNEL_REGISTRAR_INIT_10
 *
 * One step of the recursive per-dtype registrar expansion.
 *
 * Defines a static const ::phi::KernelRegistrar whose variable name is
 * __reg_pt_kernel_<kernel_name>_<backend>_<layout>_ concatenated with
 * `registrar_id`, registering `meta_kernel_fn<cpp_dtype, backend##Context>`
 * under the data type given by CppTypeToDataType<cpp_dtype>::Type(). The
 * remaining dtypes in `__VA_ARGS__` are then forwarded to
 * _PT_KERNEL_REGISTRAR_INIT_9, with `PT_ID` as the next registrar id.
 * (PT_ID is defined elsewhere; presumably it yields a unique id - confirm.)
 *
 * Every line of the definition except the last must end with a backslash:
 * the trailing `__VA_ARGS__))` line is part of the macro body.
 */
#define _PT_KERNEL_REGISTRAR_INIT_10(kernel_name,                              \
                                     backend,                                  \
                                     layout,                                   \
                                     registrar_id,                             \
                                     args_def_fn,                              \
                                     meta_kernel_fn,                           \
                                     cpp_dtype,                                \
                                     ...)                                      \
  static const ::phi::KernelRegistrar PT_CONCATENATE(                          \
      __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)(  \
      #kernel_name,                                                            \
      BACKEND(backend),                                                        \
      DATALAYOUT(layout),                                                      \
      ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(),            \
      ::phi::KernelArgsParseFunctor<decltype(                                  \
          &meta_kernel_fn<cpp_dtype, ::phi::backend##Context>)>::Parse,        \
      args_def_fn,                                                             \
      PT_KERNEL(meta_kernel_fn<cpp_dtype, ::phi::backend##Context>),           \
      PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, ::phi::backend##Context>)); \
  PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_9(kernel_name,                           \
                                        backend,                               \
                                        layout,                                \
                                        PT_ID,                                 \
                                        args_def_fn,                           \
                                        meta_kernel_fn,                        \
                                        __VA_ARGS__))
625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649
/** _PT_KERNEL_REGISTRAR_INIT_11
 *
 * One step of the recursive per-dtype registrar expansion.
 *
 * Defines a static const ::phi::KernelRegistrar whose variable name is
 * __reg_pt_kernel_<kernel_name>_<backend>_<layout>_ concatenated with
 * `registrar_id`, registering `meta_kernel_fn<cpp_dtype, backend##Context>`
 * under the data type given by CppTypeToDataType<cpp_dtype>::Type(). The
 * remaining dtypes in `__VA_ARGS__` are then forwarded to
 * _PT_KERNEL_REGISTRAR_INIT_10, with `PT_ID` as the next registrar id.
 * (PT_ID is defined elsewhere; presumably it yields a unique id - confirm.)
 *
 * Every line of the definition except the last must end with a backslash:
 * the trailing `__VA_ARGS__))` line is part of the macro body.
 */
#define _PT_KERNEL_REGISTRAR_INIT_11(kernel_name,                              \
                                     backend,                                  \
                                     layout,                                   \
                                     registrar_id,                             \
                                     args_def_fn,                              \
                                     meta_kernel_fn,                           \
                                     cpp_dtype,                                \
                                     ...)                                      \
  static const ::phi::KernelRegistrar PT_CONCATENATE(                          \
      __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)(  \
      #kernel_name,                                                            \
      BACKEND(backend),                                                        \
      DATALAYOUT(layout),                                                      \
      ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(),            \
      ::phi::KernelArgsParseFunctor<decltype(                                  \
          &meta_kernel_fn<cpp_dtype, ::phi::backend##Context>)>::Parse,        \
      args_def_fn,                                                             \
      PT_KERNEL(meta_kernel_fn<cpp_dtype, ::phi::backend##Context>),           \
      PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, ::phi::backend##Context>)); \
  PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_10(kernel_name,                          \
                                         backend,                              \
                                         layout,                               \
                                         PT_ID,                                \
                                         args_def_fn,                          \
                                         meta_kernel_fn,                       \
                                         __VA_ARGS__))
651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675
/** _PT_KERNEL_REGISTRAR_INIT_12
 *
 * One step of the recursive per-dtype registrar expansion.
 *
 * Defines a static const ::phi::KernelRegistrar whose variable name is
 * __reg_pt_kernel_<kernel_name>_<backend>_<layout>_ concatenated with
 * `registrar_id`, registering `meta_kernel_fn<cpp_dtype, backend##Context>`
 * under the data type given by CppTypeToDataType<cpp_dtype>::Type(). The
 * remaining dtypes in `__VA_ARGS__` are then forwarded to
 * _PT_KERNEL_REGISTRAR_INIT_11, with `PT_ID` as the next registrar id.
 * (PT_ID is defined elsewhere; presumably it yields a unique id - confirm.)
 *
 * Every line of the definition except the last must end with a backslash:
 * the trailing `__VA_ARGS__))` line is part of the macro body.
 */
#define _PT_KERNEL_REGISTRAR_INIT_12(kernel_name,                              \
                                     backend,                                  \
                                     layout,                                   \
                                     registrar_id,                             \
                                     args_def_fn,                              \
                                     meta_kernel_fn,                           \
                                     cpp_dtype,                                \
                                     ...)                                      \
  static const ::phi::KernelRegistrar PT_CONCATENATE(                          \
      __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)(  \
      #kernel_name,                                                            \
      BACKEND(backend),                                                        \
      DATALAYOUT(layout),                                                      \
      ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(),            \
      ::phi::KernelArgsParseFunctor<decltype(                                  \
          &meta_kernel_fn<cpp_dtype, ::phi::backend##Context>)>::Parse,        \
      args_def_fn,                                                             \
      PT_KERNEL(meta_kernel_fn<cpp_dtype, ::phi::backend##Context>),           \
      PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, ::phi::backend##Context>)); \
  PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_11(kernel_name,                          \
                                         backend,                              \
                                         layout,                               \
                                         PT_ID,                                \
                                         args_def_fn,                          \
                                         meta_kernel_fn,                       \
                                         __VA_ARGS__))
677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701
/** _PT_KERNEL_REGISTRAR_INIT_13
 *
 * One step of the recursive per-dtype registrar expansion.
 *
 * Defines a static const ::phi::KernelRegistrar whose variable name is
 * __reg_pt_kernel_<kernel_name>_<backend>_<layout>_ concatenated with
 * `registrar_id`, registering `meta_kernel_fn<cpp_dtype, backend##Context>`
 * under the data type given by CppTypeToDataType<cpp_dtype>::Type(). The
 * remaining dtypes in `__VA_ARGS__` are then forwarded to
 * _PT_KERNEL_REGISTRAR_INIT_12, with `PT_ID` as the next registrar id.
 * (PT_ID is defined elsewhere; presumably it yields a unique id - confirm.)
 *
 * Every line of the definition except the last must end with a backslash:
 * the trailing `__VA_ARGS__))` line is part of the macro body.
 */
#define _PT_KERNEL_REGISTRAR_INIT_13(kernel_name,                              \
                                     backend,                                  \
                                     layout,                                   \
                                     registrar_id,                             \
                                     args_def_fn,                              \
                                     meta_kernel_fn,                           \
                                     cpp_dtype,                                \
                                     ...)                                      \
  static const ::phi::KernelRegistrar PT_CONCATENATE(                          \
      __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)(  \
      #kernel_name,                                                            \
      BACKEND(backend),                                                        \
      DATALAYOUT(layout),                                                      \
      ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(),            \
      ::phi::KernelArgsParseFunctor<decltype(                                  \
          &meta_kernel_fn<cpp_dtype, ::phi::backend##Context>)>::Parse,        \
      args_def_fn,                                                             \
      PT_KERNEL(meta_kernel_fn<cpp_dtype, ::phi::backend##Context>),           \
      PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, ::phi::backend##Context>)); \
  PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_12(kernel_name,                          \
                                         backend,                              \
                                         layout,                               \
                                         PT_ID,                                \
                                         args_def_fn,                          \
                                         meta_kernel_fn,                       \
                                         __VA_ARGS__))
703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727
/** _PT_KERNEL_REGISTRAR_INIT_14
 *
 * One step of the recursive per-dtype registrar expansion.
 *
 * Defines a static const ::phi::KernelRegistrar whose variable name is
 * __reg_pt_kernel_<kernel_name>_<backend>_<layout>_ concatenated with
 * `registrar_id`, registering `meta_kernel_fn<cpp_dtype, backend##Context>`
 * under the data type given by CppTypeToDataType<cpp_dtype>::Type(). The
 * remaining dtypes in `__VA_ARGS__` are then forwarded to
 * _PT_KERNEL_REGISTRAR_INIT_13, with `PT_ID` as the next registrar id.
 * (PT_ID is defined elsewhere; presumably it yields a unique id - confirm.)
 *
 * Every line of the definition except the last must end with a backslash:
 * the trailing `__VA_ARGS__))` line is part of the macro body.
 */
#define _PT_KERNEL_REGISTRAR_INIT_14(kernel_name,                              \
                                     backend,                                  \
                                     layout,                                   \
                                     registrar_id,                             \
                                     args_def_fn,                              \
                                     meta_kernel_fn,                           \
                                     cpp_dtype,                                \
                                     ...)                                      \
  static const ::phi::KernelRegistrar PT_CONCATENATE(                          \
      __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)(  \
      #kernel_name,                                                            \
      BACKEND(backend),                                                        \
      DATALAYOUT(layout),                                                      \
      ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(),            \
      ::phi::KernelArgsParseFunctor<decltype(                                  \
          &meta_kernel_fn<cpp_dtype, ::phi::backend##Context>)>::Parse,        \
      args_def_fn,                                                             \
      PT_KERNEL(meta_kernel_fn<cpp_dtype, ::phi::backend##Context>),           \
      PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, ::phi::backend##Context>)); \
  PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_13(kernel_name,                          \
                                         backend,                              \
                                         layout,                               \
                                         PT_ID,                                \
                                         args_def_fn,                          \
                                         meta_kernel_fn,                       \
                                         __VA_ARGS__))
729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753
/** _PT_KERNEL_REGISTRAR_INIT_15
 *
 * One step of the recursive per-dtype registrar expansion.
 *
 * Defines a static const ::phi::KernelRegistrar whose variable name is
 * __reg_pt_kernel_<kernel_name>_<backend>_<layout>_ concatenated with
 * `registrar_id`, registering `meta_kernel_fn<cpp_dtype, backend##Context>`
 * under the data type given by CppTypeToDataType<cpp_dtype>::Type(). The
 * remaining dtypes in `__VA_ARGS__` are then forwarded to
 * _PT_KERNEL_REGISTRAR_INIT_14, with `PT_ID` as the next registrar id.
 * (PT_ID is defined elsewhere; presumably it yields a unique id - confirm.)
 *
 * Every line of the definition except the last must end with a backslash:
 * the trailing `__VA_ARGS__))` line is part of the macro body.
 */
#define _PT_KERNEL_REGISTRAR_INIT_15(kernel_name,                              \
                                     backend,                                  \
                                     layout,                                   \
                                     registrar_id,                             \
                                     args_def_fn,                              \
                                     meta_kernel_fn,                           \
                                     cpp_dtype,                                \
                                     ...)                                      \
  static const ::phi::KernelRegistrar PT_CONCATENATE(                          \
      __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)(  \
      #kernel_name,                                                            \
      BACKEND(backend),                                                        \
      DATALAYOUT(layout),                                                      \
      ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(),            \
      ::phi::KernelArgsParseFunctor<decltype(                                  \
          &meta_kernel_fn<cpp_dtype, ::phi::backend##Context>)>::Parse,        \
      args_def_fn,                                                             \
      PT_KERNEL(meta_kernel_fn<cpp_dtype, ::phi::backend##Context>),           \
      PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, ::phi::backend##Context>)); \
  PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_14(kernel_name,                          \
                                         backend,                              \
                                         layout,                               \
                                         PT_ID,                                \
                                         args_def_fn,                          \
                                         meta_kernel_fn,                       \
                                         __VA_ARGS__))
755

756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773
/** PT_REGISTER_GENERAL_KERNEL
 *
 * Basic kernel registration macro, used to register an instantiated kernel
 * function with one template argument.
 *
 * Expands to a global-namespace check (the diagnostic names this macro)
 * followed by _PT_REGISTER_GENERAL_KERNEL with the same five arguments.
 * The expansion ends with the signature of the args-def function, so the
 * invocation must be followed by a `{ ... }` body.
 */

#define PT_REGISTER_GENERAL_KERNEL(                                          \
    kernel_name, backend, layout, kernel_fn, dtype)                          \
  PT_STATIC_ASSERT_GLOBAL_NAMESPACE(                                         \
      pt_register_no_t_kernel_ns_check_##kernel_name##_##backend##_##layout, \
      "PT_REGISTER_GENERAL_KERNEL must be called in global namespace.");     \
  _PT_REGISTER_GENERAL_KERNEL(kernel_name, backend, layout, kernel_fn, dtype)

/** _PT_REGISTER_GENERAL_KERNEL
 *
 * Implementation of PT_REGISTER_GENERAL_KERNEL. The expansion, in order:
 *   1. (non-Windows branch only) explicitly instantiates `kernel_fn` via
 *      `template decltype(kernel_fn) kernel_fn;`.
 *   2. forward-declares the static args-def function
 *      __PT_KERNEL_args_def_FN_<kernel_name>_<backend>_<layout>.
 *   3. defines a static const ::phi::KernelRegistrar built from the kernel
 *      name string, backend, layout, the KernelArgsParseFunctor Parse
 *      function for `kernel_fn`, the args-def function pointer, and the
 *      PT_KERNEL / PT_VARIADIC_KERNEL wrappers of `kernel_fn`.
 *   4. defines TouchKernelSymbolFor_<kernel_name>_<backend>_<layout>(),
 *      which returns 0; PT_DECLARE_KERNEL references this symbol.
 *   5. ends with the signature of the args-def function, so the `{ ... }`
 *      written after the macro invocation becomes that function's body.
 *
 * `dtype` is accepted for interface symmetry but is not referenced by the
 * expansion: no data type is passed to the registrar here.
 *
 * NOTE(review): the Windows branch differs only by omitting step 1; the
 * reason is not visible in this file - confirm before unifying the branches.
 */
#ifndef _WIN32
#define _PT_REGISTER_GENERAL_KERNEL(                                        \
    kernel_name, backend, layout, kernel_fn, dtype)                         \
  template decltype(kernel_fn) kernel_fn;                                   \
  static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
      const ::phi::KernelKey& kernel_key, ::phi::Kernel* kernel);           \
  static const ::phi::KernelRegistrar                                       \
      __reg_pt_kernel_##kernel_name##_##backend##_##layout(                 \
          #kernel_name,                                                     \
          BACKEND(backend),                                                 \
          DATALAYOUT(layout),                                               \
          ::phi::KernelArgsParseFunctor<decltype(&kernel_fn)>::Parse,       \
          &__PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout,    \
          PT_KERNEL(kernel_fn),                                             \
          PT_VARIADIC_KERNEL(kernel_fn));                                   \
  int TouchKernelSymbolFor_##kernel_name##_##backend##_##layout() {         \
    return 0;                                                               \
  }                                                                         \
  void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout(        \
      const ::phi::KernelKey& kernel_key, ::phi::Kernel* kernel)
#else
#define _PT_REGISTER_GENERAL_KERNEL(                                        \
    kernel_name, backend, layout, kernel_fn, dtype)                         \
  static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
      const ::phi::KernelKey& kernel_key, ::phi::Kernel* kernel);           \
  static const ::phi::KernelRegistrar                                       \
      __reg_pt_kernel_##kernel_name##_##backend##_##layout(                 \
          #kernel_name,                                                     \
          BACKEND(backend),                                                 \
          DATALAYOUT(layout),                                               \
          ::phi::KernelArgsParseFunctor<decltype(&kernel_fn)>::Parse,       \
          &__PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout,    \
          PT_KERNEL(kernel_fn),                                             \
          PT_VARIADIC_KERNEL(kernel_fn));                                   \
  int TouchKernelSymbolFor_##kernel_name##_##backend##_##layout() {         \
    return 0;                                                               \
  }                                                                         \
  void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout(        \
      const ::phi::KernelKey& kernel_key, ::phi::Kernel* kernel)
#endif

810 811 812 813 814
/** PT_DECLARE_KERNEL
 *
 * Used to export the symbols of the file where the kernel is located,
 * to avoid being removed by linker.
 *
 * Expands to a global-namespace check, an extern declaration of
 * TouchKernelSymbolFor_<kernel_name>_<backend>_<layout>(), and an UNUSED
 * static int initialized by calling that function, which creates a
 * reference to the symbol from the declaring translation unit.
 *
 * The expansion has no trailing semicolon; the caller supplies it.
 */
#define PT_DECLARE_KERNEL(kernel_name, backend, layout)                   \
  PT_STATIC_ASSERT_GLOBAL_NAMESPACE(                                      \
      pt_declare_tp_kernel_ns_check_##kernel_name##_##backend##_##layout, \
      "PT_DECLARE_KERNEL must be called in global namespace.");           \
  extern int TouchKernelSymbolFor_##kernel_name##_##backend##_##layout(); \
  UNUSED static int                                                       \
      __declare_kernel_symbol_for_##kernel_name##_##backend##_##layout =  \
          TouchKernelSymbolFor_##kernel_name##_##backend##_##layout()
823

824
}  // namespace phi