op_registry.h 12.6 KB
Newer Older
1 2
#pragma once

3
#include <algorithm>
Y
Yu Yang 已提交
4
#include <type_traits>
5 6
#include <unordered_map>
#include <unordered_set>
Q
Qiao Longfei 已提交
7
#include "paddle/framework/attr_checker.h"
8 9
#include "paddle/framework/op_desc.pb.h"
#include "paddle/framework/op_proto.pb.h"
Q
Qiao Longfei 已提交
10
#include "paddle/framework/operator.h"
D
dongzhihong 已提交
11
#include "paddle/framework/scope.h"
12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64

namespace paddle {
namespace framework {

// Conversion helpers between the protobuf attribute messages (AttrProto /
// AttrDesc) and the in-memory `Attribute` value used by operators.
struct AttrTypeHelper {
  // Stores into `attr` the AttrType tag that corresponds to the C++ type T.
  // NOTE(review): only declared here; the per-type definitions are not part
  // of this header section.
  template <typename T>
  static void SetAttrType(AttrProto* attr);

  // Decodes `attr_desc` into an Attribute whose alternative is chosen by
  // attr_desc.type(): a scalar (int / float / std::string) or a std::vector
  // of the matching element type. An unrecognized tag triggers
  // PADDLE_ENFORCE; boost::blank() is returned after it.
  static Attribute GetAttrValue(const AttrDesc& attr_desc) {
    switch (attr_desc.type()) {
      case AttrType::INT:
        return attr_desc.i();
      case AttrType::FLOAT:
        return attr_desc.f();
      case AttrType::STRING:
        return attr_desc.s();
      case AttrType::INTS: {
        const auto& items = attr_desc.ints();
        return std::vector<int>(items.begin(), items.end());
      }
      case AttrType::FLOATS: {
        const auto& items = attr_desc.floats();
        return std::vector<float>(items.begin(), items.end());
      }
      case AttrType::STRINGS: {
        const auto& items = attr_desc.strings();
        return std::vector<std::string>(items.begin(), items.end());
      }
    }
    PADDLE_ENFORCE(false, "Unknown OpDesc::AttrDesc::type !");
    return boost::blank();
  }
};

// this class not only make proto but also init attribute checkers.
class OpProtoAndCheckerMaker {
 public:
  OpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
      : proto_(proto), op_checker_(op_checker) {}

65 66
  ~OpProtoAndCheckerMaker() { CheckNoDuplicatedAttrs(); }

67
 protected:
68 69
  void AddInput(const std::string& name, const std::string& comment,
                bool multiple = false) {
70
    auto input = proto_->mutable_inputs()->Add();
71 72
    *input->mutable_name() = name;
    *input->mutable_comment() = comment;
73 74 75 76 77 78 79 80
    input->set_multiple(multiple);
    if (multiple) {
      SetHasMultipleInput();
    }
  }

  void AddInputs(const std::string& name, const std::string& comment) {
    AddInput(name, comment, true);
81 82
  }

83 84
  void AddOutput(const std::string& name, const std::string& comment,
                 bool temporary = false, bool multiple = false) {
85
    auto output = proto_->mutable_outputs()->Add();
86 87
    *output->mutable_name() = name;
    *output->mutable_comment() = comment;
88 89 90 91 92 93 94 95 96 97 98 99 100
    output->set_multiple(multiple);
    if (multiple) {
      SetHasMultipleOutput();
    }
    output->set_temporary(temporary);
    if (temporary) {
      SetHasTemporaryOutput();
    }
  }

  void AddOutputs(const std::string& name, const std::string& comment,
                  bool temporary = false) {
    AddOutput(name, comment, temporary, true);
101 102 103 104
  }

  template <typename T>
  TypedAttrChecker<T>& AddAttr(const std::string& name,
105 106
                               const std::string& comment,
                               bool generated = false) {
107
    auto attr = proto_->mutable_attrs()->Add();
108 109
    *attr->mutable_name() = name;
    *attr->mutable_comment() = comment;
110
    attr->set_generated(generated);
111 112 113 114 115 116 117 118
    AttrTypeHelper::SetAttrType<T>(attr);
    return op_checker_->AddAttrChecker<T>(name);
  }

