未验证 提交 af926306 编写于 作者: L liym27 提交者: GitHub

fix bug of loop_vars in loop_transformer.test=develop (#23180)

上级 ebe4eab9
...@@ -169,7 +169,7 @@ class NameVisitor(gast.NodeVisitor): ...@@ -169,7 +169,7 @@ class NameVisitor(gast.NodeVisitor):
if self._is_call_func_name_node(node): if self._is_call_func_name_node(node):
self.generic_visit(node) self.generic_visit(node)
return return
if node.id == "False" or node.id == "True": if node.id == "False" or node.id == "True" or node.id == "None":
self.generic_visit(node) self.generic_visit(node)
return return
...@@ -187,7 +187,6 @@ class NameVisitor(gast.NodeVisitor): ...@@ -187,7 +187,6 @@ class NameVisitor(gast.NodeVisitor):
def visit_Attribute(self, node): def visit_Attribute(self, node):
if self._is_call_func_name_node(node): if self._is_call_func_name_node(node):
return return
attr_full_name = get_attribute_full_name(node) attr_full_name = get_attribute_full_name(node)
self.current_seen_vars.add(node) self.current_seen_vars.add(node)
for loop_node in self.current_loop: for loop_node in self.current_loop:
......
...@@ -35,6 +35,17 @@ def while_loop_dyfunc(x): ...@@ -35,6 +35,17 @@ def while_loop_dyfunc(x):
return i return i
def while_loop_dyfunc_with_none(x):
i = fluid.dygraph.to_variable(x)\
if x is not None \
else fluid.dygraph.to_variable(x+1)
flag = 1
while x < 10:
i = i + x if flag is not None else x + i
x = x + 1
return i
def for_loop_dyfunc(max_len): def for_loop_dyfunc(max_len):
for i in range(max_len): for i in range(max_len):
ret = fluid.layers.zeros(shape=[1], dtype='float32') ret = fluid.layers.zeros(shape=[1], dtype='float32')
...@@ -58,9 +69,14 @@ def var_create_in_for_loop(max_len): ...@@ -58,9 +69,14 @@ def var_create_in_for_loop(max_len):
class TestNameVisitor(unittest.TestCase): class TestNameVisitor(unittest.TestCase):
def setUp(self): def setUp(self):
self.loop_funcs = [while_loop_dyfunc, for_loop_dyfunc] self.loop_funcs = [
self.loop_var_names = [set(["i", "x"]), set(["i", "ret", "max_len"])] while_loop_dyfunc, for_loop_dyfunc, while_loop_dyfunc_with_none
self.create_var_names = [set(), set(["ret"])] ]
self.loop_var_names = [
set(["i", "x"]), set(["i", "ret", "max_len"]),
set(["i", "x", "flag"])
]
self.create_var_names = [set(), set(["ret"]), set()]
def test_loop_vars(self): def test_loop_vars(self):
for i in range(len(self.loop_funcs)): for i in range(len(self.loop_funcs)):
...@@ -115,6 +131,11 @@ class TestTransformWhileLoop(unittest.TestCase): ...@@ -115,6 +131,11 @@ class TestTransformWhileLoop(unittest.TestCase):
# self.assertTrue(np.allclose(self._run_dygraph(), self._run_static())) # self.assertTrue(np.allclose(self._run_dygraph(), self._run_static()))
class TestTransformWhileLoopWithNone(TestTransformWhileLoop):
def _init_dyfunc(self):
self.dyfunc = while_loop_dyfunc_with_none
class TestWhileLoopBoolOp(TestTransformWhileLoop): class TestWhileLoopBoolOp(TestTransformWhileLoop):
def _init_dyfunc(self): def _init_dyfunc(self):
self.dyfunc = while_loop_bool_op self.dyfunc = while_loop_bool_op
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册