From 2a3c58d3fecab43a753bd8c47e327ceae9f0f467 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Thu, 31 May 2018 16:56:43 +0800 Subject: [PATCH] refine programdesc copy --- paddle/fluid/framework/block_desc.cc | 2 +- paddle/fluid/framework/block_desc.h | 2 +- paddle/fluid/framework/program_desc.cc | 15 +++++++++------ 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/paddle/fluid/framework/block_desc.cc b/paddle/fluid/framework/block_desc.cc index fd409ed4c0f..b15aba91069 100644 --- a/paddle/fluid/framework/block_desc.cc +++ b/paddle/fluid/framework/block_desc.cc @@ -209,7 +209,7 @@ BlockDesc::BlockDesc(const BlockDesc &other, proto::BlockDesc *desc, : prog_(prog), desc_(desc) { need_update_ = true; for (auto &op : other.ops_) { - ops_.emplace_back(new OpDesc(*op->Proto(), prog, this)); + ops_.emplace_back(new OpDesc(*op, this)); } for (auto &it : other.vars_) { auto *var = new VarDesc(*it.second); diff --git a/paddle/fluid/framework/block_desc.h b/paddle/fluid/framework/block_desc.h index 600601669c5..189dd6c52f8 100644 --- a/paddle/fluid/framework/block_desc.h +++ b/paddle/fluid/framework/block_desc.h @@ -105,7 +105,7 @@ class BlockDesc { size_t OpSize() const { return ops_.size(); } - OpDesc *Op(int idx) { return ops_.at(idx).get(); } + OpDesc *Op(int idx) const { return ops_.at(idx).get(); } void Flush(); diff --git a/paddle/fluid/framework/program_desc.cc b/paddle/fluid/framework/program_desc.cc index 64fb028f83a..aa01f9928cf 100644 --- a/paddle/fluid/framework/program_desc.cc +++ b/paddle/fluid/framework/program_desc.cc @@ -51,12 +51,15 @@ ProgramDesc::ProgramDesc(const ProgramDesc &o) { auto *block = desc_.mutable_blocks(i); blocks_.emplace_back(new BlockDesc(*o.blocks_[i], block, this)); } - for (auto &block : blocks_) { - for (auto *op : block->AllOps()) { - for (const auto &attr : op->Proto()->attrs()) { - if (attr.type() == proto::AttrType::BLOCK) { - size_t blk_idx = attr.block_idx(); - op->SetBlockAttr(attr.name(), this->MutableBlock(blk_idx)); + for (size_t block_id = 0; block_id < blocks_.size(); ++block_id) { + auto all_ops = blocks_[block_id]->AllOps(); + for (size_t op_id = 0; op_id < all_ops.size(); ++op_id) { + auto &op = all_ops[op_id]; + for (const std::string &attr_name : op->AttrNames()) { + if (op->GetAttrType(attr_name) == proto::AttrType::BLOCK) { + int sub_block_id = + o.Block(block_id).Op(op_id)->GetBlockAttr(attr_name); + op->SetBlockAttr(attr_name, MutableBlock(sub_block_id)); } } } -- GitLab