  void AddComment(const std::string& comment) {
    *(proto_->mutable_comment()) = comment;
  }

119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177
 private:
  void SetHasMultiple(const std::string& in_out, bool* flag) {
    if (!*flag) {
      AddAttr<std::vector<int>>(in_out + "_format",
                                "The multiple index of " + in_out +
                                    "\n"
                                    R"DOC(
This attribute is used by Paddle core framework. Paddle's Op support each input
or output could be a list of variable. This attribute is used to show how that
list organized.

e.g.
  input = ["a", "b", "c", "d", "e", "f"]
  input_format = [0, 4, 5, 6]

means
  The number of all input variables this op is six, and they are segmented into
  three inputs.

  The first input is input[0:4], second is input[4:5], third is input[5:6].
)DOC",
                                /*generated*/ true);
      *flag = true;
    }
  }

  void SetHasMultipleInput() { SetHasMultiple("input", &has_multiple_input_); }
  void SetHasMultipleOutput() {
    SetHasMultiple("output", &has_multiple_output_);
  }

  void SetHasTemporaryOutput() {
    if (!has_temporary_output_) {
      AddAttr<std::vector<int>>("temporary_index",
                                R"DOC(The temporary index of output.

Not all output of Paddle Op is used by user. For faster computation, each op
could output some its internal state to other op, other op could take that
output to make compute faster.

Add a mark to which output is temporary is helpful for future optimization.
)DOC",
                                /*generated*/ true)
          .SetDefault(std::vector<int>());
      has_temporary_output_ = true;
    }
  }

  void CheckNoDuplicatedAttrs() {
    std::unordered_set<std::string> names;
    size_t cnt = 0;
    for (auto& attr : proto_->attrs()) {
      names.insert(attr.name());
      ++cnt;
    }
    PADDLE_ENFORCE(names.size() == cnt,
                   "Cannot register two attribute in same name!");
  }

178 179
  OpProto* proto_;
  OpAttrChecker* op_checker_;
180 181 182
  bool has_multiple_input_{false};
  bool has_multiple_output_{false};
  bool has_temporary_output_{false};
183 184 185
};

class OpRegistry {
Q
Qiao Longfei 已提交
186
  using OpCreator = std::function<OperatorBase*()>;
187 188 189 190

 public:
  template <typename OpType, typename ProtoMakerType>
  static void RegisterOp(const std::string& op_type) {
191 192
    creators()[op_type] = [] { return new OpType; };
    OpAttrChecker& op_checker = op_checkers()[op_type];
D
dongzhihong 已提交
193
    OpProto& op_proto = protos()[op_type];
194
    ProtoMakerType(&op_proto, &op_checker);
Y
Yu Yang 已提交
195 196 197 198 199
    *op_proto.mutable_type() = op_type;
    PADDLE_ENFORCE(
        op_proto.IsInitialized(),
        "Fail to initialize %s's OpProto, because %s is not initialized",
        op_type, op_proto.InitializationErrorString());
200 201
  }

D
dongzhihong 已提交
202 203 204 205 206
  template <typename OpType>
  static void RegisterGradOp(const std::string& op_type) {
    grad_creators()[op_type] = [] { return new OpType; };
  }

Q
Qiao Longfei 已提交
207
  static OperatorPtr CreateOp(const OpDesc& op_desc) {
208
    std::string op_type = op_desc.type();
Q
Qiao Longfei 已提交
209
    OperatorPtr op(creators().at(op_type)());
Q
Qiao Longfei 已提交
210
    op->type_ = op_desc.type();
211 212 213 214 215 216 217
    op->inputs_.reserve((size_t)op_desc.inputs_size());
    std::copy(op_desc.inputs().begin(), op_desc.inputs().end(),
              std::back_inserter(op->inputs_));
    op->outputs_.reserve((size_t)op_desc.outputs_size());
    std::copy(op_desc.outputs().begin(), op_desc.outputs().end(),
              std::back_inserter(op->outputs_));
    for (auto& attr : op_desc.attrs()) {
Q
Qiao Longfei 已提交
218
      op->attrs_[attr.name()] = AttrTypeHelper::GetAttrValue(attr);
219
    }
Q
Qiao Longfei 已提交
220
    op_checkers().at(op_type).Check(op->attrs_);
Q
Qiao Longfei 已提交
221
    op->Init();
222 223 224
    return op;
  }

D
dongzhihong 已提交
225 226 227 228 229 230 231 232 233 234 235 236 237 238 239
  static OperatorPtr CreateGradOp(std::shared_ptr<OperatorBase> op) {
    OperatorPtr op_grad(grad_creators().at(op->type_)());
    op_grad->type_ = op->type_;
    op_grad->inputs_.reserve(op->inputs_.size());
    for (auto& input : op->inputs_) {
      op_grad->inputs_.emplace_back(input);
      op_grad->outputs_.emplace_back(input + "@grad");
    }
    for (auto& output : op->outputs_) {
      op_grad->inputs_.emplace_back(output);
      op_grad->inputs_.emplace_back(output + "@grad");
    }
    return op_grad;
  }

Y
Yu Yang 已提交
240 241 242 243 244
  static std::unordered_map<std::string, OpProto>& protos() {
    static std::unordered_map<std::string, OpProto> protos_;
    return protos_;
  };

245
 private:
246 247 248 249
  static std::unordered_map<std::string, OpCreator>& creators() {
    static std::unordered_map<std::string, OpCreator> creators_;
    return creators_;
  }
250

251 252 253 254
  static std::unordered_map<std::string, OpAttrChecker>& op_checkers() {
    static std::unordered_map<std::string, OpAttrChecker> op_checkers_;
    return op_checkers_;
  };
D
dongzhihong 已提交
255 256 257 258 259

