未验证 提交 2403362d 编写于 作者: A Aurelius84 提交者: GitHub

BugFix for parsing Arguments and inserting funcs in IfElseTransormer (#23035)

* Support and/or in controlFlow if test=develop
上级 01ab8a06
......@@ -95,7 +95,7 @@ class IfElseTransformer(gast.NodeTransformer):
"""
self._insert_func_nodes(node)
def _insert_func_nodes(self, parent_node):
def _insert_func_nodes(self, node):
"""
Defined `true_func` and `false_func` will be inserted in front of corresponding
`layers.cond` statement instead of inserting them all into body of parent node.
......@@ -103,13 +103,18 @@ class IfElseTransformer(gast.NodeTransformer):
For example, `self.var_dict["key"]`. In this case, nested structure of newly
defined functions is easier to understand.
"""
if not (self.new_func_nodes and hasattr(parent_node, 'body')):
if not self.new_func_nodes:
return
idx = len(parent_node.body) - 1
idx = -1
if isinstance(node, list):
idx = len(node) - 1
elif isinstance(node, gast.AST):
for _, child in gast.iter_fields(node):
self._insert_func_nodes(child)
while idx >= 0:
child_node = parent_node.body[idx]
child_node = node[idx]
if child_node in self.new_func_nodes:
parent_node.body[idx:idx] = self.new_func_nodes[child_node]
node[idx:idx] = self.new_func_nodes[child_node]
idx = idx + len(self.new_func_nodes[child_node]) - 1
del self.new_func_nodes[child_node]
else:
......@@ -366,51 +371,133 @@ class IfConditionVisitor(object):
return new_node, new_assign_nodes
def get_name_ids(nodes, not_name_set=None, node_black_list=None):
class NameVisitor(gast.NodeVisitor):
def __init__(self, node_black_set=None):
# Set of nodes that will not be visited.
self.node_black_set = node_black_set or set()
# Dict to store the names and ctxs of vars.
self.name_ids = defaultdict(list)
# List of current visited nodes
self.ancestor_nodes = []
# Available only when node_black_set is set.
self._is_finished = False
self._candidate_ctxs = (gast.Store, gast.Load, gast.Param)
def visit(self, node):
"""Visit a node."""
if node in self.node_black_set or self._is_finished:
self._is_finished = True
return
self.ancestor_nodes.append(node)
method = 'visit_' + node.__class__.__name__
visitor = getattr(self, method, self.generic_visit)
ret = visitor(node)
self.ancestor_nodes.pop()
return ret
def visit_If(self, node):
"""
Return all ast.Name.id of python variable in nodes.
For nested `if/else`, the created vars are not always visible for parent node.
In addition, the vars created in `if.body` are not visible for `if.orelse`.
Case 1:
x = 1
if m > 1:
res = new_tensor
res = res + 1 # Error, `res` is not visible here.
Case 2:
if x_tensor > 0:
res = new_tensor
else:
res = res + 1 # Error, `res` is not visible here.
In above two cases, we should consider to manage the scope of vars to parsing
the arguments and returned vars correctly.
"""
if not isinstance(nodes, (list, tuple, set)):
raise ValueError(
"nodes must be one of list, tuple, set, but received %s" %
type(nodes))
if not_name_set is None:
not_name_set = set()
def update(old_dict, new_dict):
for k, v in new_dict.items():
old_dict[k].extend(v)
name_ids = defaultdict(list)
for node in nodes:
if node_black_list and node in node_black_list:
break
if isinstance(node, gast.AST):
# In two case, the ast.Name should be filtered.
# 1. Function name like `my_func` of my_func(x)
# 2. api prefix like `fluid` of `fluid.layers.mean`
if isinstance(node, gast.Return):
continue
elif isinstance(node, gast.Call) and isinstance(node.func,
gast.Name):
not_name_set.add(node.func.id)
elif isinstance(node, gast.Attribute) and isinstance(node.value,
gast.Name):
not_name_set.add(node.value.id)
if isinstance(
node, gast.Name
) and node.id not in name_ids and node.id not in not_name_set:
if isinstance(node.ctx, (gast.Store, gast.Load, gast.Param)):
name_ids[node.id].append(node.ctx)
before_if_name_ids = copy.deepcopy(self.name_ids)
body_name_ids = self._visit_child(node.body)
# If the traversal process stops early, just return the name_ids that have been seen.
if self._is_finished:
for name_id, ctxs in before_if_name_ids.items():
self.name_ids[name_id] = ctxs + self.name_ids[name_id]
# Blocks the vars in `if.body` and only inserts the vars both created in 'if/else' branch
# into name_ids.
else:
if isinstance(node, gast.Assign):
node = copy.copy(node)
else_name_ids = self._visit_child(node.orelse)
new_name_ids = self._find_new_name_ids(body_name_ids, else_name_ids)
for new_name_id in new_name_ids:
before_if_name_ids[new_name_id].append(gast.Store())
self.name_ids = before_if_name_ids
def visit_Attribute(self, node):
if not self._is_call_func_name_node(node):
self.generic_visit(node)
def visit_Name(self, node):
if not self._is_call_func_name_node(node):
if isinstance(node.ctx, self._candidate_ctxs):
self.name_ids[node.id].append(node.ctx)
def visit_Assign(self, node):
# Visit `value` firstly.
node._fields = ('value', 'targets')
for field, value in gast.iter_fields(node):
value = value if isinstance(value, list) else [value]
update(name_ids,
get_name_ids(value, not_name_set, node_black_list))
return name_ids
self.generic_visit(node)
def visit_Return(self, node):
# Ignore the vars in return
return
def _visit_child(self, node):
self.name_ids = defaultdict(list)
if isinstance(node, list):
for item in node:
if isinstance(item, gast.AST):
self.visit(item)
elif isinstance(node, gast.AST):
self.visit(node)
return copy.deepcopy(self.name_ids)
def _find_new_name_ids(self, body_name_ids, else_name_ids):
def is_required_ctx(ctxs, required_ctx):
for ctx in ctxs:
if isinstance(ctx, required_ctx):
return True
return False
candidate_name_ids = set(body_name_ids.keys()) & set(else_name_ids.keys(
))
store_ctx = gast.Store
new_name_ids = set()
for name_id in candidate_name_ids:
if is_required_ctx(body_name_ids[name_id],
store_ctx) and is_required_ctx(
else_name_ids[name_id], store_ctx):
new_name_ids.add(name_id)
return new_name_ids
def _is_call_func_name_node(self, node):
if len(self.ancestor_nodes) > 1:
assert self.ancestor_nodes[-1] == node
parent_node = self.ancestor_nodes[-2]
if isinstance(parent_node, gast.Call) and parent_node.func == node:
return True
return False
def get_name_ids(nodes, node_black_set=None):
"""
Return all ast.Name.id of python variable in nodes.
"""
name_visitor = NameVisitor(node_black_set)
for node in nodes:
name_visitor.visit(node)
return name_visitor.name_ids
def parse_cond_args(var_ids_dict, return_ids=None, ctx=gast.Load):
......@@ -508,7 +595,7 @@ def transform_if_else(node, root):
"""
Transform ast.If into control flow statement of Paddle static graph.
"""
parent_name_ids = get_name_ids([root], node_black_list=[node])
parent_name_ids = get_name_ids([root], node_black_set=[node])
if_name_ids = get_name_ids(node.body)
else_name_ids = get_name_ids(node.orelse)
......
......@@ -65,6 +65,58 @@ def nested_if_else(x_v):
return y
def nested_if_else_2(x):
y = fluid.layers.reshape(x, [-1, 1])
b = 2
if b < 1:
# var `z` is not visible for outer scope
z = y
x_shape_0 = x.shape[0]
if x_shape_0 < 1:
if fluid.layers.shape(y).numpy()[0] < 1:
res = fluid.layers.fill_constant(
value=2, shape=x.shape, dtype="int32")
# `z` is a new var here.
z = y + 1
else:
res = fluid.layers.fill_constant(
value=3, shape=x.shape, dtype="int32")
else:
res = x
return res
def nested_if_else_3(x):
y = fluid.layers.reshape(x, [-1, 1])
b = 2
# var `z` is visible for func.body
if b < 1:
z = y
else:
z = x
if b < 1:
res = x
# var `out` is only visible for current `if`
if b > 1:
out = x + 1
else:
out = x - 1
else:
y_shape = fluid.layers.shape(y)
if y_shape.numpy()[0] < 1:
res = fluid.layers.fill_constant(
value=2, shape=x.shape, dtype="int32")
# `z` is created in above code block.
z = y + 1
else:
res = fluid.layers.fill_constant(
value=3, shape=x.shape, dtype="int32")
# `out` is a new var.
out = x + 1
return res
class NetWithControlFlowIf(fluid.dygraph.Layer):
def __init__(self, hidden_dim=16):
super(NetWithControlFlowIf, self).__init__()
......
......@@ -22,7 +22,7 @@ import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_func
from test_basic import dyfunc_with_if_else, dyfunc_with_if_else2, nested_if_else
from ifelse_simple_func import dyfunc_with_if_else, dyfunc_with_if_else2, nested_if_else
class TestAST2Func(unittest.TestCase):
......
......@@ -72,6 +72,18 @@ class TestDygraphIfElse3(TestDygraphIfElse):
self.dyfunc = nested_if_else
class TestDygraphIfElse4(TestDygraphIfElse):
def setUp(self):
self.x = np.random.random([10, 16]).astype('float32')
self.dyfunc = nested_if_else_2
class TestDygraphIfElse5(TestDygraphIfElse):
def setUp(self):
self.x = np.random.random([10, 16]).astype('float32')
self.dyfunc = nested_if_else_3
class TestDygraphIfElseWithAndOr(TestDygraphIfElse):
def setUp(self):
self.x = np.random.random([10, 16]).astype('float32')
......
......@@ -65,14 +65,10 @@ class TestGetNameIds2(TestGetNameIds):
return z
"""
self.all_name_ids = {
'x': [
gast.Param(), gast.Store(), gast.Load(), gast.Load(),
gast.Load()
],
'a': [gast.Store(), gast.Load(), gast.Load()],
'y':
[gast.Param(), gast.Load(), gast.Load(), gast.Load(), gast.Load()],
'z': [gast.Store(), gast.Load(), gast.Store(), gast.Store()]
'x': [gast.Param(), gast.Store()],
'a': [gast.Store(), gast.Load()],
'y': [gast.Param(), gast.Load()],
'z': [gast.Store()]
}
......@@ -87,9 +83,9 @@ class TestGetNameIds3(TestGetNameIds):
return z
"""
self.all_name_ids = {
'x': [gast.Param(), gast.Load(), gast.Load(), gast.Load()],
'y': [gast.Param(), gast.Load(), gast.Load()],
'z': [gast.Store(), gast.Store(), gast.Load(), gast.Store()]
'x': [gast.Param()],
'y': [gast.Param()],
'z': [gast.Store()]
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册