diff --git a/paddle/fluid/framework/block_desc.cc b/paddle/fluid/framework/block_desc.cc index fd409ed4c0f7a504686765909e9c71692aab8824..b15aba910696150a424ead6a716ee8c1cf68b09b 100644 --- a/paddle/fluid/framework/block_desc.cc +++ b/paddle/fluid/framework/block_desc.cc @@ -209,7 +209,7 @@ BlockDesc::BlockDesc(const BlockDesc &other, proto::BlockDesc *desc, : prog_(prog), desc_(desc) { need_update_ = true; for (auto &op : other.ops_) { - ops_.emplace_back(new OpDesc(*op->Proto(), prog, this)); + ops_.emplace_back(new OpDesc(*op, this)); } for (auto &it : other.vars_) { auto *var = new VarDesc(*it.second); diff --git a/paddle/fluid/framework/block_desc.h b/paddle/fluid/framework/block_desc.h index 600601669c5d56a3ffc2fb9c804ffad5fde58f0b..189dd6c52f85b5bf623b98c64c07c0c7269505d4 100644 --- a/paddle/fluid/framework/block_desc.h +++ b/paddle/fluid/framework/block_desc.h @@ -105,7 +105,7 @@ class BlockDesc { size_t OpSize() const { return ops_.size(); } - OpDesc *Op(int idx) { return ops_.at(idx).get(); } + OpDesc *Op(int idx) const { return ops_.at(idx).get(); } void Flush(); diff --git a/paddle/fluid/framework/program_desc.cc b/paddle/fluid/framework/program_desc.cc index 64fb028f83a539d17885186d5d8ee6ef26f095e9..aa01f9928cf800dd7414c0daa6c83711fe8da681 100644 --- a/paddle/fluid/framework/program_desc.cc +++ b/paddle/fluid/framework/program_desc.cc @@ -51,12 +51,15 @@ ProgramDesc::ProgramDesc(const ProgramDesc &o) { auto *block = desc_.mutable_blocks(i); blocks_.emplace_back(new BlockDesc(*o.blocks_[i], block, this)); } - for (auto &block : blocks_) { - for (auto *op : block->AllOps()) { - for (const auto &attr : op->Proto()->attrs()) { - if (attr.type() == proto::AttrType::BLOCK) { - size_t blk_idx = attr.block_idx(); - op->SetBlockAttr(attr.name(), this->MutableBlock(blk_idx)); + for (size_t block_id = 0; block_id < blocks_.size(); ++block_id) { + auto all_ops = blocks_[block_id]->AllOps(); + for (size_t op_id = 0; op_id < all_ops.size(); ++op_id) { + auto &op = all_ops[op_id]; + for (const std::string &attr_name : op->AttrNames()) { + if (op->GetAttrType(attr_name) == proto::AttrType::BLOCK) { + int sub_block_id = + o.Block(block_id).Op(op_id)->GetBlockAttr(attr_name); + op->SetBlockAttr(attr_name, MutableBlock(sub_block_id)); } } }