未验证 提交 0d463d3b 编写于 作者: H Huihuang Zheng 提交者: GitHub

Fix NameVisitor bugs (#22847)

1. copy.deepcopy in NameVisitor should be changed to copy.copy to make hash or set work
2. read_context should be type of gast.Load()/gast.AugLoad(), not gast.Load/gast.AugLoad
上级 f686310d
......@@ -80,11 +80,11 @@ class NameVisitor(gast.NodeVisitor):
return True
def get_loop_var_names(self, node):
assert isinstance(node, gast.While) or isinstance(
while_node, gast.For), "Input node is not gast loop node"
assert isinstance(node, (gast.While,
gast.For)), "Input node is not gast loop node"
loop_var_names = set()
create_var_names = set()
read_context = {type(gast.Load), type(gast.AugLoad)}
read_context = {type(gast.Load()), type(gast.AugLoad())}
in_loop_vars = self.in_loop_vars[node]
in_loop_name_strs = set(name.id for name in in_loop_vars)
......@@ -114,13 +114,13 @@ class NameVisitor(gast.NodeVisitor):
def visit_For(self, node):
self.current_loop.append(node)
self.before_loop_vars[node] = copy.deepcopy(self.current_seen_vars)
self.before_loop_vars[node] = copy.copy(self.current_seen_vars)
self.generic_visit(node)
self.current_loop.pop()
def visit_While(self, node):
self.current_loop.append(node)
self.before_loop_vars[node] = copy.deepcopy(self.current_seen_vars)
self.before_loop_vars[node] = copy.copy(self.current_seen_vars)
self.generic_visit(node)
self.current_loop.pop()
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -21,7 +21,7 @@ import paddle.fluid as fluid
import unittest
from paddle.fluid.dygraph.jit import dygraph_to_static_graph
#from paddle.fluid.dygraph.dygraph_to_static import NameVistor
from paddle.fluid.dygraph.dygraph_to_static.loop_transformer import NameVisitor
SEED = 2020
np.random.seed(SEED)
......@@ -37,8 +37,15 @@ def while_loop_dyfunc(x):
class TestNameVisitor(unittest.TestCase):
def test_loop_vars(self):
#TODO
pass
test_func = inspect.getsource(while_loop_dyfunc)
gast_root = gast.parse(test_func)
name_visitor = NameVisitor(gast_root)
for node in gast.walk(gast_root):
if isinstance(node, gast.While):
loop_var_names, create_var_names = name_visitor.get_loop_var_names(
node)
self.assertEqual(loop_var_names, set(["i", "x"]))
self.assertEqual(create_var_names, set())
class TestTransformWhile(unittest.TestCase):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册