  static std::unordered_map<std::string, OpCreator>& grad_creators() {
    static std::unordered_map<std::string, OpCreator> grad_creators_;
    return grad_creators_;
  }
260
};
261 262 263 264

template <typename OpType, typename ProtoMakerType>
class OpRegisterHelper {
 public:
Y
Yu Yang 已提交
265
  OpRegisterHelper(const char* op_type) {
266 267 268 269
    OpRegistry::RegisterOp<OpType, ProtoMakerType>(op_type);
  }
};

D
dongzhihong 已提交
270 271 272 273 274 275 276 277
template <typename OpType>
class GradOpRegisterHelper {
 public:
  GradOpRegisterHelper(const char* op_type) {
    OpRegistry::RegisterGradOp<OpType>(op_type);
  }
};

/**
 * STATIC_ASSERT_GLOBAL_NAMESPACE(tag, message)
 *
 * Compile-time check that the point of expansion is the global namespace. It
 * declares an empty marker struct named after `tag` in the current scope and
 * asserts that the globally qualified name `::marker` denotes that same type,
 * which only holds at global namespace scope. `tag` must therefore be unique
 * per expansion.
 *
 * NOTE: outside the global namespace the `::marker` lookup itself usually
 * fails, so the compiler may report that lookup error instead of `message`.
 */
#define STATIC_ASSERT_GLOBAL_NAMESPACE(tag, message)          \
  struct __test_global_namespace_##tag##__ {};                \
  static_assert(                                              \
      std::is_same<::__test_global_namespace_##tag##__,       \
                   __test_global_namespace_##tag##__>::value, \
      message)

/**
 * REGISTER_OP(op_type, op_class, op_maker_class)
 *
 * Registers operator class `op_class` under the name `#op_type`, using
 * `op_maker_class` to build its OpProto and attribute checker. Must be
 * expanded at global namespace scope. Also defines the function
 * `__op_register_<op_type>_handle__()`, which USE_OP_WITHOUT_KERNEL references
 * so that the registering translation unit gets linked in.
 */
#define REGISTER_OP(op_type, op_class, op_maker_class)                       \
  STATIC_ASSERT_GLOBAL_NAMESPACE(__reg_op__##op_type,                        \
                                 "REGISTER_OP must be in global namespace"); \
  static ::paddle::framework::OpRegisterHelper<op_class, op_maker_class>     \
      __op_register_##op_type##__(#op_type);                                 \
  int __op_register_##op_type##_handle__() { return 0; }

/**
 * Macro to register the gradient operator `__op_class` for the forward
 * operator `__op_type`. Must be expanded at global namespace scope.
 *
 * Every generated name carries a "gradient" tag. REGISTER_OP generates
 * `__reg_op__X`, `__op_register_X__` and `__op_register_X_handle__()` for the
 * same op type X, so reusing those names here would redefine them whenever an
 * operator and its gradient are both registered (a compile error in one
 * translation unit, a duplicate-symbol link error across translation units).
 */
#define REGISTER_GRADIENT_OP(__op_type, __op_class)            \
  STATIC_ASSERT_GLOBAL_NAMESPACE(                              \
      __reg_gradient_op__##__op_type,                          \
      "REGISTER_GRADIENT_OP must be in global namespace");     \
  static ::paddle::framework::GradOpRegisterHelper<__op_class> \
      __op_gradient_register_##__op_type##__(#__op_type);      \
  int __op_gradient_register_##__op_type##_handle__() { return 0; }

