op_registry.h 21.2 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
F
fengjiayi 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

15 16
#pragma once

17
#include <algorithm>
18
#include <atomic>
19
#include <memory>
Y
Yang Yang 已提交
20 21
#include <string>
#include <tuple>
Y
Yu Yang 已提交
22
#include <type_traits>
F
WIP  
fengjiayi 已提交
23
#include <typeinfo>
24 25
#include <unordered_map>
#include <unordered_set>
Y
Yu Yang 已提交
26

P
peizhilin 已提交
27
#define GLOG_NO_ABBREVIATED_SEVERITIES  // msvc conflict logging with windows.h
28
#include "glog/logging.h"               // For VLOG()
Y
Yi Wang 已提交
29 30 31 32 33 34 35
#include "paddle/fluid/framework/attribute.h"
#include "paddle/fluid/framework/details/op_registry.h"
#include "paddle/fluid/framework/grad_op_desc_maker.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/shape_inference.h"
36
#include "paddle/phi/core/flags.h"
37
#include "paddle/phi/core/kernel_registry.h"
38
#include "paddle/phi/core/macros.h"
39

W
wanghuancoder 已提交
40 41 42 43 44 45
// Forward declaration of the execution-context type that the kernel functors
// registered in this header receive as their argument.
namespace paddle {
namespace framework {
class ExecutionContext;
}  // namespace framework
}  // namespace paddle

46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72
// Forward declarations of the types living in paddle::framework::proto.
// NOTE(review): the names look like protobuf-generated message classes —
// confirm against the framework .proto definitions.
namespace paddle {
namespace framework {
namespace proto {

class BlockDesc;
class OpDesc;
class OpDesc_Attr;
class OpDesc_Var;
class OpProto;
class OpProto_Attr;
class OpProto_Var;
class OpVersion;
class OpVersionMap;
class OpVersionMap_OpVersionPair;
class ProgramDesc;
class VarDesc;
class VarType;
class VarType_LoDTensorArrayDesc;
class VarType_LoDTensorDesc;
class VarType_ReaderDesc;
class VarType_TensorDesc;
class VarType_Tuple;
class Version;
}  // namespace proto
}  // namespace framework
}  // namespace paddle

73
// Declares the boolean flag `check_kernel_launch`; it is read below as
// FLAGS_check_kernel_launch by the CUDAPlace specialization of
// CheckKernelLaunch.
PHI_DECLARE_bool(check_kernel_launch);
74

