diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py index f7097b11985e8d3d6942661393042492e18ebdfc..09ac73057af20d31c012379be39f5755fba22603 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py @@ -169,7 +169,7 @@ class NameVisitor(gast.NodeVisitor): if self._is_call_func_name_node(node): self.generic_visit(node) return - if node.id == "False" or node.id == "True": + if node.id == "False" or node.id == "True" or node.id == "None": self.generic_visit(node) return @@ -187,7 +187,6 @@ class NameVisitor(gast.NodeVisitor): def visit_Attribute(self, node): if self._is_call_func_name_node(node): return - attr_full_name = get_attribute_full_name(node) self.current_seen_vars.add(node) for loop_node in self.current_loop: diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_loop.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_loop.py index 7ef456692d9581660dbdcd31dfc395a1881bf6c1..7c500a05c2512f8cbeb38cbd9a87d0271e59222a 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_loop.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_loop.py @@ -35,6 +35,17 @@ def while_loop_dyfunc(x): return i +def while_loop_dyfunc_with_none(x): + i = fluid.dygraph.to_variable(x)\ + if x is not None \ + else fluid.dygraph.to_variable(x+1) + flag = 1 + while x < 10: + i = i + x if flag is not None else x + i + x = x + 1 + return i + + def for_loop_dyfunc(max_len): for i in range(max_len): ret = fluid.layers.zeros(shape=[1], dtype='float32') @@ -58,9 +69,14 @@ def var_create_in_for_loop(max_len): class TestNameVisitor(unittest.TestCase): def setUp(self): - self.loop_funcs = [while_loop_dyfunc, for_loop_dyfunc] - self.loop_var_names = [set(["i", "x"]), set(["i", "ret", "max_len"])] - self.create_var_names = [set(), set(["ret"])] + self.loop_funcs = [ + while_loop_dyfunc, for_loop_dyfunc, while_loop_dyfunc_with_none + ] + self.loop_var_names = [ + set(["i", "x"]), set(["i", "ret", "max_len"]), + set(["i", "x", "flag"]) + ] + self.create_var_names = [set(), set(["ret"]), set()] def test_loop_vars(self): for i in range(len(self.loop_funcs)): @@ -115,6 +131,11 @@ class TestTransformWhileLoop(unittest.TestCase): # self.assertTrue(np.allclose(self._run_dygraph(), self._run_static())) +class TestTransformWhileLoopWithNone(TestTransformWhileLoop): + def _init_dyfunc(self): + self.dyfunc = while_loop_dyfunc_with_none + + class TestWhileLoopBoolOp(TestTransformWhileLoop): def _init_dyfunc(self): self.dyfunc = while_loop_bool_op