op_registry.h 12.0 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
F
fengjiayi 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

15 16
#pragma once

17
#include <algorithm>
#include <atomic>
#include <string>
#include <tuple>
#include <type_traits>
#include <typeindex>
#include <typeinfo>
#include <unordered_map>
#include <unordered_set>

#include "glog/logging.h"  // For VLOG()
#include "paddle/fluid/framework/attribute.h"
#include "paddle/fluid/framework/details/op_registry.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/grad_op_desc_maker.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/shape_inference.h"
35 36 37

namespace paddle {
namespace framework {
Y
Yu Yang 已提交
38 39 40 41
/**
 * Common base of every registrar class (operator registrars, kernel
 * registrars, ...).
 *
 * Registration happens as a side effect of constructing a global registrar
 * variable. Code that merely uses the framework never refers to that
 * variable, so the linker may strip it from the generated binary and the
 * registration would be lost. Touch() exists so that the USE_* macros have
 * something to call: as long as a caller expands the matching USE_OP macro,
 * the global registrar variable is referenced and is not removed.
 */
class Registrar {
 public:
  // Deliberately empty: the call itself is what matters, not its effect.
  void Touch() {}
};
50

51
template <typename... ARGS>
Y
Yu Yang 已提交
52
struct OperatorRegistrar : public Registrar {
53
  explicit OperatorRegistrar(const char* op_type) {
54 55 56 57
    PADDLE_ENFORCE(!OpInfoMap::Instance().Has(op_type),
                   "'%s' is registered more than once.", op_type);
    static_assert(sizeof...(ARGS) != 0,
                  "OperatorRegistrar should be invoked at least by OpClass");
58
    OpInfo info;
59
    details::OperatorRegistrarRecursive<0, false, ARGS...>(op_type, &info);
Y
Yu Yang 已提交
60
    OpInfoMap::Instance().Insert(op_type, info);
61 62 63
  }
};

64 65
class OpRegistry {
 public:
Y
Yu Yang 已提交
66
  static std::unique_ptr<OperatorBase> CreateOp(const std::string& type,
Y
Yu Yang 已提交
67 68
                                                const VariableNameMap& inputs,
                                                const VariableNameMap& outputs,
69
                                                AttributeMap attrs);
Y
Yu Yang 已提交
70

71
  static std::unique_ptr<OperatorBase> CreateOp(const proto::OpDesc& op_desc);
Y
Yu Yang 已提交
72

Y
Yu Yang 已提交
73
  static std::unique_ptr<OperatorBase> CreateOp(const OpDesc& op_desc);
F
Fix bug  
fengjiayi 已提交
74
};
F
fengjiayi 已提交
75

76 77 78
// Primary template: declared only, never defined. Registration walks the
// kernel-class list by index; the boolean flag selects one of the two partial
// specializations defined further down (recursive step vs. terminator).
template <typename Place, bool kAtEnd, size_t kIndex, typename... Kernels>
struct OpKernelRegistrarFunctor;

79 80 81
template <typename PlaceType, typename T, typename Func>
inline void RegisterKernelClass(const char* op_type, const char* library_type,
                                Func func) {
Y
yuyang18 已提交
82 83 84 85 86 87 88 89
  std::string library(library_type);
  std::string data_layout = "ANYLAYOUT";
  if (library == "MKLDNN") {
    data_layout = "MKLDNNLAYOUT";
  }
  OpKernelType key(ToDataType(std::type_index(typeid(T))), PlaceType(),
                   StringToDataLayout(data_layout),
                   StringToLibraryType(library_type));
90
  OperatorWithKernel::AllOpKernels()[op_type][key] = func;
Y
yuyang18 已提交
91 92
}

93 94 95 96
template <typename PlaceType, size_t I, typename... KernelTypes>
struct OpKernelRegistrarFunctor<PlaceType, false, I, KernelTypes...> {
  using KERNEL_TYPE =
      typename std::tuple_element<I, std::tuple<KernelTypes...>>::type;
97

D
dzhwinter 已提交
98
  void operator()(const char* op_type, const char* library_type) const {
99
    using T = typename KERNEL_TYPE::ELEMENT_TYPE;
100 101
    RegisterKernelClass<PlaceType, T>(
        op_type, library_type, [](const framework::ExecutionContext& ctx) {
Y
yuyang18 已提交
102
          KERNEL_TYPE().Compute(ctx);
103
        });
104 105
    constexpr auto size = std::tuple_size<std::tuple<KernelTypes...>>::value;
    OpKernelRegistrarFunctor<PlaceType, I + 1 == size, I + 1, KernelTypes...>
106
        func;
D
dzhwinter 已提交
107
    func(op_type, library_type);
108 109 110 111 112
  }
};

// Terminator: every kernel class has been registered, so there is nothing
// left to do.
template <typename PlaceType, size_t I, typename... KernelType>
struct OpKernelRegistrarFunctor<PlaceType, true, I, KernelType...> {
  void operator()(const char* /*op_type*/,
                  const char* /*library_type*/) const {}
};

M
mozga-intel 已提交
116 117
// User can register many kernel in one place. The data type could be
// different.
118
template <typename PlaceType, typename... KernelType>
F
fengjiayi 已提交
119 120
class OpKernelRegistrar : public Registrar {
 public:
D
dzhwinter 已提交
121
  explicit OpKernelRegistrar(const char* op_type, const char* library_type) {
122
    OpKernelRegistrarFunctor<PlaceType, false, 0, KernelType...> func;
D
dzhwinter 已提交
123
    func(op_type, library_type);
F
fengjiayi 已提交
124 125 126
  }
};

Y
yuyang18 已提交
127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148
// Primary template: declared only, never defined. Walks a flat list of
// (data type, kernel functor) pairs; the boolean flag selects the recursive
// step or the terminator specialization defined after OpKernelRegistrarEx.
template <typename Place, bool kAtEnd, size_t kIndex, typename... Types>
struct OpKernelRegistrarFunctorEx;

/**
 * Registrar for kernels given as explicit (data type, kernel functor) pairs.
 *
 * DataTypeAndKernelType... must be a flat, even-length list
 *     T0, Functor0, T1, Functor1, ...
 * For each pair, a default-constructed Functor is registered as the kernel of
 * `op_type` for element type T on PlaceType.
 */
template <typename PlaceType, typename... DataTypeAndKernelType>
class OpKernelRegistrarEx : public Registrar {
 public:
  explicit OpKernelRegistrarEx(const char* op_type, const char* library_type) {
    // The functor walk reads elements I and I + 1 at each step. An empty or
    // odd-length list would otherwise fail with an opaque std::tuple_element
    // out-of-range error deep inside the recursion.
    static_assert(sizeof...(DataTypeAndKernelType) != 0,
                  "OpKernelRegistrarEx needs at least one (DataType, Functor) "
                  "pair");
    static_assert(sizeof...(DataTypeAndKernelType) % 2 == 0,
                  "OpKernelRegistrarEx expects (DataType, Functor) pairs: the "
                  "number of template arguments after PlaceType must be even");
    OpKernelRegistrarFunctorEx<PlaceType, false, 0, DataTypeAndKernelType...>
        func;
    func(op_type, library_type);
  }
};

// Terminator: the index has moved past the last (data type, functor) pair,
// so registration is complete.
template <typename PlaceType, size_t I, typename... DataTypeAndKernelType>
struct OpKernelRegistrarFunctorEx<PlaceType, true, I,
                                  DataTypeAndKernelType...> {
  void operator()(const char* /*op_type*/,
                  const char* /*library_type*/) const {}
};

template <typename PlaceType, size_t I, typename... DataTypeAndKernelType>
struct OpKernelRegistrarFunctorEx<PlaceType, false, I,
                                  DataTypeAndKernelType...> {
149
  using Functor =
Y
yuyang18 已提交
150 151 152 153 154 155 156
      typename std::tuple_element<I + 1,
                                  std::tuple<DataTypeAndKernelType...>>::type;
  using T =
      typename std::tuple_element<I,
                                  std::tuple<DataTypeAndKernelType...>>::type;

  void operator()(const char* op_type, const char* library_type) const {
157
    RegisterKernelClass<PlaceType, T>(op_type, library_type, Functor());
Y
yuyang18 已提交
158 159 160 161 162 163 164 165 166 167

    constexpr auto size =
        std::tuple_size<std::tuple<DataTypeAndKernelType...>>::value;
    OpKernelRegistrarFunctorEx<PlaceType, I + 2 >= size, I + 2,
                               DataTypeAndKernelType...>
        func;
    func(op_type, library_type);
  }
};

168 169 170
/**
 * Compiles only when expanded at global namespace scope.
 *
 * The macro defines an empty struct named after `TAG` in the current scope.
 * It then checks that the same name, looked up from the global namespace
 * (`::`), is that very struct. Expanded anywhere else, `::name` either does
 * not exist (a hard compile error) or names a different type (the
 * static_assert fires with `MESSAGE`).
 *
 * Because it defines a struct, each use in the same scope needs a distinct
 * `TAG`. The expansion is not terminated; callers add the ';'.
 */
#define STATIC_ASSERT_GLOBAL_NAMESPACE(TAG, MESSAGE)                    \
  struct __test_global_namespace_##TAG##__ {};                          \
  static_assert(std::is_same<::__test_global_namespace_##TAG##__,       \
                             __test_global_namespace_##TAG##__>::value, \
                MESSAGE)

177 178 179 180 181 182 183 184
/*
  REGISTER_OPERATOR(op_type, op_class, ...)

  Registers operator `op_type`, implemented by `op_class`. The variadic
  arguments should be class types derived from one of the following classes:
    OpProtoAndCheckerMaker
    GradOpDescMakerBase
    VarTypeInference
    InferShapeBase

  Must be expanded at global namespace scope. Besides a static registrar
  object, it defines `int TouchOpRegistrar_<op_type>()`, which
  USE_OP_ITSELF(op_type) calls so that the registration is kept by the linker.
*/
#define REGISTER_OPERATOR(OP_TYPE, OP_CLASS, ...)                      \
  STATIC_ASSERT_GLOBAL_NAMESPACE(                                      \
      __reg_op__##OP_TYPE,                                             \
      "REGISTER_OPERATOR must be called in global namespace");         \
  class _OpClass_##OP_TYPE##_ : public OP_CLASS {                      \
   public:                                                             \
    DEFINE_OP_CLONE_METHOD(_OpClass_##OP_TYPE##_);                     \
    DEFINE_OP_CONSTRUCTOR(_OpClass_##OP_TYPE##_, OP_CLASS);            \
  };                                                                   \
  static ::paddle::framework::OperatorRegistrar<_OpClass_##OP_TYPE##_, \
                                                ##__VA_ARGS__>         \
      __op_registrar_##OP_TYPE##__(#OP_TYPE);                          \
  int TouchOpRegistrar_##OP_TYPE() {                                   \
    __op_registrar_##OP_TYPE##__.Touch();                              \
    return 0;                                                          \
  }

F
WIP  
fengjiayi 已提交
202
// Registers an operator with a single maker class and no further helper
// classes (presumably no gradient maker, as the name says). Shorthand for
// REGISTER_OPERATOR(op_type, op_class, op_maker_class).
#define REGISTER_OP_WITHOUT_GRADIENT(OP_TYPE, OP_CLASS, OP_MAKER_CLASS) \
  REGISTER_OPERATOR(OP_TYPE, OP_CLASS, OP_MAKER_CLASS)
D
dongzhihong 已提交
204

D
dongzhihong 已提交
205
/**
 * Registers kernels of `op_type` for library `library_type` on `place_class`.
 * The variadic arguments are kernel classes, registered in order by
 * OpKernelRegistrar.
 *
 * Must be expanded at global namespace scope. It also defines
 * `int TouchOpKernelRegistrar_<op_type>_<library_type>()`, which
 * USE_OP_DEVICE_KERNEL(op_type, library_type) calls so that the registration
 * is kept by the linker.
 */
#define REGISTER_OP_KERNEL(OP_TYPE, LIBRARY_TYPE, PLACE_CLASS, ...)        \
  STATIC_ASSERT_GLOBAL_NAMESPACE(                                          \
      __reg_op_kernel_##OP_TYPE##_##LIBRARY_TYPE##__,                      \
      "REGISTER_OP_KERNEL must be called in global namespace");            \
  static ::paddle::framework::OpKernelRegistrar<PLACE_CLASS, __VA_ARGS__>  \
      __op_kernel_registrar_##OP_TYPE##_##LIBRARY_TYPE##__(#OP_TYPE,       \
                                                           #LIBRARY_TYPE); \
  int TouchOpKernelRegistrar_##OP_TYPE##_##LIBRARY_TYPE() {                \
    __op_kernel_registrar_##OP_TYPE##_##LIBRARY_TYPE##__.Touch();          \
    return 0;                                                              \
  }
D
dongzhihong 已提交
219

Q
QI JUN 已提交
220
// Registers kernel classes of `op_type` under library tag CUDA on
// ::paddle::platform::CUDAPlace.
#define REGISTER_OP_CUDA_KERNEL(OP_TYPE, ...) \
  REGISTER_OP_KERNEL(OP_TYPE, CUDA, ::paddle::platform::CUDAPlace, __VA_ARGS__)
F
fengjiayi 已提交
222

F
fengjiayi 已提交
223 224
// Registers kernel classes of `op_type` under library tag CPU on
// ::paddle::platform::CPUPlace.
#define REGISTER_OP_CPU_KERNEL(OP_TYPE, ...) \
  REGISTER_OP_KERNEL(OP_TYPE, CPU, ::paddle::platform::CPUPlace, __VA_ARGS__)
Y
Yu Yang 已提交
225

Y
yuyang18 已提交
226 227 228 229 230 231 232 233 234 235 236 237
// REGISTER_OP_KERNEL_EX(op_type, library_type, place_class, T0, F0, T1, F1...)
// Like REGISTER_OP_KERNEL, except that the variadic arguments are
// (data type, kernel functor) pairs handled by OpKernelRegistrarEx.
// It generates the same names as REGISTER_OP_KERNEL (the registrar variable
// and TouchOpKernelRegistrar_<op_type>_<library_type>()), so the two collide
// if both are used for the same (op_type, library_type).
#define REGISTER_OP_KERNEL_EX(OP_TYPE, LIBRARY_TYPE, PLACE_CLASS, ...)      \
  STATIC_ASSERT_GLOBAL_NAMESPACE(                                           \
      __reg_op_kernel_##OP_TYPE##_##LIBRARY_TYPE##__,                       \
      "REGISTER_OP_KERNEL_EX must be called in global namespace");          \
  static ::paddle::framework::OpKernelRegistrarEx<PLACE_CLASS, __VA_ARGS__> \
      __op_kernel_registrar_##OP_TYPE##_##LIBRARY_TYPE##__(#OP_TYPE,        \
                                                           #LIBRARY_TYPE);  \
  int TouchOpKernelRegistrar_##OP_TYPE##_##LIBRARY_TYPE() {                 \
    __op_kernel_registrar_##OP_TYPE##_##LIBRARY_TYPE##__.Touch();           \
    return 0;                                                               \
  }

