From f5434507f081e02eb37f8f9da17165bd2298348a Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Mon, 10 Dec 2018 14:36:37 +0800 Subject: [PATCH] fix control_flow ops in outs test=develop --- python/paddle/fluid/framework.py | 29 ++++++++++++++++------ python/paddle/fluid/layers/control_flow.py | 14 ++++++----- 2 files changed, 30 insertions(+), 13 deletions(-) diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index b991187d424..f8e3cd3a320 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -1040,19 +1040,15 @@ class Block(object): raise ValueError("var %s not in this block" % name) return v - def _var_recursive(self, name): + def _find_var_recursive(self, name): """ Get a Variable by name from this block recursively. Args: name(str): the Variable's name. - Raises: - ValueError: this block and this parent block doesn't - have a Variable with the giving name. - Returns: - Variable: the Variable with the giving name. + Variable: the Variable with the giving name. Or None if not found. """ frontier = list() visited = set() @@ -1078,8 +1074,27 @@ class Block(object): frontier.append(prog.block(cur.forward_block_idx)) visited.add(id(cur)) + return None - raise ValueError("Var {0} is not found recursively".format(name)) + def _var_recursive(self, name): + """ + Get a Variable by name from this block recursively. + + Args: + name(str): the Variable's name. + + Raises: + ValueError: this block and this parent block doesn't + have a Variable with the giving name. + + Returns: + Variable: the Variable with the giving name. + """ + var = self._find_var_recursive(name) + if var: + return var + else: + raise ValueError("Var {0} is not found recursively".format(name)) def all_parameters(self): return list(self.iter_parameters()) diff --git a/python/paddle/fluid/layers/control_flow.py b/python/paddle/fluid/layers/control_flow.py index 05138bf9459..b7e39685691 100644 --- a/python/paddle/fluid/layers/control_flow.py +++ b/python/paddle/fluid/layers/control_flow.py @@ -717,8 +717,9 @@ class While(object): out_vars = [] for inner_out_name in inner_outputs: - if inner_out_name in parent_block.vars: - out_vars.append(parent_block.var(inner_out_name)) + inner_var = parent_block._find_var_recursive(inner_out_name) + if inner_var: + out_vars.append(inner_var) step_scope = parent_block.create_var( type=core.VarDesc.VarType.STEP_SCOPES) @@ -1264,10 +1265,11 @@ class ConditionalBlock(object): if each_name not in input_set ] - out_list = [ - parent_block.var(var_name) for var_name in parent_block.vars - if var_name in intermediate - ] + out_list = [] + for inner_out_name in intermediate: + inner_var = parent_block._find_var_recursive(inner_out_name) + if inner_var: + out_list.append(inner_var) step_scope = parent_block.create_var( type=core.VarDesc.VarType.STEP_SCOPES) -- GitLab