提交 7002e69b 编写于 作者: W whs 提交者: GitHub

Merge pull request #3878 from wanghaoshuang/fix_attr

Make attribute support for std::vector<std::pair<int, int>>
...@@ -43,6 +43,10 @@ template <> ...@@ -43,6 +43,10 @@ template <>
AttrType AttrTypeID<std::vector<std::string>>() { AttrType AttrTypeID<std::vector<std::string>>() {
return STRINGS; return STRINGS;
} }
template <>
AttrType AttrTypeID<std::vector<std::pair<int, int>>>() {
return INT_PAIRS;
}
Attribute GetAttrValue(const OpDesc::Attr& attr_desc) { Attribute GetAttrValue(const OpDesc::Attr& attr_desc) {
switch (attr_desc.type()) { switch (attr_desc.type()) {
...@@ -76,6 +80,14 @@ Attribute GetAttrValue(const OpDesc::Attr& attr_desc) { ...@@ -76,6 +80,14 @@ Attribute GetAttrValue(const OpDesc::Attr& attr_desc) {
} }
return val; return val;
} }
case paddle::framework::AttrType::INT_PAIRS: {
std::vector<std::pair<int, int>> val(attr_desc.int_pairs_size());
for (int i = 0; i < attr_desc.int_pairs_size(); ++i) {
val[i].first = attr_desc.int_pairs(i).first();
val[i].second = attr_desc.int_pairs(i).second();
}
return val;
}
} }
PADDLE_ENFORCE(false, "Unknown OpDesc::AttrDesc::type !"); PADDLE_ENFORCE(false, "Unknown OpDesc::AttrDesc::type !");
return boost::blank(); return boost::blank();
......
...@@ -28,7 +28,8 @@ namespace paddle { ...@@ -28,7 +28,8 @@ namespace paddle {
namespace framework { namespace framework {
typedef boost::variant<boost::blank, int, float, std::string, std::vector<int>, typedef boost::variant<boost::blank, int, float, std::string, std::vector<int>,
std::vector<float>, std::vector<std::string>> std::vector<float>, std::vector<std::string>,
std::vector<std::pair<int, int>>>
Attribute; Attribute;
typedef std::unordered_map<std::string, Attribute> AttributeMap; typedef std::unordered_map<std::string, Attribute> AttributeMap;
......
...@@ -22,8 +22,14 @@ enum AttrType { ...@@ -22,8 +22,14 @@ enum AttrType {
INTS = 3; INTS = 3;
FLOATS = 4; FLOATS = 4;
STRINGS = 5; STRINGS = 5;
INT_PAIRS = 6;
} }
message IntPair {
required int32 first = 1;
required int32 second = 2;
};
// OpDesc describes an instance of a C++ framework::OperatorBase // OpDesc describes an instance of a C++ framework::OperatorBase
// derived class type. // derived class type.
message OpDesc { message OpDesc {
...@@ -37,6 +43,7 @@ message OpDesc { ...@@ -37,6 +43,7 @@ message OpDesc {
repeated int32 ints = 6; repeated int32 ints = 6;
repeated float floats = 7; repeated float floats = 7;
repeated string strings = 8; repeated string strings = 8;
repeated IntPair int_pairs = 9;
}; };
message Var { message Var {
......
...@@ -94,9 +94,14 @@ class OpDescCreationMethod(object): ...@@ -94,9 +94,14 @@ class OpDescCreationMethod(object):
new_attr.floats.extend(user_defined_attr) new_attr.floats.extend(user_defined_attr)
elif attr.type == framework_pb2.STRINGS: elif attr.type == framework_pb2.STRINGS:
new_attr.strings.extend(user_defined_attr) new_attr.strings.extend(user_defined_attr)
elif attr.type == framework_pb2.INT_PAIRS:
for p in user_defined_attr:
pair = new_attr.pairs.add()
pair.first = p[0]
pair.second = p[1]
else: else:
raise NotImplementedError("Not support attribute type " + raise NotImplementedError("Not support attribute type " +
attr.type) str(attr.type))
return op_desc return op_desc
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册