op_registry.h 13.1 KB
Newer Older
F
fengjiayi 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

15 16
#pragma once

17
#include <algorithm>
#include <atomic>
#include <functional>
#include <iterator>
#include <memory>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "paddle/framework/attribute.h"
#include "paddle/framework/framework.pb.h"
#include "paddle/framework/grad_op_builder.h"
#include "paddle/framework/scope.h"
26 27 28 29 30 31 32 33 34 35

namespace paddle {
namespace framework {

// this class not only make proto but also init attribute checkers.
class OpProtoAndCheckerMaker {
 public:
  OpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
      : proto_(proto), op_checker_(op_checker) {}

36 37 38 39 40 41 42 43
  ~OpProtoAndCheckerMaker() {
    PADDLE_ENFORCE(validated_, "should call Validate after build");
  }

  void Validate() {
    validated_ = true;
    CheckNoDuplicatedInOutAttrs();
  }
44

45
 protected:
Y
Yu Yang 已提交
46
  struct VariableBuilder {
Y
Yu Yang 已提交
47
    OpProto::Var* var_;
Y
Yu Yang 已提交
48 49

    VariableBuilder& SetMultiple() {
Y
Yu Yang 已提交
50
      var_->set_duplicable(true);
Y
Yu Yang 已提交
51 52 53 54
      return *this;
    }

    VariableBuilder& SetTemporary() {
Y
Yu Yang 已提交
55
      var_->set_intermediate(true);
Y
Yu Yang 已提交
56 57 58 59
      return *this;
    }

    VariableBuilder& IgnoreGradient() {
Y
Yu Yang 已提交
60
      var_->set_no_gradient(true);
Y
Yu Yang 已提交
61 62 63 64 65 66
      return *this;
    }
  };

  VariableBuilder AddInput(const std::string& name,
                           const std::string& comment) {
67
    auto input = proto_->mutable_inputs()->Add();
68 69
    *input->mutable_name() = name;
    *input->mutable_comment() = comment;
Y
Yu Yang 已提交
70
    return VariableBuilder{input};
71 72
  }

Y
Yu Yang 已提交
73 74
  VariableBuilder AddOutput(const std::string& name,
                            const std::string& comment) {
75
    auto output = proto_->mutable_outputs()->Add();
76 77
    *output->mutable_name() = name;
    *output->mutable_comment() = comment;
Y
Yu Yang 已提交
78
    return VariableBuilder{output};
79 80 81 82
  }

  template <typename T>
  TypedAttrChecker<T>& AddAttr(const std::string& name,
83 84
                               const std::string& comment,
                               bool generated = false) {
85
    auto attr = proto_->mutable_attrs()->Add();
86 87
    *attr->mutable_name() = name;
    *attr->mutable_comment() = comment;
88
    attr->set_generated(generated);
Y
Yi Wang 已提交
89
    attr->set_type(AttrTypeID<T>());
90 91 92 93 94 95 96
    return op_checker_->AddAttrChecker<T>(name);
  }

  void AddComment(const std::string& comment) {
    *(proto_->mutable_comment()) = comment;
  }

97
 private:
98
  void CheckNoDuplicatedInOutAttrs() {
99
    std::unordered_set<std::string> names;
100 101 102 103
    auto checker = [&](const std::string& name) {
      PADDLE_ENFORCE(!names.count(name), "[%s] is duplicated", name);
      names.insert(name);
    };
104
    for (auto& attr : proto_->attrs()) {
105 106 107 108 109 110 111
      checker(attr.name());
    }
    for (auto& input : proto_->inputs()) {
      checker(input.name());
    }
    for (auto& output : proto_->outputs()) {
      checker(output.name());
112 113 114
    }
  }

115 116
  OpProto* proto_;
  OpAttrChecker* op_checker_;
117
  bool validated_{false};
118 119 120
};

class OpRegistry {
Q
Qiao Longfei 已提交
121
  using OpCreator = std::function<OperatorBase*()>;
Y
Yu Yang 已提交
122
  using VarIndexMap = std::unordered_map<std::string, int>;
Y
Yu Yang 已提交
123
  using VarNameMap = std::unordered_map<std::string, std::vector<std::string>>;
124 125 126 127

