提交 65058cfb 编写于 作者: Y Yu Yang

Change DFS to BFS

上级 14f83707
......@@ -16,6 +16,8 @@ limitations under the License. */
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include <queue>
namespace paddle {
namespace framework {
......@@ -45,23 +47,33 @@ bool BlockDesc::HasVar(const std::string &name) const {
VarDesc *BlockDesc::FindVarRecursive(const std::string &name) const {
if (name == kEmptyVarName) return nullptr;
auto it = vars_.find(name);
if (it != vars_.end()) {
return it->second.get();
}
std::queue<const BlockDesc *> frontier;
std::unordered_set<const BlockDesc *> visited;
BlockDesc *tmp = ParentBlock();
frontier.push(this);
if (tmp != nullptr) {
auto ptr = tmp->FindVarRecursive(name);
if (ptr != nullptr) {
return ptr;
while (!frontier.empty()) { // BFS
auto cur = frontier.front();
frontier.pop();
if (visited.count(cur) != 0) {
continue;
}
auto var = cur->FindVar(name);
if (var != nullptr) {
return var;
}
auto fwd = cur->ForwardBlock();
auto parent = cur->ParentBlock();
if (fwd != nullptr) {
frontier.push(fwd);
}
if (parent != nullptr) {
frontier.push(parent);
}
}
tmp = ForwardBlock();
if (tmp != nullptr) {
return tmp->FindVarRecursive(name);
visited.insert(cur);
}
return nullptr;
......
......@@ -698,26 +698,32 @@ class Block(object):
return v
def var_recursive(self, name):
if self.has_var(name):
return self.var(name)
else:
if self.idx == 0:
raise ValueError(
"var {0} is not in block({1}) nor its parents.".format(
name, self.idx))
else:
# DFS
try:
parent_block = self.program.block(self.parent_idx)
return parent_block.var_recursive(name)
except ValueError:
fwd_block = self.program.block(
self.forward_block_idx
) if self.forward_block_idx != -1 else None
if fwd_block is not None:
return fwd_block.var_recursive(name)
else:
raise
frontier = list()
visited = set()
frontier.append(self)
prog = self.program
while len(frontier) != 0: # BFS
cur = frontier[0]
frontier = frontier[1:]
if id(cur) in visited:
continue
if cur.has_var(name):
return cur.var(name)
if cur.parent_idx != -1:
frontier.append(prog.block(cur.parent_idx))
if cur.forward_block_idx != -1:
frontier.append(prog.block(cur.forward_block_idx))
visited.add(id(cur))
raise ValueError("Var {0} is not found recursively".format(name))
def all_parameters(self):
return list(self.iter_parameters())
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册