diff --git a/paddle/framework/attribute.cc b/paddle/framework/attribute.cc index 27132eaa0b3b0666fc042faf052dac2e169ba9e7..534c0d8d686cf30688078005975dec1bde89bf60 100644 --- a/paddle/framework/attribute.cc +++ b/paddle/framework/attribute.cc @@ -47,40 +47,44 @@ template <> AttrType AttrTypeID>>() { return INT_PAIRS; } +template <> +AttrType AttrTypeID() { + return BLOCK; +} Attribute GetAttrValue(const OpDesc::Attr& attr_desc) { switch (attr_desc.type()) { - case paddle::framework::AttrType::INT: { + case framework::AttrType::INT: { return attr_desc.i(); } - case paddle::framework::AttrType::FLOAT: { + case framework::AttrType::FLOAT: { return attr_desc.f(); } - case paddle::framework::AttrType::STRING: { + case framework::AttrType::STRING: { return attr_desc.s(); } - case paddle::framework::AttrType::INTS: { + case framework::AttrType::INTS: { std::vector val(attr_desc.ints_size()); for (int i = 0; i < attr_desc.ints_size(); ++i) { val[i] = attr_desc.ints(i); } return val; } - case paddle::framework::AttrType::FLOATS: { + case framework::AttrType::FLOATS: { std::vector val(attr_desc.floats_size()); for (int i = 0; i < attr_desc.floats_size(); ++i) { val[i] = attr_desc.floats(i); } return val; } - case paddle::framework::AttrType::STRINGS: { + case framework::AttrType::STRINGS: { std::vector val(attr_desc.strings_size()); for (int i = 0; i < attr_desc.strings_size(); ++i) { val[i] = attr_desc.strings(i); } return val; } - case paddle::framework::AttrType::INT_PAIRS: { + case framework::AttrType::INT_PAIRS: { std::vector> val(attr_desc.int_pairs_size()); for (int i = 0; i < attr_desc.int_pairs_size(); ++i) { val[i].first = attr_desc.int_pairs(i).first(); @@ -88,6 +92,9 @@ Attribute GetAttrValue(const OpDesc::Attr& attr_desc) { } return val; } + case framework::AttrType::BLOCK: { + return g_program_desc.blocks(attr_desc.block_idx()); + } } PADDLE_ENFORCE(false, "Unknown OpDesc::AttrDesc::type !"); return boost::blank(); diff --git a/paddle/framework/attribute.h b/paddle/framework/attribute.h index 2b788a76cafe198abb9aed8ba842e37cc6ff73a6..f18123bac7c3ca21df05cabb327380b6f27cde0b 100644 --- a/paddle/framework/attribute.h +++ b/paddle/framework/attribute.h @@ -29,11 +29,13 @@ namespace framework { typedef boost::variant, std::vector, std::vector, - std::vector>> + std::vector>, BlockDesc> Attribute; typedef std::unordered_map AttributeMap; +static ProgramDesc g_program_desc; + template AttrType AttrTypeID(); diff --git a/paddle/framework/framework.proto b/paddle/framework/framework.proto index dcbc9ec4076dbba18b1c38606bbaf30ebcd86b51..89a49f69062486ace67154f52450e7449a948851 100644 --- a/paddle/framework/framework.proto +++ b/paddle/framework/framework.proto @@ -45,7 +45,7 @@ message OpDesc { repeated float floats = 7; repeated string strings = 8; repeated IntPair int_pairs = 9; - optional int32 block = 10; + optional int32 block_idx = 10; }; message Var {