/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

15 16
#pragma once

17
#include <algorithm>
18
#include <atomic>
19
#include <memory>
Y
Yang Yang 已提交
20 21
#include <string>
#include <tuple>
Y
Yu Yang 已提交
22
#include <type_traits>
F
WIP  
fengjiayi 已提交
23
#include <typeinfo>
24 25
#include <unordered_map>
#include <unordered_set>
Y
Yu Yang 已提交
26

P
peizhilin 已提交
27
#define GLOG_NO_ABBREVIATED_SEVERITIES  // msvc conflict logging with windows.h
28 29
#include "gflags/gflags.h"
#include "glog/logging.h"  // For VLOG()
Y
Yi Wang 已提交
30 31 32 33 34 35 36
#include "paddle/fluid/framework/attribute.h"
#include "paddle/fluid/framework/details/op_registry.h"
#include "paddle/fluid/framework/grad_op_desc_maker.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/shape_inference.h"
37

W
wanghuancoder 已提交
38 39 40 41 42 43
// Forward declaration of the execution-context type that the kernel
// registration lambdas further down in this header take by const reference.
namespace paddle {
namespace framework {
class ExecutionContext;
}  // namespace framework
}  // namespace paddle

44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70
// Forward declarations of the types living in paddle::framework::proto.
// NOTE(review): the names look like protobuf-generated message classes --
// confirm against the framework .proto definitions. In the visible part of
// this header only proto::OpDesc is used (as a CreateOp parameter type).
namespace paddle {
namespace framework {
namespace proto {

class BlockDesc;
class OpDesc;
class OpDesc_Attr;
class OpDesc_Var;
class OpProto;
class OpProto_Attr;
class OpProto_Var;
class OpVersion;
class OpVersionMap;
class OpVersionMap_OpVersionPair;
class ProgramDesc;
class VarDesc;
class VarType;
class VarType_LoDTensorArrayDesc;
class VarType_LoDTensorDesc;
class VarType_ReaderDesc;
class VarType_TensorDesc;
class VarType_Tuple;
class Version;
}  // namespace proto
}  // namespace framework
}  // namespace paddle

71 72
// Declares the boolean flag FLAGS_check_kernel_launch so that this header can
// read it (see the CUDAPlace specialization of CheckKernelLaunch); the flag
// itself is not defined in this file.
DECLARE_bool(check_kernel_launch);

