From e05e27a7f71ddb6549e406f0fbc339c789373935 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Fri, 22 Sep 2017 16:59:15 -0700 Subject: [PATCH] Fix bug --- paddle/pybind/protobuf.cc | 64 +++++++++++++++++++++++++++++---------- 1 file changed, 48 insertions(+), 16 deletions(-) diff --git a/paddle/pybind/protobuf.cc b/paddle/pybind/protobuf.cc index 0fb78bf7a24..5511841c8b5 100644 --- a/paddle/pybind/protobuf.cc +++ b/paddle/pybind/protobuf.cc @@ -42,15 +42,23 @@ inline void VectorToRepeated(const std::vector &vec, class ProgramDescBind; class OpDescBind; class BlockDescBind; +class VarDescBind; -class OpDescBind { +class VarDescBind { public: - explicit OpDescBind(BlockDescBind *block) : block_(block) {} + explicit VarDescBind(const std::string &name) { var_desc_.set_name(name); } + + VarDesc *Proto() { return &var_desc_; } + +private: + VarDesc var_desc_; +}; - operator OpDesc *() { return &op_desc_; } +class OpDescBind { +public: + OpDesc *Proto() { return &op_desc_; } private: - BlockDescBind *block_; OpDesc op_desc_; }; @@ -59,14 +67,28 @@ public: BlockDescBind(ProgramDescBind *prog, BlockDesc *desc) : prog_(prog), desc_(desc), need_update_(false) {} + BlockDescBind(const BlockDescBind &o) = delete; + BlockDescBind &operator=(const BlockDescBind &o) = delete; + int32_t id() const { return desc_->idx(); } int32_t Parent() const { return desc_->parent_idx(); } + VarDescBind *NewVar(const std::string &name) { + need_update_ = true; + auto it = vars_.find(name); + PADDLE_ENFORCE(it == vars_.end(), "Duplicated variable %s", name); + auto var = new VarDescBind(name); + vars_[name].reset(var); + return var; + } + + BlockDescBind *ParentBlock() const; + OpDescBind *AppendOp() { need_update_ = true; - ops_.emplace_back(this); - return &ops_.back(); + ops_.emplace_back(new OpDescBind()); + return ops_.back().get(); } void Sync() { @@ -75,8 +97,9 @@ public: op_field.Clear(); op_field.Reserve(static_cast(ops_.size())); for (auto &op_desc : ops_) { - op_field.AddAllocated(op_desc); + op_field.AddAllocated(op_desc->Proto()); } + need_update_ = false; } } @@ -85,7 +108,8 @@ private: BlockDesc *desc_; // not_own bool need_update_; - std::deque ops_; + std::deque> ops_; + std::unordered_map> vars_; }; using ProgDescMap = @@ -106,18 +130,20 @@ public: } return *ptr; } + ProgramDescBind(const ProgramDescBind &o) = delete; + ProgramDescBind &operator=(const ProgramDescBind &o) = delete; BlockDescBind *AppendBlock(const BlockDescBind &parent) { auto *b = prog_->add_blocks(); b->set_parent_idx(parent.id()); b->set_idx(prog_->blocks_size() - 1); - blocks_.emplace_back(this, b); - return &blocks_.back(); + blocks_.emplace_back(new BlockDescBind(this, b)); + return blocks_.back().get(); } - BlockDescBind *Root() { return &blocks_.front(); } + BlockDescBind *Root() { return blocks_.front().get(); } - BlockDescBind *Block(size_t idx) { return &blocks_[idx]; } + BlockDescBind *Block(size_t idx) { return blocks_[idx].get(); } std::string DebugString() { return Proto()->DebugString(); } @@ -125,25 +151,31 @@ public: ProgramDesc *Proto() { for (auto &block : blocks_) { - block.Sync(); + block->Sync(); } return prog_; } private: explicit ProgramDescBind(ProgramDesc *prog) : prog_(prog) { - blocks_.reserve(100); for (auto &block : *prog->mutable_blocks()) { - blocks_.emplace_back(this, &block); + blocks_.emplace_back(new BlockDescBind(this, &block)); } } // Not owned ProgramDesc *prog_; - std::vector blocks_; + std::vector> blocks_; }; +BlockDescBind *BlockDescBind::ParentBlock() const { + if (this->desc_->parent_idx() == -1) { + return nullptr; + } + return prog_->Block(static_cast(this->desc_->parent_idx())); +} + void BindProgramDesc(py::module &m) { py::class_(m, "ProgramDesc", "") .def_static("instance", -- GitLab