未验证 提交 2403362d 编写于 作者: A Aurelius84 提交者: GitHub

BugFix for parsing Arguments and inserting funcs in IfElseTransormer (#23035)

* Support and/or in controlFlow if test=develop
上级 01ab8a06
......@@ -95,7 +95,7 @@ class IfElseTransformer(gast.NodeTransformer):
"""
self._insert_func_nodes(node)
def _insert_func_nodes(self, parent_node):
def _insert_func_nodes(self, node):
"""
Defined `true_func` and `false_func` will be inserted in front of corresponding
`layers.cond` statement instead of inserting them all into body of parent node.
......@@ -103,13 +103,18 @@ class IfElseTransformer(gast.NodeTransformer):
For example, `self.var_dict["key"]`. In this case, nested structure of newly
defined functions is easier to understand.
"""
if not (self.new_func_nodes and hasattr(parent_node, 'body')):
if not self.new_func_nodes:
return
idx = len(parent_node.body) - 1
idx = -1
if isinstance(node, list):
idx = len(node) - 1
elif isinstance(node, gast.AST):
for _, child in gast.iter_fields(node):
self._insert_func_nodes(child)
while idx >= 0:
child_node = parent_node.body[idx]
child_node = node[idx]
if child_node in self.new_func_nodes:
parent_node.body[idx:idx] = self.new_func_nodes[child_node]
node[idx:idx] = self.new_func_nodes[child_node]
idx = idx + len(self.new_func_nodes[child_node]) - 1
del self.new_func_nodes[child_node]
else:
......@@ -366,51 +371,133 @@ class IfConditionVisitor(object):
return new_node, new_assign_nodes
def get_name_ids(nodes, not_name_set=None, node_black_list=None):
class NameVisitor(gast.NodeVisitor):
def __init__(self, node_black_set=None):
# Set of nodes that will not be visited.
self.node_black_set = node_black_set or set()
# Dict to store the names and ctxs of vars.
self.name_ids = defaultdict(list)
# List of current visited nodes
self.ancestor_nodes = []
# Available only when node_black_set is set.
self._is_finished = False
self._candidate_ctxs = (gast.Store, gast.Load, gast.Param)
def visit(self, node):
"""Visit a node."""
if node in self.node_black_set or self._is_finished:
self._is_finished = True
return
self.ancestor_nodes.append(node)
method = 'visit_' + node.__class__.__name__
visitor = getattr(self, method, self.generic_visit)
ret = visitor(node)
self.ancestor_nodes.pop()
return ret
def visit_If(self, node):
"""
For nested `if/else`, the created vars are not always visible for parent node.
In addition, the vars created in `if.body` are not visible for `if.orelse`.
Case 1:
x = 1
if m > 1:
res = new_tensor
res = res + 1 # Error, `res` is not visible here.
Case 2:
if x_tensor > 0:
res = new_tensor
else:
res = res + 1 # Error, `res` is not visible here.
In above two cases, we should consider to manage the scope of vars to parsing
the arguments and returned vars correctly.
"""
before_if_name_ids = copy.deepcopy(self.name_ids)
body_name_ids = self._visit_child(node.body)
# If the traversal process stops early, just return the name_ids that have been seen.
if self._is_finished:
for name_id, ctxs in before_if_name_ids.items():
self.name_ids[name_id] = ctxs + self.name_ids[name_id]
# Blocks the vars in `if.body` and only inserts the vars both created in 'if/else' branch
# into name_ids.
else:
else_name_ids = self._visit_child(node.orelse)
new_name_ids = self._find_new_name_ids(body_name_ids, else_name_ids)
for new_name_id in new_name_ids:
before_if_name_ids[new_name_id].append(gast.Store())
self.name_ids = before_if_name_ids
def visit_Attribute(self, node):
if not self._is_call_func_name_node(node):
self.generic_visit(node)
def visit_Name(self, node):
if not self._is_call_func_name_node(node):
if isinstance(node.ctx, self._candidate_ctxs):
self.name_ids[node.id].append(node.ctx)
def visit_Assign(self, node):
# Visit `value` firstly.
node._fields = ('value', 'targets')
self.generic_visit(node)
def visit_Return(self, node):
# Ignore the vars in return
return
def _visit_child(self, node):
self.name_ids = defaultdict(list)
if isinstance(node, list):
for item in node:
if isinstance(item, gast.AST):
self.visit(item)
elif isinstance(node, gast.AST):
self.visit(node)
return copy.deepcopy(self.name_ids)
def _find_new_name_ids(self, body_name_ids, else_name_ids):
def is_required_ctx(ctxs, required_ctx):
for ctx in ctxs:
if isinstance(ctx, required_ctx):
return True
return False
candidate_name_ids = set(body_name_ids.keys()) & set(else_name_ids.keys(
))
store_ctx = gast.Store
new_name_ids = set()
for name_id in candidate_name_ids:
if is_required_ctx(body_name_ids[name_id],
store_ctx) and is_required_ctx(
else_name_ids[name_id], store_ctx):
new_name_ids.add(name_id)
return new_name_ids
def _is_call_func_name_node(self, node):
if len(self.ancestor_nodes) > 1:
assert self.ancestor_nodes[-1] == node
parent_node = self.ancestor_nodes[-2]
if isinstance(parent_node, gast.Call) and parent_node.func == node:
return True
return False
def get_name_ids(nodes, node_black_set=None):
"""
Return all ast.Name.id of python variable in nodes.
"""
if not isinstance(nodes, (list, tuple, set)):
raise ValueError(
"nodes must be one of list, tuple, set, but received %s" %
type(nodes))
if not_name_set is None:
not_name_set = set()
def update(old_dict, new_dict):
for k, v in new_dict.items():
old_dict[k].extend(v)
name_ids = defaultdict(list)
name_visitor = NameVisitor(node_black_set)
for node in nodes:
if node_black_list and node in node_black_list:
break
if isinstance(node, gast.AST):
# In two case, the ast.Name should be filtered.
# 1. Function name like `my_func` of my_func(x)
# 2. api prefix like `fluid` of `fluid.layers.mean`
if isinstance(node, gast.Return):
continue
elif isinstance(node, gast.Call) and isinstance(node.func,
gast.Name):
not_name_set.add(node.func.id)
elif isinstance(node, gast.Attribute) and isinstance(node.value,
gast.Name):
not_name_set.add(node.value.id)
if isinstance(
node, gast.Name
) and node.id not in name_ids and node.id not in not_name_set:
if isinstance(node.ctx, (gast.Store, gast.Load, gast.Param)):
name_ids[node.id].append(node.ctx)
else:
if isinstance(node, gast.Assign):
node = copy.copy(node)
node._fields = ('value', 'targets')
for field, value in gast.iter_fields(node):
value = value if isinstance(value, list) else [value]
update(name_ids,
get_name_ids(value, not_name_set, node_black_list))
return name_ids
name_visitor.visit(node)
return name_visitor.name_ids
def parse_cond_args(var_ids_dict, return_ids=None, ctx=gast.Load):
......@@ -508,7 +595,7 @@ def transform_if_else(node, root):
"""
Transform ast.If into control flow statement of Paddle static graph.
"""
parent_name_ids = get_name_ids([root], node_black_list=[node])
parent_name_ids = get_name_ids([root], node_black_set=[node])
if_name_ids = get_name_ids(node.body)
else_name_ids = get_name_ids(node.orelse)
......
......@@ -65,6 +65,58 @@ def nested_if_else(x_v):
return y
def nested_if_else_2(x):
y = fluid.layers.reshape(x, [-1, 1])
b = 2
if b < 1:
# var `z` is not visible for outer scope
z = y
x_shape_0 = x.shape[0]
if x_shape_0 < 1:
if fluid.layers.shape(y).numpy()[0] < 1:
res = fluid.layers.fill_constant(
value=2, shape=x.shape, dtype="int32")
# `z` is a new var here.
z = y + 1
else:
res = fluid.layers.fill_constant(
value=3, shape=x.shape, dtype="int32")
else:
res = x
return res
def nested_if_else_3(x):
y = fluid.layers.reshape(x, [-1, 1])
b = 2
# var `z` is visible for func.body
if b < 1:
z = y
else:
z = x
if b < 1:
res = x
# var `out` is only visible for current `if`
if b > 1:
out = x + 1
else:
out = x - 1
else:
y_shape = fluid.layers.shape(y)
if y_shape.numpy()[0] < 1:
res = fluid.layers.fill_constant(
value=2, shape=x.shape, dtype="int32")
# `z` is created in above code block.
z = y + 1
else:
res = fluid.layers.fill_constant(
value=3, shape=x.shape, dtype="int32")
# `out` is a new var.
out = x + 1
return res
class NetWithControlFlowIf(fluid.dygraph.Layer):
def __init__(self, hidden_dim=16):
super(NetWithControlFlowIf, self).__init__()
......
......@@ -22,7 +22,7 @@ import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_func
from test_basic import dyfunc_with_if_else, dyfunc_with_if_else2, nested_if_else
from ifelse_simple_func import dyfunc_with_if_else, dyfunc_with_if_else2, nested_if_else
class TestAST2Func(unittest.TestCase):
......
......@@ -72,6 +72,18 @@ class TestDygraphIfElse3(TestDygraphIfElse):
self.dyfunc = nested_if_else
class TestDygraphIfElse4(TestDygraphIfElse):
def setUp(self):
self.x = np.random.random([10, 16]).astype('float32')
self.dyfunc = nested_if_else_2
class TestDygraphIfElse5(TestDygraphIfElse):
def setUp(self):
self.x = np.random.random([10, 16]).astype('float32')
self.dyfunc = nested_if_else_3
class TestDygraphIfElseWithAndOr(TestDygraphIfElse):
def setUp(self):
self.x = np.random.random([10, 16]).astype('float32')
......
......@@ -65,14 +65,10 @@ class TestGetNameIds2(TestGetNameIds):
return z
"""
self.all_name_ids = {
'x': [
gast.Param(), gast.Store(), gast.Load(), gast.Load(),
gast.Load()
],
'a': [gast.Store(), gast.Load(), gast.Load()],
'y':
[gast.Param(), gast.Load(), gast.Load(), gast.Load(), gast.Load()],
'z': [gast.Store(), gast.Load(), gast.Store(), gast.Store()]
'x': [gast.Param(), gast.Store()],
'a': [gast.Store(), gast.Load()],
'y': [gast.Param(), gast.Load()],
'z': [gast.Store()]
}
......@@ -87,9 +83,9 @@ class TestGetNameIds3(TestGetNameIds):
return z
"""
self.all_name_ids = {
'x': [gast.Param(), gast.Load(), gast.Load(), gast.Load()],
'y': [gast.Param(), gast.Load(), gast.Load()],
'z': [gast.Store(), gast.Store(), gast.Load(), gast.Store()]
'x': [gast.Param()],
'y': [gast.Param()],
'z': [gast.Store()]
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册