From 14f8370738236fdd0de2e5f6c6bbf9c6d2d23e6a Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Thu, 22 Feb 2018 11:13:23 +0800 Subject: [PATCH] Add block.fwd_block_id --- paddle/fluid/framework/block_desc.cc | 38 ++++++++++++++++++++------ paddle/fluid/framework/block_desc.h | 8 +++++- paddle/fluid/framework/framework.proto | 1 + paddle/fluid/framework/program_desc.h | 8 +++++- paddle/fluid/operators/while_op.cc | 7 +++-- paddle/fluid/pybind/protobuf.cc | 2 ++ python/paddle/v2/fluid/backward.py | 5 +++- python/paddle/v2/fluid/framework.py | 26 +++++++++++++++--- 8 files changed, 78 insertions(+), 17 deletions(-) diff --git a/paddle/fluid/framework/block_desc.cc b/paddle/fluid/framework/block_desc.cc index 0dd37e7df06..996aefd0479 100644 --- a/paddle/fluid/framework/block_desc.cc +++ b/paddle/fluid/framework/block_desc.cc @@ -46,11 +46,25 @@ VarDesc *BlockDesc::FindVarRecursive(const std::string &name) const { if (name == kEmptyVarName) return nullptr; auto it = vars_.find(name); - if (it == vars_.end()) { - return Parent() == kNoneBlockIndex ? nullptr - : ParentBlock()->FindVarRecursive(name); + if (it != vars_.end()) { + return it->second.get(); } - return it->second.get(); + + BlockDesc *tmp = ParentBlock(); + + if (tmp != nullptr) { + auto ptr = tmp->FindVarRecursive(name); + if (ptr != nullptr) { + return ptr; + } + } + + tmp = ForwardBlock(); + if (tmp != nullptr) { + return tmp->FindVarRecursive(name); + } + + return nullptr; } VarDesc &BlockDesc::FindRecursiveOrCreateVar(const std::string &name_bytes) { @@ -136,10 +150,7 @@ void BlockDesc::Flush() { } BlockDesc *BlockDesc::ParentBlock() const { - if (this->desc_->parent_idx() == kNoneBlockIndex) { - return nullptr; - } - return prog_->MutableBlock(static_cast(this->desc_->parent_idx())); + return prog_->MutableBlock(static_cast(desc_->parent_idx())); } proto::BlockDesc *BlockDesc::Proto() { @@ -186,5 +197,16 @@ void BlockDesc::ClearPBVars() { } } +void BlockDesc::SetForwardBlockID(int32_t forward_block_id) { + PADDLE_ENFORCE(!desc_->has_forward_block_idx(), + "Parent block ID has been set to %d. Cannot set to %d", + desc_->forward_block_idx(), forward_block_id); + desc_->set_forward_block_idx(forward_block_id); +} + +BlockDesc *BlockDesc::ForwardBlock() const { + return prog_->MutableBlock(static_cast(desc_->forward_block_idx())); +} + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/block_desc.h b/paddle/fluid/framework/block_desc.h index 4e2b03e245f..8345934a71c 100644 --- a/paddle/fluid/framework/block_desc.h +++ b/paddle/fluid/framework/block_desc.h @@ -49,6 +49,8 @@ class BlockDesc { int32_t Parent() const { return desc_->parent_idx(); } + int32_t ForwardBlockID() const { return desc_->forward_block_idx(); } + VarDesc *Var(const std::string &name_bytes); VarDesc *FindVar(const std::string &name_bytes) const; @@ -73,6 +75,10 @@ class BlockDesc { BlockDesc *ParentBlock() const; + BlockDesc *ForwardBlock() const; + + void SetForwardBlockID(int32_t forward_block_id); + OpDesc *AppendOp(); void AppendAllocatedOp(std::unique_ptr &&op_desc); @@ -91,7 +97,7 @@ class BlockDesc { proto::BlockDesc *Proto(); - ProgramDesc *Program() { return this->prog_; } + ProgramDesc *Program() const { return this->prog_; } private: void ClearPBOps(); diff --git a/paddle/fluid/framework/framework.proto b/paddle/fluid/framework/framework.proto index 4eb18b4e4d6..5b43f5a8a4a 100644 --- a/paddle/fluid/framework/framework.proto +++ b/paddle/fluid/framework/framework.proto @@ -158,6 +158,7 @@ message BlockDesc { required int32 parent_idx = 2; repeated VarDesc vars = 3; repeated OpDesc ops = 4; + optional int32 forward_block_idx = 5 [ default = -1 ]; } // Please refer to diff --git a/paddle/fluid/framework/program_desc.h b/paddle/fluid/framework/program_desc.h index 8d4b999ad2f..538a0372116 100644 --- a/paddle/fluid/framework/program_desc.h +++ b/paddle/fluid/framework/program_desc.h @@ -38,7 +38,13 @@ class ProgramDesc { BlockDesc *AppendBlock(const BlockDesc &parent); - BlockDesc *MutableBlock(size_t idx) { return blocks_[idx].get(); } + BlockDesc *MutableBlock(size_t idx) { + if (idx == static_cast(kNoneBlockIndex)) { + return nullptr; + } else { + return blocks_[idx].get(); + } + } const BlockDesc &Block(size_t idx) const { return *blocks_[idx]; } diff --git a/paddle/fluid/operators/while_op.cc b/paddle/fluid/operators/while_op.cc index 3d5cdeda26a..5f51a273dd4 100644 --- a/paddle/fluid/operators/while_op.cc +++ b/paddle/fluid/operators/while_op.cc @@ -231,7 +231,8 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker { while_grad->SetInput(kStepScopes, Output(kStepScopes)); auto *grad_block = this->grad_block_[0]; - auto *fwd_block = grad_block->ParentBlock(); + auto *fwd_block = grad_block->ForwardBlock(); + auto *parent_block = grad_block->ParentBlock(); // Not all of IGs will be generated by inner gradient operators of while op. // Ignore IGs that is not generated by the inside block. @@ -265,8 +266,10 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker { for (auto &input_name : op->InputArgumentNames()) { // If the input of Op has been recorded or is generated by the forward // block, do not make it as input again. + if (block_ins.find(input_name) != block_ins.end() || - fwd_block->FindVar(input_name) != nullptr) { + fwd_block->FindVar(input_name) != nullptr || + parent_block->FindVar(input_name) != nullptr) { continue; } extra_inputs.insert(input_name); diff --git a/paddle/fluid/pybind/protobuf.cc b/paddle/fluid/pybind/protobuf.cc index 131971099ef..01dc53de78b 100644 --- a/paddle/fluid/pybind/protobuf.cc +++ b/paddle/fluid/pybind/protobuf.cc @@ -155,6 +155,8 @@ void BindBlockDesc(py::module &m) { py::class_(m, "BlockDesc", "") .def_property_readonly("id", &BlockDesc::ID) .def_property_readonly("parent", &BlockDesc::Parent) + .def("get_forward_block_idx", &BlockDesc::ForwardBlockID) + .def("set_forward_block_idx", &BlockDesc::SetForwardBlockID) .def("append_op", &BlockDesc::AppendOp, py::return_value_policy::reference) .def("prepend_op", &BlockDesc::PrependOp, diff --git a/python/paddle/v2/fluid/backward.py b/python/paddle/v2/fluid/backward.py index 33ff43f6930..ba27aaa2460 100644 --- a/python/paddle/v2/fluid/backward.py +++ b/python/paddle/v2/fluid/backward.py @@ -298,7 +298,8 @@ def _append_backward_ops_(block, # If the op has its own sub-block, deal with the sub-block first if op.has_attr("sub_block"): sub_block = program.block(op.block_attr("sub_block")) - grad_sub_block = program.create_block(parent_idx=sub_block.idx) + grad_sub_block = program.create_block() + grad_sub_block.set_forward_block_idx(sub_block.idx) cb = _callback_lookup_(op) if cb is not None: if callbacks is None: @@ -310,6 +311,8 @@ def _append_backward_ops_(block, else: _append_backward_ops_(sub_block, sub_block.ops, grad_sub_block, no_grad_dict, grad_to_var, callbacks) + + program.rollback() grad_sub_block_list.append(grad_sub_block.desc) # Getting op's corresponding grad_op diff --git a/python/paddle/v2/fluid/framework.py b/python/paddle/v2/fluid/framework.py index 0e11709296a..7ec04013c91 100644 --- a/python/paddle/v2/fluid/framework.py +++ b/python/paddle/v2/fluid/framework.py @@ -678,6 +678,13 @@ class Block(object): def parent_idx(self): return self.desc.parent + @property + def forward_block_idx(self): + return self.desc.get_forward_block_idx() + + def set_forward_block_idx(self, idx): + self.desc.set_forward_block_idx(idx) + @property def idx(self): return self.desc.id @@ -695,11 +702,22 @@ class Block(object): return self.var(name) else: if self.idx == 0: - raise ValueError("var %s is not in block(%d) nor its parents." % - name, self.idx) + raise ValueError( + "var {0} is not in block({1}) nor its parents.".format( + name, self.idx)) else: - parent_block = self.program.block(self.parent_idx) - return parent_block.var_recursive(name) + # DFS + try: + parent_block = self.program.block(self.parent_idx) + return parent_block.var_recursive(name) + except ValueError: + fwd_block = self.program.block( + self.forward_block_idx + ) if self.forward_block_idx != -1 else None + if fwd_block is not None: + return fwd_block.var_recursive(name) + else: + raise def all_parameters(self): return list(self.iter_parameters()) -- GitLab