diff --git a/paddle/fluid/framework/program_desc.cc b/paddle/fluid/framework/program_desc.cc index 20bdc7830f32564448a69e9cd76c02585b7a1aca..344c001a69b53c82967ee983783892a514c2490b 100644 --- a/paddle/fluid/framework/program_desc.cc +++ b/paddle/fluid/framework/program_desc.cc @@ -55,11 +55,20 @@ ProgramDesc::ProgramDesc(const ProgramDesc &o) { auto all_ops = blocks_[block_id]->AllOps(); for (size_t op_id = 0; op_id < all_ops.size(); ++op_id) { auto &op = all_ops[op_id]; + for (const std::string &attr_name : op->AttrNames()) { if (op->GetAttrType(attr_name) == proto::AttrType::BLOCK) { int sub_block_id = o.Block(block_id).Op(op_id)->GetBlockAttrId(attr_name); op->SetBlockAttr(attr_name, MutableBlock(sub_block_id)); + } else if (op->GetAttrType(attr_name) == proto::AttrType::BLOCKS) { + std::vector sub_block_ids = + o.Block(block_id).Op(op_id)->GetBlocksAttrIds(attr_name); + std::vector block_descs; + for (int block_id : sub_block_ids) { + block_descs.push_back(MutableBlock(block_id)); + } + op->SetBlocksAttr(attr_name, block_descs); } } } @@ -68,24 +77,16 @@ ProgramDesc::ProgramDesc(const ProgramDesc &o) { ProgramDesc::ProgramDesc(const proto::ProgramDesc &desc) { desc_ = desc; - for (auto &block_desc : *desc_.mutable_blocks()) { - blocks_.emplace_back(new BlockDesc(this, &block_desc)); - } - for (auto &block : blocks_) { - for (auto *op : block->AllOps()) { - for (const auto &attr : op->Proto()->attrs()) { - if (attr.type() == proto::AttrType::BLOCK) { - size_t blk_idx = attr.block_idx(); - op->SetBlockAttr(attr.name(), this->MutableBlock(blk_idx)); - } - } - } - } + InitFromProto(); } ProgramDesc::ProgramDesc(const std::string &binary_str) { PADDLE_ENFORCE(desc_.ParseFromString(binary_str), "Fail to parse program_desc from binary string."); + InitFromProto(); +} + +void ProgramDesc::InitFromProto() { for (auto &block_desc : *desc_.mutable_blocks()) { blocks_.emplace_back(new BlockDesc(this, &block_desc)); } @@ -95,6 +96,13 @@ ProgramDesc::ProgramDesc(const std::string &binary_str) { if (attr.type() == proto::AttrType::BLOCK) { size_t blk_idx = attr.block_idx(); op->SetBlockAttr(attr.name(), this->MutableBlock(blk_idx)); + } else if (attr.type() == proto::AttrType::BLOCKS) { + auto blks_idx = attr.blocks_idx(); + std::vector block_descs; + for (int blk_idx : blks_idx) { + block_descs.push_back(this->MutableBlock(blk_idx)); + } + op->SetBlocksAttr(attr.name(), block_descs); } } } diff --git a/paddle/fluid/framework/program_desc.h b/paddle/fluid/framework/program_desc.h index 65fa0a0cfd5ba6d9b8765cee1309e118cb74348a..f3afc85eb924e4b03b7597e043ffd4e267adc977 100644 --- a/paddle/fluid/framework/program_desc.h +++ b/paddle/fluid/framework/program_desc.h @@ -76,6 +76,8 @@ class ProgramDesc { void SetFetchHolderName(const std::string &fetch_holder_name); private: + void InitFromProto(); + proto::ProgramDesc desc_; std::vector> blocks_;