73 74
namespace paddle {
namespace framework {
X
Xin Pan 已提交
75

Y
Yu Yang 已提交
76 77 78 79
/**
 * Common base class of all registrar types.
 *
 * In this design every kind of registrable entity (operators, kernels, ...)
 * has its own registry and registrar. The registration itself happens in the
 * constructor of a global registrar variable. Such a variable is not referenced
 * by the code that uses the framework package, so the linker would be free to
 * drop it from the generated binary. To prevent that, every registrar exposes
 * Touch() and the USE_OP macros reach it; as long as client code expands
 * USE_OP, the global registrar variable stays referenced and is kept.
 */
class Registrar {
 public:
  // Intentionally empty: it only exists to give callers something to reference.
  void Touch() {}
};
88

89
template <typename... ARGS>
Y
Yu Yang 已提交
90
struct OperatorRegistrar : public Registrar {
91
  explicit OperatorRegistrar(const char* op_type) {
92
    PADDLE_ENFORCE_EQ(
93 94
        OpInfoMap::Instance().Has(op_type),
        false,
95 96
        platform::errors::AlreadyExists(
            "Operator '%s' is registered more than once.", op_type));
97 98
    static_assert(sizeof...(ARGS) != 0,
                  "OperatorRegistrar should be invoked at least by OpClass");
99
    OpInfo info;
100
    details::OperatorRegistrarRecursive<0, false, ARGS...>(op_type, &info);
Y
Yu Yang 已提交
101
    OpInfoMap::Instance().Insert(op_type, info);
102 103 104
  }
};

105 106
class OpRegistry {
 public:
H
hong 已提交
107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129
  /**
   * @brief Return an OperatorBase constructed by type, inputs, outputs, attrs.
   *        In dygraph mode, inputs, output, attrs will be set to empty map to
   *        improve the execution efficiency of dygraph.
   *        Dygraph mode will use:
   *        framework::OpRegistry::CreateOp(type, {}, {}, {}, false).
   *
   * @param[str] type               The operator type.
   * @param[map] inputs             Inputs map of the operator.
   * @param[map] outputs            Outputs map of the operator.
   * @param[unordered_map] attrs    Attributes map of the operator.
   * @param[bool] attr_check
   *            Whether do the attribute check before OperatorBase construction.
   *            Default is true.
   *            Attr_check is used to control the check of attribute map.
   *            The check of attribute map have two purposes:
   *            1. check whether the attribute item is valid or not.
   *            2. add attribute item which has default value
   *            if it is not in attrs.
   *            In dygraph mode, attrs is an empty unordered_map,
   *            attr_check is set to false, otherwise it will be failed
   *            when check function called.
   */
Y
Yu Yang 已提交
130
  static std::unique_ptr<OperatorBase> CreateOp(const std::string& type,
Y
Yu Yang 已提交
131 132
                                                const VariableNameMap& inputs,
                                                const VariableNameMap& outputs,
133
                                                const AttributeMap& attrs,
H
hong 已提交
134
                                                bool attr_check = true);
Y
Yu Yang 已提交
135

136
  static std::unique_ptr<OperatorBase> CreateOp(const proto::OpDesc& op_desc);
Y
Yu Yang 已提交
137

Y
Yu Yang 已提交
138
  static std::unique_ptr<OperatorBase> CreateOp(const OpDesc& op_desc);
F
Fix bug  
fengjiayi 已提交
139
};
F
fengjiayi 已提交
140

141
/**
 * Hook invoked after a kernel has been computed.
 *
 * The primary template does nothing; specializations for a particular
 * PlaceType may add a post-launch check.
 *
 * @param op_type  Name of the operator whose kernel just ran (unused here).
 */
template <typename PlaceType>
inline void CheckKernelLaunch(const char* op_type) {}
143 144 145 146 147

#ifdef PADDLE_WITH_CUDA
/**
 * CUDAPlace specialization, compiled only when PADDLE_WITH_CUDA is defined.
 *
 * When FLAGS_check_kernel_launch is set, PADDLE_ENFORCE_CUDA_LAUNCH_SUCCESS is
 * evaluated for `op_type`; otherwise this is a no-op.
 */
template <>
inline void CheckKernelLaunch<::paddle::platform::CUDAPlace>(
    const char* op_type) {
  if (FLAGS_check_kernel_launch) {
    PADDLE_ENFORCE_CUDA_LAUNCH_SUCCESS(op_type);
  }
}
#endif

154 155 156
// Compile-time recursion over the KernelType pack. `I` is the index of the
// kernel type handled by the current step and `at_end` selects between the
// registering specialization (false) and the terminating no-op one (true).
template <typename PlaceType, bool at_end, size_t I, typename... KernelType>
struct OpKernelRegistrarFunctor;

157
template <typename PlaceType, typename T, typename Func>
158 159 160 161
inline void RegisterKernelClass(const char* op_type,
                                const char* library_type,
                                int customized_type_value,
                                Func func) {
Y
yuyang18 已提交
162 163 164 165 166
  std::string library(library_type);
  std::string data_layout = "ANYLAYOUT";
  if (library == "MKLDNN") {
    data_layout = "MKLDNNLAYOUT";
  }
167 168
  OpKernelType key(ToDataType(std::type_index(typeid(T))),
                   PlaceType(),
Y
yuyang18 已提交
169
                   StringToDataLayout(data_layout),
170 171
                   StringToLibraryType(library_type),
                   customized_type_value);
172
  OperatorWithKernel::AllOpKernels()[op_type][key] = func;
Y
yuyang18 已提交
173 174
}

175 176 177 178
template <typename PlaceType, size_t I, typename... KernelTypes>
struct OpKernelRegistrarFunctor<PlaceType, false, I, KernelTypes...> {
  using KERNEL_TYPE =
      typename std::tuple_element<I, std::tuple<KernelTypes...>>::type;
179

180 181
  void operator()(const char* op_type,
                  const char* library_type,
X
Xin Pan 已提交
182
                  int customized_type_value) const {
183
    using T = typename KERNEL_TYPE::ELEMENT_TYPE;
184
    RegisterKernelClass<PlaceType, T>(
185 186 187
        op_type,
        library_type,
        customized_type_value,
X
Xin Pan 已提交
188

189
        [op_type](const framework::ExecutionContext& ctx) {
Y
yuyang18 已提交
190
          KERNEL_TYPE().Compute(ctx);
191
          CheckKernelLaunch<PlaceType>(op_type);
192
        });
193 194
    constexpr auto size = std::tuple_size<std::tuple<KernelTypes...>>::value;
    OpKernelRegistrarFunctor<PlaceType, I + 1 == size, I + 1, KernelTypes...>
195
        func;
X
Xin Pan 已提交
196
    func(op_type, library_type, customized_type_value);
197 198 199 200 201
  }
};

/**
 * Terminating step of the recursion: every kernel type has been registered,
 * so there is nothing left to do.
 */
template <typename PlaceType, size_t I, typename... KernelType>
struct OpKernelRegistrarFunctor<PlaceType, true, I, KernelType...> {
  void operator()(const char* op_type,
                  const char* library_type,
                  int customized_type_value) const {}
};

M
mozga-intel 已提交
207 208
// User can register many kernel in one place. The data type could be
// different.
209
template <typename PlaceType, typename... KernelType>
F
fengjiayi 已提交
210 211
class OpKernelRegistrar : public Registrar {
 public:
212 213
  explicit OpKernelRegistrar(const char* op_type,
                             const char* library_type,
X
Xin Pan 已提交
214
                             int customized_type_value) {
215
    OpKernelRegistrarFunctor<PlaceType, false, 0, KernelType...> func;
X
Xin Pan 已提交
216
    func(op_type, library_type, customized_type_value);
F
fengjiayi 已提交
217 218 219
  }
};

Y
yuyang18 已提交
220 221 222 223 224 225
// Compile-time recursion used by OpKernelRegistrarEx. The pack is consumed two
// entries at a time (data type, functor); `I` is the index of the current data
// type and `at_end` selects the terminating no-op specialization.
template <typename PlaceType, bool at_end, size_t I, typename... KernelType>
struct OpKernelRegistrarFunctorEx;

template <typename PlaceType, typename... DataTypeAndKernelType>
class OpKernelRegistrarEx : public Registrar {
 public:
226 227
  explicit OpKernelRegistrarEx(const char* op_type,
                               const char* library_type,
X
Xin Pan 已提交
228
                               int customized_type_value) {
Y
yuyang18 已提交
229 230
    OpKernelRegistrarFunctorEx<PlaceType, false, 0, DataTypeAndKernelType...>
        func;
X
Xin Pan 已提交
231
    func(op_type, library_type, customized_type_value);
Y
yuyang18 已提交
232 233 234 235
  }
};

template <typename PlaceType, size_t I, typename... DataTypeAndKernelType>
236 237 238
struct OpKernelRegistrarFunctorEx<PlaceType,
                                  true,
                                  I,
Y
yuyang18 已提交
239
                                  DataTypeAndKernelType...> {
240 241
  void operator()(const char* op_type,
                  const char* library_type,
X
Xin Pan 已提交
242
                  int customized_type_value) const {}
Y
yuyang18 已提交
243 244 245
};

template <typename PlaceType, size_t I, typename... DataTypeAndKernelType>
246 247 248
struct OpKernelRegistrarFunctorEx<PlaceType,
                                  false,
                                  I,
Y
yuyang18 已提交
249
                                  DataTypeAndKernelType...> {
250
  using Functor =
Y
yuyang18 已提交
251 252 253 254 255 256
      typename std::tuple_element<I + 1,
                                  std::tuple<DataTypeAndKernelType...>>::type;
  using T =
      typename std::tuple_element<I,
                                  std::tuple<DataTypeAndKernelType...>>::type;

257 258
  void operator()(const char* op_type,
                  const char* library_type,
X
Xin Pan 已提交
259
                  int customized_type_value) const {
260
    RegisterKernelClass<PlaceType, T>(
261 262 263
        op_type,
        library_type,
        customized_type_value,
264 265 266 267 268

        [op_type](const framework::ExecutionContext& ctx) {
          Functor()(ctx);
          CheckKernelLaunch<PlaceType>(op_type);
        });
Y
yuyang18 已提交
269 270 271

    constexpr auto size =
        std::tuple_size<std::tuple<DataTypeAndKernelType...>>::value;
272 273 274
    OpKernelRegistrarFunctorEx<PlaceType,
                               I + 2 >= size,
                               I + 2,
Y
yuyang18 已提交
275 276
                               DataTypeAndKernelType...>
        func;
X
Xin Pan 已提交
277
    func(op_type, library_type, customized_type_value);
Y
yuyang18 已提交
278 279 280
  }
};

X
Xin Pan 已提交
281
// clang-format off
282 283 284
/**
 * Checks that the enclosing macro is expanded in the GLOBAL namespace.
 *
 * A struct with a unique name is declared at the point of expansion and is
 * compared with the struct of the same name looked up from the global scope.
 * The static_assert (with message `msg`) only holds when both names refer to
 * the same type. No trailing semicolon is emitted; the caller supplies it.
 */
#define STATIC_ASSERT_GLOBAL_NAMESPACE(uniq_name, msg)                        \
  struct __test_global_namespace_##uniq_name##__ {};                          \
  static_assert(std::is_same<::__test_global_namespace_##uniq_name##__,       \
                             __test_global_namespace_##uniq_name##__>::value, \
                msg)

291 292 293 294 295 296 297 298
/*
  Registers operator `op_type`, implemented by `op_class`.

  The variadic arguments should be class types derived from one of the
  following classes:
    OpProtoAndCheckerMaker
    GradOpDescMakerBase
    VarTypeInference
    InferShapeBase

  Besides the static registrar object, a function TouchOpRegistrar_<op_type>()
  is emitted; USE_OP_ITSELF declares and calls it to keep the registrar linked.
*/
#define REGISTER_OPERATOR(op_type, op_class, ...)                        \
  STATIC_ASSERT_GLOBAL_NAMESPACE(                                        \
      __reg_op__##op_type,                                               \
      "REGISTER_OPERATOR must be called in global namespace");           \
  static ::paddle::framework::OperatorRegistrar<op_class, ##__VA_ARGS__> \
      __op_registrar_##op_type##__(#op_type);                            \
  int TouchOpRegistrar_##op_type() {                                     \
    __op_registrar_##op_type##__.Touch();                                \
    return 0;                                                            \
  }

310 311
// Registers an operator that has no gradient: the empty grad-op makers for
// both framework::OpDesc and imperative::OpBase are appended to the classes
// given by the caller. `##__VA_ARGS__` drops the preceding comma when no extra
// class is passed, so REGISTER_OP_WITHOUT_GRADIENT(type, OpClass) also works.
#define REGISTER_OP_WITHOUT_GRADIENT(op_type, op_class, ...)               \
  REGISTER_OPERATOR(                                                       \
      op_type, op_class, ##__VA_ARGS__,                                    \
      paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,      \
      paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>)
D
dongzhihong 已提交
314

D
dongzhihong 已提交
315
/**
 * Macro to register OperatorKernel.
 *
 * Creates a static OpKernelRegistrar<place_class, kernels...> whose name is
 * derived from op_type, library_type and customized_name, plus a matching
 * TouchOpKernelRegistrar_<op_type>_<library_type>_<customized_name>()
 * function that USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE declares and calls.
 */
#define REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(op_type, library_type,             \
                                            place_class, customized_name,      \
                                            customized_type_value, ...)        \
  STATIC_ASSERT_GLOBAL_NAMESPACE(                                              \
      __reg_op_kernel_##op_type##_##library_type##_##customized_name##__,      \
                                 "REGISTER_OP_KERNEL must be called in "       \
                                 "global namespace");                          \
  static ::paddle::framework::OpKernelRegistrar<place_class,                   \
                                                __VA_ARGS__>                   \
      __op_kernel_registrar_##op_type##_##library_type##_##customized_name##__(\
          #op_type, #library_type, customized_type_value);                     \
  int TouchOpKernelRegistrar_##op_type##_##library_type##_##customized_name() {\
    __op_kernel_registrar_##op_type##_##library_type##_##customized_name##__   \
        .Touch();                                                              \
    return 0;                                                                  \
  }
