diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 4bf0a456b5200eb93c2bcd947aac58283ac7b41a..089792059465c60da43d02e8389f4e36900c2292 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -1084,19 +1084,15 @@ class Block(object): raise ValueError("var %s not in this block" % name) return v - def _var_recursive(self, name): + def _find_var_recursive(self, name): """ Get a Variable by name from this block recursively. Args: name(str): the Variable's name. - Raises: - ValueError: this block and this parent block doesn't - have a Variable with the giving name. - Returns: - Variable: the Variable with the giving name. + Variable: the Variable with the giving name. Or None if not found. """ frontier = list() visited = set() @@ -1122,8 +1118,27 @@ class Block(object): frontier.append(prog.block(cur.forward_block_idx)) visited.add(id(cur)) + return None - raise ValueError("Var {0} is not found recursively".format(name)) + def _var_recursive(self, name): + """ + Get a Variable by name from this block recursively. + + Args: + name(str): the Variable's name. + + Raises: + ValueError: this block and this parent block doesn't + have a Variable with the giving name. + + Returns: + Variable: the Variable with the giving name. + """ + var = self._find_var_recursive(name) + if var: + return var + else: + raise ValueError("Var {0} is not found recursively".format(name)) def all_parameters(self): return list(self.iter_parameters()) diff --git a/python/paddle/fluid/layers/control_flow.py b/python/paddle/fluid/layers/control_flow.py index 05138bf94598f649ef7fdbaa94896b6ba0884416..b7e39685691809d04ecddc21d2d04a7a85e478d5 100644 --- a/python/paddle/fluid/layers/control_flow.py +++ b/python/paddle/fluid/layers/control_flow.py @@ -717,8 +717,9 @@ class While(object): out_vars = [] for inner_out_name in inner_outputs: - if inner_out_name in parent_block.vars: - out_vars.append(parent_block.var(inner_out_name)) + inner_var = parent_block._find_var_recursive(inner_out_name) + if inner_var: + out_vars.append(inner_var) step_scope = parent_block.create_var( type=core.VarDesc.VarType.STEP_SCOPES) @@ -1264,10 +1265,11 @@ class ConditionalBlock(object): if each_name not in input_set ] - out_list = [ - parent_block.var(var_name) for var_name in parent_block.vars - if var_name in intermediate - ] + out_list = [] + for inner_out_name in intermediate: + inner_var = parent_block._find_var_recursive(inner_out_name) + if inner_var: + out_list.append(inner_var) step_scope = parent_block.create_var( type=core.VarDesc.VarType.STEP_SCOPES)