未验证 提交 96126532 编写于 作者: H Huihuang Zheng 提交者: GitHub

Fix Incorrect After Node Vars in IfElseTransformer, test=develop (#28992)

The PR description is long. See details in the PR link.
上级 982fd0f3
......@@ -91,22 +91,27 @@ class IfElseTransformer(gast.NodeTransformer):
class NameVisitor(gast.NodeVisitor):
def __init__(self, end_node=None):
def __init__(self, after_node=None, end_node=None):
# The start node (exclusive) of the visitor
self.after_node = after_node
# The terminate node of the visitor.
self.end_node = end_node
# Dict to store the names and ctxs of vars.
self.name_ids = defaultdict(list)
# List of current visited nodes
self.ancestor_nodes = []
# Available only when end_node is set.
self._is_finished = False
# True when in range (after_node, end_node).
self._in_range = after_node is None
self._candidate_ctxs = (gast.Store, gast.Load, gast.Param)
self._def_func_names = set()
def visit(self, node):
"""Visit a node."""
if node == self.end_node or self._is_finished:
self._is_finished = True
if self.after_node is not None and node == self.after_node:
self._in_range = True
return
if node == self.end_node:
self._in_range = False
return
self.ancestor_nodes.append(node)
......@@ -137,18 +142,19 @@ class NameVisitor(gast.NodeVisitor):
In above two cases, we should consider to manage the scope of vars to parsing
the arguments and returned vars correctly.
"""
if not self.end_node:
if not self._in_range or not self.end_node:
self.generic_visit(node)
return
else:
before_if_name_ids = copy.deepcopy(self.name_ids)
body_name_ids = self._visit_child(node.body)
# If traversal process stops early in `if.body`, return the currently seen name_ids.
if self._is_finished:
if not self._in_range:
self._update_name_ids(before_if_name_ids)
else:
else_name_ids = self._visit_child(node.orelse)
# If traversal process stops early in `if.orelse`, return the currently seen name_ids.
if self._is_finished:
if not self._in_range:
self._update_name_ids(before_if_name_ids)
else:
# Blocks the vars in `if.body` and only inserts the vars both created in 'if/else' branch
......@@ -161,10 +167,13 @@ class NameVisitor(gast.NodeVisitor):
self.name_ids = before_if_name_ids
def visit_Attribute(self, node):
if not self._is_call_func_name_node(node):
if not self._in_range or not self._is_call_func_name_node(node):
self.generic_visit(node)
def visit_Name(self, node):
if not self._in_range:
self.generic_visit(node)
return
blacklist = {'True', 'False', 'None'}
if node.id in blacklist: return
if node.id in self._def_func_names:
......@@ -174,11 +183,17 @@ class NameVisitor(gast.NodeVisitor):
self.name_ids[node.id].append(node.ctx)
def visit_Assign(self, node):
if not self._in_range:
self.generic_visit(node)
return
# Visit `value` firstly.
node._fields = ('value', 'targets')
self.generic_visit(node)
def visit_FunctionDef(self, node):
if not self._in_range:
self.generic_visit(node)
return
self._def_func_names.add(node.name)
if not self.end_node:
self.generic_visit(node)
......@@ -187,7 +202,7 @@ class NameVisitor(gast.NodeVisitor):
self.name_ids = defaultdict(list)
self.generic_visit(node)
if self._is_finished:
if not self._in_range:
self._update_name_ids(before_name_ids)
else:
self.name_ids = before_name_ids
......@@ -235,11 +250,13 @@ class NameVisitor(gast.NodeVisitor):
self.name_ids[name_id] = ctxs + self.name_ids[name_id]
def get_name_ids(nodes, end_node=None):
def get_name_ids(nodes, after_node=None, end_node=None):
"""
Return all ast.Name.id of python variable in nodes.
Return all ast.Name.id of python variable in nodes range from
(after_node, end_node) exclusively. If after_node or end_node is None, the
range is unlimited.
"""
name_visitor = NameVisitor(end_node)
name_visitor = NameVisitor(after_node, end_node)
for node in nodes:
name_visitor.visit(node)
return name_visitor.name_ids
......@@ -434,20 +451,8 @@ def transform_if_else(node, root):
parent_name_ids = get_name_ids([root], end_node=node)
body_name_ids = get_name_ids(node.body)
orelse_name_ids = get_name_ids(node.orelse)
# Get after_ifelse_name_ids, which means used var names after If.body and If.orelse node.
after_ifelse_name_ids = defaultdict(list)
all_name_ids = get_name_ids([root])
for name in all_name_ids:
before_var_names_ids = parent_name_ids.get(name, []) + \
body_name_ids.get(name, []) + orelse_name_ids.get(name, [])
# Note: context of node.Name like gast.Load is a concrete object which has unique id different from other gast.Load
# E.g. ctx of `x` can be [<gast.Load object at 0x142a33c90>, <gast.Load object at 0x142a51950>, <gast.Param object at 0x1407d8250>]
after_var_names_ids = [
ctx for ctx in all_name_ids[name] if ctx not in before_var_names_ids
]
if after_var_names_ids:
after_ifelse_name_ids[name] = after_var_names_ids
after_ifelse_name_ids = get_name_ids([root], after_node=node)
return_name_ids, modified_name_ids_from_parent, new_vars_to_create = parse_cond_return(
parent_name_ids, body_name_ids, orelse_name_ids, after_ifelse_name_ids)
......
......@@ -17,6 +17,7 @@ from __future__ import print_function
import numpy as np
import unittest
import paddle
from paddle.fluid.dygraph.jit import declarative
from paddle.fluid.dygraph.dygraph_to_static.program_translator import ProgramTranslator
......@@ -271,5 +272,75 @@ class TestNetWithExternalFunc(TestDygraphIfElseNet):
self.Net = NetWithExternalFunc
class DiffModeNet1(paddle.nn.Layer):
def __init__(self, mode):
super(DiffModeNet1, self).__init__()
self.mode = mode
@paddle.jit.to_static
def forward(self, x, y):
if self.mode == 'train':
out = x + y
elif self.mode == 'infer':
out = x - y
else:
raise ValueError('Illegal mode')
return out
class DiffModeNet2(paddle.nn.Layer):
def __init__(self, mode):
super(DiffModeNet2, self).__init__()
self.mode = mode
@paddle.jit.to_static
def forward(self, x, y):
if self.mode == 'train':
out = x + y
return out
elif self.mode == 'infer':
out = x - y
return out
else:
raise ValueError('Illegal mode')
class TestDiffModeNet(unittest.TestCase):
"""
TestCase for the net with different modes
"""
def setUp(self):
self.x = paddle.randn([10, 16], 'float32')
self.y = paddle.randn([10, 16], 'float32')
self.init_net()
def init_net(self):
self.Net = DiffModeNet1
def _run(self, mode, to_static):
prog_trans = ProgramTranslator()
prog_trans.enable(to_static)
net = self.Net(mode)
ret = net(self.x, self.y)
return ret.numpy()
def test_train_mode(self):
self.assertTrue((self._run(
mode='train', to_static=True) == self._run(
mode='train', to_static=False)).all())
def test_infer_mode(self):
self.assertTrue((self._run(
mode='infer', to_static=True) == self._run(
mode='infer', to_static=False)).all())
class TestDiffModeNet2(TestDiffModeNet):
def init_net(self):
self.Net = DiffModeNet2
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册