diff --git a/paddle/framework/attribute.cc b/paddle/framework/attribute.cc index 9eb07acdff1d00dd926f1cee9c24f9f151006d7e..27132eaa0b3b0666fc042faf052dac2e169ba9e7 100644 --- a/paddle/framework/attribute.cc +++ b/paddle/framework/attribute.cc @@ -43,6 +43,10 @@ template <> AttrType AttrTypeID>() { return STRINGS; } +template <> +AttrType AttrTypeID>>() { + return INT_PAIRS; +} Attribute GetAttrValue(const OpDesc::Attr& attr_desc) { switch (attr_desc.type()) { @@ -76,6 +80,14 @@ Attribute GetAttrValue(const OpDesc::Attr& attr_desc) { } return val; } + case paddle::framework::AttrType::INT_PAIRS: { + std::vector> val(attr_desc.int_pairs_size()); + for (int i = 0; i < attr_desc.int_pairs_size(); ++i) { + val[i].first = attr_desc.int_pairs(i).first(); + val[i].second = attr_desc.int_pairs(i).second(); + } + return val; + } } PADDLE_ENFORCE(false, "Unknown OpDesc::AttrDesc::type !"); return boost::blank(); diff --git a/paddle/framework/attribute.h b/paddle/framework/attribute.h index 08b47cabd4c2225c50022bd35734dcc2663324d6..071879a9d453377ccc2e9e71b62e8568a7ef1c9b 100644 --- a/paddle/framework/attribute.h +++ b/paddle/framework/attribute.h @@ -28,7 +28,8 @@ namespace paddle { namespace framework { typedef boost::variant, - std::vector, std::vector> + std::vector, std::vector, + std::vector>> Attribute; typedef std::unordered_map AttributeMap; diff --git a/paddle/framework/framework.proto b/paddle/framework/framework.proto index ae44a1ffd45dacdc44a72edc630e771e7a2f2990..368136a9729dd2c745cc71bc391031e0a390fc87 100644 --- a/paddle/framework/framework.proto +++ b/paddle/framework/framework.proto @@ -22,8 +22,14 @@ enum AttrType { INTS = 3; FLOATS = 4; STRINGS = 5; + INT_PAIRS = 6; } +message IntPair { + required int32 first = 1; + required int32 second = 2; +}; + // OpDesc describes an instance of a C++ framework::OperatorBase // derived class type. message OpDesc { @@ -37,6 +43,7 @@ message OpDesc { repeated int32 ints = 6; repeated float floats = 7; repeated string strings = 8; + repeated IntPair int_pairs = 9; }; message Var { diff --git a/python/paddle/v2/framework/op.py b/python/paddle/v2/framework/op.py index e7e932f6fe46f2db8de37cecfd9875b310bdcbe5..0349407a851ebb48f69d7daef7a318cf348aad5d 100644 --- a/python/paddle/v2/framework/op.py +++ b/python/paddle/v2/framework/op.py @@ -94,9 +94,14 @@ class OpDescCreationMethod(object): new_attr.floats.extend(user_defined_attr) elif attr.type == framework_pb2.STRINGS: new_attr.strings.extend(user_defined_attr) + elif attr.type == framework_pb2.INT_PAIRS: + for p in user_defined_attr: + pair = new_attr.pairs.add() + pair.first = p[0] + pair.second = p[1] else: raise NotImplementedError("Not support attribute type " + - attr.type) + str(attr.type)) return op_desc