From fad48fa6b1865d353cc277845db5195f79df3be7 Mon Sep 17 00:00:00 2001 From: dangqingqing Date: Tue, 19 Sep 2017 15:31:58 +0800 Subject: [PATCH] Add bool type for attr. --- paddle/framework/attribute.cc | 18 ++++++++++++++++++ paddle/framework/attribute.h | 5 +++-- paddle/framework/framework.proto | 4 ++++ python/paddle/v2/framework/op.py | 4 ++++ 4 files changed, 29 insertions(+), 2 deletions(-) diff --git a/paddle/framework/attribute.cc b/paddle/framework/attribute.cc index 27132eaa0..e18d1add9 100644 --- a/paddle/framework/attribute.cc +++ b/paddle/framework/attribute.cc @@ -19,6 +19,10 @@ limitations under the License. */ namespace paddle { namespace framework { +template <> +AttrType AttrTypeID() { + return BOOL; +} template <> AttrType AttrTypeID() { return INT; @@ -32,6 +36,10 @@ AttrType AttrTypeID() { return STRING; } template <> +AttrType AttrTypeID>() { + return BOOLS; +} +template <> AttrType AttrTypeID>() { return INTS; } @@ -50,6 +58,9 @@ AttrType AttrTypeID>>() { Attribute GetAttrValue(const OpDesc::Attr& attr_desc) { switch (attr_desc.type()) { + case paddle::framework::AttrType::BOOL: { + return attr_desc.b(); + } case paddle::framework::AttrType::INT: { return attr_desc.i(); } @@ -59,6 +70,13 @@ Attribute GetAttrValue(const OpDesc::Attr& attr_desc) { case paddle::framework::AttrType::STRING: { return attr_desc.s(); } + case paddle::framework::AttrType::BOOLS: { + std::vector val(attr_desc.bools_size()); + for (int i = 0; i < attr_desc.bools_size(); ++i) { + val[i] = attr_desc.bools(i); + } + return val; + } case paddle::framework::AttrType::INTS: { std::vector val(attr_desc.ints_size()); for (int i = 0; i < attr_desc.ints_size(); ++i) { diff --git a/paddle/framework/attribute.h b/paddle/framework/attribute.h index 2b788a76c..3232a9003 100644 --- a/paddle/framework/attribute.h +++ b/paddle/framework/attribute.h @@ -27,8 +27,9 @@ limitations under the License. */ namespace paddle { namespace framework { -typedef boost::variant, - std::vector, std::vector, +typedef boost::variant, std::vector, std::vector, + std::vector, std::vector>> Attribute; diff --git a/paddle/framework/framework.proto b/paddle/framework/framework.proto index dfcb5fb62..ec7b750d8 100644 --- a/paddle/framework/framework.proto +++ b/paddle/framework/framework.proto @@ -23,6 +23,8 @@ enum AttrType { FLOATS = 4; STRINGS = 5; INT_PAIRS = 6; + BOOL = 7; + BOOLS = 8; } message IntPair { @@ -44,6 +46,8 @@ message OpDesc { repeated float floats = 7; repeated string strings = 8; repeated IntPair int_pairs = 9; + optional bool b = 10; + repeated bool bools = 6; }; message Var { diff --git a/python/paddle/v2/framework/op.py b/python/paddle/v2/framework/op.py index 6cca41e43..93dfbc5d3 100644 --- a/python/paddle/v2/framework/op.py +++ b/python/paddle/v2/framework/op.py @@ -89,12 +89,16 @@ class OpDescCreationMethod(object): new_attr.f = user_defined_attr elif attr.type == framework_pb2.STRING: new_attr.s = user_defined_attr + elif attr.type == framework_pb2.BOOL: + new_attr.b = user_defined_attr elif attr.type == framework_pb2.INTS: new_attr.ints.extend(user_defined_attr) elif attr.type == framework_pb2.FLOATS: new_attr.floats.extend(user_defined_attr) elif attr.type == framework_pb2.STRINGS: new_attr.strings.extend(user_defined_attr) + elif attr.type == framework_pb2.BOOLS: + new_attr.bools.extend(user_defined_attr) elif attr.type == framework_pb2.INT_PAIRS: for p in user_defined_attr: pair = new_attr.int_pairs.add() -- GitLab