// op_registry.h
#pragma once

#include <algorithm>
#include <functional>
#include <iterator>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "paddle/framework/attr_checker.h"
//#include "paddle/framework/op_base.h"
#include "paddle/framework/op_desc.pb.h"
#include "paddle/framework/op_proto.pb.h"

namespace paddle {
namespace framework {

//==================For test================//
// Minimal abstract operator used by this registry. It is presumably a
// stand-in for the class in "paddle/framework/op_base.h", whose include is
// commented out at the top of this file — TODO confirm and replace.
class OpBase {
 public:
  // Input and output argument names; OpRegistry::CreateOp copies them from
  // the OpDesc it is given.
  std::vector<std::string> inputs_;
  std::vector<std::string> outputs_;
  // Attribute values keyed by attribute name.
  AttributeMap attr_map_;

  // Runs the operator. The meaning of the returned string is left to
  // concrete subclasses.
  virtual std::string Run() const = 0;

  // Virtual so operators created by OpRegistry can be deleted through an
  // OpBase pointer.
  virtual ~OpBase() = default;
};
//=========================================//

// Helper for translating attribute types between C++ and their protobuf
// descriptions.
struct AttrTypeHelper {
  // Records on `attr` the protobuf type tag that corresponds to C++ type `T`.
  // Only declared here; definitions/specializations are not visible in this
  // header section.
  template <typename T>
  static void SetAttrType(AttrProto* attr);

  // Decodes a serialized attribute into an Attribute value.
  //
  // The alternative is selected by attr_desc.type(). Scalar fields are
  // returned as-is. Repeated fields are copied element-wise into a
  // std::vector of the matching element type. A type tag not handled by the
  // switch reaches PADDLE_ENFORCE(false, ...); boost::blank() is returned only
  // if that check does not abort the call.
  static Attribute GetAttrValue(const AttrDesc& attr_desc) {
    switch (attr_desc.type()) {
      case paddle::framework::AttrType::INT:
        return attr_desc.i();
      case paddle::framework::AttrType::FLOAT:
        return attr_desc.f();
      case paddle::framework::AttrType::STRING:
        return attr_desc.s();
      case paddle::framework::AttrType::INTS: {
        const auto& ints = attr_desc.ints();
        return std::vector<int>(ints.begin(), ints.end());
      }
      case paddle::framework::AttrType::FLOATS: {
        const auto& floats = attr_desc.floats();
        return std::vector<float>(floats.begin(), floats.end());
      }
      case paddle::framework::AttrType::STRINGS: {
        const auto& strings = attr_desc.strings();
        return std::vector<std::string>(strings.begin(), strings.end());
      }
    }
    PADDLE_ENFORCE(false, "Unknown OpDesc::AttrDesc::type !");
    return boost::blank();
  }
};

// Base class for operator "makers". It fills in an operator's OpProto
// description and, at the same time, configures the OpAttrChecker for that
// operator's attributes. Subclasses call the protected Add* helpers.
class OpProtoAndCheckerMaker {
 public:
  // Neither pointer is owned. Both are dereferenced by the Add* helpers, so
  // they must remain valid while those helpers are called.
  OpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
      : proto_(proto), op_checker_(op_checker) {}

 protected:
  // Appends an input entry with the given name and documentation.
  void AddInput(const std::string& name, const std::string& comment) {
    auto* in = proto_->add_inputs();
    in->set_name(name);
    in->set_comment(comment);
  }

  // Appends an output entry with the given name and documentation.
  void AddOutput(const std::string& name, const std::string& comment) {
    auto* out = proto_->add_outputs();
    out->set_name(name);
    out->set_comment(comment);
  }

  // Appends an attribute entry of C++ type `T`, stamps its protobuf type via
  // AttrTypeHelper::SetAttrType<T>, and registers a typed checker for it.
  // Returns that checker so the caller can add constraints to it.
  template <typename T>
  TypedAttrChecker<T>& AddAttr(const std::string& name,
                               const std::string& comment) {
    auto* attr_proto = proto_->add_attrs();
    attr_proto->set_name(name);
    attr_proto->set_comment(comment);
    AttrTypeHelper::SetAttrType<T>(attr_proto);
    return op_checker_->AddAttrChecker<T>(name);
  }

