// op_registry.h
#pragma once

#include <algorithm>
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "paddle/framework/attr_checker.h"
#include "paddle/framework/op_desc.pb.h"
#include "paddle/framework/op_proto.pb.h"
#include "paddle/framework/operator.h"

namespace paddle {
namespace framework {

// Helper for mapping between C++ attribute types and the protobuf attribute
// messages: it records an attribute's type in an AttrProto and converts an
// AttrDesc back into an Attribute value.
struct AttrTypeHelper {
  // Records in `attr` the protobuf attribute type corresponding to the C++
  // type `T`. Only declared here; the specializations are presumably defined
  // in another file — TODO confirm.
  template <typename T>
  static void SetAttrType(AttrProto* attr);

  // Converts `attr_desc` into an Attribute, reading the payload field that
  // matches attr_desc.type(). An unhandled type is reported through
  // PADDLE_ENFORCE; the trailing boost::blank() is only reached if that
  // enforcement returns.
  static Attribute GetAttrValue(const AttrDesc& attr_desc) {
    // No `default:` label, so compilers can flag enumerators that are not
    // handled here (-Wswitch).
    switch (attr_desc.type()) {
      case paddle::framework::AttrType::INT: {
        return attr_desc.i();
      }
      case paddle::framework::AttrType::FLOAT: {
        return attr_desc.f();
      }
      case paddle::framework::AttrType::STRING: {
        return attr_desc.s();
      }
      // The repeated protobuf fields expose begin()/end(), so each vector is
      // built directly from that iterator range.
      case paddle::framework::AttrType::INTS: {
        return std::vector<int>(attr_desc.ints().begin(),
                                attr_desc.ints().end());
      }
      case paddle::framework::AttrType::FLOATS: {
        return std::vector<float>(attr_desc.floats().begin(),
                                  attr_desc.floats().end());
      }
      case paddle::framework::AttrType::STRINGS: {
        return std::vector<std::string>(attr_desc.strings().begin(),
                                        attr_desc.strings().end());
      }
    }
    PADDLE_ENFORCE(false, "Unknown OpDesc::AttrDesc::type !");
    return boost::blank();
  }
};

// Base class for operator proto makers. Besides filling in the OpProto
// (inputs, outputs, attributes, type and comment), it registers an attribute
// checker in the OpAttrChecker for every attribute it adds.
class OpProtoAndCheckerMaker {
 public:
  // `proto` and `op_checker` are not owned; they are dereferenced by the
  // Add* methods and must stay valid while those methods run.
  OpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
      : proto_(proto), op_checker_(op_checker) {}

 protected:
  // Appends an input named `name`, described by `comment`, to the proto.
  void AddInput(const std::string& name, const std::string& comment) {
    auto* input = proto_->add_inputs();
    input->set_name(name);
    input->set_comment(comment);
  }

  // Appends an output named `name`, described by `comment`, to the proto.
  void AddOutput(const std::string& name, const std::string& comment) {
    auto* output = proto_->add_outputs();
    output->set_name(name);
    output->set_comment(comment);
  }

  // Appends an attribute of C++ type `T` to the proto and returns the typed
  // checker registered for it, so callers can attach constraints to it.
  template <typename T>
  TypedAttrChecker<T>& AddAttr(const std::string& name,
                               const std::string& comment) {
    auto* attr = proto_->add_attrs();
    attr->set_name(name);
    attr->set_comment(comment);
    AttrTypeHelper::SetAttrType<T>(attr);
    return op_checker_->AddAttrChecker<T>(name);
  }

  // Sets the operator type string stored in the proto.
  void AddType(const std::string& op_type) { proto_->set_type(op_type); }

  // Sets the operator's descriptive comment stored in the proto.
  void AddComment(const std::string& comment) { proto_->set_comment(comment); }

  OpProto* proto_;
  OpAttrChecker* op_checker_;
};

class OpRegistry {
Q
Qiao Longfei 已提交
95
  using OpCreator = std::function<OperatorBase*()>;
96 97 98 99

 public:
  template <typename OpType, typename ProtoMakerType>
  static void RegisterOp(const std::string& op_type) {
100 101 102
    creators()[op_type] = [] { return new OpType; };
    OpProto& op_proto = protos()[op_type];
    OpAttrChecker& op_checker = op_checkers()[op_type];
103
    ProtoMakerType(&op_proto, &op_checker);
104
    PADDLE_ENFORCE(op_proto.IsInitialized(),
105 106 107
                   "Fail to initialize %s's OpProto !", op_type);
  }

Q
Qiao Longfei 已提交
108
  static OperatorBase* CreateOp(const OpDesc& op_desc) {
109
    std::string op_type = op_desc.type();
Q
Qiao Longfei 已提交
110 111
    OperatorBase* op = creators().at(op_type)();
    op->desc_ = op_desc;
112 113 114 115 116 117 118
    op->inputs_.reserve((size_t)op_desc.inputs_size());
    std::copy(op_desc.inputs().begin(), op_desc.inputs().end(),
              std::back_inserter(op->inputs_));
    op->outputs_.reserve((size_t)op_desc.outputs_size());
    std::copy(op_desc.outputs().begin(), op_desc.outputs().end(),
              std::back_inserter(op->outputs_));
    for (auto& attr : op_desc.attrs()) {
Q
Qiao Longfei 已提交
119
      op->attrs_[attr.name()] = AttrTypeHelper::GetAttrValue(attr);
120
    }
Q
Qiao Longfei 已提交
121
    op_checkers().at(op_type).Check(op->attrs_);
122 123 124 125
    return op;
  }

 private:
126 127 128 129
  static std::unordered_map<std::string, OpCreator>& creators() {
    static std::unordered_map<std::string, OpCreator> creators_;
    return creators_;
  }
130

131 132 133 134 135 136 137 138 139 140
  static std::unordered_map<std::string, OpProto>& protos() {
    static std::unordered_map<std::string, OpProto> protos_;
    return protos_;
  };

  static std::unordered_map<std::string, OpAttrChecker>& op_checkers() {
    static std::unordered_map<std::string, OpAttrChecker> op_checkers_;
    return op_checkers_;
  };
};
141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159

// Registers `OpType` with proto maker `ProtoMakerType` under the given op
// type when constructed. REGISTER_OP instantiates it as a static object, so
// the registration happens during static initialization.
template <typename OpType, typename ProtoMakerType>
class OpRegisterHelper {
 public:
  explicit OpRegisterHelper(const std::string& op_type) {
    OpRegistry::RegisterOp<OpType, ProtoMakerType>(op_type);
  }
};

// Registers operator class `op_class` with proto maker `op_maker_class`
// under the type name `op_type`, which is stringized (pass it as a bare
// token, not a string literal). The macro expands to a helper class whose
// static OpRegisterHelper member performs the registration when it is
// initialized. The helper type is fully qualified so the macro also works
// when expanded outside namespace paddle::framework.
#define REGISTER_OP(op_class, op_maker_class, op_type)                  \
  class op_class##Register {                                            \
   private:                                                             \
    static const ::paddle::framework::OpRegisterHelper<op_class,        \
                                                       op_maker_class>  \
        reg;                                                            \
  };                                                                    \
  const ::paddle::framework::OpRegisterHelper<op_class, op_maker_class> \
      op_class##Register::reg(#op_type);

}  // namespace framework
}  // namespace paddle