/**
 * Macro to Register OperatorKernel.
 *
 * Registers `KernelType` as the kernel of operator `type` for `PlaceType`.
 * Must be expanded at global namespace scope.
 *
 * DEVICE_TYPE (e.g. CPU / GPU) is pasted into every generated name. The
 * registrar struct's constructor is an inline function with external linkage:
 * if the CPU and GPU registrations of one op produced the same struct name,
 * two translation units would carry different definitions of it (an ODR
 * violation, where the linker may keep only one constructor and silently drop
 * the other kernel's registration), and a single translation unit would not
 * compile at all.
 */
#define REGISTER_OP_KERNEL(type, DEVICE_TYPE, PlaceType, KernelType)      \
  STATIC_ASSERT_GLOBAL_NAMESPACE(                                         \
      __reg_op_kernel_##type##_##DEVICE_TYPE##__,                         \
      "REGISTER_OP_KERNEL must be in global namespace");                  \
  struct __op_kernel_register__##type##__##DEVICE_TYPE##__ {              \
    __op_kernel_register__##type##__##DEVICE_TYPE##__() {                 \
      ::paddle::framework::OperatorWithKernel::OpKernelKey key;           \
      key.place_ = PlaceType();                                           \
      ::paddle::framework::OperatorWithKernel::AllOpKernels()[#type][key] \
          .reset(new KernelType());                                       \
    }                                                                     \
  };                                                                      \
  static __op_kernel_register__##type##__##DEVICE_TYPE##__                \
      __reg_kernel_##type##__##DEVICE_TYPE##__;                           \
  int __op_kernel_register_##type##_handle_##DEVICE_TYPE##__() { return 0; }

// Device-specific shorthands for REGISTER_OP_KERNEL: they fill in the device
// tag (CPU / GPU) and the matching ::paddle::platform place type.
#define REGISTER_OP_CPU_KERNEL(op_type, kernel_class) \
  REGISTER_OP_KERNEL(op_type, CPU, ::paddle::platform::CPUPlace, kernel_class)

#define REGISTER_OP_GPU_KERNEL(op_type, kernel_class) \
  REGISTER_OP_KERNEL(op_type, GPU, ::paddle::platform::GPUPlace, kernel_class)

/**
 * Link anchors. Each USE_* macro calls a handle function defined by the
 * matching REGISTER_* macro, which tells the toolchain to link the operator
 * (and kernel) registrations into the target.
 *
 * USE_OP_WITHOUT_KERNEL(op) calls the handle defined by REGISTER_OP(op).
 * Must be expanded at global namespace scope.
 */
#define USE_OP_WITHOUT_KERNEL(op_name)                      \
  STATIC_ASSERT_GLOBAL_NAMESPACE(                           \
      __use_op_without_kernel_##op_name,                    \
      "USE_OP_WITHOUT_KERNEL must be in global namespace"); \
  extern int __op_register_##op_name##_handle__();          \
  static int __use_op_ptr_##op_name##_without_kernel__      \
      __attribute__((unused)) = __op_register_##op_name##_handle__()

// USE_OP_KERNEL(op, DEVICE) calls the handle defined by REGISTER_OP_KERNEL for
// the same op and device tag, pulling that kernel registration into the link.
// Must be expanded at global namespace scope.
#define USE_OP_KERNEL(op_name, device)                               \
  STATIC_ASSERT_GLOBAL_NAMESPACE(                                    \
      __use_op_kernel_##op_name##_##device##__,                      \
      "USE_OP_KERNEL must be in global namespace");                  \
  extern int __op_kernel_register_##op_name##_handle_##device##__(); \
  static int __use_op_ptr_##op_name##_##device##_kernel__            \
      __attribute__((unused)) =                                      \
          __op_kernel_register_##op_name##_handle_##device##__()

// Links in an operator that only has a CPU kernel: its REGISTER_OP
// registration plus its CPU kernel registration.
#define USE_OP_CPU(op_name)       \
  USE_OP_WITHOUT_KERNEL(op_name); \
  USE_OP_KERNEL(op_name, CPU)

// USE_OP(op) links in an operator with every kernel the build provides: the
// CPU kernel always, and the GPU kernel too unless PADDLE_ONLY_CPU is defined.
#ifndef PADDLE_ONLY_CPU
#define USE_OP(op_name) \
  USE_OP_CPU(op_name);  \
  USE_OP_KERNEL(op_name, GPU)
#else
#define USE_OP(op_name) USE_OP_CPU(op_name)
#endif

}  // namespace framework
}  // namespace paddle