From a3aca2a3cfeb6ab246ff95987374809be1a3c863 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Thu, 31 May 2018 20:58:26 +0800 Subject: [PATCH] fix bugs --- paddle/fluid/framework/block_desc.cc | 2 +- paddle/fluid/framework/op_desc.cc | 2 +- paddle/fluid/framework/op_desc.h | 3 ++- paddle/fluid/framework/program_desc.cc | 10 ++++++++++ 4 files changed, 14 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/framework/block_desc.cc b/paddle/fluid/framework/block_desc.cc index b15aba91069..e7842e9b813 100644 --- a/paddle/fluid/framework/block_desc.cc +++ b/paddle/fluid/framework/block_desc.cc @@ -200,7 +200,7 @@ BlockDesc::BlockDesc(ProgramDesc *prog, proto::BlockDesc *desc) vars_[var_desc.name()].reset(new VarDesc(var_desc)); } for (const proto::OpDesc &op_desc : desc_->ops()) { - ops_.emplace_back(new OpDesc(op_desc, prog, this)); + ops_.emplace_back(new OpDesc(op_desc, this)); } } diff --git a/paddle/fluid/framework/op_desc.cc b/paddle/fluid/framework/op_desc.cc index 09b67e5a174..f92769192c2 100644 --- a/paddle/fluid/framework/op_desc.cc +++ b/paddle/fluid/framework/op_desc.cc @@ -103,7 +103,7 @@ void OpDesc::CopyFrom(const OpDesc &op_desc) { need_update_ = true; } -OpDesc::OpDesc(const proto::OpDesc &desc, ProgramDesc *prog, BlockDesc *block) +OpDesc::OpDesc(const proto::OpDesc &desc, BlockDesc *block) : desc_(desc), need_update_(false) { // restore inputs_ int input_size = desc_.inputs_size(); diff --git a/paddle/fluid/framework/op_desc.h b/paddle/fluid/framework/op_desc.h index 1a330db7cc5..a02d3e26912 100644 --- a/paddle/fluid/framework/op_desc.h +++ b/paddle/fluid/framework/op_desc.h @@ -33,13 +33,14 @@ class OpDesc { OpDesc(const std::string &type, const VariableNameMap &inputs, const VariableNameMap &outputs, const AttributeMap &attrs); - OpDesc(const proto::OpDesc &desc, ProgramDesc *prog, BlockDesc *block); + OpDesc(const proto::OpDesc &desc, BlockDesc *block); explicit OpDesc(BlockDesc *block) : block_(block) {} OpDesc(const OpDesc &other, BlockDesc *block) { *this = other; block_ = block; + need_update_ = true; } void CopyFrom(const OpDesc &op_desc); diff --git a/paddle/fluid/framework/program_desc.cc b/paddle/fluid/framework/program_desc.cc index aa01f9928cf..1e01a6e9004 100644 --- a/paddle/fluid/framework/program_desc.cc +++ b/paddle/fluid/framework/program_desc.cc @@ -89,6 +89,16 @@ ProgramDesc::ProgramDesc(const std::string &binary_str) { for (auto &block_desc : *desc_.mutable_blocks()) { blocks_.emplace_back(new BlockDesc(this, &block_desc)); } + for (auto &block : blocks_) { + for (auto *op : block->AllOps()) { + for (const auto &attr : op->Proto()->attrs()) { + if (attr.type() == proto::AttrType::BLOCK) { + size_t blk_idx = attr.block_idx(); + op->SetBlockAttr(attr.name(), this->MutableBlock(blk_idx)); + } + } + } + } } const std::vector ProgramDesc::GetFeedTargetNames() { -- GitLab