未验证 提交 5ea82e8a 编写于 作者: L liym27 提交者: GitHub

[Dy2Static]Support return variable created in only one of If.body or If.orelse (#24841)

* Support return variable in only one of if body or else. 

* remove after_visit in IfElseTransformer.

* Modify the result of get_name_ids in test_ifelse_basic.py 

* Add unittest to test the new case. 

* Modify code according to reviews. 
上级 0494239b
......@@ -26,7 +26,6 @@ from paddle.fluid import unique_name
from paddle.fluid.dygraph.dygraph_to_static.utils import compare_with_none
from paddle.fluid.dygraph.dygraph_to_static.utils import is_candidate_node
from paddle.fluid.dygraph.dygraph_to_static.utils import is_paddle_api
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
from paddle.fluid.dygraph.dygraph_to_static.utils import create_funcDef_node
from paddle.fluid.dygraph.dygraph_to_static.utils import create_assign_node
......@@ -34,6 +33,7 @@ from paddle.fluid.dygraph.dygraph_to_static.utils import IsControlFlowVisitor
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import NodeVarType
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_static_variable_gast_node
TRUE_FUNC_PREFIX = 'true_fn'
FALSE_FUNC_PREFIX = 'false_fn'
......@@ -55,14 +55,12 @@ class IfElseTransformer(gast.NodeTransformer):
wrapper_root)
self.root = wrapper_root.node
self.static_analysis_visitor = StaticAnalysisVisitor(self.root)
self.new_func_nodes = {}
def transform(self):
"""
Main function to transform AST.
"""
self.visit(self.root)
self.after_visit(self.root)
def visit_If(self, node):
if_condition_visitor = IfConditionVisitor(node.test,
......@@ -71,14 +69,14 @@ class IfElseTransformer(gast.NodeTransformer):
self.generic_visit(node)
if need_transform:
pred_node, new_assign_nodes = if_condition_visitor.transform()
true_func_node, false_func_node, return_name_ids = transform_if_else(
new_vars_stmts, true_func_node, false_func_node, return_name_ids = transform_if_else(
node, self.root)
# create layers.cond
new_node = create_cond_node(return_name_ids, pred_node,
cond_node = create_cond_node(return_name_ids, pred_node,
true_func_node, false_func_node)
self.new_func_nodes[new_node] = [true_func_node, false_func_node
] + new_assign_nodes
return new_node
return new_vars_stmts + [true_func_node, false_func_node
] + new_assign_nodes + [cond_node]
else:
return node
......@@ -117,43 +115,6 @@ class IfElseTransformer(gast.NodeTransformer):
else:
return node
def after_visit(self, node):
"""
This function will add some postprocessing operations with node.
It can be used to add the created `true_fn/false_fn` in front of
the node.body before they are called in cond layer.
"""
self._insert_func_nodes(node)
def _insert_func_nodes(self, node):
"""
Defined `true_func` and `false_func` will be inserted in front of corresponding
`layers.cond` statement instead of inserting them all into body of parent node.
Because private variables of class or other external scope will be modified.
For example, `self.var_dict["key"]`. In this case, nested structure of newly
defined functions is easier to understand.
"""
if not self.new_func_nodes:
return
idx = -1
if isinstance(node, list):
idx = len(node) - 1
elif isinstance(node, gast.AST):
for _, child in gast.iter_fields(node):
self._insert_func_nodes(child)
while idx >= 0:
child_node = node[idx]
if child_node in self.new_func_nodes:
node[idx:idx] = self.new_func_nodes[child_node]
idx = idx + len(self.new_func_nodes[child_node]) - 1
del self.new_func_nodes[child_node]
else:
self._insert_func_nodes(child_node)
idx = idx - 1
def get_new_func_nodes(self):
return self.new_func_nodes
def merge_multi_assign_nodes(assign_nodes):
"""
......@@ -467,10 +428,6 @@ class NameVisitor(gast.NodeVisitor):
else:
self.name_ids = before_name_ids
def visit_Return(self, node):
# Ignore the vars in return
return
def _visit_child(self, node):
self.name_ids = defaultdict(list)
if isinstance(node, list):
......@@ -553,25 +510,63 @@ def parse_cond_args(var_ids_dict, return_ids=None, ctx=gast.Load):
return arguments
def parse_cond_return(parent_vars_dict, if_vars_dict, else_vars_dict):
def parse_cond_return(parent_vars_dict, if_vars_dict, else_vars_dict,
after_ifelse_vars_dict):
"""
Find out the ast.Name list of output by analyzing node's AST information.
Following conditions should be satisfied while determining whether a variable is a return value:
1. the var in parent scope is modified in if/else node.
2. new var is both created in if and else node.
One of the following conditions should be satisfied while determining whether a variable is a return value:
1. the var in parent scope is modified in If.body or If.orelse node.
2. new var is both created in If.body and If.orelse node.
3. new var is created only in one of If.body or If.orelse node, and it used as gast.Load firstly after gast.If node.
If different var is modified in if and else node, it should add the var in return_ids
of different node.
For example:
x, y = 5, 10
if x > 4:
x = x+1
z = x*x
q = 10
else:
y = y - 1
z = y*y
m = 20
n = 20
print(q)
n = 30
print(n)
The return_ids are (x, y, z, q) for `If.body` and `If.orelse`node, because
1. x is modified in If.body node,
2. y is modified in If.body node,
3. z is both created in If.body and If.orelse node,
4. q is created only in If.body, and it is used by `print(q)` as gast.Load.
Note:
After transformed, q and z are created in parent scope. For example,
x, y = 5, 10
q = fluid.dygraph.dygraph_to_static.variable_trans_func.data_layer_not_check(name='q', shape=[-1], dtype='float32')
z = fluid.dygraph.dygraph_to_static.variable_trans_func.data_layer_not_check(name='z', shape=[-1], dtype='float32')
def true_func(x, y, q):
x = x+1
z = x*x
q = 10
return x,y,z,q
def false_func(x, y, q):
y = y - 1
z = y*y
m = 20
n = 20
return x,y,z,q
x,y,z,q = fluid.layers.cond(x>4, lambda: true_func(x, y), lambda: false_func(x, y, q))
m and n are not in return_ids, because
5. m is created only in If.orelse, but it is not used after gast.If node.
6. n is created only in If.orelse, and it is used by `n = 30` and `print(n)`, but it is not used as gast.Load firstly but gast.Store .
The return_ids should be (x, y, z) for `if` and `else`node.
"""
def _is_return_var(ctxs):
......@@ -587,57 +582,112 @@ def parse_cond_return(parent_vars_dict, if_vars_dict, else_vars_dict):
vars.append(k)
return vars
def _candidate_vars(child_dict, parent_dict):
def _modified_vars(child_dict, parent_dict):
return set([
var for var in _vars_with_store(child_dict) if var in parent_dict
])
# 1. the var in parent_ids is modified in if/else node.
if_candidate_vars = _candidate_vars(if_vars_dict, parent_vars_dict)
else_candidate_vars = _candidate_vars(else_vars_dict, parent_vars_dict)
# 2. new var is both created in if and else node.
if_new_vars = set([
def _vars_loaded_before_store(ids_dict):
new_dict = defaultdict(list)
for k, ctxs in ids_dict.items():
for ctx in ctxs:
if isinstance(ctx, gast.Load):
new_dict[k].append(ctx)
elif isinstance(ctx, gast.Store):
break
return new_dict
# modified vars
body_modified_vars = _modified_vars(if_vars_dict, parent_vars_dict)
orelse_modified_vars = _modified_vars(else_vars_dict, parent_vars_dict)
modified_vars = body_modified_vars | orelse_modified_vars
# new vars
body_new_vars = set([
var for var in _vars_with_store(if_vars_dict)
if var not in parent_vars_dict
])
else_new_vars = set([
orelse_new_vars = set([
var for var in _vars_with_store(else_vars_dict)
if var not in parent_vars_dict
])
new_vars = if_new_vars & else_new_vars
new_vars_in_body_or_orelse = body_new_vars | orelse_new_vars
new_vars_in_one_of_body_or_orelse = body_new_vars ^ orelse_new_vars
# 1. the var in parent scope is modified in If.body or If.orelse node.
modified_vars_from_parent = modified_vars - new_vars_in_body_or_orelse
# generate return_ids of if/else node.
modified_vars = if_candidate_vars | else_candidate_vars
return_ids = list(modified_vars | new_vars)
# 2. new var is both created in If.body and If.orelse node.
new_vars_in_body_and_orelse = body_new_vars & orelse_new_vars
# 3. new var is created only in one of If.body or If.orelse node, and it used as gast.Load firstly after gast.If node.
used_vars_after_ifelse = set(
[var for var in _vars_loaded_before_store(after_ifelse_vars_dict)])
new_vars_to_create = new_vars_in_one_of_body_or_orelse & used_vars_after_ifelse | new_vars_in_body_and_orelse
# 4. generate return_ids of if/else node.
return_ids = list(modified_vars_from_parent | new_vars_in_body_and_orelse |
new_vars_to_create)
return_ids.sort()
return return_ids, list(modified_vars - new_vars)
return return_ids, modified_vars_from_parent, new_vars_to_create
def transform_if_else(node, root):
"""
Transform ast.If into control flow statement of Paddle static graph.
"""
# TODO(liym27): Consider variable like `self.a` modified in if/else node.
parent_name_ids = get_name_ids([root], end_node=node)
if_name_ids = get_name_ids(node.body)
else_name_ids = get_name_ids(node.orelse)
return_name_ids, modified_name_ids = parse_cond_return(
parent_name_ids, if_name_ids, else_name_ids)
body_name_ids = get_name_ids(node.body)
orelse_name_ids = get_name_ids(node.orelse)
# Get after_ifelse_name_ids, which means used var names after If.body and If.orelse node.
after_ifelse_name_ids = defaultdict(list)
all_name_ids = get_name_ids([root])
for name in all_name_ids:
before_var_names_ids = parent_name_ids.get(name, []) + \
body_name_ids.get(name, []) + orelse_name_ids.get(name, [])
# Note: context of node.Name like gast.Load is a concrete object which has unique id different from other gast.Load
# E.g. ctx of `x` can be [<gast.Load object at 0x142a33c90>, <gast.Load object at 0x142a51950>, <gast.Param object at 0x1407d8250>]
after_var_names_ids = [
ctx for ctx in all_name_ids[name] if ctx not in before_var_names_ids
]
if after_var_names_ids:
after_ifelse_name_ids[name] = after_var_names_ids
return_name_ids, modified_name_ids_from_parent, new_vars_to_create = parse_cond_return(
parent_name_ids, body_name_ids, orelse_name_ids, after_ifelse_name_ids)
# NOTE: Python can create variable only in if body or only in else body, and use it out of if/else.
# E.g.
#
# if x > 5:
# a = 10
# print(a)
#
# Create static variable for those variables
create_new_vars_in_parent_stmts = []
for name in new_vars_to_create:
# NOTE: Consider variable like `self.a` modified in if/else node.
if "." not in name:
create_new_vars_in_parent_stmts.append(
create_static_variable_gast_node(name))
modified_name_ids = modified_name_ids_from_parent | new_vars_to_create
true_func_node = create_funcDef_node(
node.body,
name=unique_name.generate(TRUE_FUNC_PREFIX),
input_args=parse_cond_args(if_name_ids, modified_name_ids),
input_args=parse_cond_args(body_name_ids, modified_name_ids),
return_name_ids=return_name_ids)
false_func_node = create_funcDef_node(
node.orelse,
name=unique_name.generate(FALSE_FUNC_PREFIX),
input_args=parse_cond_args(else_name_ids, modified_name_ids),
input_args=parse_cond_args(orelse_name_ids, modified_name_ids),
return_name_ids=return_name_ids)
return true_func_node, false_func_node, return_name_ids
return create_new_vars_in_parent_stmts, true_func_node, false_func_node, return_name_ids
def create_cond_node(return_name_ids,
......
......@@ -52,6 +52,51 @@ def dyfunc_with_if_else2(x, col=100):
return y
def dyfunc_with_if_else3(x):
# Create new var in parent scope, return it in true_fn and false_fn.
# The var is created only in one of If.body or If.orelse node, and it used as gast.Load firstly after gast.If node.
# The transformed code:
"""
q = fluid.dygraph.dygraph_to_static.variable_trans_func.
data_layer_not_check(name='q', shape=[-1], dtype='float32')
z = fluid.dygraph.dygraph_to_static.variable_trans_func.
data_layer_not_check(name='z', shape=[-1], dtype='float32')
def true_fn_0(q, x, y):
x = x + 1
z = x + 2
q = x + 3
return q, x, y, z
def false_fn_0(q, x, y):
y = y + 1
z = x - 2
m = x + 2
n = x + 3
return q, x, y, z
q, x, y, z = fluid.layers.cond(fluid.layers.mean(x)[0] < 5, lambda :
fluid.dygraph.dygraph_to_static.convert_call(true_fn_0)(q, x, y),
lambda : fluid.dygraph.dygraph_to_static.convert_call(false_fn_0)(q,
x, y))
"""
y = x + 1
# NOTE: x_v[0] < 5 is True
if fluid.layers.mean(x).numpy()[0] < 5:
x = x + 1
z = x + 2
q = x + 3
else:
y = y + 1
z = x - 2
m = x + 2
n = x + 3
q = q + 1
n = q + 2
x = n
return x
def nested_if_else(x_v):
batch_size = 16
feat_size = x_v.shape[-1]
......
......@@ -64,18 +64,24 @@ class TestDygraphIfElse2(TestDygraphIfElse):
class TestDygraphIfElse3(TestDygraphIfElse):
def setUp(self):
self.x = np.random.random([10, 16]).astype('float32')
self.dyfunc = dyfunc_with_if_else3
class TestDygraphNestedIfElse(TestDygraphIfElse):
def setUp(self):
self.x = np.random.random([10, 16]).astype('float32')
self.dyfunc = nested_if_else
class TestDygraphIfElse4(TestDygraphIfElse):
class TestDygraphNestedIfElse2(TestDygraphIfElse):
def setUp(self):
self.x = np.random.random([10, 16]).astype('float32')
self.dyfunc = nested_if_else_2
class TestDygraphIfElse5(TestDygraphIfElse):
class TestDygraphNestedIfElse3(TestDygraphIfElse):
def setUp(self):
self.x = np.random.random([10, 16]).astype('float32')
self.dyfunc = nested_if_else_3
......
......@@ -34,7 +34,7 @@ class TestGetNameIds(unittest.TestCase):
def test_fn(x):
return x+1
"""
self.all_name_ids = {'x': [gast.Param()]}
self.all_name_ids = {'x': [gast.Param(), gast.Load()]}
def test_get_name_ids(self):
source = textwrap.dedent(self.source)
......@@ -82,6 +82,7 @@ class TestGetNameIds2(TestGetNameIds):
gast.Load(),
gast.Store(),
gast.Store(),
gast.Load(),
]
}
......@@ -113,6 +114,7 @@ class TestGetNameIds3(TestGetNameIds):
gast.Store(),
gast.Load(),
gast.Store(),
gast.Load(),
]
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册