75 76
namespace paddle {
namespace framework {
X
Xin Pan 已提交
77

Y
Yu Yang 已提交
78 79 80 81
// Common base class of every registrar type declared in this header.
class Registrar {
 public:
  // In our design, various kinds of classes, e.g., operators and kernels,
  // have their corresponding registry and registrar. The action of
  // registration is in the constructor of a global registrar variable, which
  // is not used in the code that calls package framework, and would
  // be removed from the generated binary file by the linker. To avoid such
  // removal, we add Touch to all registrar classes and make USE_OP macros to
  // call this method. So, as long as the caller code calls USE_OP, the global
  // registrar variable won't be removed by the linker.
  void Touch() {}
};
90

91
template <typename... ARGS>
Y
Yu Yang 已提交
92
struct OperatorRegistrar : public Registrar {
93
  explicit OperatorRegistrar(const char* op_type) {
94
    PADDLE_ENFORCE_EQ(
95 96
        OpInfoMap::Instance().Has(op_type),
        false,
97 98
        platform::errors::AlreadyExists(
            "Operator '%s' is registered more than once.", op_type));
99 100
    static_assert(sizeof...(ARGS) != 0,
                  "OperatorRegistrar should be invoked at least by OpClass");
101
    OpInfo info;
102
    details::OperatorRegistrarRecursive<0, false, ARGS...>(op_type, &info);
Y
Yu Yang 已提交
103
    OpInfoMap::Instance().Insert(op_type, info);
104 105 106
  }
};

107 108
class OpRegistry {
 public:
H
hong 已提交
109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131
  /**
   * @brief Return an OperatorBase constructed by type, inputs, outputs, attrs.
   *        In dygraph mode, inputs, output, attrs will be set to empty map to
   *        improve the execution efficiency of dygraph.
   *        Dygraph mode will use:
   *        framework::OpRegistry::CreateOp(type, {}, {}, {}, false).
   *
   * @param[str] type               The operator type.
   * @param[map] inputs             Inputs map of the operator.
   * @param[map] outputs            Outputs map of the operator.
   * @param[unordered_map] attrs    Attributes map of the operator.
   * @param[bool] attr_check
   *            Whether do the attribute check before OperatorBase construction.
   *            Default is true.
   *            Attr_check is used to control the check of attribute map.
   *            The check of attribute map have two purposes:
   *            1. check whether the attribute item is valid or not.
   *            2. add attribute item which has default value
   *            if it is not in attrs.
   *            In dygraph mode, attrs is an empty unordered_map,
   *            attr_check is set to false, otherwise it will be failed
   *            when check function called.
   */
Y
Yu Yang 已提交
132
  static std::unique_ptr<OperatorBase> CreateOp(const std::string& type,
Y
Yu Yang 已提交
133 134
                                                const VariableNameMap& inputs,
                                                const VariableNameMap& outputs,
135
                                                const AttributeMap& attrs,
H
hong 已提交
136
                                                bool attr_check = true);
137 138 139 140 141 142 143
  static std::unique_ptr<OperatorBase> CreateOp(
      const std::string& type,
      const VariableNameMap& inputs,
      const VariableNameMap& outputs,
      const AttributeMap& attrs,
      const AttributeMap& runtime_attrs,
      bool attr_check = true);
Y
Yu Yang 已提交
144

145
  static std::unique_ptr<OperatorBase> CreateOp(const proto::OpDesc& op_desc);
Y
Yu Yang 已提交
146

Y
Yu Yang 已提交
147
  static std::unique_ptr<OperatorBase> CreateOp(const OpDesc& op_desc);
F
Fix bug  
fengjiayi 已提交
148
};
F
fengjiayi 已提交
149

150
template <typename PlaceType>
151
inline void CheckKernelLaunch(const char* op_type UNUSED) {}
152 153 154 155 156

#ifdef PADDLE_WITH_CUDA
// CUDAPlace specialization: when the `check_kernel_launch` flag is set,
// verifies via PADDLE_ENFORCE_CUDA_LAUNCH_SUCCESS that the kernel launch for
// `op_type` succeeded. With the flag unset this is a no-op.
template <>
inline void CheckKernelLaunch<::paddle::platform::CUDAPlace>(
    const char* op_type) {
  if (FLAGS_check_kernel_launch) {
    PADDLE_ENFORCE_CUDA_LAUNCH_SUCCESS(op_type);
  }
}
161 162
#endif

163 164 165
// Compile-time recursion helper used by OpKernelRegistrar. `I` is the index
// of the kernel type currently being registered and `at_end` selects the
// terminating specialization once the whole pack has been processed.
template <typename PlaceType, bool at_end, size_t I, typename... KernelType>
struct OpKernelRegistrarFunctor;

166
template <typename PlaceType, typename T, typename Func>
167 168 169 170
inline void RegisterKernelClass(const char* op_type,
                                const char* library_type,
                                int customized_type_value,
                                Func func) {
Y
yuyang18 已提交
171 172 173 174 175
  std::string library(library_type);
  std::string data_layout = "ANYLAYOUT";
  if (library == "MKLDNN") {
    data_layout = "MKLDNNLAYOUT";
  }
176 177 178 179
#ifdef PADDLE_WITH_CUSTOM_DEVICE
  if (std::is_same<PlaceType, platform::CustomPlace>::value) {
    OpKernelType key(ToDataType(std::type_index(typeid(T))),
                     platform::CustomPlace(library_type),
180
                     phi::StringToDataLayout(data_layout),
181 182 183 184 185 186
                     LibraryType::kPlain,
                     customized_type_value);
    OperatorWithKernel::AllOpKernels()[op_type][key] = func;
    return;
  }
#endif
187 188
  OpKernelType key(ToDataType(std::type_index(typeid(T))),
                   PlaceType(),
189
                   phi::StringToDataLayout(data_layout),
190 191
                   StringToLibraryType(library_type),
                   customized_type_value);
192
  OperatorWithKernel::AllOpKernels()[op_type][key] = func;
Y
yuyang18 已提交
193 194
}

195 196 197 198
template <typename PlaceType, size_t I, typename... KernelTypes>
struct OpKernelRegistrarFunctor<PlaceType, false, I, KernelTypes...> {
  using KERNEL_TYPE =
      typename std::tuple_element<I, std::tuple<KernelTypes...>>::type;
199

200 201
  void operator()(const char* op_type,
                  const char* library_type,
X
Xin Pan 已提交
202
                  int customized_type_value) const {
203
    using T = typename KERNEL_TYPE::ELEMENT_TYPE;
204
    RegisterKernelClass<PlaceType, T>(
205 206 207
        op_type,
        library_type,
        customized_type_value,
X
Xin Pan 已提交
208

209
        [op_type](const framework::ExecutionContext& ctx) {
Y
yuyang18 已提交
210
          KERNEL_TYPE().Compute(ctx);
211
          CheckKernelLaunch<PlaceType>(op_type);
212
        });
213 214
    constexpr auto size = std::tuple_size<std::tuple<KernelTypes...>>::value;
    OpKernelRegistrarFunctor<PlaceType, I + 1 == size, I + 1, KernelTypes...>
215
        func;
X
Xin Pan 已提交
216
    func(op_type, library_type, customized_type_value);
217 218 219 220 221
  }
};

template <typename PlaceType, size_t I, typename... KernelType>
struct OpKernelRegistrarFunctor<PlaceType, true, I, KernelType...> {
222 223 224
  void operator()(const char* op_type UNUSED,
                  const char* library_type UNUSED,
                  int customized_type_value UNUSED) const {}
225 226
};

M
mozga-intel 已提交
227 228
// User can register many kernel in one place. The data type could be
// different.
229
template <typename PlaceType, typename... KernelType>
F
fengjiayi 已提交
230 231
class OpKernelRegistrar : public Registrar {
 public:
232 233
  explicit OpKernelRegistrar(const char* op_type,
                             const char* library_type,
X
Xin Pan 已提交
234
                             int customized_type_value) {
235
    OpKernelRegistrarFunctor<PlaceType, false, 0, KernelType...> func;
X
Xin Pan 已提交
236
    func(op_type, library_type, customized_type_value);
F
fengjiayi 已提交
237 238 239
  }
};

Y
yuyang18 已提交
240 241 242 243 244 245
// Compile-time recursion helper used by OpKernelRegistrarEx. `I` is the index
// of the pack element currently being processed and `at_end` selects the
// terminating specialization.
template <typename PlaceType, bool at_end, size_t I, typename... KernelType>
struct OpKernelRegistrarFunctorEx;

template <typename PlaceType, typename... DataTypeAndKernelType>
class OpKernelRegistrarEx : public Registrar {
 public:
246 247
  explicit OpKernelRegistrarEx(const char* op_type,
                               const char* library_type,
X
Xin Pan 已提交
248
                               int customized_type_value) {
Y
yuyang18 已提交
249 250
    OpKernelRegistrarFunctorEx<PlaceType, false, 0, DataTypeAndKernelType...>
        func;
X
Xin Pan 已提交
251
    func(op_type, library_type, customized_type_value);
Y
yuyang18 已提交
252 253 254 255
  }
};

template <typename PlaceType, size_t I, typename... DataTypeAndKernelType>
256 257 258
struct OpKernelRegistrarFunctorEx<PlaceType,
                                  true,
                                  I,
Y
yuyang18 已提交
259
                                  DataTypeAndKernelType...> {
260 261
  void operator()(const char* op_type,
                  const char* library_type,
X
Xin Pan 已提交
262
                  int customized_type_value) const {}
Y
yuyang18 已提交
263 264 265
};

template <typename PlaceType, size_t I, typename... DataTypeAndKernelType>
266 267 268
struct OpKernelRegistrarFunctorEx<PlaceType,
                                  false,
                                  I,
Y
yuyang18 已提交
269
                                  DataTypeAndKernelType...> {
270
  using Functor =
Y
yuyang18 已提交
271 272 273 274 275 276
      typename std::tuple_element<I + 1,
                                  std::tuple<DataTypeAndKernelType...>>::type;
  using T =
      typename std::tuple_element<I,
                                  std::tuple<DataTypeAndKernelType...>>::type;

277 278
  void operator()(const char* op_type,
                  const char* library_type,
X
Xin Pan 已提交
279
                  int customized_type_value) const {
280
    RegisterKernelClass<PlaceType, T>(
281 282 283
        op_type,
        library_type,
        customized_type_value,
284 285 286 287 288

        [op_type](const framework::ExecutionContext& ctx) {
          Functor()(ctx);
          CheckKernelLaunch<PlaceType>(op_type);
        });
Y
yuyang18 已提交
289 290 291

    constexpr auto size =
        std::tuple_size<std::tuple<DataTypeAndKernelType...>>::value;
292 293 294
    OpKernelRegistrarFunctorEx<PlaceType,
                               I + 2 >= size,
                               I + 2,
Y
yuyang18 已提交
295 296
                               DataTypeAndKernelType...>
        func;
X
Xin Pan 已提交
297
    func(op_type, library_type, customized_type_value);
Y
yuyang18 已提交
298 299 300
  }
};

X
Xin Pan 已提交
301
// clang-format off
302 303 304
/**
 * check if MACRO is used in GLOBAL NAMESPACE.
 */
Y
Yu Yang 已提交
305 306 307 308 309 310
// Fails to compile with `msg` unless expanded at global namespace scope.
// It defines an empty struct named __test_global_namespace_<uniq_name>__ in
// the current scope and asserts that the unqualified name denotes the same
// type as the `::`-qualified one, which only holds in the global namespace.
// `uniq_name` must be unique per expansion within a translation unit.
#define STATIC_ASSERT_GLOBAL_NAMESPACE(uniq_name, msg)                        \
  struct __test_global_namespace_##uniq_name##__ {};                          \
  static_assert(std::is_same<::__test_global_namespace_##uniq_name##__,       \
                             __test_global_namespace_##uniq_name##__>::value, \
                msg)

311 312 313 314 315 316 317 318
/*
  Registers operator `op_type`, implemented by `op_class`.

  Expands to a static OperatorRegistrar<op_class, ...> object named
  __op_registrar_<op_type>__, constructed with the stringified op name, and
  to a function TouchOpRegistrar_<op_type>() that calls Touch() on it so that
  USE_OP_ITSELF can reference the registrar. Must be used in the global
  namespace.

  The variadic arguments should be class types derived from one of the
  following classes:
    OpProtoAndCheckerMaker
    GradOpDescMakerBase
    VarTypeInference
    InferShapeBase
*/
#define REGISTER_OPERATOR(op_type, op_class, ...)                        \
  STATIC_ASSERT_GLOBAL_NAMESPACE(                                        \
      __reg_op__##op_type,                                               \
      "REGISTER_OPERATOR must be called in global namespace");           \
  static ::paddle::framework::OperatorRegistrar<op_class, ##__VA_ARGS__> \
      __op_registrar_##op_type##__(#op_type);                            \
  int TouchOpRegistrar_##op_type() {                                     \
    __op_registrar_##op_type##__.Touch();                                \
    return 0;                                                            \
  }

330 331
// Registers an operator and appends the two EmptyGradOpMaker instantiations
// (for OpDesc and for imperative::OpBase) to the caller's argument list.
#define REGISTER_OP_WITHOUT_GRADIENT(op_type, op_class, ...)              \
  REGISTER_OPERATOR(op_type, op_class, __VA_ARGS__,                       \
        paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,   \
        paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>)
D
dongzhihong 已提交
334

D
dongzhihong 已提交
335
/**
 * Macro to register OperatorKernel.
 *
 * Expands to a static OpKernelRegistrar<place_class, ...> object constructed
 * with the stringified op name, the stringified library type and
 * customized_type_value, plus a function
 * TouchOpKernelRegistrar_<op_type>_<library_type>_<customized_name>() that
 * calls Touch() on it so that USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE can
 * reference the registrar. Must be used in the global namespace.
 */
#define REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(op_type, library_type,             \
                                            place_class, customized_name,      \
                                            customized_type_value, ...)        \
  STATIC_ASSERT_GLOBAL_NAMESPACE(                                              \
      __reg_op_kernel_##op_type##_##library_type##_##customized_name##__,      \
                                 "REGISTER_OP_KERNEL must be called in "       \
                                 "global namespace");                          \
  static ::paddle::framework::OpKernelRegistrar<place_class,                   \
                                                __VA_ARGS__>                   \
      __op_kernel_registrar_##op_type##_##library_type##_##customized_name##__(\
          #op_type, #library_type, customized_type_value);                     \
  int TouchOpKernelRegistrar_##op_type##_##library_type##_##customized_name() {\
    __op_kernel_registrar_##op_type##_##library_type##_##customized_name##__   \
        .Touch();                                                              \
    return 0;                                                                  \
  }
