diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py index 79d24c05184713d2fff6005ab9bde25af0a27570..de788487feabc7f01b8c26bbd62e4d9a595a34fd 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py @@ -238,11 +238,16 @@ class NameVisitor(gast.NodeVisitor): return new_name_ids def _is_call_func_name_node(self, node): + white_func_names = set(['append', 'extend']) if len(self.ancestor_nodes) > 1: assert self.ancestor_nodes[-1] == node parent_node = self.ancestor_nodes[-2] if isinstance(parent_node, gast.Call) and parent_node.func == node: - return True + # e.g: var_list.append(elem), var_list is also a name_id. + should_skip = isinstance( + node, gast.Attribute) and node.attr in white_func_names + if not should_skip: + return True return False def _update_name_ids(self, new_name_ids): @@ -398,10 +403,13 @@ def parse_cond_return(parent_vars_dict, if_vars_dict, else_vars_dict, ]) def _vars_loaded_before_store(ids_dict): + """ + gast.Param is also a kind of `load` semantic. + """ new_dict = defaultdict(list) for k, ctxs in six.iteritems(ids_dict): for ctx in ctxs: - if isinstance(ctx, gast.Load): + if isinstance(ctx, (gast.Load, gast.Param)): new_dict[k].append(ctx) elif isinstance(ctx, gast.Store): break diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse.py index 419150345b8f4c36854767640d01a93aba5f170e..5db1bb2a384f582c30a7877e49745cd9582e096e 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse.py @@ -342,5 +342,28 @@ class TestDiffModeNet2(TestDiffModeNet): self.Net = DiffModeNet2 +class TestNewVarCreateInOneBranch(unittest.TestCase): + def test_var_used_in_another_for(self): + def case_func(training): + # targets and targets_list is dynamically defined by training + if training: + targets = [1, 2, 3] + targets_list = [targets] + + num_step = 3 + for i in range(num_step): + if i > 0: + rois, rosi_num = 1, 2 + # targets is in loop_vars. + if training: + ros, rosi_num, targets = -1, -2, [-1, -2, -3] + targets_list.append(targets) + + return rosi_num + + self.assertEqual(paddle.jit.to_static(case_func)(False), 2) + self.assertEqual(paddle.jit.to_static(case_func)(True), -2) + + if __name__ == '__main__': unittest.main()