 public:
  template <typename OpType, typename ProtoMakerType>
  static void RegisterOp(const std::string& op_type) {
128
    op_creators()[op_type] = [] { return new OpType; };
129
    OpAttrChecker& op_checker = op_checkers()[op_type];
D
dongzhihong 已提交
130
    OpProto& op_proto = protos()[op_type];
131 132
    auto maker = ProtoMakerType(&op_proto, &op_checker);
    maker.Validate();
Y
Yu Yang 已提交
133 134 135 136 137
    *op_proto.mutable_type() = op_type;
    PADDLE_ENFORCE(
        op_proto.IsInitialized(),
        "Fail to initialize %s's OpProto, because %s is not initialized",
        op_type, op_proto.InitializationErrorString());
Y
Yu Yang 已提交
138 139 140 141 142 143 144 145 146 147 148

    VarIndexMaps()[op_type].reset(new VarIndexMap());
    auto& varmap = *VarIndexMaps()[op_type];
    int idx = 0;
    for (auto& var : op_proto.inputs()) {
      varmap[var.name()] = idx++;
    }
    idx = 0;
    for (auto& var : op_proto.outputs()) {
      varmap[var.name()] = idx++;
    }
149 150
  }

151 152 153 154 155
  template <typename GradOpType>
  static void RegisterGradOp(const std::string& op_type,
                             const std::string& grad_op_type) {
    op_creators()[grad_op_type] = [] { return new GradOpType; };
    grad_ops()[op_type] = grad_op_type;
F
fengjiayi 已提交
156 157
  }

Y
Yu Yang 已提交
158
  static std::shared_ptr<OperatorBase> CreateOp(const std::string& type,
Y
Yu Yang 已提交
159 160
                                                const VarNameMap& inputs,
                                                const VarNameMap& outputs,
Y
Yu Yang 已提交
161
                                                const AttributeMap& attrs) {
162 163
    auto op_create_it = op_creators().find(type);
    PADDLE_ENFORCE(op_create_it != op_creators().end(),
F
fengjiayi 已提交
164
                   "Operator %s cannot be found.", type);
165

Y
Yu Yang 已提交
166 167 168 169
    auto op = op_create_it->second();
    op->type_ = type;
    op->inputs_ = inputs;
    op->outputs_ = outputs;
F
fengjiayi 已提交
170

Y
Yu Yang 已提交
171 172
    op->attrs_ = attrs;
    op_checkers().at(type).Check(op->attrs_);
173

Y
Yu Yang 已提交
174
    GenerateTempVariableName(op);
175

Q
Qiao Longfei 已提交
176
    op->Init();
Y
Yu Yang 已提交
177
    return std::shared_ptr<OperatorBase>(op);
178 179
  }

Y
Yu Yang 已提交
180
  static std::shared_ptr<OperatorBase> CreateOp(const OpDesc& op_desc) {
Y
Yu Yang 已提交
181 182
    VarNameMap inputs;
    for (auto& input : op_desc.inputs()) {
183 184
      auto& var_names = inputs[input.parameter()];
      auto& var_names_in_proto = input.arguments();
Y
Yu Yang 已提交
185 186 187 188
      var_names.reserve(static_cast<size_t>(var_names_in_proto.size()));
      std::copy(var_names_in_proto.begin(), var_names_in_proto.end(),
                std::back_inserter(var_names));
    }
Y
Yu Yang 已提交
189

Y
Yu Yang 已提交
190 191
    VarNameMap outputs;
    for (auto& output : op_desc.outputs()) {
192 193
      auto& var_names = outputs[output.parameter()];
      auto& var_names_in_proto = output.arguments();
Y
Yu Yang 已提交
194 195 196 197
      var_names.reserve(static_cast<size_t>(var_names_in_proto.size()));
      std::copy(var_names_in_proto.begin(), var_names_in_proto.end(),
                std::back_inserter(var_names));
    }
Y
Yu Yang 已提交
198 199

    AttributeMap attrs;
200
    for (auto& attr : op_desc.attrs()) {
Y
Yi Wang 已提交
201
      attrs[attr.name()] = GetAttrValue(attr);
202
    }
Y
Yu Yang 已提交
203 204

    return CreateOp(op_desc.type(), inputs, outputs, attrs);
205 206
  }

Y
Yu Yang 已提交
207 208
  static std::shared_ptr<OperatorBase> CreateGradOp(const OperatorBase& op) {
    PADDLE_ENFORCE(!op.IsNetOp(),
Y
Yu Yang 已提交
209
                   "Use framework::Backward to get backward ops");
210
    std::shared_ptr<OperatorBase> grad_op(BuildGradOp(&op));
F
fengjiayi 已提交
211 212
    grad_op->Init();
    return grad_op;
D
dongzhihong 已提交
213 214
  }

Y
Yu Yang 已提交
215 216 217
  static std::unordered_map<std::string, OpProto>& protos() {
    static std::unordered_map<std::string, OpProto> protos_;
    return protos_;
L
liaogang 已提交
218
  }
Y
Yu Yang 已提交
219

220 221 222
  static std::unordered_map<std::string, std::string>& grad_ops() {
    static std::unordered_map<std::string, std::string> grad_ops_;
    return grad_ops_;
223 224
  }

Y
Yu Yang 已提交
225 226 227 228 229 230
  static std::unordered_map<std::string, std::shared_ptr<VarIndexMap>>&
  VarIndexMaps() {
    static std::unordered_map<std::string, std::shared_ptr<VarIndexMap>> maps_;
    return maps_;
  }

231 232 233
  static std::unordered_map<std::string, OpCreator>& op_creators() {
    static std::unordered_map<std::string, OpCreator> op_creators_;
    return op_creators_;
F
fengjiayi 已提交
234 235
  }

236
 private:
F
fengjiayi 已提交
237 238 239
  static std::unordered_map<std::string, OpAttrChecker>& op_checkers() {
    static std::unordered_map<std::string, OpAttrChecker> op_checkers_;
    return op_checkers_;
L
liaogang 已提交
240
  }
F
fengjiayi 已提交
241

242
  static void GenerateTempVariableName(OperatorBase* op) {
243
    static std::atomic<size_t> gUniqId(0UL);
Y
Yu Yang 已提交
244 245 246 247 248 249 250
    for (auto& output : op->outputs_) {
      for (auto& output_name : output.second) {
        if (output_name == kTempVarName) {
          output_name += op->type_;
          output_name += "@";
          output_name += std::to_string(gUniqId.fetch_add(1));
        }
251 252 253
      }
    }
  }
254
};
255 256 257 258

template <typename OpType, typename ProtoMakerType>
class OpRegisterHelper {
 public:
L
liaogang 已提交
259
  explicit OpRegisterHelper(const char* op_type) {
260 261 262 263
    OpRegistry::RegisterOp<OpType, ProtoMakerType>(op_type);
  }
};

264
template <typename GradOpType>
D
dongzhihong 已提交
265 266
class GradOpRegisterHelper {
 public:
267 268
  GradOpRegisterHelper(const char* op_type, const char* grad_op_type) {
    OpRegistry::RegisterGradOp<GradOpType>(op_type, grad_op_type);
D
dongzhihong 已提交
269 270 271
  }
};

272 273 274
/**
 * Fails to compile unless expanded at global namespace scope.
 *
 * It declares an empty struct whose name is built from `tag` in the current
 * scope and asserts that it is the same type as `::` + that name. In any
 * other namespace the `::` name either does not exist or is a different type,
 * so compilation fails with `message` or a lookup error. Expanding it twice
 * with the same `tag` in one translation unit also fails, because the struct
 * is redefined.
 */
#define STATIC_ASSERT_GLOBAL_NAMESPACE(tag, message)                \
  struct __test_global_namespace_##tag##__ {};                      \
  static_assert(                                                    \
      std::is_same<::__test_global_namespace_##tag##__,             \
                   __test_global_namespace_##tag##__>::value,       \
      message)

281 282 283
/**
 * Registers operator class `op_class`, described by `op_maker_class`, under
 * the name `op_type`.
 *
 * It must be used in the global namespace. It defines a static
 * OpRegisterHelper, whose constructor performs the registration, and the
 * function __op_register_<op_type>_handle__(), which USE_OP_WITHOUT_KERNEL
 * references to force this translation unit to be linked.
 */
#define REGISTER_OP(op_type, op_class, op_maker_class)                       \
  STATIC_ASSERT_GLOBAL_NAMESPACE(__reg_op__##op_type,                        \
                                 "REGISTER_OP must be in global namespace"); \
  static ::paddle::framework::OpRegisterHelper<op_class, op_maker_class>     \
      __op_register_##op_type##__(#op_type);                                 \
  int __op_register_##op_type##_handle__() { return 0; }

D
dongzhihong 已提交
291
/**
 * Registers `grad_op_class` as gradient operator `grad_op_type` of the
 * forward operator `op_type`.
 *
 * It must be used in the global namespace. It defines a static
 * GradOpRegisterHelper, whose constructor performs the registration, and a
 * __op_gradient_register_<op_type><grad_op_type>_handle__() function.
 */
#define REGISTER_GRADIENT_OP(op_type, grad_op_type, grad_op_class)          \
  STATIC_ASSERT_GLOBAL_NAMESPACE(                                           \
      __reg_gradient_op__##op_type##grad_op_type,                           \
      "REGISTER_GRADIENT_OP must be in global namespace");                  \
  static ::paddle::framework::GradOpRegisterHelper<grad_op_class>           \
      __op_gradient_register_##op_type##grad_op_type##__(#op_type,          \
                                                         #grad_op_type);    \
  int __op_gradient_register_##op_type##grad_op_type##_handle__() {         \
    return 0;                                                               \
  }
D
dongzhihong 已提交
304

D
dongzhihong 已提交
305 306 307 308 309 310 311 312
/**
 * Declares that operator `op_type` must not get a gradient operator.
 *
 * The tag it passes to STATIC_ASSERT_GLOBAL_NAMESPACE is the same one that
 * REGISTER_GRADIENT_OP(op_type, op_type_grad, ...) passes. Using both in one
 * translation unit therefore redefines the same struct and fails to compile.
 */
#define NO_GRADIENT(op_type)                          \
  STATIC_ASSERT_GLOBAL_NAMESPACE(                     \
      __reg_gradient_op__##op_type##op_type##_grad,   \
      "NO_GRADIENT must be in global namespace")

313 314 315
/**
 * Registers kernel `__VA_ARGS__` of operator `op_type` for device tag
 * `device_type`, keyed by a default-constructed `place_type`.
 *
 * It must be used in the global namespace. It defines a struct whose
 * constructor stores a newly allocated kernel in
 * OperatorWithKernel::AllOpKernels()[op_type][key], plus a static instance of
 * that struct. It also defines the
 * __op_kernel_register_<op_type>_handle_<device_type>__() function that
 * USE_OP_KERNEL references.
 */
#define REGISTER_OP_KERNEL(op_type, device_type, place_type, ...)             \
  STATIC_ASSERT_GLOBAL_NAMESPACE(                                            \
      __reg_op_kernel_##op_type##_##device_type##__,                         \
      "REGISTER_OP_KERNEL must be in global namespace");                     \
  struct __op_kernel_register__##op_type##__##device_type##__ {              \
    __op_kernel_register__##op_type##__##device_type##__() {                 \
      ::paddle::framework::OperatorWithKernel::OpKernelKey key;              \
      key.place_ = place_type();                                             \
      ::paddle::framework::OperatorWithKernel::AllOpKernels()[#op_type][key] \
          .reset(new __VA_ARGS__());                                         \
    }                                                                        \
  };                                                                         \
  static __op_kernel_register__##op_type##__##device_type##__                \
      __reg_kernel_##op_type##__##device_type##__;                           \
  int __op_kernel_register_##op_type##_handle_##device_type##__() { return 0; }