D
dongzhihong 已提交
354

X
Xin Pan 已提交
355 356 357 358 359 360
// Registers kernels with the DEFAULT_TYPE customized name and
// OpKernelType::kDefaultCustomizedTypeValue.
#define REGISTER_OP_KERNEL(op_type, library_type, place_class, ...)   \
  REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(                                \
      op_type, library_type, place_class, DEFAULT_TYPE,               \
      ::paddle::framework::OpKernelType::kDefaultCustomizedTypeValue, \
      __VA_ARGS__)

361
// Registers kernels for the CUDA library on CUDAPlace. In builds without
// CUDA or HIP the macro expands to nothing, so registrations can stay in
// shared source files.
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#define REGISTER_OP_CUDA_KERNEL(op_type, ...) \
  REGISTER_OP_KERNEL(op_type, CUDA, ::paddle::platform::CUDAPlace, __VA_ARGS__)
#else
#define REGISTER_OP_CUDA_KERNEL(op_type, ...)
#endif
F
fengjiayi 已提交
367

F
fengjiayi 已提交
368 369
// Registers kernels for the CPU library on CPUPlace.
#define REGISTER_OP_CPU_KERNEL(op_type, ...) \
  REGISTER_OP_KERNEL(op_type, CPU, ::paddle::platform::CPUPlace, __VA_ARGS__)
