提交 a3aca2a3 编写于 作者: F fengjiayi

fix bugs

上级 2a3c58d3
...@@ -200,7 +200,7 @@ BlockDesc::BlockDesc(ProgramDesc *prog, proto::BlockDesc *desc) ...@@ -200,7 +200,7 @@ BlockDesc::BlockDesc(ProgramDesc *prog, proto::BlockDesc *desc)
vars_[var_desc.name()].reset(new VarDesc(var_desc)); vars_[var_desc.name()].reset(new VarDesc(var_desc));
} }
for (const proto::OpDesc &op_desc : desc_->ops()) { for (const proto::OpDesc &op_desc : desc_->ops()) {
ops_.emplace_back(new OpDesc(op_desc, prog, this)); ops_.emplace_back(new OpDesc(op_desc, this));
} }
} }
......
...@@ -103,7 +103,7 @@ void OpDesc::CopyFrom(const OpDesc &op_desc) { ...@@ -103,7 +103,7 @@ void OpDesc::CopyFrom(const OpDesc &op_desc) {
need_update_ = true; need_update_ = true;
} }
OpDesc::OpDesc(const proto::OpDesc &desc, ProgramDesc *prog, BlockDesc *block) OpDesc::OpDesc(const proto::OpDesc &desc, BlockDesc *block)
: desc_(desc), need_update_(false) { : desc_(desc), need_update_(false) {
// restore inputs_ // restore inputs_
int input_size = desc_.inputs_size(); int input_size = desc_.inputs_size();
......
...@@ -33,13 +33,14 @@ class OpDesc { ...@@ -33,13 +33,14 @@ class OpDesc {
OpDesc(const std::string &type, const VariableNameMap &inputs, OpDesc(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs, const AttributeMap &attrs); const VariableNameMap &outputs, const AttributeMap &attrs);
OpDesc(const proto::OpDesc &desc, ProgramDesc *prog, BlockDesc *block); OpDesc(const proto::OpDesc &desc, BlockDesc *block);
explicit OpDesc(BlockDesc *block) : block_(block) {} explicit OpDesc(BlockDesc *block) : block_(block) {}
OpDesc(const OpDesc &other, BlockDesc *block) { OpDesc(const OpDesc &other, BlockDesc *block) {
*this = other; *this = other;
block_ = block; block_ = block;
need_update_ = true;
} }
void CopyFrom(const OpDesc &op_desc); void CopyFrom(const OpDesc &op_desc);
......
...@@ -89,6 +89,16 @@ ProgramDesc::ProgramDesc(const std::string &binary_str) { ...@@ -89,6 +89,16 @@ ProgramDesc::ProgramDesc(const std::string &binary_str) {
for (auto &block_desc : *desc_.mutable_blocks()) { for (auto &block_desc : *desc_.mutable_blocks()) {
blocks_.emplace_back(new BlockDesc(this, &block_desc)); blocks_.emplace_back(new BlockDesc(this, &block_desc));
} }
for (auto &block : blocks_) {
for (auto *op : block->AllOps()) {
for (const auto &attr : op->Proto()->attrs()) {
if (attr.type() == proto::AttrType::BLOCK) {
size_t blk_idx = attr.block_idx();
op->SetBlockAttr(attr.name(), this->MutableBlock(blk_idx));
}
}
}
}
} }
const std::vector<std::string> ProgramDesc::GetFeedTargetNames() { const std::vector<std::string> ProgramDesc::GetFeedTargetNames() {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册