From 4b948abbf00d097e3eb2a2121174860f2f35c989 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Wed, 20 Sep 2017 21:24:54 -0700 Subject: [PATCH] Update Attribute to make it compatible with BLOCK --- paddle/framework/attribute.cc | 21 ++++++++++++++------- paddle/framework/attribute.h | 4 +++- paddle/framework/framework.proto | 2 +- 3 files changed, 18 insertions(+), 9 deletions(-) diff --git a/paddle/framework/attribute.cc b/paddle/framework/attribute.cc index 27132eaa0b3..534c0d8d686 100644 --- a/paddle/framework/attribute.cc +++ b/paddle/framework/attribute.cc @@ -47,40 +47,44 @@ template <> AttrType AttrTypeID>>() { return INT_PAIRS; } +template <> +AttrType AttrTypeID() { + return BLOCK; +} Attribute GetAttrValue(const OpDesc::Attr& attr_desc) { switch (attr_desc.type()) { - case paddle::framework::AttrType::INT: { + case framework::AttrType::INT: { return attr_desc.i(); } - case paddle::framework::AttrType::FLOAT: { + case framework::AttrType::FLOAT: { return attr_desc.f(); } - case paddle::framework::AttrType::STRING: { + case framework::AttrType::STRING: { return attr_desc.s(); } - case paddle::framework::AttrType::INTS: { + case framework::AttrType::INTS: { std::vector val(attr_desc.ints_size()); for (int i = 0; i < attr_desc.ints_size(); ++i) { val[i] = attr_desc.ints(i); } return val; } - case paddle::framework::AttrType::FLOATS: { + case framework::AttrType::FLOATS: { std::vector val(attr_desc.floats_size()); for (int i = 0; i < attr_desc.floats_size(); ++i) { val[i] = attr_desc.floats(i); } return val; } - case paddle::framework::AttrType::STRINGS: { + case framework::AttrType::STRINGS: { std::vector val(attr_desc.strings_size()); for (int i = 0; i < attr_desc.strings_size(); ++i) { val[i] = attr_desc.strings(i); } return val; } - case paddle::framework::AttrType::INT_PAIRS: { + case framework::AttrType::INT_PAIRS: { std::vector> val(attr_desc.int_pairs_size()); for (int i = 0; i < attr_desc.int_pairs_size(); ++i) { val[i].first = attr_desc.int_pairs(i).first(); @@ -88,6 +92,9 @@ Attribute GetAttrValue(const OpDesc::Attr& attr_desc) { } return val; } + case framework::AttrType::BLOCK: { + return g_program_desc.blocks(attr_desc.block_idx()); + } } PADDLE_ENFORCE(false, "Unknown OpDesc::AttrDesc::type !"); return boost::blank(); diff --git a/paddle/framework/attribute.h b/paddle/framework/attribute.h index 2b788a76caf..f18123bac7c 100644 --- a/paddle/framework/attribute.h +++ b/paddle/framework/attribute.h @@ -29,11 +29,13 @@ namespace framework { typedef boost::variant, std::vector, std::vector, - std::vector>> + std::vector>, BlockDesc> Attribute; typedef std::unordered_map AttributeMap; +static ProgramDesc g_program_desc; + template AttrType AttrTypeID(); diff --git a/paddle/framework/framework.proto b/paddle/framework/framework.proto index dcbc9ec4076..89a49f69062 100644 --- a/paddle/framework/framework.proto +++ b/paddle/framework/framework.proto @@ -45,7 +45,7 @@ message OpDesc { repeated float floats = 7; repeated string strings = 8; repeated IntPair int_pairs = 9; - optional int32 block = 10; + optional int32 block_idx = 10; }; message Var { -- GitLab