From 4636d13616a1c7d4475fdc1135747c74dd38b7a8 Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Fri, 9 Apr 2021 15:17:10 +0800 Subject: [PATCH] [Dy2Stat] Fix undefined var used in For (#32153) * fix undefind var in For * fix code style --- .../dygraph_to_static/ifelse_transformer.py | 12 ++++++++-- .../dygraph_to_static/test_ifelse.py | 23 +++++++++++++++++++ 2 files changed, 33 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py index 79d24c05184..de788487fea 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py @@ -238,11 +238,16 @@ class NameVisitor(gast.NodeVisitor): return new_name_ids def _is_call_func_name_node(self, node): + white_func_names = set(['append', 'extend']) if len(self.ancestor_nodes) > 1: assert self.ancestor_nodes[-1] == node parent_node = self.ancestor_nodes[-2] if isinstance(parent_node, gast.Call) and parent_node.func == node: - return True + # e.g: var_list.append(elem), var_list is also a name_id. + should_skip = isinstance( + node, gast.Attribute) and node.attr in white_func_names + if not should_skip: + return True return False def _update_name_ids(self, new_name_ids): @@ -398,10 +403,13 @@ def parse_cond_return(parent_vars_dict, if_vars_dict, else_vars_dict, ]) def _vars_loaded_before_store(ids_dict): + """ + gast.Param is also a kind of `load` semantic. + """ new_dict = defaultdict(list) for k, ctxs in six.iteritems(ids_dict): for ctx in ctxs: - if isinstance(ctx, gast.Load): + if isinstance(ctx, (gast.Load, gast.Param)): new_dict[k].append(ctx) elif isinstance(ctx, gast.Store): break diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse.py index 419150345b8..5db1bb2a384 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse.py @@ -342,5 +342,28 @@ class TestDiffModeNet2(TestDiffModeNet): self.Net = DiffModeNet2 +class TestNewVarCreateInOneBranch(unittest.TestCase): + def test_var_used_in_another_for(self): + def case_func(training): + # targets and targets_list is dynamically defined by training + if training: + targets = [1, 2, 3] + targets_list = [targets] + + num_step = 3 + for i in range(num_step): + if i > 0: + rois, rosi_num = 1, 2 + # targets is in loop_vars. + if training: + ros, rosi_num, targets = -1, -2, [-1, -2, -3] + targets_list.append(targets) + + return rosi_num + + self.assertEqual(paddle.jit.to_static(case_func)(False), 2) + self.assertEqual(paddle.jit.to_static(case_func)(True), -2) + + if __name__ == '__main__': unittest.main() -- GitLab