D
dongzhihong 已提交
334

X
Xin Pan 已提交
335 336 337 338 339 340
// Registers kernels with the default customization: the customized name is
// DEFAULT_TYPE and the value is OpKernelType::kDefaultCustomizedTypeValue.
#define REGISTER_OP_KERNEL(op_type, library_type, place_class, ...)     \
  REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(                                  \
      op_type,                                                          \
      library_type,                                                     \
      place_class,                                                      \
      DEFAULT_TYPE,                                                     \
      ::paddle::framework::OpKernelType::kDefaultCustomizedTypeValue,   \
      __VA_ARGS__)

341
// Registers CUDA kernels when building with CUDA or HIP; in every other build
// the macro expands to nothing.
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#define REGISTER_OP_CUDA_KERNEL(op_type, ...) \
  REGISTER_OP_KERNEL(op_type, CUDA, ::paddle::platform::CUDAPlace, __VA_ARGS__)
#else
#define REGISTER_OP_CUDA_KERNEL(op_type, ...)
#endif
F
fengjiayi 已提交
347

F
fengjiayi 已提交
348 349
// Registers kernels for library type CPU on ::paddle::platform::CPUPlace.
#define REGISTER_OP_CPU_KERNEL(op_type, ...)                       \
  REGISTER_OP_KERNEL(                                              \
      op_type, CPU, ::paddle::platform::CPUPlace, __VA_ARGS__)
