提交 f5434507 编写于 作者: X Xin Pan

fix control_flow ops in outs

test=develop
上级 0e3048db
...@@ -1040,19 +1040,15 @@ class Block(object): ...@@ -1040,19 +1040,15 @@ class Block(object):
raise ValueError("var %s not in this block" % name) raise ValueError("var %s not in this block" % name)
return v return v
def _var_recursive(self, name): def _find_var_recursive(self, name):
""" """
Get a Variable by name from this block recursively. Get a Variable by name from this block recursively.
Args: Args:
name(str): the Variable's name. name(str): the Variable's name.
Raises:
ValueError: this block and this parent block doesn't
have a Variable with the giving name.
Returns: Returns:
Variable: the Variable with the giving name. Variable: the Variable with the giving name. Or None if not found.
""" """
frontier = list() frontier = list()
visited = set() visited = set()
...@@ -1078,8 +1074,27 @@ class Block(object): ...@@ -1078,8 +1074,27 @@ class Block(object):
frontier.append(prog.block(cur.forward_block_idx)) frontier.append(prog.block(cur.forward_block_idx))
visited.add(id(cur)) visited.add(id(cur))
return None
raise ValueError("Var {0} is not found recursively".format(name)) def _var_recursive(self, name):
"""
Get a Variable by name from this block recursively.
Args:
name(str): the Variable's name.
Raises:
ValueError: this block and this parent block doesn't
have a Variable with the giving name.
Returns:
Variable: the Variable with the giving name.
"""
var = self._find_var_recursive(name)
if var:
return var
else:
raise ValueError("Var {0} is not found recursively".format(name))
def all_parameters(self): def all_parameters(self):
return list(self.iter_parameters()) return list(self.iter_parameters())
......
...@@ -717,8 +717,9 @@ class While(object): ...@@ -717,8 +717,9 @@ class While(object):
out_vars = [] out_vars = []
for inner_out_name in inner_outputs: for inner_out_name in inner_outputs:
if inner_out_name in parent_block.vars: inner_var = parent_block._find_var_recursive(inner_out_name)
out_vars.append(parent_block.var(inner_out_name)) if inner_var:
out_vars.append(inner_var)
step_scope = parent_block.create_var( step_scope = parent_block.create_var(
type=core.VarDesc.VarType.STEP_SCOPES) type=core.VarDesc.VarType.STEP_SCOPES)
...@@ -1264,10 +1265,11 @@ class ConditionalBlock(object): ...@@ -1264,10 +1265,11 @@ class ConditionalBlock(object):
if each_name not in input_set if each_name not in input_set
] ]
out_list = [ out_list = []
parent_block.var(var_name) for var_name in parent_block.vars for inner_out_name in intermediate:
if var_name in intermediate inner_var = parent_block._find_var_recursive(inner_out_name)
] if inner_var:
out_list.append(inner_var)
step_scope = parent_block.create_var( step_scope = parent_block.create_var(
type=core.VarDesc.VarType.STEP_SCOPES) type=core.VarDesc.VarType.STEP_SCOPES)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册