From 65058cfb7ac07204cbd2dcdc05e845a447fc54f8 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Fri, 23 Feb 2018 12:58:32 +0800 Subject: [PATCH] Change DFS to BFS --- paddle/fluid/framework/block_desc.cc | 38 +++++++++++++++-------- python/paddle/v2/fluid/framework.py | 46 ++++++++++++++++------------ 2 files changed, 51 insertions(+), 33 deletions(-) diff --git a/paddle/fluid/framework/block_desc.cc b/paddle/fluid/framework/block_desc.cc index 996aefd0479..1efb775cdca 100644 --- a/paddle/fluid/framework/block_desc.cc +++ b/paddle/fluid/framework/block_desc.cc @@ -16,6 +16,8 @@ limitations under the License. */ #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/program_desc.h" +#include + namespace paddle { namespace framework { @@ -45,23 +47,33 @@ bool BlockDesc::HasVar(const std::string &name) const { VarDesc *BlockDesc::FindVarRecursive(const std::string &name) const { if (name == kEmptyVarName) return nullptr; - auto it = vars_.find(name); - if (it != vars_.end()) { - return it->second.get(); - } + std::queue frontier; + std::unordered_set visited; - BlockDesc *tmp = ParentBlock(); + frontier.push(this); - if (tmp != nullptr) { - auto ptr = tmp->FindVarRecursive(name); - if (ptr != nullptr) { - return ptr; + while (!frontier.empty()) { // BFS + auto cur = frontier.front(); + frontier.pop(); + if (visited.count(cur) != 0) { + continue; + } + auto var = cur->FindVar(name); + if (var != nullptr) { + return var; + } + + auto fwd = cur->ForwardBlock(); + auto parent = cur->ParentBlock(); + + if (fwd != nullptr) { + frontier.push(fwd); + } + if (parent != nullptr) { + frontier.push(parent); } - } - tmp = ForwardBlock(); - if (tmp != nullptr) { - return tmp->FindVarRecursive(name); + visited.insert(cur); } return nullptr; diff --git a/python/paddle/v2/fluid/framework.py b/python/paddle/v2/fluid/framework.py index 7ec04013c91..3ec8d978140 100644 --- a/python/paddle/v2/fluid/framework.py +++ b/python/paddle/v2/fluid/framework.py @@ -698,26 +698,32 @@ class Block(object): return v def var_recursive(self, name): - if self.has_var(name): - return self.var(name) - else: - if self.idx == 0: - raise ValueError( - "var {0} is not in block({1}) nor its parents.".format( - name, self.idx)) - else: - # DFS - try: - parent_block = self.program.block(self.parent_idx) - return parent_block.var_recursive(name) - except ValueError: - fwd_block = self.program.block( - self.forward_block_idx - ) if self.forward_block_idx != -1 else None - if fwd_block is not None: - return fwd_block.var_recursive(name) - else: - raise + frontier = list() + visited = set() + + frontier.append(self) + + prog = self.program + + while len(frontier) != 0: # BFS + cur = frontier[0] + frontier = frontier[1:] + + if id(cur) in visited: + continue + + if cur.has_var(name): + return cur.var(name) + + if cur.parent_idx != -1: + frontier.append(prog.block(cur.parent_idx)) + + if cur.forward_block_idx != -1: + frontier.append(prog.block(cur.forward_block_idx)) + + visited.add(id(cur)) + + raise ValueError("Var {0} is not found recursively".format(name)) def all_parameters(self): return list(self.iter_parameters()) -- GitLab