Y
Yu Yang 已提交
331

332 333 334
// REGISTER_OP_CPU_KERNEL(op_type, KernelType): registers KernelType as the
// CPU kernel of `op_type`, keyed by ::paddle::platform::CPUPlace.
#define REGISTER_OP_CPU_KERNEL(op_type, ...) \
  REGISTER_OP_KERNEL(op_type, CPU, ::paddle::platform::CPUPlace, __VA_ARGS__)

// REGISTER_OP_GPU_KERNEL(op_type, KernelType): registers KernelType as the
// GPU kernel of `op_type`, keyed by ::paddle::platform::GPUPlace.
#define REGISTER_OP_GPU_KERNEL(op_type, ...) \
  REGISTER_OP_KERNEL(op_type, GPU, ::paddle::platform::GPUPlace, __VA_ARGS__)
Y
Yu Yang 已提交
339

340 341 342 343
/**
 * Force the registration code of an operator (USE_OP_WITHOUT_KERNEL) or of
 * one of its kernels (USE_OP_KERNEL) to be linked into the target.
 *
 * Each macro declares the `..._handle__()` function defined by the matching
 * REGISTER_OP / REGISTER_OP_KERNEL. It then initializes an unused static int
 * with that function's result, which references the registering translation
 * unit. Both must be used in the global namespace.
 */
