From e687f3f540d3a403ab376f6c533362a6e6c577ff Mon Sep 17 00:00:00 2001 From: wanghaoshuang Date: Tue, 5 Sep 2017 17:37:12 +0800 Subject: [PATCH] Make attribute support for std::vector> --- paddle/framework/attribute.cc | 12 ++++++++++++ paddle/framework/attribute.h | 3 ++- paddle/framework/framework.proto | 7 +++++++ python/paddle/v2/framework/op.py | 7 ++++++- 4 files changed, 27 insertions(+), 2 deletions(-) diff --git a/paddle/framework/attribute.cc b/paddle/framework/attribute.cc index 9eb07acdff1..27132eaa0b3 100644 --- a/paddle/framework/attribute.cc +++ b/paddle/framework/attribute.cc @@ -43,6 +43,10 @@ template <> AttrType AttrTypeID>() { return STRINGS; } +template <> +AttrType AttrTypeID>>() { + return INT_PAIRS; +} Attribute GetAttrValue(const OpDesc::Attr& attr_desc) { switch (attr_desc.type()) { @@ -76,6 +80,14 @@ Attribute GetAttrValue(const OpDesc::Attr& attr_desc) { } return val; } + case paddle::framework::AttrType::INT_PAIRS: { + std::vector> val(attr_desc.int_pairs_size()); + for (int i = 0; i < attr_desc.int_pairs_size(); ++i) { + val[i].first = attr_desc.int_pairs(i).first(); + val[i].second = attr_desc.int_pairs(i).second(); + } + return val; + } } PADDLE_ENFORCE(false, "Unknown OpDesc::AttrDesc::type !"); return boost::blank(); diff --git a/paddle/framework/attribute.h b/paddle/framework/attribute.h index 08b47cabd4c..071879a9d45 100644 --- a/paddle/framework/attribute.h +++ b/paddle/framework/attribute.h @@ -28,7 +28,8 @@ namespace paddle { namespace framework { typedef boost::variant, - std::vector, std::vector> + std::vector, std::vector, + std::vector>> Attribute; typedef std::unordered_map AttributeMap; diff --git a/paddle/framework/framework.proto b/paddle/framework/framework.proto index ae44a1ffd45..368136a9729 100644 --- a/paddle/framework/framework.proto +++ b/paddle/framework/framework.proto @@ -22,8 +22,14 @@ enum AttrType { INTS = 3; FLOATS = 4; STRINGS = 5; + INT_PAIRS = 6; } +message IntPair { + required int32 first = 1; + required int32 second = 2; +}; + // OpDesc describes an instance of a C++ framework::OperatorBase // derived class type. message OpDesc { @@ -37,6 +43,7 @@ message OpDesc { repeated int32 ints = 6; repeated float floats = 7; repeated string strings = 8; + repeated IntPair int_pairs = 9; }; message Var { diff --git a/python/paddle/v2/framework/op.py b/python/paddle/v2/framework/op.py index e7e932f6fe4..0349407a851 100644 --- a/python/paddle/v2/framework/op.py +++ b/python/paddle/v2/framework/op.py @@ -94,9 +94,14 @@ class OpDescCreationMethod(object): new_attr.floats.extend(user_defined_attr) elif attr.type == framework_pb2.STRINGS: new_attr.strings.extend(user_defined_attr) + elif attr.type == framework_pb2.INT_PAIRS: + for p in user_defined_attr: + pair = new_attr.pairs.add() + pair.first = p[0] + pair.second = p[1] else: raise NotImplementedError("Not support attribute type " + - attr.type) + str(attr.type)) return op_desc -- GitLab