Y
Yu Yang 已提交
350

J
jianghaicheng 已提交
351 352 353
// Registers kernels for library type IPU on ::paddle::platform::IPUPlace.
#define REGISTER_OP_IPU_KERNEL(op_type, ...)                       \
  REGISTER_OP_KERNEL(                                              \
      op_type, IPU, ::paddle::platform::IPUPlace, __VA_ARGS__)

354 355 356
// Registers kernels for library type XPU on ::paddle::platform::XPUPlace.
#define REGISTER_OP_XPU_KERNEL(op_type, ...)                       \
  REGISTER_OP_KERNEL(                                              \
      op_type, XPU, ::paddle::platform::XPUPlace, __VA_ARGS__)

357 358 359
// Registers kernels for library type NPU on ::paddle::platform::NPUPlace.
#define REGISTER_OP_NPU_KERNEL(op_type, ...)                       \
  REGISTER_OP_KERNEL(                                              \
      op_type, NPU, ::paddle::platform::NPUPlace, __VA_ARGS__)

F
fwenguang 已提交
360 361 362
// Registers kernels for library type MLU on ::paddle::platform::MLUPlace.
#define REGISTER_OP_MLU_KERNEL(op_type, ...)                       \
  REGISTER_OP_KERNEL(                                              \
      op_type, MLU, ::paddle::platform::MLUPlace, __VA_ARGS__)

