diff --git a/paddle/framework/op_desc.cc b/paddle/framework/op_desc.cc index 0c12c55dc09f6aa064066b5c73bc5e985a57343f..33a064890cc62a4e3f04c79bc000951e76fd477c 100644 --- a/paddle/framework/op_desc.cc +++ b/paddle/framework/op_desc.cc @@ -112,6 +112,30 @@ const std::unordered_map &OpDescBind::GetAttrMap() return attrs_; } +struct SetAttrDescVisitor : public boost::static_visitor { + explicit SetAttrDescVisitor(OpDesc::Attr *attr) : attr_(attr) {} + mutable OpDesc::Attr *attr_; + void operator()(int v) const { attr_->set_i(v); } + void operator()(float v) const { attr_->set_f(v); } + void operator()(const std::string &v) const { attr_->set_s(v); } + void operator()(bool b) const { attr_->set_b(b); } + + void operator()(const std::vector &v) const { + VectorToRepeated(v, attr_->mutable_ints()); + } + void operator()(const std::vector &v) const { + VectorToRepeated(v, attr_->mutable_floats()); + } + void operator()(const std::vector &v) const { + VectorToRepeated(v, attr_->mutable_strings()); + } + void operator()(const std::vector &v) const { + VectorToRepeated(v, attr_->mutable_bools()); + } + void operator()(BlockDesc *desc) const { attr_->set_block_idx(desc->idx()); } + void operator()(boost::blank) const { PADDLE_THROW("Unexpected branch"); } +}; + void OpDescBind::Sync() { if (need_update_) { this->op_desc_.mutable_inputs()->Clear(); @@ -134,7 +158,8 @@ void OpDescBind::Sync() { attr_desc->set_name(attr.first); attr_desc->set_type( static_cast(attr.second.which() - 1)); - boost::apply_visitor(SetAttrDescVisitor(attr_desc), attr.second); + SetAttrDescVisitor visitor(attr_desc); + boost::apply_visitor(visitor, attr.second); } need_update_ = false; diff --git a/paddle/framework/op_desc.h b/paddle/framework/op_desc.h index 0cf7d13971675eb825bcd0c7636896f0862d6ebb..e03b4d067f189f24851e55a979574e6e6984454d 100644 --- a/paddle/framework/op_desc.h +++ b/paddle/framework/op_desc.h @@ -17,6 +17,7 @@ limitations under the License. */ #include #include #include "paddle/framework/attribute.h" +#include "paddle/framework/op_info.h" #include "paddle/framework/var_desc.h" namespace paddle { @@ -61,48 +62,22 @@ class OpDescBind { void SetBlockAttr(const std::string &name, BlockDescBind &block); // Only be used in C++ - void SetAttrMap(const std::unordered_map &attr_map); + void SetAttrMap(const AttributeMap &attr_map); Attribute GetAttr(const std::string &name) const; int GetBlockAttr(const std::string &name) const; // Only be used in C++ - const std::unordered_map &GetAttrMap() const; + const AttributeMap &GetAttrMap() const; private: - struct SetAttrDescVisitor : public boost::static_visitor { - explicit SetAttrDescVisitor(OpDesc::Attr *attr) : attr_(attr) {} - mutable OpDesc::Attr *attr_; - void operator()(int v) const { attr_->set_i(v); } - void operator()(float v) const { attr_->set_f(v); } - void operator()(const std::string &v) const { attr_->set_s(v); } - void operator()(bool b) const { attr_->set_b(b); } - - void operator()(const std::vector &v) const { - VectorToRepeated(v, attr_->mutable_ints()); - } - void operator()(const std::vector &v) const { - VectorToRepeated(v, attr_->mutable_floats()); - } - void operator()(const std::vector &v) const { - VectorToRepeated(v, attr_->mutable_strings()); - } - void operator()(const std::vector &v) const { - VectorToRepeated(v, attr_->mutable_bools()); - } - void operator()(BlockDesc *desc) const { - attr_->set_block_idx(desc->idx()); - } - void operator()(boost::blank) const { PADDLE_THROW("Unexpected branch"); } - }; - void Sync(); OpDesc op_desc_; - std::unordered_map> inputs_; - std::unordered_map> outputs_; - std::unordered_map attrs_; + VariableNameMap inputs_; + VariableNameMap outputs_; + AttributeMap attrs_; // need_update_ indicate there some local changes not be synchronized. If // local changes should be synchronized, need_update_ should be set to true. diff --git a/paddle/platform/enforce.h b/paddle/platform/enforce.h index b523ef03c0053622bfda5b4bf07515c1b480b4af..52bd23039b82fc0290edb640f5d5851ba9abee61 100644 --- a/paddle/platform/enforce.h +++ b/paddle/platform/enforce.h @@ -185,7 +185,7 @@ inline void throw_on_error(T e) { std::make_exception_ptr( \ std::runtime_error(paddle::string::Sprintf(__VA_ARGS__))), \ __FILE__, __LINE__); \ - } while (0) + } while (false) #define PADDLE_ENFORCE(...) \ do { \ @@ -195,7 +195,7 @@ inline void throw_on_error(T e) { throw ::paddle::platform::EnforceNotMet(std::current_exception(), \ __FILE__, __LINE__); \ } \ - } while (0) + } while (false) /* * Some enforce helpers here, usage: