op_registry.h 15.7 KB
Newer Older
F
fengjiayi 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

15 16
#pragma once

17
#include <algorithm>
18
#include <atomic>
Y
Yu Yang 已提交
19
#include <type_traits>
20 21
#include <unordered_map>
#include <unordered_set>
Y
Yi Wang 已提交
22
#include "paddle/framework/attribute.h"
F
fengjiayi 已提交
23
#include "paddle/framework/grad_op_builder.h"
24
#include "paddle/framework/op_desc.pb.h"
D
dongzhihong 已提交
25
#include "paddle/framework/scope.h"
26 27 28 29 30 31 32 33 34 35

namespace paddle {
namespace framework {

// this class not only make proto but also init attribute checkers.
class OpProtoAndCheckerMaker {
 public:
  OpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
      : proto_(proto), op_checker_(op_checker) {}

36 37 38 39 40 41 42 43
  ~OpProtoAndCheckerMaker() {
    PADDLE_ENFORCE(validated_, "should call Validate after build");
  }

  void Validate() {
    validated_ = true;
    CheckNoDuplicatedInOutAttrs();
  }
44

45
 protected:
Y
Yu Yang 已提交
46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71
  struct VariableBuilder {
    VarProto* var_;
    std::function<void()> on_multiple_;
    std::function<void()> on_temporary_;

    VariableBuilder& SetMultiple() {
      var_->set_multiple(true);
      on_multiple_();
      return *this;
    }

    VariableBuilder& SetTemporary() {
      PADDLE_ENFORCE(bool(on_temporary_), "Cannot set temporary");
      var_->set_temporary(true);
      on_temporary_();
      return *this;
    }

    VariableBuilder& IgnoreGradient() {
      var_->set_ignore_gradient(true);
      return *this;
    }
  };

  VariableBuilder AddInput(const std::string& name,
                           const std::string& comment) {
72
    auto input = proto_->mutable_inputs()->Add();
73 74
    *input->mutable_name() = name;
    *input->mutable_comment() = comment;
Y
Yu Yang 已提交
75 76
    return VariableBuilder{input, [=] { this->SetHasMultipleInput(); },
                           nullptr};
77 78
  }

Y
Yu Yang 已提交
79 80
  VariableBuilder AddOutput(const std::string& name,
                            const std::string& comment) {
81
    auto output = proto_->mutable_outputs()->Add();
82 83
    *output->mutable_name() = name;
    *output->mutable_comment() = comment;
Y
Yu Yang 已提交
84 85
    return VariableBuilder{output, [=] { this->SetHasMultipleOutput(); },
                           [=] { this->SetHasTemporaryOutput(); }};
86 87 88 89
  }

  template <typename T>
  TypedAttrChecker<T>& AddAttr(const std::string& name,
90 91
                               const std::string& comment,
                               bool generated = false) {
92
    auto attr = proto_->mutable_attrs()->Add();
93 94
    *attr->mutable_name() = name;
    *attr->mutable_comment() = comment;
95
    attr->set_generated(generated);
Y
Yi Wang 已提交
96
    attr->set_type(AttrTypeID<T>());
97 98 99 100 101 102 103
    return op_checker_->AddAttrChecker<T>(name);
  }

  void AddComment(const std::string& comment) {
    *(proto_->mutable_comment()) = comment;
  }

104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151
 private:
  void SetHasMultiple(const std::string& in_out, bool* flag) {
    if (!*flag) {
      AddAttr<std::vector<int>>(in_out + "_format",
                                "The multiple index of " + in_out +
                                    "\n"
                                    R"DOC(
This attribute is used by Paddle core framework. Paddle's Op support each input
or output could be a list of variable. This attribute is used to show how that
list organized.

e.g.
  input = ["a", "b", "c", "d", "e", "f"]
  input_format = [0, 4, 5, 6]

means
  The number of all input variables this op is six, and they are segmented into
  three inputs.

  The first input is input[0:4], second is input[4:5], third is input[5:6].
)DOC",
                                /*generated*/ true);
      *flag = true;
    }
  }

