提交 fad48fa6 编写于 作者: D dangqingqing

Add bool type for attr.

上级 6e9337e3
...@@ -19,6 +19,10 @@ limitations under the License. */ ...@@ -19,6 +19,10 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
template <>
AttrType AttrTypeID<bool>() {
return BOOL;
}
template <> template <>
AttrType AttrTypeID<int>() { AttrType AttrTypeID<int>() {
return INT; return INT;
...@@ -32,6 +36,10 @@ AttrType AttrTypeID<std::string>() { ...@@ -32,6 +36,10 @@ AttrType AttrTypeID<std::string>() {
return STRING; return STRING;
} }
template <> template <>
AttrType AttrTypeID<std::vector<bool>>() {
return BOOLS;
}
template <>
AttrType AttrTypeID<std::vector<int>>() { AttrType AttrTypeID<std::vector<int>>() {
return INTS; return INTS;
} }
...@@ -50,6 +58,9 @@ AttrType AttrTypeID<std::vector<std::pair<int, int>>>() { ...@@ -50,6 +58,9 @@ AttrType AttrTypeID<std::vector<std::pair<int, int>>>() {
Attribute GetAttrValue(const OpDesc::Attr& attr_desc) { Attribute GetAttrValue(const OpDesc::Attr& attr_desc) {
switch (attr_desc.type()) { switch (attr_desc.type()) {
case paddle::framework::AttrType::BOOL: {
return attr_desc.b();
}
case paddle::framework::AttrType::INT: { case paddle::framework::AttrType::INT: {
return attr_desc.i(); return attr_desc.i();
} }
...@@ -59,6 +70,13 @@ Attribute GetAttrValue(const OpDesc::Attr& attr_desc) { ...@@ -59,6 +70,13 @@ Attribute GetAttrValue(const OpDesc::Attr& attr_desc) {
case paddle::framework::AttrType::STRING: { case paddle::framework::AttrType::STRING: {
return attr_desc.s(); return attr_desc.s();
} }
case paddle::framework::AttrType::BOOLS: {
std::vector<bool> val(attr_desc.bools_size());
for (int i = 0; i < attr_desc.bools_size(); ++i) {
val[i] = attr_desc.bools(i);
}
return val;
}
case paddle::framework::AttrType::INTS: { case paddle::framework::AttrType::INTS: {
std::vector<int> val(attr_desc.ints_size()); std::vector<int> val(attr_desc.ints_size());
for (int i = 0; i < attr_desc.ints_size(); ++i) { for (int i = 0; i < attr_desc.ints_size(); ++i) {
......
...@@ -27,8 +27,9 @@ limitations under the License. */ ...@@ -27,8 +27,9 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
typedef boost::variant<boost::blank, int, float, std::string, std::vector<int>, typedef boost::variant<boost::blank, bool, int, float, std::string,
std::vector<float>, std::vector<std::string>, std::vector<bool>, std::vector<int>, std::vector<float>,
std::vector<std::string>,
std::vector<std::pair<int, int>>> std::vector<std::pair<int, int>>>
Attribute; Attribute;
......
...@@ -23,6 +23,8 @@ enum AttrType { ...@@ -23,6 +23,8 @@ enum AttrType {
FLOATS = 4; FLOATS = 4;
STRINGS = 5; STRINGS = 5;
INT_PAIRS = 6; INT_PAIRS = 6;
BOOL = 7;
BOOLS = 8;
} }
message IntPair { message IntPair {
...@@ -44,6 +46,8 @@ message OpDesc { ...@@ -44,6 +46,8 @@ message OpDesc {
repeated float floats = 7; repeated float floats = 7;
repeated string strings = 8; repeated string strings = 8;
repeated IntPair int_pairs = 9; repeated IntPair int_pairs = 9;
optional bool b = 10;
repeated bool bools = 6;
}; };
message Var { message Var {
......
...@@ -89,12 +89,16 @@ class OpDescCreationMethod(object): ...@@ -89,12 +89,16 @@ class OpDescCreationMethod(object):
new_attr.f = user_defined_attr new_attr.f = user_defined_attr
elif attr.type == framework_pb2.STRING: elif attr.type == framework_pb2.STRING:
new_attr.s = user_defined_attr new_attr.s = user_defined_attr
elif attr.type == framework_pb2.BOOL:
new_attr.b = user_defined_attr
elif attr.type == framework_pb2.INTS: elif attr.type == framework_pb2.INTS:
new_attr.ints.extend(user_defined_attr) new_attr.ints.extend(user_defined_attr)
elif attr.type == framework_pb2.FLOATS: elif attr.type == framework_pb2.FLOATS:
new_attr.floats.extend(user_defined_attr) new_attr.floats.extend(user_defined_attr)
elif attr.type == framework_pb2.STRINGS: elif attr.type == framework_pb2.STRINGS:
new_attr.strings.extend(user_defined_attr) new_attr.strings.extend(user_defined_attr)
elif attr.type == framework_pb2.BOOLS:
new_attr.bools.extend(user_defined_attr)
elif attr.type == framework_pb2.INT_PAIRS: elif attr.type == framework_pb2.INT_PAIRS:
for p in user_defined_attr: for p in user_defined_attr:
pair = new_attr.int_pairs.add() pair = new_attr.int_pairs.add()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册