diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py index f3f297c6a6350bf659fda9e80e1f364f7bff8f8a..2314d3c6c9500420e0f053fbf986fcd1e2dedbd5 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py @@ -80,11 +80,11 @@ class NameVisitor(gast.NodeVisitor): return True def get_loop_var_names(self, node): - assert isinstance(node, gast.While) or isinstance( - while_node, gast.For), "Input node is not gast loop node" + assert isinstance(node, (gast.While, + gast.For)), "Input node is not gast loop node" loop_var_names = set() create_var_names = set() - read_context = {type(gast.Load), type(gast.AugLoad)} + read_context = {type(gast.Load()), type(gast.AugLoad())} in_loop_vars = self.in_loop_vars[node] in_loop_name_strs = set(name.id for name in in_loop_vars) @@ -114,13 +114,13 @@ class NameVisitor(gast.NodeVisitor): def visit_For(self, node): self.current_loop.append(node) - self.before_loop_vars[node] = copy.deepcopy(self.current_seen_vars) + self.before_loop_vars[node] = copy.copy(self.current_seen_vars) self.generic_visit(node) self.current_loop.pop() def visit_While(self, node): self.current_loop.append(node) - self.before_loop_vars[node] = copy.deepcopy(self.current_seen_vars) + self.before_loop_vars[node] = copy.copy(self.current_seen_vars) self.generic_visit(node) self.current_loop.pop() diff --git a/python/paddle/fluid/tests/unittests/test_dygraph_to_static_loop.py b/python/paddle/fluid/tests/unittests/test_dygraph_to_static_loop.py index 9ca551f91b88d264b02dec84f6dd1cefee27933c..c9594ff17147b74793df34cbb4ed9b7fd7e0d018 100644 --- a/python/paddle/fluid/tests/unittests/test_dygraph_to_static_loop.py +++ b/python/paddle/fluid/tests/unittests/test_dygraph_to_static_loop.py @@ -1,4 +1,4 @@ -# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -21,7 +21,7 @@ import paddle.fluid as fluid import unittest from paddle.fluid.dygraph.jit import dygraph_to_static_graph -#from paddle.fluid.dygraph.dygraph_to_static import NameVistor +from paddle.fluid.dygraph.dygraph_to_static.loop_transformer import NameVisitor SEED = 2020 np.random.seed(SEED) @@ -37,8 +37,15 @@ def while_loop_dyfunc(x): class TestNameVisitor(unittest.TestCase): def test_loop_vars(self): - #TODO - pass + test_func = inspect.getsource(while_loop_dyfunc) + gast_root = gast.parse(test_func) + name_visitor = NameVisitor(gast_root) + for node in gast.walk(gast_root): + if isinstance(node, gast.While): + loop_var_names, create_var_names = name_visitor.get_loop_var_names( + node) + self.assertEqual(loop_var_names, set(["i", "x"])) + self.assertEqual(create_var_names, set()) class TestTransformWhile(unittest.TestCase):