#define USE_OP_WITHOUT_KERNEL(op_name)                      \
  STATIC_ASSERT_GLOBAL_NAMESPACE(                           \
      __use_op_without_kernel_##op_name,                    \
      "USE_OP_WITHOUT_KERNEL must be in global namespace"); \
  extern int __op_register_##op_name##_handle__();          \
  static int __use_op_ptr_##op_name##_without_kernel__      \
      __attribute__((unused)) = __op_register_##op_name##_handle__()

#define USE_OP_KERNEL(op_name, device)                                 \
  STATIC_ASSERT_GLOBAL_NAMESPACE(                                      \
      __use_op_kernel_##op_name##_##device##__,                        \
      "USE_OP_KERNEL must be in global namespace");                    \
  extern int __op_kernel_register_##op_name##_handle_##device##__();   \
  static int __use_op_ptr_##op_name##_##device##_kernel__              \
      __attribute__((unused)) =                                        \
          __op_kernel_register_##op_name##_handle_##device##__()
Y
Yu Yang 已提交
360

361 362
// USE_OP_CPU(op_type): links in operator `op_type` together with its CPU
// kernel only.
#define USE_OP_CPU(op_name)       \
  USE_OP_WITHOUT_KERNEL(op_name); \
  USE_OP_KERNEL(op_name, CPU)

// USE_OP(op_type): links in operator `op_type` with its CPU kernel. It also
// links in the GPU kernel unless the build defines PADDLE_ONLY_CPU.
#ifndef PADDLE_ONLY_CPU
#define USE_OP(op_name) \
  USE_OP_CPU(op_name);  \
  USE_OP_KERNEL(op_name, GPU)
#else
#define USE_OP(op_name) USE_OP_CPU(op_name)
#endif
373 374 375

}  // namespace framework
}  // namespace paddle