提交 04e604b7 编写于 作者: Y Yu Yang 提交者: dongzhihong

Unify Map in OpDescBind

上级 9b54ad18
......@@ -112,6 +112,30 @@ const std::unordered_map<std::string, Attribute> &OpDescBind::GetAttrMap()
return attrs_;
}
struct SetAttrDescVisitor : public boost::static_visitor<void> {
explicit SetAttrDescVisitor(OpDesc::Attr *attr) : attr_(attr) {}
mutable OpDesc::Attr *attr_;
void operator()(int v) const { attr_->set_i(v); }
void operator()(float v) const { attr_->set_f(v); }
void operator()(const std::string &v) const { attr_->set_s(v); }
void operator()(bool b) const { attr_->set_b(b); }
void operator()(const std::vector<int> &v) const {
VectorToRepeated(v, attr_->mutable_ints());
}
void operator()(const std::vector<float> &v) const {
VectorToRepeated(v, attr_->mutable_floats());
}
void operator()(const std::vector<std::string> &v) const {
VectorToRepeated(v, attr_->mutable_strings());
}
void operator()(const std::vector<bool> &v) const {
VectorToRepeated(v, attr_->mutable_bools());
}
void operator()(BlockDesc *desc) const { attr_->set_block_idx(desc->idx()); }
void operator()(boost::blank) const { PADDLE_THROW("Unexpected branch"); }
};
void OpDescBind::Sync() {
if (need_update_) {
this->op_desc_.mutable_inputs()->Clear();
......@@ -134,7 +158,8 @@ void OpDescBind::Sync() {
attr_desc->set_name(attr.first);
attr_desc->set_type(
static_cast<framework::AttrType>(attr.second.which() - 1));
boost::apply_visitor(SetAttrDescVisitor(attr_desc), attr.second);
SetAttrDescVisitor visitor(attr_desc);
boost::apply_visitor(visitor, attr.second);
}
need_update_ = false;
......
......@@ -17,6 +17,7 @@ limitations under the License. */
#include <unordered_map>
#include <vector>
#include "paddle/framework/attribute.h"
#include "paddle/framework/op_info.h"
#include "paddle/framework/var_desc.h"
namespace paddle {
......@@ -61,48 +62,22 @@ class OpDescBind {
void SetBlockAttr(const std::string &name, BlockDescBind &block);
// Only be used in C++
void SetAttrMap(const std::unordered_map<std::string, Attribute> &attr_map);
void SetAttrMap(const AttributeMap &attr_map);
Attribute GetAttr(const std::string &name) const;
int GetBlockAttr(const std::string &name) const;
// Only be used in C++
const std::unordered_map<std::string, Attribute> &GetAttrMap() const;
const AttributeMap &GetAttrMap() const;
private:
struct SetAttrDescVisitor : public boost::static_visitor<void> {
explicit SetAttrDescVisitor(OpDesc::Attr *attr) : attr_(attr) {}
mutable OpDesc::Attr *attr_;
void operator()(int v) const { attr_->set_i(v); }
void operator()(float v) const { attr_->set_f(v); }
void operator()(const std::string &v) const { attr_->set_s(v); }
void operator()(bool b) const { attr_->set_b(b); }
void operator()(const std::vector<int> &v) const {
VectorToRepeated(v, attr_->mutable_ints());
}
void operator()(const std::vector<float> &v) const {
VectorToRepeated(v, attr_->mutable_floats());
}
void operator()(const std::vector<std::string> &v) const {
VectorToRepeated(v, attr_->mutable_strings());
}
void operator()(const std::vector<bool> &v) const {
VectorToRepeated(v, attr_->mutable_bools());
}
void operator()(BlockDesc *desc) const {
attr_->set_block_idx(desc->idx());
}
void operator()(boost::blank) const { PADDLE_THROW("Unexpected branch"); }
};
void Sync();
OpDesc op_desc_;
std::unordered_map<std::string, std::vector<std::string>> inputs_;
std::unordered_map<std::string, std::vector<std::string>> outputs_;
std::unordered_map<std::string, Attribute> attrs_;
VariableNameMap inputs_;
VariableNameMap outputs_;
AttributeMap attrs_;
// need_update_ indicate there some local changes not be synchronized. If
// local changes should be synchronized, need_update_ should be set to true.
......
......@@ -185,7 +185,7 @@ inline void throw_on_error(T e) {
std::make_exception_ptr( \
std::runtime_error(paddle::string::Sprintf(__VA_ARGS__))), \
__FILE__, __LINE__); \
} while (0)
} while (false)
#define PADDLE_ENFORCE(...) \
do { \
......@@ -195,7 +195,7 @@ inline void throw_on_error(T e) {
throw ::paddle::platform::EnforceNotMet(std::current_exception(), \
__FILE__, __LINE__); \
} \
} while (0)
} while (false)
/*
* Some enforce helpers here, usage:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册