Y
Yu Yang 已提交
370

J
jianghaicheng 已提交
371 372 373
// Registers kernels for the IPU library on IPUPlace.
#define REGISTER_OP_IPU_KERNEL(op_type, ...) \
  REGISTER_OP_KERNEL(op_type, IPU, ::paddle::platform::IPUPlace, __VA_ARGS__)

374 375 376
// Registers kernels for the XPU library on XPUPlace.
#define REGISTER_OP_XPU_KERNEL(op_type, ...) \
  REGISTER_OP_KERNEL(op_type, XPU, ::paddle::platform::XPUPlace, __VA_ARGS__)

377 378 379
// Registers kernels for the NPU library on NPUPlace.
#define REGISTER_OP_NPU_KERNEL(op_type, ...) \
  REGISTER_OP_KERNEL(op_type, NPU, ::paddle::platform::NPUPlace, __VA_ARGS__)

X
Xin Pan 已提交
380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395
// Functor-style counterpart of REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE: the
// variadic arguments are passed to OpKernelRegistrarEx, which consumes them
// as (data type, functor type) pairs. Also defines
// TouchOpKernelRegistrar_<op_type>_<library_type>_<customized_name>() so the
// registrar can be referenced from USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE.
// Must be used in the global namespace.
#define REGISTER_OP_KERNEL_EX(op_type, library_type, place_class,              \
                              customized_name,                                 \
                              customized_type_value,                           \
                              ...)                                             \
  STATIC_ASSERT_GLOBAL_NAMESPACE(                                              \
      __reg_op_kernel_##op_type##_##library_type##_##customized_name##__,      \
                                 "REGISTER_OP_KERNEL_EX must be called in "    \
                                 "global namespace");                          \
  static ::paddle::framework::OpKernelRegistrarEx<place_class,                 \
                                                  __VA_ARGS__>                 \
      __op_kernel_registrar_##op_type##_##library_type##_##customized_name##__(\
          #op_type, #library_type, customized_type_value);                     \
  int TouchOpKernelRegistrar_##op_type##_##library_type##_##customized_name() {\
    __op_kernel_registrar_##op_type##_##library_type##_##customized_name##__   \
        .Touch();                                                              \
    return 0;                                                                  \
  }