238 239
// Registers (data type, kernel functor) pairs of `op_type` under library tag
// CUDA on ::paddle::platform::CUDAPlace.
#define REGISTER_OP_CUDA_KERNEL_FUNCTOR(OP_TYPE, ...)                 \
  REGISTER_OP_KERNEL_EX(OP_TYPE, CUDA, ::paddle::platform::CUDAPlace, \
                        __VA_ARGS__)

242
// Registers (data type, kernel functor) pairs of `op_type` under library tag
// CPU on ::paddle::platform::CPUPlace.
#define REGISTER_OP_CPU_KERNEL_FUNCTOR(OP_TYPE, ...)                \
  REGISTER_OP_KERNEL_EX(OP_TYPE, CPU, ::paddle::platform::CPUPlace, \
                        __VA_ARGS__)

245
/**
 * USE_* macros declare which operators and kernels a translation unit depends
 * on, so that the linker keeps their registrations in the target binary.
 *
 * USE_OP_ITSELF(op_type) declares TouchOpRegistrar_<op_type>() (defined by
 * REGISTER_OPERATOR) and calls it while initializing a file-local static int.
 * The int is marked unused so that it does not trigger warnings. Must be
 * expanded at global namespace scope.
 */
#define USE_OP_ITSELF(OP_TYPE)                                    \
  STATIC_ASSERT_GLOBAL_NAMESPACE(                                 \
      __use_op_itself_##OP_TYPE,                                  \
      "USE_OP_ITSELF must be called in global namespace");        \
  extern int TouchOpRegistrar_##OP_TYPE();                        \
  static int use_op_itself_##OP_TYPE##_ __attribute__((unused)) = \
      TouchOpRegistrar_##OP_TYPE()