  void SetHasMultipleInput() { SetHasMultiple("input", &has_multiple_input_); }
  void SetHasMultipleOutput() {
    SetHasMultiple("output", &has_multiple_output_);
  }

  void SetHasTemporaryOutput() {
    if (!has_temporary_output_) {
      AddAttr<std::vector<int>>("temporary_index",
                                R"DOC(The temporary index of output.

Not all output of Paddle Op is used by user. For faster computation, each op
could output some its internal state to other op, other op could take that
output to make compute faster.

Add a mark to which output is temporary is helpful for future optimization.
)DOC",
                                /*generated*/ true)
          .SetDefault(std::vector<int>());
      has_temporary_output_ = true;
    }
  }

152
  void CheckNoDuplicatedInOutAttrs() {
153
    std::unordered_set<std::string> names;
154 155 156 157
    auto checker = [&](const std::string& name) {
      PADDLE_ENFORCE(!names.count(name), "[%s] is duplicated", name);
      names.insert(name);
    };
158
    for (auto& attr : proto_->attrs()) {
159 160 161 162 163 164 165
      checker(attr.name());
    }
    for (auto& input : proto_->inputs()) {
      checker(input.name());
    }
    for (auto& output : proto_->outputs()) {
      checker(output.name());
166 167 168
    }
  }

169 170
  OpProto* proto_;
  OpAttrChecker* op_checker_;
171
  bool validated_{false};
172 173 174
  bool has_multiple_input_{false};
  bool has_multiple_output_{false};
  bool has_temporary_output_{false};
175 176 177
};

class OpRegistry {
Q
Qiao Longfei 已提交
178
  using OpCreator = std::function<OperatorBase*()>;
Y
Yu Yang 已提交
179
  using VarIndexMap = std::unordered_map<std::string, int>;
Y
Yu Yang 已提交
180
  using VarNameList = std::vector<std::string>;
181 182 183 184

