diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index b991187d424108db176ebd6996d7d161f11dcd3d..f8e3cd3a3208b6d79eeeef015388923464679856 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -1040,19 +1040,15 @@ class Block(object): raise ValueError("var %s not in this block" % name) return v - def _var_recursive(self, name): + def _find_var_recursive(self, name): """ Get a Variable by name from this block recursively. Args: name(str): the Variable's name. - Raises: - ValueError: this block and this parent block doesn't - have a Variable with the giving name. - Returns: - Variable: the Variable with the giving name. + Variable: the Variable with the giving name. Or None if not found. """ frontier = list() visited = set() @@ -1078,8 +1074,27 @@ class Block(object): frontier.append(prog.block(cur.forward_block_idx)) visited.add(id(cur)) + return None - raise ValueError("Var {0} is not found recursively".format(name)) + def _var_recursive(self, name): + """ + Get a Variable by name from this block recursively. + + Args: + name(str): the Variable's name. + + Raises: + ValueError: this block and this parent block doesn't + have a Variable with the giving name. + + Returns: + Variable: the Variable with the giving name. + """ + var = self._find_var_recursive(name) + if var: + return var + else: + raise ValueError("Var {0} is not found recursively".format(name)) def all_parameters(self): return list(self.iter_parameters()) diff --git a/python/paddle/fluid/layers/control_flow.py b/python/paddle/fluid/layers/control_flow.py index 05138bf94598f649ef7fdbaa94896b6ba0884416..b7e39685691809d04ecddc21d2d04a7a85e478d5 100644 --- a/python/paddle/fluid/layers/control_flow.py +++ b/python/paddle/fluid/layers/control_flow.py @@ -717,8 +717,9 @@ class While(object): out_vars = [] for inner_out_name in inner_outputs: - if inner_out_name in parent_block.vars: - out_vars.append(parent_block.var(inner_out_name)) + inner_var = parent_block._find_var_recursive(inner_out_name) + if inner_var: + out_vars.append(inner_var) step_scope = parent_block.create_var( type=core.VarDesc.VarType.STEP_SCOPES) @@ -1264,10 +1265,11 @@ class ConditionalBlock(object): if each_name not in input_set ] - out_list = [ - parent_block.var(var_name) for var_name in parent_block.vars - if var_name in intermediate - ] + out_list = [] + for inner_out_name in intermediate: + inner_var = parent_block._find_var_recursive(inner_out_name) + if inner_var: + out_list.append(inner_var) step_scope = parent_block.create_var( type=core.VarDesc.VarType.STEP_SCOPES)