提交 de350987 编写于 作者: Y Yu Yang

Fix CI and follow comment

上级 d72a140e
...@@ -44,8 +44,11 @@ template <template <class...> class V, class... Ts> ...@@ -44,8 +44,11 @@ template <template <class...> class V, class... Ts>
struct variant_caster<V<Ts...>> { struct variant_caster<V<Ts...>> {
using Type = V<Ts...>; using Type = V<Ts...>;
template <class T> template <typename T>
bool try_load(handle src, bool convert) { typename std::enable_if<
!std::is_same<T, boost::detail::variant::void_>::value,
bool>::type
try_load(handle src, bool convert) {
auto caster = make_caster<T>(); auto caster = make_caster<T>();
if (!load_success_ && caster.load(src, convert)) { if (!load_success_ && caster.load(src, convert)) {
load_success_ = true; load_success_ = true;
...@@ -55,6 +58,13 @@ struct variant_caster<V<Ts...>> { ...@@ -55,6 +58,13 @@ struct variant_caster<V<Ts...>> {
return false; return false;
} }
template <typename T>
typename std::enable_if<std::is_same<T, boost::detail::variant::void_>::value,
bool>::type
try_load(handle src, bool convert) {
return false;
}
bool load(handle src, bool convert) { bool load(handle src, bool convert) {
auto unused = {false, try_load<Ts>(src, convert)...}; auto unused = {false, try_load<Ts>(src, convert)...};
(void)(unused); (void)(unused);
...@@ -210,6 +220,45 @@ public: ...@@ -210,6 +220,45 @@ public:
std::string DebugString() { return this->Proto()->DebugString(); } std::string DebugString() { return this->Proto()->DebugString(); }
bool HasAttr(const std::string &name) const {
return attrs_.find(name) != attrs_.end();
}
framework::AttrType GetAttrType(const std::string &name) const {
auto it = attrs_.find(name);
PADDLE_ENFORCE(it != attrs_.end(), "Attribute %s is not found", name);
return static_cast<framework::AttrType>(it->second.which() - 1);
}
std::vector<std::string> AttrNames() const {
std::vector<std::string> retv;
retv.reserve(attrs_.size());
for (auto &attr : attrs_) {
retv.push_back(attr.first);
}
return retv;
}
void SetAttr(const std::string &name, const Attribute &v) {
this->attrs_[name] = v;
need_update_ = true;
}
void SetBlockAttr(const std::string &name, BlockDescBind &block);
Attribute GetAttr(const std::string &name) const {
auto it = attrs_.find(name);
PADDLE_ENFORCE(it != attrs_.end(), "Attribute %s is not found", name);
return it->second;
}
int GetBlockAttr(const std::string &name) const {
auto it = attrs_.find(name);
PADDLE_ENFORCE(it != attrs_.end(), "Attribute %s is not found", name);
return boost::get<BlockDesc *>(it->second)->idx();
}
private:
struct SetAttrDescVisitor : public boost::static_visitor<void> { struct SetAttrDescVisitor : public boost::static_visitor<void> {
explicit SetAttrDescVisitor(OpDesc::Attr *attr) : attr_(attr) {} explicit SetAttrDescVisitor(OpDesc::Attr *attr) : attr_(attr) {}
mutable OpDesc::Attr *attr_; mutable OpDesc::Attr *attr_;
...@@ -265,49 +314,13 @@ public: ...@@ -265,49 +314,13 @@ public:
} }
} }
bool HasAttr(const std::string &name) const {
return attrs_.find(name) != attrs_.end();
}
framework::AttrType GetAttrType(const std::string &name) const {
auto it = attrs_.find(name);
PADDLE_ENFORCE(it != attrs_.end(), "Attribute %s is not found", name);
return static_cast<framework::AttrType>(it->second.which() - 1);
}
std::vector<std::string> AttrNames() const {
std::vector<std::string> retv;
retv.reserve(attrs_.size());
for (auto &attr : attrs_) {
retv.push_back(attr.first);
}
return retv;
}
void SetAttr(const std::string &name, const Attribute &v) {
this->attrs_[name] = v;
}
void SetBlockAttr(const std::string &name, BlockDescBind &block);
Attribute GetAttr(const std::string &name) const {
auto it = attrs_.find(name);
PADDLE_ENFORCE(it != attrs_.end(), "Attribute %s is not found", name);
return it->second;
}
int GetBlockAttr(const std::string &name) const {
auto it = attrs_.find(name);
PADDLE_ENFORCE(it != attrs_.end(), "Attribute %s is not found", name);
return boost::get<BlockDesc *>(it->second)->idx();
}
private:
OpDesc op_desc_; OpDesc op_desc_;
std::unordered_map<std::string, std::vector<std::string>> inputs_; std::unordered_map<std::string, std::vector<std::string>> inputs_;
std::unordered_map<std::string, std::vector<std::string>> outputs_; std::unordered_map<std::string, std::vector<std::string>> outputs_;
std::unordered_map<std::string, Attribute> attrs_; std::unordered_map<std::string, Attribute> attrs_;
// need_update_ indicate there some local changes not be synchronized. If
// local changes should be synchronized, need_update_ should be set to true.
bool need_update_{false}; bool need_update_{false};
}; };
......
...@@ -23,6 +23,7 @@ class TestOpDesc(unittest.TestCase): ...@@ -23,6 +23,7 @@ class TestOpDesc(unittest.TestCase):
op.set_attr("int_attr", 1) op.set_attr("int_attr", 1)
self.assertEqual(1, op.attr("int_attr")) self.assertEqual(1, op.attr("int_attr"))
self.assertTrue(op.has_attr("int_attr")) self.assertTrue(op.has_attr("int_attr"))
self.assertEqual(core.AttrType.INT, op.attr_type("int_attr"))
op.set_attr("float_attr", -1.32) op.set_attr("float_attr", -1.32)
self.assertAlmostEqual(-1.32, op.attr("float_attr"), delta=1e-4) self.assertAlmostEqual(-1.32, op.attr("float_attr"), delta=1e-4)
...@@ -53,7 +54,6 @@ class TestOpDesc(unittest.TestCase): ...@@ -53,7 +54,6 @@ class TestOpDesc(unittest.TestCase):
op.set_block_attr("block_attr", prog.block(0)) op.set_block_attr("block_attr", prog.block(0))
self.assertEqual(0, op.get_block_attr("block_attr")) self.assertEqual(0, op.get_block_attr("block_attr"))
self.assertEqual(core.AttrType.INT, op.attr_type("int_attr"))
class TestProgramDesc(unittest.TestCase): class TestProgramDesc(unittest.TestCase):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册