未验证 提交 96126532 编写于 作者: H Huihuang Zheng 提交者: GitHub

Fix Incorrect After Node Vars in IfElseTransformer, test=develop (#28992)

The PR description is long. See details in the PR link.
上级 982fd0f3
...@@ -91,22 +91,27 @@ class IfElseTransformer(gast.NodeTransformer): ...@@ -91,22 +91,27 @@ class IfElseTransformer(gast.NodeTransformer):
class NameVisitor(gast.NodeVisitor): class NameVisitor(gast.NodeVisitor):
def __init__(self, end_node=None): def __init__(self, after_node=None, end_node=None):
# The start node (exclusive) of the visitor
self.after_node = after_node
# The terminate node of the visitor. # The terminate node of the visitor.
self.end_node = end_node self.end_node = end_node
# Dict to store the names and ctxs of vars. # Dict to store the names and ctxs of vars.
self.name_ids = defaultdict(list) self.name_ids = defaultdict(list)
# List of current visited nodes # List of current visited nodes
self.ancestor_nodes = [] self.ancestor_nodes = []
# Available only when end_node is set. # True when in range (after_node, end_node).
self._is_finished = False self._in_range = after_node is None
self._candidate_ctxs = (gast.Store, gast.Load, gast.Param) self._candidate_ctxs = (gast.Store, gast.Load, gast.Param)
self._def_func_names = set() self._def_func_names = set()
def visit(self, node): def visit(self, node):
"""Visit a node.""" """Visit a node."""
if node == self.end_node or self._is_finished: if self.after_node is not None and node == self.after_node:
self._is_finished = True self._in_range = True
return
if node == self.end_node:
self._in_range = False
return return
self.ancestor_nodes.append(node) self.ancestor_nodes.append(node)
...@@ -137,18 +142,19 @@ class NameVisitor(gast.NodeVisitor): ...@@ -137,18 +142,19 @@ class NameVisitor(gast.NodeVisitor):
In above two cases, we should consider to manage the scope of vars to parsing In above two cases, we should consider to manage the scope of vars to parsing
the arguments and returned vars correctly. the arguments and returned vars correctly.
""" """
if not self.end_node: if not self._in_range or not self.end_node:
self.generic_visit(node) self.generic_visit(node)
return
else: else:
before_if_name_ids = copy.deepcopy(self.name_ids) before_if_name_ids = copy.deepcopy(self.name_ids)
body_name_ids = self._visit_child(node.body) body_name_ids = self._visit_child(node.body)
# If traversal process stops early in `if.body`, return the currently seen name_ids. # If traversal process stops early in `if.body`, return the currently seen name_ids.
if self._is_finished: if not self._in_range:
self._update_name_ids(before_if_name_ids) self._update_name_ids(before_if_name_ids)
else: else:
else_name_ids = self._visit_child(node.orelse) else_name_ids = self._visit_child(node.orelse)
# If traversal process stops early in `if.orelse`, return the currently seen name_ids. # If traversal process stops early in `if.orelse`, return the currently seen name_ids.
if self._is_finished: if not self._in_range:
self._update_name_ids(before_if_name_ids) self._update_name_ids(before_if_name_ids)
else: else:
# Blocks the vars in `if.body` and only inserts the vars both created in 'if/else' branch # Blocks the vars in `if.body` and only inserts the vars both created in 'if/else' branch
...@@ -161,10 +167,13 @@ class NameVisitor(gast.NodeVisitor): ...@@ -161,10 +167,13 @@ class NameVisitor(gast.NodeVisitor):
self.name_ids = before_if_name_ids self.name_ids = before_if_name_ids
def visit_Attribute(self, node): def visit_Attribute(self, node):
if not self._is_call_func_name_node(node): if not self._in_range or not self._is_call_func_name_node(node):
self.generic_visit(node) self.generic_visit(node)
def visit_Name(self, node): def visit_Name(self, node):
if not self._in_range:
self.generic_visit(node)
return
blacklist = {'True', 'False', 'None'} blacklist = {'True', 'False', 'None'}
if node.id in blacklist: return if node.id in blacklist: return
if node.id in self._def_func_names: if node.id in self._def_func_names:
...@@ -174,11 +183,17 @@ class NameVisitor(gast.NodeVisitor): ...@@ -174,11 +183,17 @@ class NameVisitor(gast.NodeVisitor):
self.name_ids[node.id].append(node.ctx) self.name_ids[node.id].append(node.ctx)
def visit_Assign(self, node): def visit_Assign(self, node):
if not self._in_range:
self.generic_visit(node)
return
# Visit `value` firstly. # Visit `value` firstly.
node._fields = ('value', 'targets') node._fields = ('value', 'targets')
self.generic_visit(node) self.generic_visit(node)
def visit_FunctionDef(self, node): def visit_FunctionDef(self, node):
if not self._in_range:
self.generic_visit(node)
return
self._def_func_names.add(node.name) self._def_func_names.add(node.name)
if not self.end_node: if not self.end_node:
self.generic_visit(node) self.generic_visit(node)
...@@ -187,7 +202,7 @@ class NameVisitor(gast.NodeVisitor): ...@@ -187,7 +202,7 @@ class NameVisitor(gast.NodeVisitor):
self.name_ids = defaultdict(list) self.name_ids = defaultdict(list)
self.generic_visit(node) self.generic_visit(node)
if self._is_finished: if not self._in_range:
self._update_name_ids(before_name_ids) self._update_name_ids(before_name_ids)
else: else:
self.name_ids = before_name_ids self.name_ids = before_name_ids
...@@ -235,11 +250,13 @@ class NameVisitor(gast.NodeVisitor): ...@@ -235,11 +250,13 @@ class NameVisitor(gast.NodeVisitor):
self.name_ids[name_id] = ctxs + self.name_ids[name_id] self.name_ids[name_id] = ctxs + self.name_ids[name_id]
def get_name_ids(nodes, end_node=None): def get_name_ids(nodes, after_node=None, end_node=None):
""" """
Return all ast.Name.id of python variable in nodes. Return all ast.Name.id of python variable in nodes range from
(after_node, end_node) exclusively. If after_node or end_node is None, the
range is unlimited.
""" """
name_visitor = NameVisitor(end_node) name_visitor = NameVisitor(after_node, end_node)
for node in nodes: for node in nodes:
name_visitor.visit(node) name_visitor.visit(node)
return name_visitor.name_ids return name_visitor.name_ids
...@@ -434,20 +451,8 @@ def transform_if_else(node, root): ...@@ -434,20 +451,8 @@ def transform_if_else(node, root):
parent_name_ids = get_name_ids([root], end_node=node) parent_name_ids = get_name_ids([root], end_node=node)
body_name_ids = get_name_ids(node.body) body_name_ids = get_name_ids(node.body)
orelse_name_ids = get_name_ids(node.orelse) orelse_name_ids = get_name_ids(node.orelse)
# Get after_ifelse_name_ids, which means used var names after If.body and If.orelse node. # Get after_ifelse_name_ids, which means used var names after If.body and If.orelse node.
after_ifelse_name_ids = defaultdict(list) after_ifelse_name_ids = get_name_ids([root], after_node=node)
all_name_ids = get_name_ids([root])
for name in all_name_ids:
before_var_names_ids = parent_name_ids.get(name, []) + \
body_name_ids.get(name, []) + orelse_name_ids.get(name, [])
# Note: context of node.Name like gast.Load is a concrete object which has unique id different from other gast.Load
# E.g. ctx of `x` can be [<gast.Load object at 0x142a33c90>, <gast.Load object at 0x142a51950>, <gast.Param object at 0x1407d8250>]
after_var_names_ids = [
ctx for ctx in all_name_ids[name] if ctx not in before_var_names_ids
]
if after_var_names_ids:
after_ifelse_name_ids[name] = after_var_names_ids
return_name_ids, modified_name_ids_from_parent, new_vars_to_create = parse_cond_return( return_name_ids, modified_name_ids_from_parent, new_vars_to_create = parse_cond_return(
parent_name_ids, body_name_ids, orelse_name_ids, after_ifelse_name_ids) parent_name_ids, body_name_ids, orelse_name_ids, after_ifelse_name_ids)
......
...@@ -17,6 +17,7 @@ from __future__ import print_function ...@@ -17,6 +17,7 @@ from __future__ import print_function
import numpy as np import numpy as np
import unittest import unittest
import paddle
from paddle.fluid.dygraph.jit import declarative from paddle.fluid.dygraph.jit import declarative
from paddle.fluid.dygraph.dygraph_to_static.program_translator import ProgramTranslator from paddle.fluid.dygraph.dygraph_to_static.program_translator import ProgramTranslator
...@@ -271,5 +272,75 @@ class TestNetWithExternalFunc(TestDygraphIfElseNet): ...@@ -271,5 +272,75 @@ class TestNetWithExternalFunc(TestDygraphIfElseNet):
self.Net = NetWithExternalFunc self.Net = NetWithExternalFunc
class DiffModeNet1(paddle.nn.Layer):
def __init__(self, mode):
super(DiffModeNet1, self).__init__()
self.mode = mode
@paddle.jit.to_static
def forward(self, x, y):
if self.mode == 'train':
out = x + y
elif self.mode == 'infer':
out = x - y
else:
raise ValueError('Illegal mode')
return out
class DiffModeNet2(paddle.nn.Layer):
def __init__(self, mode):
super(DiffModeNet2, self).__init__()
self.mode = mode
@paddle.jit.to_static
def forward(self, x, y):
if self.mode == 'train':
out = x + y
return out
elif self.mode == 'infer':
out = x - y
return out
else:
raise ValueError('Illegal mode')
class TestDiffModeNet(unittest.TestCase):
"""
TestCase for the net with different modes
"""
def setUp(self):
self.x = paddle.randn([10, 16], 'float32')
self.y = paddle.randn([10, 16], 'float32')
self.init_net()
def init_net(self):
self.Net = DiffModeNet1
def _run(self, mode, to_static):
prog_trans = ProgramTranslator()
prog_trans.enable(to_static)
net = self.Net(mode)
ret = net(self.x, self.y)
return ret.numpy()
def test_train_mode(self):
self.assertTrue((self._run(
mode='train', to_static=True) == self._run(
mode='train', to_static=False)).all())
def test_infer_mode(self):
self.assertTrue((self._run(
mode='infer', to_static=True) == self._run(
mode='infer', to_static=False)).all())
class TestDiffModeNet2(TestDiffModeNet):
def init_net(self):
self.Net = DiffModeNet2
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册