F
fengjiayi 已提交
257

258 259 260 261 262 263 264 265
// Keeps the kernels registered by REGISTER_OP_KERNEL(op_type, LIBRARY_TYPE,
// ...) in the binary. It declares
// TouchOpKernelRegistrar_<op_type>_<LIBRARY_TYPE>() and calls it while
// initializing a file-local static int. Must be expanded at global namespace
// scope.
#define USE_OP_DEVICE_KERNEL(OP_TYPE, LIBRARY_TYPE)               \
  STATIC_ASSERT_GLOBAL_NAMESPACE(                                 \
      __use_op_kernel_##OP_TYPE##_##LIBRARY_TYPE##__,             \
      "USE_OP_DEVICE_KERNEL must be in global namespace");        \
  extern int TouchOpKernelRegistrar_##OP_TYPE##_##LIBRARY_TYPE(); \
  static int use_op_kernel_##OP_TYPE##_##LIBRARY_TYPE##_          \
      __attribute__((unused)) =                                   \
          TouchOpKernelRegistrar_##OP_TYPE##_##LIBRARY_TYPE()
Y
Yu Yang 已提交
266

267 268
// TODO(fengjiayi): The following macros seem ugly; is there a better method?

// USE_OP_KERNEL(op_type) keeps the CPU kernel of `op_type`. When
// PADDLE_WITH_CUDA is defined, it also keeps the CUDA kernel.
#ifdef PADDLE_WITH_CUDA
#define USE_OP_KERNEL(OP_TYPE)        \
  USE_OP_DEVICE_KERNEL(OP_TYPE, CPU); \
  USE_OP_DEVICE_KERNEL(OP_TYPE, CUDA)