398
// Registers functor-style kernels for the CUDA library on CUDAPlace with the
// DEFAULT_TYPE customized name and the default customized type value.
#define REGISTER_OP_CUDA_KERNEL_FUNCTOR(op_type, ...)                 \
  REGISTER_OP_KERNEL_EX(                                              \
      op_type, CUDA, ::paddle::platform::CUDAPlace, DEFAULT_TYPE,     \
      ::paddle::framework::OpKernelType::kDefaultCustomizedTypeValue, \
      __VA_ARGS__)
Y
yuyang18 已提交
403

X
Xin Pan 已提交
404 405 406 407 408
// Registers functor-style kernels for the CPU library on CPUPlace with the
// DEFAULT_TYPE customized name and the default customized type value.
#define REGISTER_OP_CPU_KERNEL_FUNCTOR(op_type, ...)                  \
  REGISTER_OP_KERNEL_EX(                                              \
      op_type, CPU, ::paddle::platform::CPUPlace, DEFAULT_TYPE,       \
      ::paddle::framework::OpKernelType::kDefaultCustomizedTypeValue, \
      __VA_ARGS__)
Y
yuyang18 已提交
409

410 411 412 413 414 415
// Registers functor-style kernels for the XPU library on XPUPlace with the
// DEFAULT_TYPE customized name and the default customized type value.
#define REGISTER_OP_XPU_KERNEL_FUNCTOR(op_type, ...)                  \
  REGISTER_OP_KERNEL_EX(                                              \
      op_type, XPU, ::paddle::platform::XPUPlace, DEFAULT_TYPE,       \
      ::paddle::framework::OpKernelType::kDefaultCustomizedTypeValue, \
      __VA_ARGS__)

