未验证 提交 4636d136 编写于 作者: A Aurelius84 提交者: GitHub

[Dy2Stat] Fix undefined var used in For (#32153)

* fix undefind var in For

* fix code style
上级 95122ebe
......@@ -238,11 +238,16 @@ class NameVisitor(gast.NodeVisitor):
return new_name_ids
def _is_call_func_name_node(self, node):
white_func_names = set(['append', 'extend'])
if len(self.ancestor_nodes) > 1:
assert self.ancestor_nodes[-1] == node
parent_node = self.ancestor_nodes[-2]
if isinstance(parent_node, gast.Call) and parent_node.func == node:
return True
# e.g: var_list.append(elem), var_list is also a name_id.
should_skip = isinstance(
node, gast.Attribute) and node.attr in white_func_names
if not should_skip:
return True
return False
def _update_name_ids(self, new_name_ids):
......@@ -398,10 +403,13 @@ def parse_cond_return(parent_vars_dict, if_vars_dict, else_vars_dict,
])
def _vars_loaded_before_store(ids_dict):
"""
gast.Param is also a kind of `load` semantic.
"""
new_dict = defaultdict(list)
for k, ctxs in six.iteritems(ids_dict):
for ctx in ctxs:
if isinstance(ctx, gast.Load):
if isinstance(ctx, (gast.Load, gast.Param)):
new_dict[k].append(ctx)
elif isinstance(ctx, gast.Store):
break
......
......@@ -342,5 +342,28 @@ class TestDiffModeNet2(TestDiffModeNet):
self.Net = DiffModeNet2
class TestNewVarCreateInOneBranch(unittest.TestCase):
def test_var_used_in_another_for(self):
def case_func(training):
# targets and targets_list is dynamically defined by training
if training:
targets = [1, 2, 3]
targets_list = [targets]
num_step = 3
for i in range(num_step):
if i > 0:
rois, rosi_num = 1, 2
# targets is in loop_vars.
if training:
ros, rosi_num, targets = -1, -2, [-1, -2, -3]
targets_list.append(targets)
return rosi_num
self.assertEqual(paddle.jit.to_static(case_func)(False), 2)
self.assertEqual(paddle.jit.to_static(case_func)(True), -2)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册