提交 4b948abb 编写于 作者: F fengjiayi

Update Attribute to make it compatible with BLOCK

上级 c869b8e2
......@@ -47,40 +47,44 @@ template <>
AttrType AttrTypeID<std::vector<std::pair<int, int>>>() {
return INT_PAIRS;
}
template <>
AttrType AttrTypeID<BlockDesc>() {
return BLOCK;
}
Attribute GetAttrValue(const OpDesc::Attr& attr_desc) {
switch (attr_desc.type()) {
case paddle::framework::AttrType::INT: {
case framework::AttrType::INT: {
return attr_desc.i();
}
case paddle::framework::AttrType::FLOAT: {
case framework::AttrType::FLOAT: {
return attr_desc.f();
}
case paddle::framework::AttrType::STRING: {
case framework::AttrType::STRING: {
return attr_desc.s();
}
case paddle::framework::AttrType::INTS: {
case framework::AttrType::INTS: {
std::vector<int> val(attr_desc.ints_size());
for (int i = 0; i < attr_desc.ints_size(); ++i) {
val[i] = attr_desc.ints(i);
}
return val;
}
case paddle::framework::AttrType::FLOATS: {
case framework::AttrType::FLOATS: {
std::vector<float> val(attr_desc.floats_size());
for (int i = 0; i < attr_desc.floats_size(); ++i) {
val[i] = attr_desc.floats(i);
}
return val;
}
case paddle::framework::AttrType::STRINGS: {
case framework::AttrType::STRINGS: {
std::vector<std::string> val(attr_desc.strings_size());
for (int i = 0; i < attr_desc.strings_size(); ++i) {
val[i] = attr_desc.strings(i);
}
return val;
}
case paddle::framework::AttrType::INT_PAIRS: {
case framework::AttrType::INT_PAIRS: {
std::vector<std::pair<int, int>> val(attr_desc.int_pairs_size());
for (int i = 0; i < attr_desc.int_pairs_size(); ++i) {
val[i].first = attr_desc.int_pairs(i).first();
......@@ -88,6 +92,9 @@ Attribute GetAttrValue(const OpDesc::Attr& attr_desc) {
}
return val;
}
case framework::AttrType::BLOCK: {
return g_program_desc.blocks(attr_desc.block_idx());
}
}
PADDLE_ENFORCE(false, "Unknown OpDesc::AttrDesc::type !");
return boost::blank();
......
......@@ -29,11 +29,13 @@ namespace framework {
typedef boost::variant<boost::blank, int, float, std::string, std::vector<int>,
std::vector<float>, std::vector<std::string>,
std::vector<std::pair<int, int>>>
std::vector<std::pair<int, int>>, BlockDesc>
Attribute;
typedef std::unordered_map<std::string, Attribute> AttributeMap;
static ProgramDesc g_program_desc;
template <typename T>
AttrType AttrTypeID();
......
......@@ -45,7 +45,7 @@ message OpDesc {
repeated float floats = 7;
repeated string strings = 8;
repeated IntPair int_pairs = 9;
optional int32 block = 10;
optional int32 block_idx = 10;
};
message Var {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册