 public:
  template <typename OpType, typename ProtoMakerType>
  static void RegisterOp(const std::string& op_type) {
185
    op_creators()[op_type] = [] { return new OpType; };
186
    OpAttrChecker& op_checker = op_checkers()[op_type];
D
dongzhihong 已提交
187
    OpProto& op_proto = protos()[op_type];
188 189
    auto maker = ProtoMakerType(&op_proto, &op_checker);
    maker.Validate();
Y
Yu Yang 已提交
190 191 192 193 194
    *op_proto.mutable_type() = op_type;
    PADDLE_ENFORCE(
        op_proto.IsInitialized(),
        "Fail to initialize %s's OpProto, because %s is not initialized",
        op_type, op_proto.InitializationErrorString());
Y
Yu Yang 已提交
195 196 197 198 199 200 201 202 203 204 205

    VarIndexMaps()[op_type].reset(new VarIndexMap());
    auto& varmap = *VarIndexMaps()[op_type];
    int idx = 0;
    for (auto& var : op_proto.inputs()) {
      varmap[var.name()] = idx++;
    }
    idx = 0;
    for (auto& var : op_proto.outputs()) {
      varmap[var.name()] = idx++;
    }
206 207
  }

208 209 210 211 212
  template <typename GradOpType>
  static void RegisterGradOp(const std::string& op_type,
                             const std::string& grad_op_type) {
    op_creators()[grad_op_type] = [] { return new GradOpType; };
    grad_ops()[op_type] = grad_op_type;
F
fengjiayi 已提交
213 214
  }

Y
Yu Yang 已提交
215 216 217 218
  static std::shared_ptr<OperatorBase> CreateOp(const std::string& type,
                                                const VarNameList& inputs,
                                                const VarNameList& outputs,
                                                const AttributeMap& attrs) {
219 220
    auto op_create_it = op_creators().find(type);
    PADDLE_ENFORCE(op_create_it != op_creators().end(),
F
fengjiayi 已提交
221
                   "Operator %s cannot be found.", type);
222

Y
Yu Yang 已提交
223 224 225 226
    auto op = op_create_it->second();
    op->type_ = type;
    op->inputs_ = inputs;
    op->outputs_ = outputs;
F
fengjiayi 已提交
227

Y
Yu Yang 已提交
228 229
    op->attrs_ = attrs;
    op_checkers().at(type).Check(op->attrs_);
230

Y
Yu Yang 已提交
231
    GenerateTempVariableName(op);
232

Y
Yu Yang 已提交
233
    {
Y
Yu Yang 已提交
234
      auto var_index_it = VarIndexMaps().find(type);
Y
Yu Yang 已提交
235 236 237 238
      if (var_index_it != VarIndexMaps().end()) {
        op->in_out_idxs_ = var_index_it->second;
      }
    }
Y
Yu Yang 已提交
239

Q
Qiao Longfei 已提交
240
    op->Init();
Y
Yu Yang 已提交
241
    return std::shared_ptr<OperatorBase>(op);
242 243
  }

Y
Yu Yang 已提交
244
  static std::shared_ptr<OperatorBase> CreateOp(const OpDesc& op_desc) {
Y
Yu Yang 已提交
245 246
    std::vector<std::string> inputs;
    inputs.reserve((size_t)op_desc.inputs_size());
247
    std::copy(op_desc.inputs().begin(), op_desc.inputs().end(),
Y
Yu Yang 已提交
248 249 250 251
              std::back_inserter(inputs));

    std::vector<std::string> outputs;
    outputs.reserve((size_t)op_desc.outputs_size());
252
    std::copy(op_desc.outputs().begin(), op_desc.outputs().end(),
Y
Yu Yang 已提交
253 254 255
              std::back_inserter(outputs));

    AttributeMap attrs;
256
    for (auto& attr : op_desc.attrs()) {
Y
Yi Wang 已提交
257
      attrs[attr.name()] = GetAttrValue(attr);
258
    }
Y
Yu Yang 已提交
259 260

    return CreateOp(op_desc.type(), inputs, outputs, attrs);
261 262
  }

Y
Yu Yang 已提交
263 264
  static std::shared_ptr<OperatorBase> CreateGradOp(const OperatorBase& op) {
    PADDLE_ENFORCE(!op.IsNetOp(),
Y
Yu Yang 已提交
265
                   "Use framework::Backward to get backward ops");
266
    std::shared_ptr<OperatorBase> grad_op(BuildGradOp(&op));
F
fengjiayi 已提交
267 268
    grad_op->Init();
    return grad_op;
D
dongzhihong 已提交
269 270
  }

Y
Yu Yang 已提交
271 272 273
  static std::unordered_map<std::string, OpProto>& protos() {
    static std::unordered_map<std::string, OpProto> protos_;
    return protos_;
L
liaogang 已提交
274
  }
Y
Yu Yang 已提交
275

276 277 278
  static std::unordered_map<std::string, std::string>& grad_ops() {
    static std::unordered_map<std::string, std::string> grad_ops_;
    return grad_ops_;
279 280
  }

Y
Yu Yang 已提交
281 282 283 284 285 286
  static std::unordered_map<std::string, std::shared_ptr<VarIndexMap>>&
  VarIndexMaps() {
    static std::unordered_map<std::string, std::shared_ptr<VarIndexMap>> maps_;
    return maps_;
  }

287 288 289
  static std::unordered_map<std::string, OpCreator>& op_creators() {
    static std::unordered_map<std::string, OpCreator> op_creators_;
    return op_creators_;
F
fengjiayi 已提交
290 291
  }

292
 private:
F
fengjiayi 已提交
293 294 295
  static std::unordered_map<std::string, OpAttrChecker>& op_checkers() {
    static std::unordered_map<std::string, OpAttrChecker> op_checkers_;
    return op_checkers_;
L
liaogang 已提交
296
  }
F
fengjiayi 已提交
297

298
  static void GenerateTempVariableName(OperatorBase* op) {
299 300
    static std::atomic<size_t> gUniqId(0UL);
    for (auto& outname : op->outputs_) {
301
      if (outname == kTempVarName) {
302
        outname += op->type_;
303 304 305 306 307
        outname += "@";
        outname += std::to_string(gUniqId.fetch_add(1));
      }
    }
  }
308
};
309