X
Xin Pan 已提交
363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378
// Functor-style counterpart of REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE: the
// variadic arguments are (data type, functor type) pairs handed to
// OpKernelRegistrarEx. Emits the static registrar and the matching
// TouchOpKernelRegistrar_<op_type>_<library_type>_<customized_name>() function.
#define REGISTER_OP_KERNEL_EX(op_type, library_type, place_class,              \
                              customized_name,                                 \
                              customized_type_value,                           \
                              ...)                                             \
  STATIC_ASSERT_GLOBAL_NAMESPACE(                                              \
      __reg_op_kernel_##op_type##_##library_type##_##customized_name##__,      \
                                 "REGISTER_OP_KERNEL_EX must be called in "    \
                                 "global namespace");                          \
  static ::paddle::framework::OpKernelRegistrarEx<place_class,                 \
                                                  __VA_ARGS__>                 \
      __op_kernel_registrar_##op_type##_##library_type##_##customized_name##__(\
          #op_type, #library_type, customized_type_value);                     \
  int TouchOpKernelRegistrar_##op_type##_##library_type##_##customized_name() {\
    __op_kernel_registrar_##op_type##_##library_type##_##customized_name##__   \
        .Touch();                                                              \
    return 0;                                                                  \
  }

381
// Registers functor-style kernels for library type CUDA on CUDAPlace, using
// the default customized name and value.
#define REGISTER_OP_CUDA_KERNEL_FUNCTOR(op_type, ...)                 \
  REGISTER_OP_KERNEL_EX(                                              \
      op_type, CUDA, ::paddle::platform::CUDAPlace, DEFAULT_TYPE,     \
      ::paddle::framework::OpKernelType::kDefaultCustomizedTypeValue, \
      __VA_ARGS__)
