diff --git a/paddle/fluid/framework/block_desc.cc b/paddle/fluid/framework/block_desc.cc
index fbedd6c825ba65386e3e31c3482375a7e6361278..d72b64700f7cf680501fd3e355d20e694f1f097d 100644
--- a/paddle/fluid/framework/block_desc.cc
+++ b/paddle/fluid/framework/block_desc.cc
@@ -16,6 +16,8 @@ limitations under the License. */
 #include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/framework/program_desc.h"
 
+#include <queue>
+
 namespace paddle {
 namespace framework {
 
@@ -64,12 +66,36 @@ VarDesc *BlockDesc::RenameVar(const std::string &old_name,
 VarDesc *BlockDesc::FindVarRecursive(const std::string &name) const {
   if (name == kEmptyVarName) return nullptr;
 
-  auto it = vars_.find(name);
-  if (it == vars_.end()) {
-    return Parent() == kNoneBlockIndex ? nullptr
-                                       : ParentBlock()->FindVarRecursive(name);
+  std::queue<const BlockDesc *> frontier;
+  std::unordered_set<const BlockDesc *> visited;
+
+  frontier.push(this);
+
+  while (!frontier.empty()) {  // BFS
+    auto cur = frontier.front();
+    frontier.pop();
+    if (visited.count(cur) != 0) {
+      continue;
+    }
+    auto var = cur->FindVar(name);
+    if (var != nullptr) {
+      return var;
+    }
+
+    auto fwd = cur->ForwardBlock();
+    auto parent = cur->ParentBlock();
+
+    if (fwd != nullptr) {
+      frontier.push(fwd);
+    }
+    if (parent != nullptr) {
+      frontier.push(parent);
+    }
+
+    visited.insert(cur);
   }
-  return it->second.get();
+
+  return nullptr;
 }
 
 VarDesc &BlockDesc::FindRecursiveOrCreateVar(const std::string &name_bytes) {
@@ -155,10 +181,7 @@ void BlockDesc::Flush() {
 }
 
 BlockDesc *BlockDesc::ParentBlock() const {
-  if (this->desc_->parent_idx() == kNoneBlockIndex) {
-    return nullptr;
-  }
-  return prog_->MutableBlock(static_cast<size_t>(this->desc_->parent_idx()));
+  return prog_->MutableBlock(static_cast<size_t>(desc_->parent_idx()));
 }
 
 proto::BlockDesc *BlockDesc::Proto() {
@@ -205,5 +228,16 @@ void BlockDesc::ClearPBVars() {
   }
 }
 
+void BlockDesc::SetForwardBlockID(int32_t forward_block_id) {
+  PADDLE_ENFORCE(!desc_->has_forward_block_idx(),
+                 "Forward block ID has been set to %d. Cannot set to %d",
+                 desc_->forward_block_idx(), forward_block_id);
+  desc_->set_forward_block_idx(forward_block_id);
+}
+
+BlockDesc *BlockDesc::ForwardBlock() const {
+  return prog_->MutableBlock(static_cast<size_t>(desc_->forward_block_idx()));
+}
+
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/fluid/framework/block_desc.h b/paddle/fluid/framework/block_desc.h
index b2375b53e3ac6bd8d82897f9a8a640178e6b7a39..3bd90f38907c0a45ae0c9bb00706e5c127f08417 100644
--- a/paddle/fluid/framework/block_desc.h
+++ b/paddle/fluid/framework/block_desc.h
@@ -49,6 +49,8 @@ class BlockDesc {
 
   int32_t Parent() const { return desc_->parent_idx(); }
 
+  int32_t ForwardBlockID() const { return desc_->forward_block_idx(); }
+
   VarDesc *Var(const std::string &name_bytes);
 
   VarDesc *FindVar(const std::string &name_bytes) const;
 
@@ -75,6 +77,10 @@ class BlockDesc {
 
   BlockDesc *ParentBlock() const;
 
+  BlockDesc *ForwardBlock() const;
+
+  void SetForwardBlockID(int32_t forward_block_id);
+
   OpDesc *AppendOp();
 
   void AppendAllocatedOp(std::unique_ptr<OpDesc> &&op_desc);
 
@@ -93,7 +99,7 @@ class BlockDesc {
 
   proto::BlockDesc *Proto();
 
-  ProgramDesc *Program() { return this->prog_; }
+  ProgramDesc *Program() const { return this->prog_; }
 
  private:
   void ClearPBOps();
diff --git a/paddle/fluid/framework/framework.proto b/paddle/fluid/framework/framework.proto
index 4eb18b4e4d685111d02387d5ab944146c9217e62..5b43f5a8a4a1c128b04ac206d387e30c55f533fe 100644
--- a/paddle/fluid/framework/framework.proto
+++ b/paddle/fluid/framework/framework.proto
@@ -158,6 +158,7 @@ message BlockDesc {
   required int32 parent_idx = 2;
   repeated VarDesc vars = 3;
   repeated OpDesc ops = 4;
+  optional int32 forward_block_idx = 5 [ default = -1 ];
 }
 
 // Please refer to
diff --git a/paddle/fluid/framework/program_desc.h b/paddle/fluid/framework/program_desc.h
index 8d4b999ad2fc0924d4609415172b87ac7c6357e9..538a0372116e6f90fd2fae5f00097b8efc5dcb5c 100644
--- a/paddle/fluid/framework/program_desc.h
+++ b/paddle/fluid/framework/program_desc.h
@@ -38,7 +38,13 @@ class ProgramDesc {
 
   BlockDesc *AppendBlock(const BlockDesc &parent);
 
-  BlockDesc *MutableBlock(size_t idx) { return blocks_[idx].get(); }
+  BlockDesc *MutableBlock(size_t idx) {
+    if (idx == static_cast<size_t>(kNoneBlockIndex)) {
+      return nullptr;
+    } else {
+      return blocks_[idx].get();
+    }
+  }
 
   const BlockDesc &Block(size_t idx) const { return *blocks_[idx]; }
 
diff --git a/paddle/fluid/operators/while_op.cc b/paddle/fluid/operators/while_op.cc
index 3d5cdeda26ada94fbd8e6a7c25995aa7de93fb3d..8b62b242cf8745378eb216db10605388b294ca75 100644
--- a/paddle/fluid/operators/while_op.cc
+++ b/paddle/fluid/operators/while_op.cc
@@ -231,7 +231,8 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker {
     while_grad->SetInput(kStepScopes, Output(kStepScopes));
 
     auto *grad_block = this->grad_block_[0];
-    auto *fwd_block = grad_block->ParentBlock();
+    auto *fwd_block = grad_block->ForwardBlock();
+    auto *parent_block = grad_block->ParentBlock();
 
     // Not all of IGs will be generated by inner gradient operators of while op.
     // Ignore IGs that is not generated by the inside block.
@@ -260,33 +261,37 @@
     for (auto &o : Output(kOutputs)) {
       block_ins.insert(o);
     }
-    std::unordered_set<std::string> extra_inputs;
+    std::unordered_set<std::string> output_grads;
     for (const auto *op : grad_block->AllOps()) {
       for (auto &input_name : op->InputArgumentNames()) {
         // If the input of Op has been recorded or is generated by the forward
         // block, do not make it as input again.
+
+        // The input is located in I/O or other op's outputs or the variable is
+        // located in grad_block's parents
         if (block_ins.find(input_name) != block_ins.end() ||
-            fwd_block->FindVar(input_name) != nullptr) {
+            (fwd_block->FindVarRecursive(input_name) != nullptr ||
+             parent_block->FindVarRecursive(input_name) != nullptr)) {
           continue;
         }
-        extra_inputs.insert(input_name);
+        output_grads.insert(input_name);
       }
       for (auto &output_name : op->OutputArgumentNames()) {
         block_ins.insert(output_name);
       }
     }
 
-    std::vector<std::string> extra_inputs_list;
-    extra_inputs_list.resize(extra_inputs.size());
-    std::copy(extra_inputs.begin(), extra_inputs.end(),
-              extra_inputs_list.begin());
-    while_grad->SetInput(framework::GradVarName(kOutputs), extra_inputs_list);
+    std::vector<std::string> output_grads_list;
+    output_grads_list.resize(output_grads.size());
+    std::copy(output_grads.begin(), output_grads.end(),
+              output_grads_list.begin());
+    while_grad->SetInput(framework::GradVarName(kOutputs), output_grads_list);
 
     while_grad->SetAttrMap(this->Attrs());
     while_grad->SetBlockAttr(kStepBlock, *grad_block);
     // record the original output gradient names, since the gradient name of
     // while operator could be renamed.
-    while_grad->SetAttr("original_output_grad", extra_inputs_list);
+    while_grad->SetAttr("original_output_grad", output_grads_list);
 
     return std::unique_ptr<framework::OpDesc>(while_grad);
   }
diff --git a/paddle/fluid/pybind/protobuf.cc b/paddle/fluid/pybind/protobuf.cc
index 1a9d7c421b741187390e0ea3d837e8ef1cce70e8..b725be79529c5ccdde12446b5b5c09eaf47550e6 100644
--- a/paddle/fluid/pybind/protobuf.cc
+++ b/paddle/fluid/pybind/protobuf.cc
@@ -155,6 +155,8 @@ void BindBlockDesc(py::module &m) {
   py::class_<BlockDesc>(m, "BlockDesc", "")
       .def_property_readonly("id", &BlockDesc::ID)
       .def_property_readonly("parent", &BlockDesc::Parent)
+      .def("get_forward_block_idx", &BlockDesc::ForwardBlockID)
+      .def("set_forward_block_idx", &BlockDesc::SetForwardBlockID)
       .def("append_op", &BlockDesc::AppendOp,
            py::return_value_policy::reference)
       .def("prepend_op", &BlockDesc::PrependOp,
diff --git a/python/paddle/v2/fluid/backward.py b/python/paddle/v2/fluid/backward.py
index 33ff43f69304ddd9330c61114dba85994b5f1bdd..ba27aaa24601bd72bcdbd064242ea2b1c345340c 100644
--- a/python/paddle/v2/fluid/backward.py
+++ b/python/paddle/v2/fluid/backward.py
@@ -298,7 +298,8 @@ def _append_backward_ops_(block,
         # If the op has its own sub-block, deal with the sub-block first
         if op.has_attr("sub_block"):
             sub_block = program.block(op.block_attr("sub_block"))
-            grad_sub_block = program.create_block(parent_idx=sub_block.idx)
+            grad_sub_block = program.create_block()
+            grad_sub_block.set_forward_block_idx(sub_block.idx)
             cb = _callback_lookup_(op)
             if cb is not None:
                 if callbacks is None:
@@ -310,6 +311,8 @@
             else:
                 _append_backward_ops_(sub_block, sub_block.ops, grad_sub_block,
                                       no_grad_dict, grad_to_var, callbacks)
+
+            program.rollback()
             grad_sub_block_list.append(grad_sub_block.desc)
 
     # Getting op's corresponding grad_op
diff --git a/python/paddle/v2/fluid/framework.py b/python/paddle/v2/fluid/framework.py
index 1cb06c52a43b3585a49d4b8bef031afef07e9b0d..78318dc6d63347f332417402c1b55809870d8fa2 100644
--- a/python/paddle/v2/fluid/framework.py
+++ b/python/paddle/v2/fluid/framework.py
@@ -696,6 +696,13 @@ class Block(object):
     def parent_idx(self):
         return self.desc.parent
 
+    @property
+    def forward_block_idx(self):
+        return self.desc.get_forward_block_idx()
+
+    def set_forward_block_idx(self, idx):
+        self.desc.set_forward_block_idx(idx)
+
     @property
     def idx(self):
         return self.desc.id
@@ -709,15 +716,32 @@
         return v
 
     def var_recursive(self, name):
-        if self.has_var(name):
-            return self.var(name)
-        else:
-            if self.idx == 0:
-                raise ValueError("var %s is not in block(%d) nor its parents." %
-                                 name, self.idx)
-            else:
-                parent_block = self.program.block(self.parent_idx)
-                return parent_block.var_recursive(name)
+        frontier = list()
+        visited = set()
+
+        frontier.append(self)
+
+        prog = self.program
+
+        while len(frontier) != 0:  # BFS
+            cur = frontier[0]
+            frontier = frontier[1:]
+
+            if id(cur) in visited:
+                continue
+
+            if cur.has_var(name):
+                return cur.var(name)
+
+            if cur.parent_idx != -1:
+                frontier.append(prog.block(cur.parent_idx))
+
+            if cur.forward_block_idx != -1:
+                frontier.append(prog.block(cur.forward_block_idx))
+
+            visited.add(id(cur))
+
+        raise ValueError("Var {0} is not found recursively".format(name))
 
     def all_parameters(self):
         return list(self.iter_parameters())
diff --git a/python/paddle/v2/fluid/memory_optimization_transpiler.py b/python/paddle/v2/fluid/memory_optimization_transpiler.py
index ee56ccdcf1175b1e6733166f1e7e41dd6c3e3298..6952ca7fe49931b9ebc84e214569c47d632d2a06 100644
--- a/python/paddle/v2/fluid/memory_optimization_transpiler.py
+++ b/python/paddle/v2/fluid/memory_optimization_transpiler.py
@@ -223,15 +223,15 @@
 
     # Find while/while_grad block pair
     for grad_id in while_grad_sub_block_ids:
-        parent_id = pdesc.block(grad_id).parent
-        if parent_id in while_sub_block_ids:
-            while_block_id_pair.append((parent_id, grad_id))
-            while_sub_block_ids.remove(parent_id)
+        forward_id = pdesc.block(grad_id).get_forward_block_idx()
+        if forward_id in while_sub_block_ids:
+            while_block_id_pair.append((forward_id, grad_id))
+            while_sub_block_ids.remove(forward_id)
 
     # Get while/while_grad block ops
-    for parent_id, grad_id in while_block_id_pair:
+    for forward_id, grad_id in while_block_id_pair:
         while_block_ops = []
-        while_block = pdesc.block(parent_id)
+        while_block = pdesc.block(forward_id)
         while_block_op_size = while_block.op_size()
         for i in range(while_block_op_size):
             while_block_ops.append(while_block.op(i))
@@ -242,21 +242,21 @@
             while_block_ops.append(while_grad_block.op(i))
 
         while_op_output = set()
-        while_op_output.update(while_op_dict[parent_id].output_arg_names())
+        while_op_output.update(while_op_dict[forward_id].output_arg_names())
         while_op_output.update(while_op_dict[grad_id].output_arg_names())
 
         ops_list.append((while_block_ops, while_block_op_size, while_op_output))
 
     # Process rest while block ops
-    for parent_id in while_sub_block_ids:
+    for forward_id in while_sub_block_ids:
         while_block_ops = []
-        while_block = pdesc.block(parent_id)
+        while_block = pdesc.block(forward_id)
         while_block_op_size = while_block.op_size()
         for i in range(while_block_op_size):
             while_block_ops.append(while_block.op(i))
 
         while_op_output = set()
-        while_op_output.update(while_op_dict[parent_id].output_arg_names())
+        while_op_output.update(while_op_dict[forward_id].output_arg_names())
 
         ops_list.append((while_block_ops, while_block_op_size, while_op_output))
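
A minimal usage sketch of the new forward-block link from the Python side. This snippet is illustrative only and is not part of the patch; it uses the patch's new set_forward_block_idx/forward_block_idx/var_recursive interfaces and assumes the surrounding paddle.v2.fluid API at this revision (fluid.Program, current_block, create_var).

import paddle.v2.fluid as fluid

program = fluid.Program()
fwd_block = program.current_block()  # block 0 plays the role of the forward block here
fwd_block.create_var(name="x", shape=[1], dtype="float32")

# As _append_backward_ops_ now does: create the gradient block, record which
# forward block it differentiates, then roll back to the previous block.
grad_block = program.create_block()
grad_block.set_forward_block_idx(fwd_block.idx)
program.rollback()

# var_recursive does a BFS over both the parent_idx and forward_block_idx
# links, so a variable defined in the forward block is reachable from grad_block.
assert grad_block.forward_block_idx == fwd_block.idx
assert grad_block.var_recursive("x").name == "x"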