310
// Common, intentionally empty base class of the registrar helpers
// (OpRegistrar, GradOpRegistrar, OpKernelRegistrar). It has no state or
// behaviour of its own.
class Registrar {
 public:
  Registrar() = default;
};
F
fengjiayi 已提交
311

312
template <typename OpType, typename ProtoMakerType>
F
fengjiayi 已提交
313
class OpRegistrar : public Registrar {
314
 public:
F
fengjiayi 已提交
315
  explicit OpRegistrar(const char* op_type) {
316 317 318 319
    OpRegistry::RegisterOp<OpType, ProtoMakerType>(op_type);
  }
};

320
template <typename GradOpType>
F
fengjiayi 已提交
321
class GradOpRegistrar : public Registrar {
D
dongzhihong 已提交
322
 public:
F
fengjiayi 已提交
323
  GradOpRegistrar(const char* op_type, const char* grad_op_type) {
324
    OpRegistry::RegisterGradOp<GradOpType>(op_type, grad_op_type);
D
dongzhihong 已提交
325 326 327
  }
};

F
fengjiayi 已提交
328 329 330 331 332 333 334 335 336 337 338
template <typename PlaceType, typename KernelType>
class OpKernelRegistrar : public Registrar {
 public:
  explicit OpKernelRegistrar(const char* op_type) {
    ::paddle::framework::OperatorWithKernel::OpKernelKey key;
    key.place_ = PlaceType();
    ::paddle::framework::OperatorWithKernel::AllOpKernels()[op_type][key].reset(
        new KernelType);
  }
};

339 340 341
/**
 * check if MACRO is used in GLOBAL NAMESPACE.
 */
Y
Yu Yang 已提交
342 343 344 345 346 347
// Declares an empty struct __test_global_namespace_<uniq_name>__ in the
// current scope and static_asserts that it is the same type as the one named
// with a leading `::`. Outside the global namespace the `::`-qualified name
// does not refer to the struct just declared, so compilation fails (often
// with a name-lookup error rather than `msg`). Because the struct name is
// derived from `uniq_name`, expanding this twice with the same `uniq_name` in
// one scope is a redefinition error.
#define STATIC_ASSERT_GLOBAL_NAMESPACE(uniq_name, msg)                        \
  struct __test_global_namespace_##uniq_name##__ {};                          \
  static_assert(std::is_same<::__test_global_namespace_##uniq_name##__,       \
                             __test_global_namespace_##uniq_name##__>::value, \
                msg)

348 349 350
/**
 * Macro to register an operator.
 *
 * Must be used in the global namespace. Defines
 * `int TouchOpRegistrar_<op_type>()`, whose function-local static
 * OpRegistrar<op_class, op_maker_class> registers the operator under the
 * name #op_type the first time the function is called (USE_OP_ITSELF calls
 * it).
 */
#define REGISTER_OP(op_type, op_class, op_maker_class)                        \
  STATIC_ASSERT_GLOBAL_NAMESPACE(                                             \
      __reg_op__##op_type, "REGISTER_OP must be called in global namespace"); \
  int TouchOpRegistrar_##op_type() {                                          \
    static ::paddle::framework::OpRegistrar<op_class, op_maker_class>         \
        __op_registrar_##op_type##__(#op_type);                               \
    return 0;                                                                 \
  }
Y
Yu Yang 已提交
359