Y
yuyang18 已提交
386

X
Xin Pan 已提交
387 388 389 390 391
// Registers functor-style kernels for library type CPU on CPUPlace, using the
// default customized name and value.
#define REGISTER_OP_CPU_KERNEL_FUNCTOR(op_type, ...)                    \
  REGISTER_OP_KERNEL_EX(                                                \
      op_type,                                                          \
      CPU,                                                              \
      ::paddle::platform::CPUPlace,                                     \
      DEFAULT_TYPE,                                                     \
      ::paddle::framework::OpKernelType::kDefaultCustomizedTypeValue,   \
      __VA_ARGS__)
Y
yuyang18 已提交
392

393 394 395 396 397 398
// Registers functor-style kernels for library type XPU on XPUPlace, using the
// default customized name and value.
#define REGISTER_OP_XPU_KERNEL_FUNCTOR(op_type, ...)                    \
  REGISTER_OP_KERNEL_EX(                                                \
      op_type,                                                          \
      XPU,                                                              \
      ::paddle::platform::XPUPlace,                                     \
      DEFAULT_TYPE,                                                     \
      ::paddle::framework::OpKernelType::kDefaultCustomizedTypeValue,   \
      __VA_ARGS__)

399 400 401 402 403 404
// Registers functor-style kernels for library type NPU on NPUPlace, using the
// default customized name and value.
#define REGISTER_OP_NPU_KERNEL_FUNCTOR(op_type, ...)                    \
  REGISTER_OP_KERNEL_EX(                                                \
      op_type,                                                          \
      NPU,                                                              \
      ::paddle::platform::NPUPlace,                                     \
      DEFAULT_TYPE,                                                     \
      ::paddle::framework::OpKernelType::kDefaultCustomizedTypeValue,   \
      __VA_ARGS__)

F
fwenguang 已提交
405 406 407 408 409 410
// Registers functor-style kernels for library type MLU on MLUPlace, using the
// default customized name and value.
#define REGISTER_OP_MLU_KERNEL_FUNCTOR(op_type, ...)                    \
  REGISTER_OP_KERNEL_EX(                                                \
      op_type,                                                          \
      MLU,                                                              \
      ::paddle::platform::MLUPlace,                                     \
      DEFAULT_TYPE,                                                     \
      ::paddle::framework::OpKernelType::kDefaultCustomizedTypeValue,   \
      __VA_ARGS__)

