From af926306663f6998c8881537caec991299252d89 Mon Sep 17 00:00:00 2001 From: liym27 <33742067+liym27@users.noreply.github.com> Date: Thu, 26 Mar 2020 09:58:25 +0800 Subject: [PATCH] fix bug of loop_vars in loop_transformer.test=develop (#23180) --- .../dygraph_to_static/loop_transformer.py | 3 +-- .../unittests/dygraph_to_static/test_loop.py | 27 ++++++++++++++++--- 2 files changed, 25 insertions(+), 5 deletions(-) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py index f7097b11985..09ac73057af 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py @@ -169,7 +169,7 @@ class NameVisitor(gast.NodeVisitor): if self._is_call_func_name_node(node): self.generic_visit(node) return - if node.id == "False" or node.id == "True": + if node.id == "False" or node.id == "True" or node.id == "None": self.generic_visit(node) return @@ -187,7 +187,6 @@ class NameVisitor(gast.NodeVisitor): def visit_Attribute(self, node): if self._is_call_func_name_node(node): return - attr_full_name = get_attribute_full_name(node) self.current_seen_vars.add(node) for loop_node in self.current_loop: diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_loop.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_loop.py index 7ef456692d9..7c500a05c25 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_loop.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_loop.py @@ -35,6 +35,17 @@ def while_loop_dyfunc(x): return i +def while_loop_dyfunc_with_none(x): + i = fluid.dygraph.to_variable(x)\ + if x is not None \ + else fluid.dygraph.to_variable(x+1) + flag = 1 + while x < 10: + i = i + x if flag is not None else x + i + x = x + 1 + return i + + def for_loop_dyfunc(max_len): for i in range(max_len): ret = fluid.layers.zeros(shape=[1], dtype='float32') @@ -58,9 +69,14 @@ def var_create_in_for_loop(max_len): class TestNameVisitor(unittest.TestCase): def setUp(self): - self.loop_funcs = [while_loop_dyfunc, for_loop_dyfunc] - self.loop_var_names = [set(["i", "x"]), set(["i", "ret", "max_len"])] - self.create_var_names = [set(), set(["ret"])] + self.loop_funcs = [ + while_loop_dyfunc, for_loop_dyfunc, while_loop_dyfunc_with_none + ] + self.loop_var_names = [ + set(["i", "x"]), set(["i", "ret", "max_len"]), + set(["i", "x", "flag"]) + ] + self.create_var_names = [set(), set(["ret"]), set()] def test_loop_vars(self): for i in range(len(self.loop_funcs)): @@ -115,6 +131,11 @@ class TestTransformWhileLoop(unittest.TestCase): # self.assertTrue(np.allclose(self._run_dygraph(), self._run_static())) +class TestTransformWhileLoopWithNone(TestTransformWhileLoop): + def _init_dyfunc(self): + self.dyfunc = while_loop_dyfunc_with_none + + class TestWhileLoopBoolOp(TestTransformWhileLoop): def _init_dyfunc(self): self.dyfunc = while_loop_bool_op -- GitLab