提交 7d33447d 编写于 作者: F fengjiayi 提交者: GitHub

Merge pull request #4276 from Canpio/add_program_proto

Add `BlockDesc` and `ProgramDesc` to framework.proto
...@@ -19,6 +19,15 @@ limitations under the License. */ ...@@ -19,6 +19,15 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
static ProgramDesc* g_program_desc = nullptr;
ProgramDesc& GetProgramDesc() {
if (g_program_desc == nullptr) {
g_program_desc = new ProgramDesc();
}
return *g_program_desc;
}
template <> template <>
AttrType AttrTypeID<int>() { AttrType AttrTypeID<int>() {
return INT; return INT;
...@@ -47,40 +56,44 @@ template <> ...@@ -47,40 +56,44 @@ template <>
AttrType AttrTypeID<std::vector<std::pair<int, int>>>() { AttrType AttrTypeID<std::vector<std::pair<int, int>>>() {
return INT_PAIRS; return INT_PAIRS;
} }
template <>
AttrType AttrTypeID<BlockDesc>() {
return BLOCK;
}
Attribute GetAttrValue(const OpDesc::Attr& attr_desc) { Attribute GetAttrValue(const OpDesc::Attr& attr_desc) {
switch (attr_desc.type()) { switch (attr_desc.type()) {
case paddle::framework::AttrType::INT: { case framework::AttrType::INT: {
return attr_desc.i(); return attr_desc.i();
} }
case paddle::framework::AttrType::FLOAT: { case framework::AttrType::FLOAT: {
return attr_desc.f(); return attr_desc.f();
} }
case paddle::framework::AttrType::STRING: { case framework::AttrType::STRING: {
return attr_desc.s(); return attr_desc.s();
} }
case paddle::framework::AttrType::INTS: { case framework::AttrType::INTS: {
std::vector<int> val(attr_desc.ints_size()); std::vector<int> val(attr_desc.ints_size());
for (int i = 0; i < attr_desc.ints_size(); ++i) { for (int i = 0; i < attr_desc.ints_size(); ++i) {
val[i] = attr_desc.ints(i); val[i] = attr_desc.ints(i);
} }
return val; return val;
} }
case paddle::framework::AttrType::FLOATS: { case framework::AttrType::FLOATS: {
std::vector<float> val(attr_desc.floats_size()); std::vector<float> val(attr_desc.floats_size());
for (int i = 0; i < attr_desc.floats_size(); ++i) { for (int i = 0; i < attr_desc.floats_size(); ++i) {
val[i] = attr_desc.floats(i); val[i] = attr_desc.floats(i);
} }
return val; return val;
} }
case paddle::framework::AttrType::STRINGS: { case framework::AttrType::STRINGS: {
std::vector<std::string> val(attr_desc.strings_size()); std::vector<std::string> val(attr_desc.strings_size());
for (int i = 0; i < attr_desc.strings_size(); ++i) { for (int i = 0; i < attr_desc.strings_size(); ++i) {
val[i] = attr_desc.strings(i); val[i] = attr_desc.strings(i);
} }
return val; return val;
} }
case paddle::framework::AttrType::INT_PAIRS: { case framework::AttrType::INT_PAIRS: {
std::vector<std::pair<int, int>> val(attr_desc.int_pairs_size()); std::vector<std::pair<int, int>> val(attr_desc.int_pairs_size());
for (int i = 0; i < attr_desc.int_pairs_size(); ++i) { for (int i = 0; i < attr_desc.int_pairs_size(); ++i) {
val[i].first = attr_desc.int_pairs(i).first(); val[i].first = attr_desc.int_pairs(i).first();
...@@ -88,6 +101,9 @@ Attribute GetAttrValue(const OpDesc::Attr& attr_desc) { ...@@ -88,6 +101,9 @@ Attribute GetAttrValue(const OpDesc::Attr& attr_desc) {
} }
return val; return val;
} }
case framework::AttrType::BLOCK: {
return GetProgramDesc().mutable_blocks(attr_desc.block_idx());
}
} }
PADDLE_ENFORCE(false, "Unknown OpDesc::AttrDesc::type !"); PADDLE_ENFORCE(false, "Unknown OpDesc::AttrDesc::type !");
return boost::blank(); return boost::blank();
......
...@@ -29,11 +29,13 @@ namespace framework { ...@@ -29,11 +29,13 @@ namespace framework {
typedef boost::variant<boost::blank, int, float, std::string, std::vector<int>, typedef boost::variant<boost::blank, int, float, std::string, std::vector<int>,
std::vector<float>, std::vector<std::string>, std::vector<float>, std::vector<std::string>,
std::vector<std::pair<int, int>>> std::vector<std::pair<int, int>>, BlockDesc*>
Attribute; Attribute;
typedef std::unordered_map<std::string, Attribute> AttributeMap; typedef std::unordered_map<std::string, Attribute> AttributeMap;
ProgramDesc& GetProgramDesc();
template <typename T> template <typename T>
AttrType AttrTypeID(); AttrType AttrTypeID();
......
...@@ -23,6 +23,7 @@ enum AttrType { ...@@ -23,6 +23,7 @@ enum AttrType {
FLOATS = 4; FLOATS = 4;
STRINGS = 5; STRINGS = 5;
INT_PAIRS = 6; INT_PAIRS = 6;
BLOCK = 7;
} }
message IntPair { message IntPair {
...@@ -44,6 +45,7 @@ message OpDesc { ...@@ -44,6 +45,7 @@ message OpDesc {
repeated float floats = 7; repeated float floats = 7;
repeated string strings = 8; repeated string strings = 8;
repeated IntPair int_pairs = 9; repeated IntPair int_pairs = 9;
optional int32 block_idx = 10;
}; };
message Var { message Var {
...@@ -108,3 +110,12 @@ message VarDesc { ...@@ -108,3 +110,12 @@ message VarDesc {
required string name = 1; required string name = 1;
optional LoDTensorDesc lod_tensor = 2; optional LoDTensorDesc lod_tensor = 2;
} }
message BlockDesc {
required int32 idx = 1;
required int32 parent_idx = 2;
repeated VarDesc vars = 3;
repeated OpDesc ops = 4;
}
message ProgramDesc { repeated BlockDesc blocks = 1; }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册