#else
#define USE_OP_KERNEL(OP_TYPE) USE_OP_DEVICE_KERNEL(OP_TYPE, CPU)
#endif
277

278 279
// Operator without kernels: keeps only the operator registration.
// NOTE(review): unlike USE_OP, this expansion already ends with ';'.
#define USE_NO_KERNEL_OP(OP_TYPE) USE_OP_ITSELF(OP_TYPE);

F
WIP  
fengjiayi 已提交
280 281 282
// Keeps the operator registration plus its CPU kernel only.
// NOTE(review): unlike USE_OP, this expansion already ends with ';'.
#define USE_CPU_ONLY_OP(OP_TYPE) \
  USE_OP_ITSELF(OP_TYPE);        \
  USE_OP_DEVICE_KERNEL(OP_TYPE, CPU);
283

Q
QI JUN 已提交
284 285 286
// Keeps the operator registration plus its CUDA kernel only.
#define USE_CUDA_ONLY_OP(OP_TYPE) \
  USE_OP_ITSELF(OP_TYPE);         \
  USE_OP_DEVICE_KERNEL(OP_TYPE, CUDA)
D
Dong Zhihong 已提交
287

F
WIP  
fengjiayi 已提交
288 289 290
// Keeps the operator registration plus the kernels selected by
// USE_OP_KERNEL (CPU, and also CUDA when PADDLE_WITH_CUDA is defined).
#define USE_OP(OP_TYPE)   \
  USE_OP_ITSELF(OP_TYPE); \
  USE_OP_KERNEL(OP_TYPE)
291

292 293
}  // namespace framework
}  // namespace paddle