416 417 418 419 420 421
// Registers functor-style kernels for the NPU library on NPUPlace with the
// DEFAULT_TYPE customized name and the default customized type value.
#define REGISTER_OP_NPU_KERNEL_FUNCTOR(op_type, ...)                  \
  REGISTER_OP_KERNEL_EX(                                              \
      op_type, NPU, ::paddle::platform::NPUPlace, DEFAULT_TYPE,       \
      ::paddle::framework::OpKernelType::kDefaultCustomizedTypeValue, \
      __VA_ARGS__)

422 423 424 425 426 427
// Registers functor-style kernels for the IPU library on IPUPlace with the
// DEFAULT_TYPE customized name and the default customized type value.
#define REGISTER_OP_IPU_KERNEL_FUNCTOR(op_type, ...)                  \
  REGISTER_OP_KERNEL_EX(                                              \
      op_type, IPU, ::paddle::platform::IPUPlace, DEFAULT_TYPE,       \
      ::paddle::framework::OpKernelType::kDefaultCustomizedTypeValue, \
      __VA_ARGS__)

428
/**
 * Macros to mark which Operators and Kernels
 * we will use and to tell the compiler to
 * link them into the target.
 */
D
dzhwinter 已提交
433 434 435 436 437 438
// Declares TouchOpRegistrar_<op_type>() (defined by REGISTER_OPERATOR) and
// calls it to initialize a file-local static int, which creates a reference
// to the operator's registrar from the translation unit using this macro.
// Must be used in the global namespace; the expansion has no trailing
// semicolon.
#define USE_OP_ITSELF(op_type)                             \
  STATIC_ASSERT_GLOBAL_NAMESPACE(                          \
      __use_op_itself_##op_type,                           \
      "USE_OP_ITSELF must be called in global namespace"); \
  extern int TouchOpRegistrar_##op_type();                 \
  UNUSED static int use_op_itself_##op_type##_ = TouchOpRegistrar_##op_type()
