diff --git a/paddle/fluid/framework/program_desc.cc b/paddle/fluid/framework/program_desc.cc index 20bdc7830f32564448a69e9cd76c02585b7a1aca..344c001a69b53c82967ee983783892a514c2490b 100644 --- a/paddle/fluid/framework/program_desc.cc +++ b/paddle/fluid/framework/program_desc.cc @@ -55,11 +55,20 @@ ProgramDesc::ProgramDesc(const ProgramDesc &o) { auto all_ops = blocks_[block_id]->AllOps(); for (size_t op_id = 0; op_id < all_ops.size(); ++op_id) { auto &op = all_ops[op_id]; + for (const std::string &attr_name : op->AttrNames()) { if (op->GetAttrType(attr_name) == proto::AttrType::BLOCK) { int sub_block_id = o.Block(block_id).Op(op_id)->GetBlockAttrId(attr_name); op->SetBlockAttr(attr_name, MutableBlock(sub_block_id)); + } else if (op->GetAttrType(attr_name) == proto::AttrType::BLOCKS) { + std::vector sub_block_ids = + o.Block(block_id).Op(op_id)->GetBlocksAttrIds(attr_name); + std::vector block_descs; + for (int block_id : sub_block_ids) { + block_descs.push_back(MutableBlock(block_id)); + } + op->SetBlocksAttr(attr_name, block_descs); } } } @@ -68,24 +77,16 @@ ProgramDesc::ProgramDesc(const ProgramDesc &o) { ProgramDesc::ProgramDesc(const proto::ProgramDesc &desc) { desc_ = desc; - for (auto &block_desc : *desc_.mutable_blocks()) { - blocks_.emplace_back(new BlockDesc(this, &block_desc)); - } - for (auto &block : blocks_) { - for (auto *op : block->AllOps()) { - for (const auto &attr : op->Proto()->attrs()) { - if (attr.type() == proto::AttrType::BLOCK) { - size_t blk_idx = attr.block_idx(); - op->SetBlockAttr(attr.name(), this->MutableBlock(blk_idx)); - } - } - } - } + InitFromProto(); } ProgramDesc::ProgramDesc(const std::string &binary_str) { PADDLE_ENFORCE(desc_.ParseFromString(binary_str), "Fail to parse program_desc from binary string."); + InitFromProto(); +} + +void ProgramDesc::InitFromProto() { for (auto &block_desc : *desc_.mutable_blocks()) { blocks_.emplace_back(new BlockDesc(this, &block_desc)); } @@ -95,6 +96,13 @@ ProgramDesc::ProgramDesc(const std::string &binary_str) { if (attr.type() == proto::AttrType::BLOCK) { size_t blk_idx = attr.block_idx(); op->SetBlockAttr(attr.name(), this->MutableBlock(blk_idx)); + } else if (attr.type() == proto::AttrType::BLOCKS) { + auto blks_idx = attr.blocks_idx(); + std::vector block_descs; + for (int blk_idx : blks_idx) { + block_descs.push_back(this->MutableBlock(blk_idx)); + } + op->SetBlocksAttr(attr.name(), block_descs); } } } diff --git a/paddle/fluid/framework/program_desc.h b/paddle/fluid/framework/program_desc.h index 65fa0a0cfd5ba6d9b8765cee1309e118cb74348a..f3afc85eb924e4b03b7597e043ffd4e267adc977 100644 --- a/paddle/fluid/framework/program_desc.h +++ b/paddle/fluid/framework/program_desc.h @@ -76,6 +76,8 @@ class ProgramDesc { void SetFetchHolderName(const std::string &fetch_holder_name); private: + void InitFromProto(); + proto::ProgramDesc desc_; std::vector> blocks_; diff --git a/paddle/fluid/framework/program_desc_test.cc b/paddle/fluid/framework/program_desc_test.cc index 6c46e9aad5b7fbf67fdcc07a12e7932ac8b6412b..925ea98dbe62e4da91689f6e56c135e51c24a8a3 100644 --- a/paddle/fluid/framework/program_desc_test.cc +++ b/paddle/fluid/framework/program_desc_test.cc @@ -42,6 +42,19 @@ TEST(ProgramDesc, copy_ctor) { out->SetType(proto::VarType::LOD_TENSOR); op->SetOutput("Y", {out->Name()}); + BlockDesc* new_block = program.AppendBlock(*global_block); + op = new_block->AppendOp(); + op->SetType("mul"); + + op = global_block->AppendOp(); + op->SetType("op_with_subblock"); + op->SetAttr("sub_block", new_block); + + std::vector sub_blocks; + sub_blocks.push_back(program.AppendBlock(*global_block)); + sub_blocks.push_back(program.AppendBlock(*global_block)); + op->SetAttr("sub_blocks", sub_blocks); + ProgramDesc program_copy(program); auto* global_block_copy = program_copy.MutableBlock(0); @@ -64,6 +77,8 @@ TEST(ProgramDesc, copy_ctor) { assert_same_var("Y", y); assert_same_var("Out", out); + bool found_sub_block = false; + bool found_sub_blocks = false; for (size_t i = 0; i < global_block->OpSize(); ++i) { auto op_origin = global_block->Op(i); auto op_copy = global_block_copy->Op(i); @@ -74,8 +89,17 @@ TEST(ProgramDesc, copy_ctor) { ASSERT_EQ(op_copy->Proto()->SerializeAsString(), op_origin->Proto()->SerializeAsString()); - } + if (op->Type() == "op_with_subblock") { + ASSERT_EQ(1, op->GetBlockAttrId("sub_block")); + found_sub_block = true; + + ASSERT_EQ(2, op->GetBlocksAttrIds("sub_blocks").size()); + found_sub_blocks = true; + } + } + ASSERT_TRUE(found_sub_block); + ASSERT_TRUE(found_sub_blocks); // Not check block's protostr are same it because the order of vars could be // different and it is correct. }