  // Sets the operator type name on the proto.
  void AddType(const std::string& op_type) { proto_->set_type(op_type); }

  // Sets the operator-level documentation on the proto.
  void AddComment(const std::string& comment) { proto_->set_comment(comment); }

  OpProto* proto_;             // not owned
  OpAttrChecker* op_checker_;  // not owned
};

class OpRegistry {
108
  using OpCreator = std::function<OpBase*()>;
109 110 111 112

 public:
  template <typename OpType, typename ProtoMakerType>
  static void RegisterOp(const std::string& op_type) {
113 114 115
    creators()[op_type] = [] { return new OpType; };
    OpProto& op_proto = protos()[op_type];
    OpAttrChecker& op_checker = op_checkers()[op_type];
116
    ProtoMakerType(&op_proto, &op_checker);
117
    PADDLE_ENFORCE(op_proto.IsInitialized(),
118 119 120 121 122
                   "Fail to initialize %s's OpProto !", op_type);
  }

  static OpBase* CreateOp(const OpDesc& op_desc) {
    std::string op_type = op_desc.type();
123 124 125 126 127 128 129 130 131
    OpBase* op = creators().at(op_type)();
    op->inputs_.reserve((size_t)op_desc.inputs_size());
    std::copy(op_desc.inputs().begin(), op_desc.inputs().end(),
              std::back_inserter(op->inputs_));
    op->outputs_.reserve((size_t)op_desc.outputs_size());
    std::copy(op_desc.outputs().begin(), op_desc.outputs().end(),
              std::back_inserter(op->outputs_));
    for (auto& attr : op_desc.attrs()) {
      op->attr_map_[attr.name()] = AttrTypeHelper::GetAttrValue(attr);
132
    }
133
    op_checkers().at(op_type).Check(op->attr_map_);
134 135 136 137
    return op;
  }

 private:
138 139 140 141
  static std::unordered_map<std::string, OpCreator>& creators() {
    static std::unordered_map<std::string, OpCreator> creators_;
    return creators_;
  }
142

143 144 145 146 147 148 149 150 151 152
  static std::unordered_map<std::string, OpProto>& protos() {
    static std::unordered_map<std::string, OpProto> protos_;
    return protos_;
  };

  static std::unordered_map<std::string, OpAttrChecker>& op_checkers() {
    static std::unordered_map<std::string, OpAttrChecker> op_checkers_;
    return op_checkers_;
  };
};
153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171

// Performs operator registration as a side effect of construction.
// REGISTER_OP defines a static data member of this type, so registration
// runs when that static object is initialized. The helper holds no state.
template <typename OpType, typename ProtoMakerType>
class OpRegisterHelper {
 public:
  OpRegisterHelper(std::string type_name) {
    // All of the work is delegated to the registry.
    OpRegistry::RegisterOp<OpType, ProtoMakerType>(type_name);
  }
};

// Registers operator class `op_class`, described by `op_maker_class`, under
// the type name `op_type`. `op_type` is stringified, so pass a bare token,
// not a string literal.
//
// Expands to a helper class `<op_class>Register` that holds a static
// OpRegisterHelper; constructing that static object performs the
// registration.
// - `op_class` must be a plain identifier, because it is token-pasted into
//   the helper class name.
// - OpRegisterHelper is fully qualified so the macro also works when
//   expanded outside namespace paddle::framework.
// - Parameter names avoid double underscores, which are reserved
//   identifiers.
#define REGISTER_OP(op_class, op_maker_class, op_type)                  \
  class op_class##Register {                                            \
   private:                                                             \
    const static ::paddle::framework::OpRegisterHelper<op_class,        \
                                                       op_maker_class>  \
        reg;                                                            \
  };                                                                    \
  const ::paddle::framework::OpRegisterHelper<op_class, op_maker_class> \
      op_class##Register::reg(#op_type);

}  // namespace framework
}  // namespace paddle