411
/**
 * Macro to mark what Operator and Kernel
 * we will use and tell the compiler to
 * link them into target.
 *
 * USE_OP_ITSELF declares the TouchOpRegistrar_<op_type>() function emitted by
 * REGISTER_OPERATOR and calls it from a static initializer, which keeps the
 * operator's registrar referenced. No trailing semicolon is emitted.
 */
#define USE_OP_ITSELF(op_type)                             \
  STATIC_ASSERT_GLOBAL_NAMESPACE(                          \
      __use_op_itself_##op_type,                           \
      "USE_OP_ITSELF must be called in global namespace"); \
  extern int TouchOpRegistrar_##op_type();                 \
  UNUSED static int use_op_itself_##op_type##_ = TouchOpRegistrar_##op_type()
F
fengjiayi 已提交
422

X
Xin Pan 已提交
423 424 425 426 427 428 429 430
// Declares the TouchOpKernelRegistrar_<op_type>_<LIBRARY_TYPE>_<customized_name>()
// function emitted by the kernel registration macros and calls it from a
// static initializer, which keeps that kernel registrar referenced.
// No trailing semicolon is emitted.
#define USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(op_type,                     \
                                              LIBRARY_TYPE,                \
                                              customized_name)             \
  STATIC_ASSERT_GLOBAL_NAMESPACE(                                          \
      __use_op_kernel_##op_type##_##LIBRARY_TYPE##_##customized_name##__,  \
      "USE_OP_DEVICE_KERNEL must be in global namespace");                 \
  extern int                                                               \
      TouchOpKernelRegistrar_##op_type##_##LIBRARY_TYPE##_##customized_name(); \
  UNUSED static int use_op_kernel_##op_type##_##LIBRARY_TYPE##_##customized_name##_ = /* NOLINT */ \
      TouchOpKernelRegistrar_##op_type##_##LIBRARY_TYPE##_##customized_name()

// Shorthand for kernels registered under the default customized name
// DEFAULT_TYPE.
#define USE_OP_DEVICE_KERNEL(op_type, LIBRARY_TYPE)  \
  USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(             \
      op_type, LIBRARY_TYPE, DEFAULT_TYPE)
Y
Yu Yang 已提交
436

437 438
// TODO(fengjiayi): The following macros
// seems ugly, do we have better method?

// USE_OP_KERNEL references the CPU kernel of an operator and, when building
// with CUDA or HIP, its CUDA kernel as well.
#if !defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_HIP)
#define USE_OP_KERNEL(op_type) USE_OP_DEVICE_KERNEL(op_type, CPU)
#else
#define USE_OP_KERNEL(op_type)        \
  USE_OP_DEVICE_KERNEL(op_type, CPU); \
  USE_OP_DEVICE_KERNEL(op_type, CUDA)
#endif
447

448 449
// References an operator that has no kernels. The expansion already ends with
// a semicolon.
#define USE_NO_KERNEL_OP(op_type) \
  USE_OP_ITSELF(op_type);

F
WIP  
fengjiayi 已提交
450 451 452
// References an operator together with its CPU kernel only. The expansion
// already ends with a semicolon.
#define USE_CPU_ONLY_OP(op_type)                               \
  USE_OP_ITSELF(op_type); USE_OP_DEVICE_KERNEL(op_type, CPU);
453

Q
QI JUN 已提交
454 455 456
// References an operator together with its CUDA kernel only. No trailing
// semicolon is emitted.
#define USE_CUDA_ONLY_OP(op_type)                              \
  USE_OP_ITSELF(op_type); USE_OP_DEVICE_KERNEL(op_type, CUDA)
D
Dong Zhihong 已提交
457

F
WIP  
fengjiayi 已提交
458 459 460
// References an operator and the kernels selected by USE_OP_KERNEL. No
// trailing semicolon is emitted.
#define USE_OP(op_type)                            \
  USE_OP_ITSELF(op_type); USE_OP_KERNEL(op_type)
X
Xin Pan 已提交
461
// clang-format on
462

463 464
}  // namespace framework
}  // namespace paddle