diff --git a/paddle/framework/attribute.cc b/paddle/framework/attribute.cc index 27132eaa0b3b0666fc042faf052dac2e169ba9e7..e18d1add9650290d81f94a288cea91b374fbf475 100644 --- a/paddle/framework/attribute.cc +++ b/paddle/framework/attribute.cc @@ -19,6 +19,10 @@ limitations under the License. */ namespace paddle { namespace framework { +template <> +AttrType AttrTypeID() { + return BOOL; +} template <> AttrType AttrTypeID() { return INT; @@ -32,6 +36,10 @@ AttrType AttrTypeID() { return STRING; } template <> +AttrType AttrTypeID>() { + return BOOLS; +} +template <> AttrType AttrTypeID>() { return INTS; } @@ -50,6 +58,9 @@ AttrType AttrTypeID>>() { Attribute GetAttrValue(const OpDesc::Attr& attr_desc) { switch (attr_desc.type()) { + case paddle::framework::AttrType::BOOL: { + return attr_desc.b(); + } case paddle::framework::AttrType::INT: { return attr_desc.i(); } @@ -59,6 +70,13 @@ Attribute GetAttrValue(const OpDesc::Attr& attr_desc) { case paddle::framework::AttrType::STRING: { return attr_desc.s(); } + case paddle::framework::AttrType::BOOLS: { + std::vector val(attr_desc.bools_size()); + for (int i = 0; i < attr_desc.bools_size(); ++i) { + val[i] = attr_desc.bools(i); + } + return val; + } case paddle::framework::AttrType::INTS: { std::vector val(attr_desc.ints_size()); for (int i = 0; i < attr_desc.ints_size(); ++i) { diff --git a/paddle/framework/attribute.h b/paddle/framework/attribute.h index 2b788a76cafe198abb9aed8ba842e37cc6ff73a6..3232a9003ebe2ad79035d1ac3ad93582622d2f41 100644 --- a/paddle/framework/attribute.h +++ b/paddle/framework/attribute.h @@ -27,8 +27,9 @@ limitations under the License. */ namespace paddle { namespace framework { -typedef boost::variant, - std::vector, std::vector, +typedef boost::variant, std::vector, std::vector, + std::vector, std::vector>> Attribute; diff --git a/paddle/framework/framework.proto b/paddle/framework/framework.proto index dfcb5fb6210a08f35193b83e3b5f7cee92f618d7..ec7b750d812ba2330cddaa1f7b1176ca2c882d79 100644 --- a/paddle/framework/framework.proto +++ b/paddle/framework/framework.proto @@ -23,6 +23,8 @@ enum AttrType { FLOATS = 4; STRINGS = 5; INT_PAIRS = 6; + BOOL = 7; + BOOLS = 8; } message IntPair { @@ -44,6 +46,8 @@ message OpDesc { repeated float floats = 7; repeated string strings = 8; repeated IntPair int_pairs = 9; + optional bool b = 10; + repeated bool bools = 6; }; message Var { diff --git a/python/paddle/v2/framework/op.py b/python/paddle/v2/framework/op.py index 6cca41e43b38b8cccb65ff9b347ef226dddecd4d..93dfbc5d300494aa0830cd43ea91e359a2b54ae3 100644 --- a/python/paddle/v2/framework/op.py +++ b/python/paddle/v2/framework/op.py @@ -89,12 +89,16 @@ class OpDescCreationMethod(object): new_attr.f = user_defined_attr elif attr.type == framework_pb2.STRING: new_attr.s = user_defined_attr + elif attr.type == framework_pb2.BOOL: + new_attr.b = user_defined_attr elif attr.type == framework_pb2.INTS: new_attr.ints.extend(user_defined_attr) elif attr.type == framework_pb2.FLOATS: new_attr.floats.extend(user_defined_attr) elif attr.type == framework_pb2.STRINGS: new_attr.strings.extend(user_defined_attr) + elif attr.type == framework_pb2.BOOLS: + new_attr.bools.extend(user_defined_attr) elif attr.type == framework_pb2.INT_PAIRS: for p in user_defined_attr: pair = new_attr.int_pairs.add()