未验证 提交 908132dd 编写于 作者: X xiongkun 提交者: GitHub

[Dy2Static] refactor the return transformer (#45900)

* 1. refactor the return transformer.
2. fix some bugs in return transformer.

* support raise error while return stmt's father is For or while

* fix ci error.

* fix ci error and add some unittest

* code format

* fix ci error
上级 82bbbe2c
......@@ -41,6 +41,7 @@ class CreateVariableTransformer(BaseTransformer):
def visit_FunctionDef(self, node):
#attributes = set(filter(lambda x: '.' in x, node.pd_scope.modified_vars()))
self.generic_visit(node)
bodys = node.body
names = sorted(node.pd_scope.created_vars())
for name in names:
......
......@@ -22,6 +22,8 @@ from paddle.fluid.dygraph.dygraph_to_static.break_continue_transformer import Fo
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_fill_constant_node
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
from paddle.fluid.dygraph.dygraph_to_static.base_transformer import BaseTransformer
from paddle.fluid.dygraph.dygraph_to_static.utils import Dygraph2StaticException
from paddle.fluid.dygraph.dygraph_to_static.utils import ORIGI_INFO
__all__ = [
'RETURN_NO_VALUE_MAGIC_NUM', 'RETURN_NO_VALUE_VAR_NAME', 'ReturnTransformer'
......@@ -90,50 +92,37 @@ class ReturnAnalysisVisitor(gast.NodeVisitor):
def __init__(self, root_node):
self.root = root_node
assert isinstance(
self.root, gast.FunctionDef), "Input is not gast.FunctionDef node"
# A list to store where the current function is.
self.function_def = []
# the number of return statements
self.count_return = 0
# Mapping from gast.FunctionDef node to the number of return statements
# Python allows define function inside function so we have to handle it
self.count_return = {}
# maximum number of variables
self.max_return_length = 0
# Mapping from gast.FunctionDef node to the maximum number of variables
# returned by the function's return statement
self.max_return_length = {}
self.visit(self.root)
def visit_FunctionDef(self, node):
self.function_def.append(node)
self.count_return[node] = 0
self.max_return_length[node] = 0
self.generic_visit(node)
self.function_def.pop()
return node
"""
don't analysis closure, just analyze current func def level.
"""
if node == self.root:
self.generic_visit(node)
def visit_Return(self, node):
assert len(
self.function_def) > 0, "Found 'return' statement out of function."
cur_func = self.function_def[-1]
if cur_func in self.count_return:
self.count_return[cur_func] += 1
else:
self.count_return[cur_func] = 1
self.count_return += 1
return_length = get_return_size(node)
if cur_func in self.max_return_length:
self.max_return_length[cur_func] = max(
self.max_return_length[cur_func], return_length)
else:
self.max_return_length[cur_func] = return_length
self.max_return_length = max(self.max_return_length, return_length)
self.generic_visit(node)
def get_func_return_count(self, func_node):
return self.count_return[func_node]
def get_func_return_count(self):
return self.count_return
def get_func_max_return_length(self, func_node):
return self.max_return_length[func_node]
def get_func_max_return_length(self):
return self.max_return_length
class ReturnTransformer(BaseTransformer):
......@@ -143,32 +132,51 @@ class ReturnTransformer(BaseTransformer):
variable to store the early return statements and boolean states with
if-else to skip the statements after the return.
Go through all the function definition and call SingleReturnTransformer for each function.
SingleReturnTransformer don't care the nested function def.
"""
def __init__(self, wrapper_root):
self.wrapper_root = wrapper_root
self.root = wrapper_root.node
pre_transformer = ReplaceReturnNoneTransformer(self.root)
pre_transformer.transform()
def transform(self):
self.visit(self.root)
def visit_FunctionDef(self, node):
node = self.generic_visit(node)
node = SingleReturnTransformer(node).transform()
return node
class SingleReturnTransformer(BaseTransformer):
"""
This function only apply to single function. don't care the nested function_def
"""
def __init__(self, root):
self.root = root
assert isinstance(
self.root, gast.FunctionDef), "Input is not gast.FunctionDef node"
self.ancestor_nodes = []
# The name of the variable which stores the final return value
# Mapping from FunctionDef node to string
self.return_value_name = {}
# The names of the variable which stores the boolean state that skip
# statments. Mapping from FunctionDef node to list
self.return_name = {}
# The names of the variable which is placeholder to handle various-
# length return. Mapping from FunctionDef node to list
self.return_no_value_name = {}
# A list of FunctionDef to store where the current function is.
self.function_def = []
# The name of return placeholder
self.return_value_name = None
# Every return stmt corresponds to a bool value variable, and return name is the name of the boolean variable
self.return_name = []
self.pre_analysis = None
def transform(self):
self.visit(self.root)
def assert_parent_is_not_while(self, parent_node_of_return):
if isinstance(parent_node_of_return, (gast.While, gast.For)):
raise Dygraph2StaticException(
"Found return statement in While or For body and loop "
"is meaningless, please check you code and remove return in while/for."
)
def generic_visit(self, node):
# Because we change ancestor nodes during visit_Return, not current
......@@ -188,28 +196,46 @@ class ReturnTransformer(BaseTransformer):
Self-defined visit for appending ancestor
"""
self.ancestor_nodes.append(node)
ret = super(ReturnTransformer, self).visit(node)
ret = super(SingleReturnTransformer, self).visit(node)
self.ancestor_nodes.pop()
return ret
def visit_FunctionDef(self, node):
self.function_def.append(node)
self.return_value_name[node] = None
self.return_name[node] = []
self.return_no_value_name[node] = []
"""
don't analysis closure, just analyze current func def level.
"""
if node == self.root:
self.generic_visit(node)
return node
def append_assign_to_return_node(self, value, parent_node_of_return,
return_name, assign_nodes):
self.assert_parent_is_not_while(parent_node_of_return)
assert value in [True, False], "value must be True or False."
if isinstance(parent_node_of_return, gast.If):
# Prepend control flow boolean nodes such as '__return@1 = True'
node_str = "{} = _jst.create_bool_as_type({}, {})".format(
return_name,
ast_to_source_code(parent_node_of_return.test).strip(), value)
assign_node = gast.parse(node_str).body[0]
assign_nodes.append(assign_node)
def transform(self):
node = self.root
self.pre_analysis = ReturnAnalysisVisitor(node)
max_return_length = self.pre_analysis.get_func_max_return_length(node)
while self.pre_analysis.get_func_return_count(node) > 1:
self.generic_visit(node)
max_return_length = self.pre_analysis.get_func_max_return_length()
while self.pre_analysis.get_func_return_count() > 0:
# every visit will decrease the number of returns.
# so we need a while.
self.visit(node)
self.pre_analysis = ReturnAnalysisVisitor(node)
if max_return_length == 0:
self.function_def.pop()
return node
# Prepend initialization of final return and append final return statement
value_name = self.return_value_name[node]
value_name = self.return_value_name
if value_name is not None:
node.body.append(
gast.Return(value=gast.Name(id=value_name,
......@@ -227,54 +253,32 @@ class ReturnTransformer(BaseTransformer):
node.body.insert(0, assign_return_value_node)
# Prepend no value placeholders
self.function_def.pop()
# Need update self.pre_analysis after pop
# For fix this case:
'''
def fun(cond):
def inner():
pass
if cond:
return True
else:
return False
'''
if self.function_def:
self.pre_analysis = ReturnAnalysisVisitor(self.function_def[-1])
return node
def visit_Return(self, node):
cur_func_node = self.function_def[-1]
return_name = unique_name.generate(RETURN_PREFIX)
self.return_name[cur_func_node].append(return_name)
max_return_length = self.pre_analysis.get_func_max_return_length(
cur_func_node)
self.return_name.append(return_name)
max_return_length = self.pre_analysis.get_func_max_return_length()
parent_node_of_return = self.ancestor_nodes[-2]
for ancestor_index in reversed(range(len(self.ancestor_nodes) - 1)):
ancestor = self.ancestor_nodes[ancestor_index]
cur_node = self.ancestor_nodes[ancestor_index + 1]
if hasattr(ancestor,
"body") and index_in_list(ancestor.body, cur_node) != -1:
if cur_node == node:
self._replace_return_in_stmt_list(ancestor.body, cur_node,
return_name,
max_return_length,
parent_node_of_return)
self._replace_after_node_to_if_in_stmt_list(
ancestor.body, cur_node, return_name, parent_node_of_return)
elif hasattr(ancestor, "orelse") and index_in_list(
ancestor.orelse, cur_node) != -1:
if cur_node == node:
self._replace_return_in_stmt_list(ancestor.orelse, cur_node,
return_name,
max_return_length,
parent_node_of_return)
self._replace_after_node_to_if_in_stmt_list(
ancestor.orelse, cur_node, return_name,
parent_node_of_return)
def _deal_branches(branch_name):
if hasattr(ancestor, branch_name):
branch_node = getattr(ancestor, branch_name)
if index_in_list(branch_node, cur_node) != -1:
if cur_node == node:
self._replace_return_in_stmt_list(
branch_node, cur_node, return_name,
max_return_length, parent_node_of_return)
self._replace_after_node_to_if_in_stmt_list(
branch_node, cur_node, return_name,
parent_node_of_return)
_deal_branches("body")
_deal_branches("orelse")
# If return node in while loop, add `not return_name` in gast.While.test
if isinstance(ancestor, gast.While):
cond_var_node = gast.UnaryOp(op=gast.Not(),
......@@ -302,7 +306,7 @@ class ReturnTransformer(BaseTransformer):
while_node = new_stmts[-1]
self.ancestor_nodes[ancestor_index] = while_node
if ancestor == cur_func_node:
if ancestor == self.root:
break
# return_node is replaced so we shouldn't return here
......@@ -315,34 +319,29 @@ class ReturnTransformer(BaseTransformer):
return False
assign_nodes = []
# Here assume that the parent node of return is gast.If
if isinstance(parent_node_of_return, gast.If):
# Prepend control flow boolean nodes such as '__return@1 = True'
node_str = "{} = _jst.create_bool_as_type({}, True)".format(
return_name,
ast_to_source_code(parent_node_of_return.test).strip())
self.append_assign_to_return_node(True, parent_node_of_return,
return_name, assign_nodes)
assign_true_node = gast.parse(node_str).body[0]
assign_nodes.append(assign_true_node)
cur_func_node = self.function_def[-1]
return_length = get_return_size(return_node)
# In this case we should NOT append RETURN_NO_VALUE placeholder
if return_node.value is not None:
cur_func_node = self.function_def[-1]
if self.return_value_name[cur_func_node] is None:
self.return_value_name[cur_func_node] = unique_name.generate(
if self.return_value_name is None:
self.return_value_name = unique_name.generate(
RETURN_VALUE_PREFIX)
assign_nodes.append(
gast.Assign(targets=[
gast.Name(id=self.return_value_name[cur_func_node],
gast.Name(id=self.return_value_name,
ctx=gast.Store(),
annotation=None,
type_comment=None)
],
value=return_node.value))
return_origin_info = getattr(return_node, ORIGI_INFO, None)
setattr(assign_nodes[-1], ORIGI_INFO, return_origin_info)
# If there is a return in the body or else of if, the remaining statements
# will not be executed, so they can be properly replaced.
stmt_list[i:] = assign_nodes
return True
......@@ -368,12 +367,8 @@ class ReturnTransformer(BaseTransformer):
stmt_list[i + 1:] = [if_stmt]
# Here assume that the parent node of return is gast.If
if isinstance(parent_node_of_return, gast.If):
# Prepend control flow boolean nodes such as '__return@1 = False'
node_str = "{} = _jst.create_bool_as_type({}, False)".format(
return_name,
ast_to_source_code(parent_node_of_return.test).strip())
assign_false_node = gast.parse(node_str).body[0]
stmt_list[i:i] = [assign_false_node]
assign_nodes = []
self.append_assign_to_return_node(False, parent_node_of_return,
return_name, assign_nodes)
stmt_list[i:i] = assign_nodes
return True
......@@ -313,8 +313,10 @@ class TestDynamicToStaticCode2(TestDynamicToStaticCode):
class StaticCode():
def func_convert_then_not_to_static(x):
__return_value_0 = None
y = _jst.Call(func_not_to_static)(x)
return y
__return_value_0 = y
return __return_value_0
self.answer_func = StaticCode.func_convert_then_not_to_static
......
......@@ -65,7 +65,7 @@ class TestOriginInfo(unittest.TestCase):
self.func = simple_func
def set_static_lineno(self):
self.static_abs_lineno_list = [9, 10, 11]
self.static_abs_lineno_list = [9, 11, 12]
def set_dygraph_info(self):
self.line_num = 3
......@@ -93,7 +93,6 @@ class TestOriginInfo(unittest.TestCase):
self.static_func, _ = ast_to_func(transformed_ast, self.dygraph_func)
info_map = create_and_update_origin_info_map(dygraph_ast,
self.static_func)
return info_map
def test_origin_info_map(self):
......@@ -149,7 +148,7 @@ class TestOriginInfoWithNestedFunc(TestOriginInfo):
self.func = nested_func
def set_static_lineno(self):
self.static_abs_lineno_list = [9, 11, 12, 13, 14]
self.static_abs_lineno_list = [9, 12, 14, 16, 17]
def set_dygraph_info(self):
self.line_num = 5
......@@ -174,7 +173,7 @@ class TestOriginInfoWithDecoratedFunc(TestOriginInfo):
self.func = decorated_func
def set_static_lineno(self):
self.static_abs_lineno_list = [9, 10]
self.static_abs_lineno_list = [9, 11]
def set_dygraph_info(self):
self.line_num = 2
......@@ -208,7 +207,7 @@ class TestOriginInfoWithDecoratedFunc2(TestOriginInfo):
self.func = decorated_func2
def set_static_lineno(self):
self.static_abs_lineno_list = [9, 10]
self.static_abs_lineno_list = [9, 11]
def set_dygraph_info(self):
self.line_num = 2
......
......@@ -224,6 +224,49 @@ def test_diff_return(x):
return y, z
@to_static
def test_return_if_else_2(x):
rr = 0
if True:
rr = 1
return 1
else:
a = 0
@to_static
def test_return_in_while_2(x):
while True:
a = 12
return 12
return 10
@to_static
def test_return_in_for_2(x):
a = 12
for i in range(10):
return 12
return 10
@to_static
def test_return_nested(x):
def func():
rr = 0
if True:
rr = 1
return 1
rr = 2
else:
a = 0
return 4
return 3
return func()
class TestReturnBase(unittest.TestCase):
def setUp(self):
......@@ -256,7 +299,6 @@ class TestReturnBase(unittest.TestCase):
np.testing.assert_allclose(dygraph_res[i],
static_res[i],
rtol=1e-05)
elif isinstance(dygraph_res, np.ndarray):
np.testing.assert_allclose(dygraph_res, static_res, rtol=1e-05)
else:
......@@ -282,6 +324,24 @@ class TestReturnIf(TestReturnBase):
self.dygraph_func = test_return_if
class TestReturnOnlyIf(TestReturnBase):
def init_dygraph_func(self):
self.dygraph_func = test_return_if_else_2
class TestReturnInFor(TestReturnBase):
def init_dygraph_func(self):
self.dygraph_func = test_return_in_for
class TestReturnInWhile(TestReturnBase):
def init_dygraph_func(self):
self.dygraph_func = test_return_in_while
class TestReturnIfDiff(TestReturnBase):
def init_dygraph_func(self):
......@@ -294,16 +354,18 @@ class TestReturnIfElse(TestReturnBase):
self.dygraph_func = test_return_if_else
class TestReturnInWhile(TestReturnBase):
class TestReturnInWhile2(TestReturnBase):
def init_dygraph_func(self):
self.dygraph_func = test_return_in_while
self.dygraph_func = test_return_in_while_2
self.error = "Found return statement in While or For body and loop"
class TestReturnInFor(TestReturnBase):
class TestReturnInFor2(TestReturnBase):
def init_dygraph_func(self):
self.dygraph_func = test_return_in_for
self.dygraph_func = test_return_in_for_2
self.error = "Found return statement in While or For body and loop"
class TestRecursiveReturn(TestReturnBase):
......@@ -371,6 +433,12 @@ class TestReturnTupleManyValue(TestReturnBase):
self.dygraph_func = test_return_tuple_many_values
class TestReturnNested(TestReturnBase):
def init_dygraph_func(self):
self.dygraph_func = test_return_nested
class TestReturnSpecial(TestReturnBase):
def init_dygraph_func(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册