未验证 提交 0d463d3b 编写于 作者: H Huihuang Zheng 提交者: GitHub

Fix NameVisitor bugs (#22847)

1. copy.deepcopy in NameVisitor should be changed to copy.copy to make hash or set work
2. read_context should be type of gast.Load()/gast.AugLoad(), not gast.Load/gast.AugLoad
上级 f686310d
...@@ -80,11 +80,11 @@ class NameVisitor(gast.NodeVisitor): ...@@ -80,11 +80,11 @@ class NameVisitor(gast.NodeVisitor):
return True return True
def get_loop_var_names(self, node): def get_loop_var_names(self, node):
assert isinstance(node, gast.While) or isinstance( assert isinstance(node, (gast.While,
while_node, gast.For), "Input node is not gast loop node" gast.For)), "Input node is not gast loop node"
loop_var_names = set() loop_var_names = set()
create_var_names = set() create_var_names = set()
read_context = {type(gast.Load), type(gast.AugLoad)} read_context = {type(gast.Load()), type(gast.AugLoad())}
in_loop_vars = self.in_loop_vars[node] in_loop_vars = self.in_loop_vars[node]
in_loop_name_strs = set(name.id for name in in_loop_vars) in_loop_name_strs = set(name.id for name in in_loop_vars)
...@@ -114,13 +114,13 @@ class NameVisitor(gast.NodeVisitor): ...@@ -114,13 +114,13 @@ class NameVisitor(gast.NodeVisitor):
def visit_For(self, node): def visit_For(self, node):
self.current_loop.append(node) self.current_loop.append(node)
self.before_loop_vars[node] = copy.deepcopy(self.current_seen_vars) self.before_loop_vars[node] = copy.copy(self.current_seen_vars)
self.generic_visit(node) self.generic_visit(node)
self.current_loop.pop() self.current_loop.pop()
def visit_While(self, node): def visit_While(self, node):
self.current_loop.append(node) self.current_loop.append(node)
self.before_loop_vars[node] = copy.deepcopy(self.current_seen_vars) self.before_loop_vars[node] = copy.copy(self.current_seen_vars)
self.generic_visit(node) self.generic_visit(node)
self.current_loop.pop() self.current_loop.pop()
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -21,7 +21,7 @@ import paddle.fluid as fluid ...@@ -21,7 +21,7 @@ import paddle.fluid as fluid
import unittest import unittest
from paddle.fluid.dygraph.jit import dygraph_to_static_graph from paddle.fluid.dygraph.jit import dygraph_to_static_graph
#from paddle.fluid.dygraph.dygraph_to_static import NameVistor from paddle.fluid.dygraph.dygraph_to_static.loop_transformer import NameVisitor
SEED = 2020 SEED = 2020
np.random.seed(SEED) np.random.seed(SEED)
...@@ -37,8 +37,15 @@ def while_loop_dyfunc(x): ...@@ -37,8 +37,15 @@ def while_loop_dyfunc(x):
class TestNameVisitor(unittest.TestCase): class TestNameVisitor(unittest.TestCase):
def test_loop_vars(self): def test_loop_vars(self):
#TODO test_func = inspect.getsource(while_loop_dyfunc)
pass gast_root = gast.parse(test_func)
name_visitor = NameVisitor(gast_root)
for node in gast.walk(gast_root):
if isinstance(node, gast.While):
loop_var_names, create_var_names = name_visitor.get_loop_var_names(
node)
self.assertEqual(loop_var_names, set(["i", "x"]))
self.assertEqual(create_var_names, set())
class TestTransformWhile(unittest.TestCase): class TestTransformWhile(unittest.TestCase):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册