提交 f5434507 编写于 作者: X Xin Pan

fix control_flow ops in outs

test=develop
上级 0e3048db
......@@ -1040,19 +1040,15 @@ class Block(object):
raise ValueError("var %s not in this block" % name)
return v
def _var_recursive(self, name):
def _find_var_recursive(self, name):
"""
Get a Variable by name from this block recursively.
Args:
name(str): the Variable's name.
Raises:
ValueError: this block and this parent block doesn't
have a Variable with the giving name.
Returns:
Variable: the Variable with the giving name.
Variable: the Variable with the giving name. Or None if not found.
"""
frontier = list()
visited = set()
......@@ -1078,8 +1074,27 @@ class Block(object):
frontier.append(prog.block(cur.forward_block_idx))
visited.add(id(cur))
return None
raise ValueError("Var {0} is not found recursively".format(name))
def _var_recursive(self, name):
"""
Get a Variable by name from this block recursively.
Args:
name(str): the Variable's name.
Raises:
ValueError: this block and this parent block doesn't
have a Variable with the giving name.
Returns:
Variable: the Variable with the giving name.
"""
var = self._find_var_recursive(name)
if var:
return var
else:
raise ValueError("Var {0} is not found recursively".format(name))
def all_parameters(self):
return list(self.iter_parameters())
......
......@@ -717,8 +717,9 @@ class While(object):
out_vars = []
for inner_out_name in inner_outputs:
if inner_out_name in parent_block.vars:
out_vars.append(parent_block.var(inner_out_name))
inner_var = parent_block._find_var_recursive(inner_out_name)
if inner_var:
out_vars.append(inner_var)
step_scope = parent_block.create_var(
type=core.VarDesc.VarType.STEP_SCOPES)
......@@ -1264,10 +1265,11 @@ class ConditionalBlock(object):
if each_name not in input_set
]
out_list = [
parent_block.var(var_name) for var_name in parent_block.vars
if var_name in intermediate
]
out_list = []
for inner_out_name in intermediate:
inner_var = parent_block._find_var_recursive(inner_out_name)
if inner_var:
out_list.append(inner_var)
step_scope = parent_block.create_var(
type=core.VarDesc.VarType.STEP_SCOPES)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册