D
dongzhihong 已提交
360
/**
 * Macro to register a gradient operator.
 *
 * Must be used in the global namespace. Defines
 * `int TouchOpGradientRegistrar_<op_type>()`, whose function-local static
 * GradOpRegistrar<grad_op_class> registers grad_op_class under the name
 * #grad_op_type and records it as the gradient of #op_type the first time
 * the function is called.
 */
#define REGISTER_GRADIENT_OP(op_type, grad_op_type, grad_op_class)             \
  STATIC_ASSERT_GLOBAL_NAMESPACE(                                              \
      __reg_gradient_op__##op_type##_##grad_op_type,                           \
      "REGISTER_GRADIENT_OP must be called in global namespace");              \
  int TouchOpGradientRegistrar_##op_type() {                                   \
    static ::paddle::framework::GradOpRegistrar<grad_op_class>                 \
        __op_gradient_registrar_##op_type##_##grad_op_type##__(#op_type,       \
                                                               #grad_op_type); \
    return 0;                                                                  \
  }
D
dongzhihong 已提交
373

D
dongzhihong 已提交
374
/**
 * Macro to register an operator kernel.
 *
 * Must be used in the global namespace. Defines
 * `int TouchOpKernelRegistrar_<op_type>_<DEVICE_TYPE>()`, whose
 * function-local static OpKernelRegistrar<place_class, ...> registers the
 * kernel for #op_type on place_class the first time the function is called.
 * The trailing `...` is forwarded as the KernelType template argument, so a
 * kernel type that contains commas can be passed without extra parentheses.
 */
#define REGISTER_OP_KERNEL(op_type, DEVICE_TYPE, place_class, ...)          \
  STATIC_ASSERT_GLOBAL_NAMESPACE(                                           \
      __reg_op_kernel_##op_type##_##DEVICE_TYPE##__,                        \
      "REGISTER_OP_KERNEL must be called in global namespace");             \
  int TouchOpKernelRegistrar_##op_type##_##DEVICE_TYPE() {                  \
    static ::paddle::framework::OpKernelRegistrar<place_class, __VA_ARGS__> \
        __op_kernel_registrar_##op_type##_##DEVICE_TYPE##__(#op_type);      \
    return 0;                                                               \
  }
D
dongzhihong 已提交
386

387
/**
 * Macro to forbid registering a gradient operator `<op_type>_grad` for
 * op_type.
 *
 * It declares the same guard struct that
 * REGISTER_GRADIENT_OP(op_type, op_type##_grad, ...) would declare, so using
 * both for the same operator is a redefinition error at compile time.
 */
#define NO_GRADIENT(op_type)                           \
  STATIC_ASSERT_GLOBAL_NAMESPACE(                      \
      __reg_gradient_op__##op_type##_##op_type##_grad, \
      "NO_GRADIENT must be called in global namespace")

F
fengjiayi 已提交
395 396
// Shorthand for REGISTER_OP_KERNEL with DEVICE_TYPE GPU and place
// ::paddle::platform::GPUPlace; `...` is the kernel type.
#define REGISTER_OP_GPU_KERNEL(op_type, ...) \
  REGISTER_OP_KERNEL(op_type, GPU, ::paddle::platform::GPUPlace, __VA_ARGS__)
F
fengjiayi 已提交
397

F
fengjiayi 已提交
398 399
// Shorthand for REGISTER_OP_KERNEL with DEVICE_TYPE CPU and place
// ::paddle::platform::CPUPlace; `...` is the kernel type.
#define REGISTER_OP_CPU_KERNEL(op_type, ...) \
  REGISTER_OP_KERNEL(op_type, CPU, ::paddle::platform::CPUPlace, __VA_ARGS__)
Y
Yu Yang 已提交
400

401 402 403 404
/**
 * Macros to mark which operators and kernels a target uses, so that their
 * registration code is linked into that target.
 *
 * USE_OP_ITSELF(op_type) declares TouchOpRegistrar_<op_type>() (defined by
 * REGISTER_OP) and calls it to initialize an otherwise unused static int.
 * That call references the registration function from this translation unit
 * and performs the registration during static initialization.
 */
