提交 7f6b5044 编写于 作者: Y Yu Yang

Make OpInfoMap as a class

* Add Get/Has methods to OpInfoMap
* Add PADDLE_ENFORCE for OpInfo to get field.
上级 59b3df31
...@@ -24,9 +24,9 @@ static void TransOpArg(const OperatorBase* src_op, const OpArgType& src_type, ...@@ -24,9 +24,9 @@ static void TransOpArg(const OperatorBase* src_op, const OpArgType& src_type,
const auto& src_inout = const auto& src_inout =
src_type == OpArgType::IN ? src_op->Inputs() : src_op->Outputs(); src_type == OpArgType::IN ? src_op->Inputs() : src_op->Outputs();
auto& dst_inout = *vars; auto& dst_inout = *vars;
const OpProto* proto = OpInfoMap().at(src_op->Type()).proto_; auto& proto = OpInfoMap::Instance().Get(src_op->Type()).Proto();
const auto& src_arg_list = const auto& src_arg_list =
src_type == OpArgType::IN ? proto->inputs() : proto->outputs(); src_type == OpArgType::IN ? proto.inputs() : proto.outputs();
for (const auto& arg : src_arg_list) { for (const auto& arg : src_arg_list) {
if (arg.not_in_gradient() && !is_grad) continue; if (arg.not_in_gradient() && !is_grad) continue;
const std::string src_name = arg.name(); const std::string src_name = arg.name();
...@@ -40,14 +40,8 @@ static void TransOpArg(const OperatorBase* src_op, const OpArgType& src_type, ...@@ -40,14 +40,8 @@ static void TransOpArg(const OperatorBase* src_op, const OpArgType& src_type,
} }
OperatorBase* BuildGradOp(const OperatorBase* op) { OperatorBase* BuildGradOp(const OperatorBase* op) {
auto it = OpInfoMap().find(op->Type()); auto& info = OpInfoMap::Instance().Get(op->Type());
PADDLE_ENFORCE(it != OpInfoMap().end(), "'%s' has not been registered.", PADDLE_ENFORCE(info.HasGradientOp());
op->Type());
PADDLE_ENFORCE(it->second.proto_ != nullptr, "'%s' has no OpProto.",
op->Type());
std::string grad_op_type = it->second.grad_op_type_;
PADDLE_ENFORCE(!grad_op_type.empty(), "'%s' has no gradient operator.",
op->Type());
VariableNameMap inputs; VariableNameMap inputs;
VariableNameMap outputs; VariableNameMap outputs;
...@@ -56,10 +50,8 @@ OperatorBase* BuildGradOp(const OperatorBase* op) { ...@@ -56,10 +50,8 @@ OperatorBase* BuildGradOp(const OperatorBase* op) {
TransOpArg(op, OpArgType::OUT, true, &inputs); // OG TransOpArg(op, OpArgType::OUT, true, &inputs); // OG
TransOpArg(op, OpArgType::IN, true, &outputs); // IG TransOpArg(op, OpArgType::IN, true, &outputs); // IG
it = OpInfoMap().find(grad_op_type); auto& grad_info = OpInfoMap::Instance().Get(info.grad_op_type_);
PADDLE_ENFORCE(it != OpInfoMap().end(), "'%s' has not been registered.", return grad_info.Creator()(info.grad_op_type_, inputs, outputs, op->Attrs());
grad_op_type);
return it->second.creator_(grad_op_type, inputs, outputs, op->Attrs());
} }
} // namespace framework } // namespace framework
......
...@@ -17,12 +17,11 @@ ...@@ -17,12 +17,11 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
static std::unordered_map<std::string, const paddle::framework::OpInfo>* static OpInfoMap* g_op_info_map = nullptr;
g_op_info_map = nullptr;
std::unordered_map<std::string, const paddle::framework::OpInfo>& OpInfoMap() { OpInfoMap& OpInfoMap::Instance() {
if (g_op_info_map == nullptr) { if (g_op_info_map == nullptr) {
g_op_info_map = g_op_info_map = new OpInfoMap();
new std::unordered_map<std::string, const paddle::framework::OpInfo>();
} }
return *g_op_info_map; return *g_op_info_map;
} }
......
...@@ -34,9 +34,68 @@ struct OpInfo { ...@@ -34,9 +34,68 @@ struct OpInfo {
std::string grad_op_type_; std::string grad_op_type_;
OpProto* proto_; OpProto* proto_;
OpAttrChecker* checker_; OpAttrChecker* checker_;
bool HasOpProtoAndChecker() const {
return proto_ != nullptr && checker_ != nullptr;
}
const OpProto& Proto() const {
PADDLE_ENFORCE_NOT_NULL(proto_, "Operator Proto has not been registered");
PADDLE_ENFORCE(proto_->IsInitialized(),
"Operator Proto must be initialized in op info");
return *proto_;
}
const OpAttrChecker& Checker() const {
PADDLE_ENFORCE_NOT_NULL(checker_,
"Operator Checker has not been registered");
return *checker_;
}
const OpCreator& Creator() const {
PADDLE_ENFORCE_NOT_NULL(creator_,
"Operator Creator has not been registered");
return creator_;
}
bool HasGradientOp() const { return !grad_op_type_.empty(); }
}; };
extern std::unordered_map<std::string, const OpInfo>& OpInfoMap(); class OpInfoMap {
public:
static OpInfoMap& Instance();
OpInfoMap(const OpInfoMap& o) = delete;
OpInfoMap(OpInfoMap&& o) = delete;
OpInfoMap& operator=(const OpInfoMap& o) = delete;
OpInfoMap& operator=(OpInfoMap&& o) = delete;
bool Has(const std::string& op_type) const {
return map_.find(op_type) != map_.end();
}
void Insert(const std::string& type, const OpInfo& info) {
PADDLE_ENFORCE(!Has(type), "Operator %s has been registered", type);
map_.insert({type, info});
}
const OpInfo& Get(const std::string& type) const {
auto it = map_.find(type);
PADDLE_ENFORCE(it != map_.end(), "Operator %s are not found", type);
return it->second;
}
template <typename Callback>
void IterAllInfo(Callback callback) {
for (auto& it : map_) {
callback(it.first, it.second);
}
}
private:
OpInfoMap() = default;
std::unordered_map<std::string, const OpInfo> map_;
};
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -22,11 +22,9 @@ namespace framework { ...@@ -22,11 +22,9 @@ namespace framework {
std::unique_ptr<OperatorBase> OpRegistry::CreateOp( std::unique_ptr<OperatorBase> OpRegistry::CreateOp(
const std::string& type, const VariableNameMap& inputs, const std::string& type, const VariableNameMap& inputs,
const VariableNameMap& outputs, AttributeMap attrs) { const VariableNameMap& outputs, AttributeMap attrs) {
auto it = OpInfoMap().find(type); auto& info = OpInfoMap::Instance().Get(type);
PADDLE_ENFORCE(it != OpInfoMap().end(), info.Checker().Check(attrs);
"Operator '%s' has not been registered.", type); auto op = info.Creator()(type, inputs, outputs, attrs);
it->second.checker_->Check(attrs);
auto op = it->second.creator_(type, inputs, outputs, attrs);
return std::unique_ptr<OperatorBase>(op); return std::unique_ptr<OperatorBase>(op);
} }
......
...@@ -35,7 +35,7 @@ class OpRegistry { ...@@ -35,7 +35,7 @@ class OpRegistry {
template <typename OpType, typename ProtoMakerType, typename GradOpType> template <typename OpType, typename ProtoMakerType, typename GradOpType>
static void RegisterOp(const std::string& op_type, static void RegisterOp(const std::string& op_type,
const std::string& grad_op_type) { const std::string& grad_op_type) {
PADDLE_ENFORCE(OpInfoMap().count(op_type) == 0, PADDLE_ENFORCE(!OpInfoMap::Instance().Has(op_type),
"'%s' is registered more than once.", op_type); "'%s' is registered more than once.", op_type);
OpInfo op_info; OpInfo op_info;
op_info.creator_ = []( op_info.creator_ = [](
...@@ -59,7 +59,7 @@ class OpRegistry { ...@@ -59,7 +59,7 @@ class OpRegistry {
op_info.proto_ = nullptr; op_info.proto_ = nullptr;
op_info.checker_ = nullptr; op_info.checker_ = nullptr;
} }
OpInfoMap().insert(std::make_pair(op_type, op_info)); OpInfoMap::Instance().Insert(op_type, op_info);
// register gradient op // register gradient op
if (!grad_op_type.empty()) { if (!grad_op_type.empty()) {
RegisterOp<GradOpType, NOPMaker, NOP>(grad_op_type, ""); RegisterOp<GradOpType, NOPMaker, NOP>(grad_op_type, "");
......
...@@ -141,18 +141,10 @@ std::vector<std::string> OperatorBase::OutputVars(bool has_intermediate) const { ...@@ -141,18 +141,10 @@ std::vector<std::string> OperatorBase::OutputVars(bool has_intermediate) const {
} }
return ret_val; return ret_val;
} }
auto it = OpInfoMap().find(type_); auto& info = OpInfoMap::Instance().Get(Type());
PADDLE_ENFORCE(
it != OpInfoMap().end(),
"Operator %s not registered, cannot figure out intermediate outputs",
type_);
PADDLE_ENFORCE(
it->second.proto_ != nullptr,
"Operator %s has no OpProto, cannot figure out intermediate outputs",
type_);
// get all OpProto::Var for outputs // get all OpProto::Var for outputs
for (auto& o : it->second.proto_->outputs()) { for (auto& o : info.Proto().outputs()) {
// ignore all intermediate output // ignore all intermediate output
if (o.intermediate()) continue; if (o.intermediate()) continue;
auto out = outputs_.find(o.name()); auto out = outputs_.find(o.name());
......
...@@ -138,19 +138,16 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -138,19 +138,16 @@ All parameter, weight, gradient are variables in Paddle.
//! @note: Be careful! PyBind will return std::string as an unicode, not //! @note: Be careful! PyBind will return std::string as an unicode, not
//! Python str. If you want a str object, you should cast them in Python. //! Python str. If you want a str object, you should cast them in Python.
m.def("get_all_op_protos", []() -> std::vector<py::bytes> { m.def("get_all_op_protos", []() -> std::vector<py::bytes> {
auto &op_info_map = OpInfoMap();
std::vector<py::bytes> ret_values; std::vector<py::bytes> ret_values;
for (auto it = op_info_map.begin(); it != op_info_map.end(); ++it) {
const OpProto *proto = it->second.proto_; OpInfoMap::Instance().IterAllInfo([&ret_values](const std::string &type,
if (proto == nullptr) { const OpInfo &info) {
continue; if (!info.HasOpProtoAndChecker()) return;
}
PADDLE_ENFORCE(proto->IsInitialized(), "OpProto must all be initialized");
std::string str; std::string str;
PADDLE_ENFORCE(proto->SerializeToString(&str), PADDLE_ENFORCE(info.Proto().SerializeToString(&str),
"Serialize OpProto Error. This could be a bug of Paddle."); "Serialize OpProto Error. This could be a bug of Paddle.");
ret_values.push_back(py::bytes(str)); ret_values.emplace_back(str);
} });
return ret_values; return ret_values;
}); });
m.def_submodule( m.def_submodule(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册