F
fengjiayi 已提交
439

X
Xin Pan 已提交
440 441 442 443 444 445 446 447
// Declares TouchOpKernelRegistrar_<op_type>_<LIBRARY_TYPE>_<customized_name>()
// (defined by the REGISTER_OP_KERNEL* macros) and calls it to initialize a
// file-local static int, which creates a reference to that kernel registrar
// from the translation unit using this macro. Must be used in the global
// namespace; the expansion has no trailing semicolon.
#define USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(op_type,                     \
                                              LIBRARY_TYPE,                \
                                              customized_name)             \
  STATIC_ASSERT_GLOBAL_NAMESPACE(                                          \
      __use_op_kernel_##op_type##_##LIBRARY_TYPE##_##customized_name##__,  \
      "USE_OP_DEVICE_KERNEL must be in global namespace");                 \
  extern int                                                               \
      TouchOpKernelRegistrar_##op_type##_##LIBRARY_TYPE##_##customized_name(); \
  UNUSED static int use_op_kernel_##op_type##_##LIBRARY_TYPE##_##customized_name##_ = /* NOLINT */ \
      TouchOpKernelRegistrar_##op_type##_##LIBRARY_TYPE##_##customized_name()

// References the kernel registrar of `op_type` for LIBRARY_TYPE that was
// registered under the DEFAULT_TYPE customized name.
#define USE_OP_DEVICE_KERNEL(op_type, LIBRARY_TYPE) \
  USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(op_type, LIBRARY_TYPE, DEFAULT_TYPE)
Y
Yu Yang 已提交
453

454 455
// TODO(fengjiayi): The following macros
// seem ugly; do we have a better method?
Y
Yu Yang 已提交
456

457
// References the kernels of `op_type`: the CPU kernel only in builds without
// CUDA/HIP, otherwise both the CPU and the CUDA kernels.
#if !defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_HIP)
#define USE_OP_KERNEL(op_type) USE_OP_DEVICE_KERNEL(op_type, CPU)
#else
#define USE_OP_KERNEL(op_type)        \
  USE_OP_DEVICE_KERNEL(op_type, CPU); \
  USE_OP_DEVICE_KERNEL(op_type, CUDA)
#endif
464

465 466
// References only the operator registrar of `op_type` (no kernels).
// Note: unlike USE_OP, the expansion already ends with a semicolon.
#define USE_NO_KERNEL_OP(op_type) USE_OP_ITSELF(op_type);

F
WIP  
fengjiayi 已提交
467 468 469
// References the operator registrar of `op_type` and its CPU kernel.
// Note: unlike USE_OP, the expansion already ends with a semicolon.
#define USE_CPU_ONLY_OP(op_type) \
  USE_OP_ITSELF(op_type);        \
  USE_OP_DEVICE_KERNEL(op_type, CPU);
