提交 fad48fa6 编写于 作者: D dangqingqing

Add bool type for attr.

上级 6e9337e3
......@@ -19,6 +19,10 @@ limitations under the License. */
namespace paddle {
namespace framework {
template <>
AttrType AttrTypeID<bool>() {
return BOOL;
}
template <>
AttrType AttrTypeID<int>() {
return INT;
......@@ -32,6 +36,10 @@ AttrType AttrTypeID<std::string>() {
return STRING;
}
template <>
AttrType AttrTypeID<std::vector<bool>>() {
return BOOLS;
}
template <>
AttrType AttrTypeID<std::vector<int>>() {
return INTS;
}
......@@ -50,6 +58,9 @@ AttrType AttrTypeID<std::vector<std::pair<int, int>>>() {
Attribute GetAttrValue(const OpDesc::Attr& attr_desc) {
switch (attr_desc.type()) {
case paddle::framework::AttrType::BOOL: {
return attr_desc.b();
}
case paddle::framework::AttrType::INT: {
return attr_desc.i();
}
......@@ -59,6 +70,13 @@ Attribute GetAttrValue(const OpDesc::Attr& attr_desc) {
case paddle::framework::AttrType::STRING: {
return attr_desc.s();
}
case paddle::framework::AttrType::BOOLS: {
std::vector<bool> val(attr_desc.bools_size());
for (int i = 0; i < attr_desc.bools_size(); ++i) {
val[i] = attr_desc.bools(i);
}
return val;
}
case paddle::framework::AttrType::INTS: {
std::vector<int> val(attr_desc.ints_size());
for (int i = 0; i < attr_desc.ints_size(); ++i) {
......
......@@ -27,8 +27,9 @@ limitations under the License. */
namespace paddle {
namespace framework {
typedef boost::variant<boost::blank, int, float, std::string, std::vector<int>,
std::vector<float>, std::vector<std::string>,
typedef boost::variant<boost::blank, bool, int, float, std::string,
std::vector<bool>, std::vector<int>, std::vector<float>,
std::vector<std::string>,
std::vector<std::pair<int, int>>>
Attribute;
......
......@@ -23,6 +23,8 @@ enum AttrType {
FLOATS = 4;
STRINGS = 5;
INT_PAIRS = 6;
BOOL = 7;
BOOLS = 8;
}
message IntPair {
......@@ -44,6 +46,8 @@ message OpDesc {
repeated float floats = 7;
repeated string strings = 8;
repeated IntPair int_pairs = 9;
optional bool b = 10;
repeated bool bools = 6;
};
message Var {
......
......@@ -89,12 +89,16 @@ class OpDescCreationMethod(object):
new_attr.f = user_defined_attr
elif attr.type == framework_pb2.STRING:
new_attr.s = user_defined_attr
elif attr.type == framework_pb2.BOOL:
new_attr.b = user_defined_attr
elif attr.type == framework_pb2.INTS:
new_attr.ints.extend(user_defined_attr)
elif attr.type == framework_pb2.FLOATS:
new_attr.floats.extend(user_defined_attr)
elif attr.type == framework_pb2.STRINGS:
new_attr.strings.extend(user_defined_attr)
elif attr.type == framework_pb2.BOOLS:
new_attr.bools.extend(user_defined_attr)
elif attr.type == framework_pb2.INT_PAIRS:
for p in user_defined_attr:
pair = new_attr.int_pairs.add()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册