#define USE_OP_ITSELF(op_type)                                    \
  STATIC_ASSERT_GLOBAL_NAMESPACE(                                 \
      __use_op_itself_##op_type,                                  \
      "USE_OP_ITSELF must be called in global namespace");        \
  extern int TouchOpRegistrar_##op_type();                        \
  static int use_op_itself_##op_type##_ __attribute__((unused)) = \
      TouchOpRegistrar_##op_type()
F
fengjiayi 已提交
412

413 414 415 416 417 418
// TODO(jiayi): Most ops' gradient ops have not been completed yet, so
// `NO_GRAD` is defined here to turn the macro USE_OP_GRADIENT(op_type) into a
// no-op; otherwise targets fail to build. `NO_GRAD` should be removed once all
// gradient ops are completed.
#define NO_GRAD
#ifndef NO_GRAD
// Links in the gradient registration made by REGISTER_GRADIENT_OP(op_type,
// ...), using the same mechanism as USE_OP_ITSELF.
#define USE_OP_GRADIENT(op_type)                                    \
  STATIC_ASSERT_GLOBAL_NAMESPACE(                                   \
      __use_op_gradient_##op_type,                                  \
      "USE_OP_GRADIENT must be called in global namespace");        \
  extern int TouchOpGradientRegistrar_##op_type();                  \
  static int use_op_gradient_##op_type##_ __attribute__((unused)) = \
      TouchOpGradientRegistrar_##op_type()
#else
#define USE_OP_GRADIENT(op_type)
#endif
F
fengjiayi 已提交
429

430
// Links in the kernel registered by REGISTER_OP_KERNEL(op_type, DEVICE_TYPE,
// ...), using the same mechanism as USE_OP_ITSELF.
#define USE_OP_DEVICE_KERNEL(op_type, DEVICE_TYPE)               \
  STATIC_ASSERT_GLOBAL_NAMESPACE(                                \
      __use_op_kernel_##op_type##_##DEVICE_TYPE##__,             \
      "USE_OP_DEVICE_KERNEL must be in global namespace");       \
  extern int TouchOpKernelRegistrar_##op_type##_##DEVICE_TYPE(); \
  static int use_op_kernel_##op_type##_##DEVICE_TYPE##_          \
      __attribute__((unused)) =                                  \
          TouchOpKernelRegistrar_##op_type##_##DEVICE_TYPE()
Y
Yu Yang 已提交
438

F
fengjiayi 已提交
439
// USE_OP_KERNEL(op_type) links in the CPU kernel of op_type, and also its GPU
// kernel unless PADDLE_ONLY_CPU is defined.
#ifdef PADDLE_ONLY_CPU
#define USE_OP_KERNEL(op_type) USE_OP_DEVICE_KERNEL(op_type, CPU)
#else
#define USE_OP_KERNEL(op_type)        \
  USE_OP_DEVICE_KERNEL(op_type, CPU); \
  USE_OP_DEVICE_KERNEL(op_type, GPU)
#endif
446

447 448 449 450 451 452 453 454 455 456 457 458 459
// Links in an operator and its kernel(s), without its gradient operator.
#define USE_NO_GRAD_OP(op_type) \
  USE_OP_ITSELF(op_type);       \
  USE_OP_KERNEL(op_type)

// Links in an operator, its CPU kernel only, and its gradient operator
// (the latter is a no-op while NO_GRAD is defined).
#define USE_CPU_OP(op_type)           \
  USE_OP_ITSELF(op_type);             \
  USE_OP_DEVICE_KERNEL(op_type, CPU); \
  USE_OP_GRADIENT(op_type)

// Links in an operator, its kernel(s) and its gradient operator (the latter
// is a no-op while NO_GRAD is defined).
#define USE_OP(op_type)    \
  USE_NO_GRAD_OP(op_type); \
  USE_OP_GRADIENT(op_type)

460 461
}  // namespace framework
}  // namespace paddle