470

Q
QI JUN 已提交
471 472 473
// References the operator registrar of `op_type` and its CUDA kernel.
// The expansion has no trailing semicolon.
#define USE_CUDA_ONLY_OP(op_type) \
  USE_OP_ITSELF(op_type);         \
  USE_OP_DEVICE_KERNEL(op_type, CUDA)
D
Dong Zhihong 已提交
474

F
WIP  
fengjiayi 已提交
475 476 477
// References the operator registrar of `op_type` together with the kernels
// selected by USE_OP_KERNEL. The expansion has no trailing semicolon.
#define USE_OP(op_type)   \
  USE_OP_ITSELF(op_type); \
  USE_OP_KERNEL(op_type)
X
Xin Pan 已提交
478
// clang-format on
479

480 481 482
// Adapter exposing a structure-style kernel through a static
// Compute(phi::KernelContext*) function. The second template parameter is
// only used to select between the partial specializations via enable_if.
template <typename StructureKernel, typename enable = void>
struct StructKernelImpl;

483
template <typename StructureKernel>
484 485 486 487
struct StructKernelImpl<
    StructureKernel,
    typename std::enable_if<std::is_base_of<paddle::framework::OpKernelBase,
                                            StructureKernel>::value>::type> {
488 489 490 491 492 493
  static void Compute(phi::KernelContext* ctx) {
    auto exe_ctx = static_cast<paddle::framework::ExecutionContext*>(ctx);
    StructureKernel().Compute(*exe_ctx);
  }
};

494 495 496 497 498 499 500 501 502 503 504
template <typename StructureKernel>
struct StructKernelImpl<
    StructureKernel,
    typename std::enable_if<!std::is_base_of<paddle::framework::OpKernelBase,
                                             StructureKernel>::value>::type> {
  static void Compute(phi::KernelContext* ctx) {
    auto exe_ctx = static_cast<paddle::framework::ExecutionContext*>(ctx);
    StructureKernel()(*exe_ctx);
  }
};

505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527
// Expands to the static Compute function of the StructKernelImpl adapter
// instantiated with the given template arguments.
#define PHI_STRUCTURE_KERNEL(...) \
  ::paddle::framework::StructKernelImpl<__VA_ARGS__>::Compute
// Structure kernels supply neither a variadic kernel function nor an
// argument-parse functor; both macros expand to nullptr.
#define PHI_STRUCTURE_VARIADIC_KERNEL(...) nullptr
#define STRUCTURE_ARG_PARSE_FUNCTOR(...) nullptr

// Explicitly instantiates the kernel class template for one
// (cpp_dtype, context) combination.
#define STRUCTURE_KERNEL_INSTANTIATION(        \
    meta_kernel_structure, cpp_dtype, context) \
  template class meta_kernel_structure<cpp_dtype, context>;

// Registers a structure-style kernel by forwarding to _PD_REGISTER_KERNEL
// with ::phi::RegType::INNER, the context type ::phi::<backend>Context and
// the structure-specific helper macros defined above. The variadic
// arguments are passed through unchanged.
// NOTE(review): _PD_REGISTER_KERNEL is presumably provided by
// paddle/phi/core/kernel_registry.h — confirm.
#define PD_REGISTER_STRUCT_KERNEL(                            \
    kernel_name, backend, layout, meta_kernel_structure, ...) \
  _PD_REGISTER_KERNEL(::phi::RegType::INNER,                  \
                      kernel_name,                            \
                      backend,                                \
                      ::phi::backend##Context,                \
                      layout,                                 \
                      meta_kernel_structure,                  \
                      STRUCTURE_KERNEL_INSTANTIATION,         \
                      STRUCTURE_ARG_PARSE_FUNCTOR,            \
                      PHI_STRUCTURE_KERNEL,                   \
                      PHI_STRUCTURE_VARIADIC_KERNEL,          \
                      __VA_ARGS__)

528 529
}  // namespace framework
}  // namespace paddle