提交 65058cfb 编写于 作者: Y Yu Yang

Change DFS to BFS

上级 14f83707
...@@ -16,6 +16,8 @@ limitations under the License. */ ...@@ -16,6 +16,8 @@ limitations under the License. */
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include <queue>
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -45,23 +47,33 @@ bool BlockDesc::HasVar(const std::string &name) const { ...@@ -45,23 +47,33 @@ bool BlockDesc::HasVar(const std::string &name) const {
VarDesc *BlockDesc::FindVarRecursive(const std::string &name) const { VarDesc *BlockDesc::FindVarRecursive(const std::string &name) const {
if (name == kEmptyVarName) return nullptr; if (name == kEmptyVarName) return nullptr;
auto it = vars_.find(name); std::queue<const BlockDesc *> frontier;
if (it != vars_.end()) { std::unordered_set<const BlockDesc *> visited;
return it->second.get();
}
BlockDesc *tmp = ParentBlock(); frontier.push(this);
if (tmp != nullptr) { while (!frontier.empty()) { // BFS
auto ptr = tmp->FindVarRecursive(name); auto cur = frontier.front();
if (ptr != nullptr) { frontier.pop();
return ptr; if (visited.count(cur) != 0) {
continue;
}
auto var = cur->FindVar(name);
if (var != nullptr) {
return var;
}
auto fwd = cur->ForwardBlock();
auto parent = cur->ParentBlock();
if (fwd != nullptr) {
frontier.push(fwd);
}
if (parent != nullptr) {
frontier.push(parent);
} }
}
tmp = ForwardBlock(); visited.insert(cur);
if (tmp != nullptr) {
return tmp->FindVarRecursive(name);
} }
return nullptr; return nullptr;
......
...@@ -698,26 +698,32 @@ class Block(object): ...@@ -698,26 +698,32 @@ class Block(object):
return v return v
def var_recursive(self, name): def var_recursive(self, name):
if self.has_var(name): frontier = list()
return self.var(name) visited = set()
else:
if self.idx == 0: frontier.append(self)
raise ValueError(
"var {0} is not in block({1}) nor its parents.".format( prog = self.program
name, self.idx))
else: while len(frontier) != 0: # BFS
# DFS cur = frontier[0]
try: frontier = frontier[1:]
parent_block = self.program.block(self.parent_idx)
return parent_block.var_recursive(name) if id(cur) in visited:
except ValueError: continue
fwd_block = self.program.block(
self.forward_block_idx if cur.has_var(name):
) if self.forward_block_idx != -1 else None return cur.var(name)
if fwd_block is not None:
return fwd_block.var_recursive(name) if cur.parent_idx != -1:
else: frontier.append(prog.block(cur.parent_idx))
raise
if cur.forward_block_idx != -1:
frontier.append(prog.block(cur.forward_block_idx))
visited.add(id(cur))
raise ValueError("Var {0} is not found recursively".format(name))
def all_parameters(self): def all_parameters(self):
return list(self.iter_parameters()) return list(self.iter_parameters())
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册