提交 04e604b7 编写于 作者: Y Yu Yang 提交者: dongzhihong

Unify Map in OpDescBind

上级 9b54ad18
...@@ -112,6 +112,30 @@ const std::unordered_map<std::string, Attribute> &OpDescBind::GetAttrMap() ...@@ -112,6 +112,30 @@ const std::unordered_map<std::string, Attribute> &OpDescBind::GetAttrMap()
return attrs_; return attrs_;
} }
struct SetAttrDescVisitor : public boost::static_visitor<void> {
explicit SetAttrDescVisitor(OpDesc::Attr *attr) : attr_(attr) {}
mutable OpDesc::Attr *attr_;
void operator()(int v) const { attr_->set_i(v); }
void operator()(float v) const { attr_->set_f(v); }
void operator()(const std::string &v) const { attr_->set_s(v); }
void operator()(bool b) const { attr_->set_b(b); }
void operator()(const std::vector<int> &v) const {
VectorToRepeated(v, attr_->mutable_ints());
}
void operator()(const std::vector<float> &v) const {
VectorToRepeated(v, attr_->mutable_floats());
}
void operator()(const std::vector<std::string> &v) const {
VectorToRepeated(v, attr_->mutable_strings());
}
void operator()(const std::vector<bool> &v) const {
VectorToRepeated(v, attr_->mutable_bools());
}
void operator()(BlockDesc *desc) const { attr_->set_block_idx(desc->idx()); }
void operator()(boost::blank) const { PADDLE_THROW("Unexpected branch"); }
};
void OpDescBind::Sync() { void OpDescBind::Sync() {
if (need_update_) { if (need_update_) {
this->op_desc_.mutable_inputs()->Clear(); this->op_desc_.mutable_inputs()->Clear();
...@@ -134,7 +158,8 @@ void OpDescBind::Sync() { ...@@ -134,7 +158,8 @@ void OpDescBind::Sync() {
attr_desc->set_name(attr.first); attr_desc->set_name(attr.first);
attr_desc->set_type( attr_desc->set_type(
static_cast<framework::AttrType>(attr.second.which() - 1)); static_cast<framework::AttrType>(attr.second.which() - 1));
boost::apply_visitor(SetAttrDescVisitor(attr_desc), attr.second); SetAttrDescVisitor visitor(attr_desc);
boost::apply_visitor(visitor, attr.second);
} }
need_update_ = false; need_update_ = false;
......
...@@ -17,6 +17,7 @@ limitations under the License. */ ...@@ -17,6 +17,7 @@ limitations under the License. */
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
#include "paddle/framework/attribute.h" #include "paddle/framework/attribute.h"
#include "paddle/framework/op_info.h"
#include "paddle/framework/var_desc.h" #include "paddle/framework/var_desc.h"
namespace paddle { namespace paddle {
...@@ -61,48 +62,22 @@ class OpDescBind { ...@@ -61,48 +62,22 @@ class OpDescBind {
void SetBlockAttr(const std::string &name, BlockDescBind &block); void SetBlockAttr(const std::string &name, BlockDescBind &block);
// Only be used in C++ // Only be used in C++
void SetAttrMap(const std::unordered_map<std::string, Attribute> &attr_map); void SetAttrMap(const AttributeMap &attr_map);
Attribute GetAttr(const std::string &name) const; Attribute GetAttr(const std::string &name) const;
int GetBlockAttr(const std::string &name) const; int GetBlockAttr(const std::string &name) const;
// Only be used in C++ // Only be used in C++
const std::unordered_map<std::string, Attribute> &GetAttrMap() const; const AttributeMap &GetAttrMap() const;
private: private:
struct SetAttrDescVisitor : public boost::static_visitor<void> {
explicit SetAttrDescVisitor(OpDesc::Attr *attr) : attr_(attr) {}
mutable OpDesc::Attr *attr_;
void operator()(int v) const { attr_->set_i(v); }
void operator()(float v) const { attr_->set_f(v); }
void operator()(const std::string &v) const { attr_->set_s(v); }
void operator()(bool b) const { attr_->set_b(b); }
void operator()(const std::vector<int> &v) const {
VectorToRepeated(v, attr_->mutable_ints());
}
void operator()(const std::vector<float> &v) const {
VectorToRepeated(v, attr_->mutable_floats());
}
void operator()(const std::vector<std::string> &v) const {
VectorToRepeated(v, attr_->mutable_strings());
}
void operator()(const std::vector<bool> &v) const {
VectorToRepeated(v, attr_->mutable_bools());
}
void operator()(BlockDesc *desc) const {
attr_->set_block_idx(desc->idx());
}
void operator()(boost::blank) const { PADDLE_THROW("Unexpected branch"); }
};
void Sync(); void Sync();
OpDesc op_desc_; OpDesc op_desc_;
std::unordered_map<std::string, std::vector<std::string>> inputs_; VariableNameMap inputs_;
std::unordered_map<std::string, std::vector<std::string>> outputs_; VariableNameMap outputs_;
std::unordered_map<std::string, Attribute> attrs_; AttributeMap attrs_;
// need_update_ indicate there some local changes not be synchronized. If // need_update_ indicate there some local changes not be synchronized. If
// local changes should be synchronized, need_update_ should be set to true. // local changes should be synchronized, need_update_ should be set to true.
......
...@@ -185,7 +185,7 @@ inline void throw_on_error(T e) { ...@@ -185,7 +185,7 @@ inline void throw_on_error(T e) {
std::make_exception_ptr( \ std::make_exception_ptr( \
std::runtime_error(paddle::string::Sprintf(__VA_ARGS__))), \ std::runtime_error(paddle::string::Sprintf(__VA_ARGS__))), \
__FILE__, __LINE__); \ __FILE__, __LINE__); \
} while (0) } while (false)
#define PADDLE_ENFORCE(...) \ #define PADDLE_ENFORCE(...) \
do { \ do { \
...@@ -195,7 +195,7 @@ inline void throw_on_error(T e) { ...@@ -195,7 +195,7 @@ inline void throw_on_error(T e) {
throw ::paddle::platform::EnforceNotMet(std::current_exception(), \ throw ::paddle::platform::EnforceNotMet(std::current_exception(), \
__FILE__, __LINE__); \ __FILE__, __LINE__); \
} \ } \
} while (0) } while (false)
/* /*
* Some enforce helpers here, usage: * Some enforce helpers here, usage:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册