未验证 提交 4282af69 编写于 作者: X xiongkun 提交者: GitHub

[Dy2Static] refactor the return transformer (#45900) (#46205)

* 1. refactor the return transformer.
2. fix some bugs in return transformer.

* support raise error while return stmt's father is For or while

* fix ci error.

* fix ci error and add some unittest

* code format

* fix ci error
上级 a58663f3
...@@ -41,6 +41,7 @@ class CreateVariableTransformer(BaseTransformer): ...@@ -41,6 +41,7 @@ class CreateVariableTransformer(BaseTransformer):
def visit_FunctionDef(self, node): def visit_FunctionDef(self, node):
#attributes = set(filter(lambda x: '.' in x, node.pd_scope.modified_vars())) #attributes = set(filter(lambda x: '.' in x, node.pd_scope.modified_vars()))
self.generic_visit(node)
bodys = node.body bodys = node.body
names = sorted(node.pd_scope.created_vars()) names = sorted(node.pd_scope.created_vars())
for name in names: for name in names:
......
...@@ -22,6 +22,8 @@ from paddle.fluid.dygraph.dygraph_to_static.break_continue_transformer import Fo ...@@ -22,6 +22,8 @@ from paddle.fluid.dygraph.dygraph_to_static.break_continue_transformer import Fo
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_fill_constant_node from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_fill_constant_node
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
from paddle.fluid.dygraph.dygraph_to_static.base_transformer import BaseTransformer from paddle.fluid.dygraph.dygraph_to_static.base_transformer import BaseTransformer
from paddle.fluid.dygraph.dygraph_to_static.utils import Dygraph2StaticException
from paddle.fluid.dygraph.dygraph_to_static.utils import ORIGI_INFO
__all__ = [ __all__ = [
'RETURN_NO_VALUE_MAGIC_NUM', 'RETURN_NO_VALUE_VAR_NAME', 'ReturnTransformer' 'RETURN_NO_VALUE_MAGIC_NUM', 'RETURN_NO_VALUE_VAR_NAME', 'ReturnTransformer'
...@@ -90,50 +92,37 @@ class ReturnAnalysisVisitor(gast.NodeVisitor): ...@@ -90,50 +92,37 @@ class ReturnAnalysisVisitor(gast.NodeVisitor):
def __init__(self, root_node): def __init__(self, root_node):
self.root = root_node self.root = root_node
assert isinstance(
self.root, gast.FunctionDef), "Input is not gast.FunctionDef node"
# A list to store where the current function is. # the number of return statements
self.function_def = [] self.count_return = 0
# Mapping from gast.FunctionDef node to the number of return statements # maximum number of variables
# Python allows define function inside function so we have to handle it self.max_return_length = 0
self.count_return = {}
# Mapping from gast.FunctionDef node to the maximum number of variables
# returned by the function's return statement
self.max_return_length = {}
self.visit(self.root) self.visit(self.root)
def visit_FunctionDef(self, node): def visit_FunctionDef(self, node):
self.function_def.append(node) """
self.count_return[node] = 0 don't analysis closure, just analyze current func def level.
self.max_return_length[node] = 0 """
if node == self.root:
self.generic_visit(node) self.generic_visit(node)
self.function_def.pop()
return node
def visit_Return(self, node): def visit_Return(self, node):
assert len( self.count_return += 1
self.function_def) > 0, "Found 'return' statement out of function."
cur_func = self.function_def[-1]
if cur_func in self.count_return:
self.count_return[cur_func] += 1
else:
self.count_return[cur_func] = 1
return_length = get_return_size(node) return_length = get_return_size(node)
if cur_func in self.max_return_length: self.max_return_length = max(self.max_return_length, return_length)
self.max_return_length[cur_func] = max(
self.max_return_length[cur_func], return_length)
else:
self.max_return_length[cur_func] = return_length
self.generic_visit(node) self.generic_visit(node)
def get_func_return_count(self, func_node): def get_func_return_count(self):
return self.count_return[func_node] return self.count_return
def get_func_max_return_length(self, func_node): def get_func_max_return_length(self):
return self.max_return_length[func_node] return self.max_return_length
class ReturnTransformer(BaseTransformer): class ReturnTransformer(BaseTransformer):
...@@ -143,32 +132,51 @@ class ReturnTransformer(BaseTransformer): ...@@ -143,32 +132,51 @@ class ReturnTransformer(BaseTransformer):
variable to store the early return statements and boolean states with variable to store the early return statements and boolean states with
if-else to skip the statements after the return. if-else to skip the statements after the return.
Go through all the function definition and call SingleReturnTransformer for each function.
SingleReturnTransformer don't care the nested function def.
""" """
def __init__(self, wrapper_root): def __init__(self, wrapper_root):
self.wrapper_root = wrapper_root self.wrapper_root = wrapper_root
self.root = wrapper_root.node self.root = wrapper_root.node
pre_transformer = ReplaceReturnNoneTransformer(self.root) pre_transformer = ReplaceReturnNoneTransformer(self.root)
pre_transformer.transform() pre_transformer.transform()
def transform(self):
self.visit(self.root)
def visit_FunctionDef(self, node):
node = self.generic_visit(node)
node = SingleReturnTransformer(node).transform()
return node
class SingleReturnTransformer(BaseTransformer):
"""
This function only apply to single function. don't care the nested function_def
"""
def __init__(self, root):
self.root = root
assert isinstance(
self.root, gast.FunctionDef), "Input is not gast.FunctionDef node"
self.ancestor_nodes = [] self.ancestor_nodes = []
# The name of the variable which stores the final return value
# Mapping from FunctionDef node to string # The name of return placeholder
self.return_value_name = {} self.return_value_name = None
# The names of the variable which stores the boolean state that skip
# statments. Mapping from FunctionDef node to list # Every return stmt corresponds to a bool value variable, and return name is the name of the boolean variable
self.return_name = {} self.return_name = []
# The names of the variable which is placeholder to handle various-
# length return. Mapping from FunctionDef node to list
self.return_no_value_name = {}
# A list of FunctionDef to store where the current function is.
self.function_def = []
self.pre_analysis = None self.pre_analysis = None
def transform(self): def assert_parent_is_not_while(self, parent_node_of_return):
self.visit(self.root) if isinstance(parent_node_of_return, (gast.While, gast.For)):
raise Dygraph2StaticException(
"Found return statement in While or For body and loop "
"is meaningless, please check you code and remove return in while/for."
)
def generic_visit(self, node): def generic_visit(self, node):
# Because we change ancestor nodes during visit_Return, not current # Because we change ancestor nodes during visit_Return, not current
...@@ -188,28 +196,46 @@ class ReturnTransformer(BaseTransformer): ...@@ -188,28 +196,46 @@ class ReturnTransformer(BaseTransformer):
Self-defined visit for appending ancestor Self-defined visit for appending ancestor
""" """
self.ancestor_nodes.append(node) self.ancestor_nodes.append(node)
ret = super(ReturnTransformer, self).visit(node) ret = super(SingleReturnTransformer, self).visit(node)
self.ancestor_nodes.pop() self.ancestor_nodes.pop()
return ret return ret
def visit_FunctionDef(self, node): def visit_FunctionDef(self, node):
self.function_def.append(node) """
self.return_value_name[node] = None don't analysis closure, just analyze current func def level.
self.return_name[node] = [] """
self.return_no_value_name[node] = [] if node == self.root:
self.generic_visit(node)
return node
def append_assign_to_return_node(self, value, parent_node_of_return,
return_name, assign_nodes):
self.assert_parent_is_not_while(parent_node_of_return)
assert value in [True, False], "value must be True or False."
if isinstance(parent_node_of_return, gast.If):
# Prepend control flow boolean nodes such as '__return@1 = True'
node_str = "{} = _jst.create_bool_as_type({}, {})".format(
return_name,
ast_to_source_code(parent_node_of_return.test).strip(), value)
assign_node = gast.parse(node_str).body[0]
assign_nodes.append(assign_node)
def transform(self):
node = self.root
self.pre_analysis = ReturnAnalysisVisitor(node) self.pre_analysis = ReturnAnalysisVisitor(node)
max_return_length = self.pre_analysis.get_func_max_return_length(node) max_return_length = self.pre_analysis.get_func_max_return_length()
while self.pre_analysis.get_func_return_count(node) > 1: while self.pre_analysis.get_func_return_count() > 0:
self.generic_visit(node) # every visit will decrease the number of returns.
# so we need a while.
self.visit(node)
self.pre_analysis = ReturnAnalysisVisitor(node) self.pre_analysis = ReturnAnalysisVisitor(node)
if max_return_length == 0: if max_return_length == 0:
self.function_def.pop()
return node return node
# Prepend initialization of final return and append final return statement # Prepend initialization of final return and append final return statement
value_name = self.return_value_name[node] value_name = self.return_value_name
if value_name is not None: if value_name is not None:
node.body.append( node.body.append(
gast.Return(value=gast.Name(id=value_name, gast.Return(value=gast.Name(id=value_name,
...@@ -225,56 +251,32 @@ class ReturnTransformer(BaseTransformer): ...@@ -225,56 +251,32 @@ class ReturnTransformer(BaseTransformer):
value=gast.Constant( value=gast.Constant(
kind=None, value=None)) kind=None, value=None))
node.body.insert(0, assign_return_value_node) node.body.insert(0, assign_return_value_node)
# Prepend no value placeholders
self.function_def.pop()
# Need update self.pre_analysis after pop
# For fix this case:
'''
def fun(cond):
def inner():
pass
if cond:
return True
else:
return False
'''
if self.function_def:
self.pre_analysis = ReturnAnalysisVisitor(self.function_def[-1])
return node return node
def visit_Return(self, node): def visit_Return(self, node):
cur_func_node = self.function_def[-1]
return_name = unique_name.generate(RETURN_PREFIX) return_name = unique_name.generate(RETURN_PREFIX)
self.return_name[cur_func_node].append(return_name) self.return_name.append(return_name)
max_return_length = self.pre_analysis.get_func_max_return_length( max_return_length = self.pre_analysis.get_func_max_return_length()
cur_func_node)
parent_node_of_return = self.ancestor_nodes[-2] parent_node_of_return = self.ancestor_nodes[-2]
for ancestor_index in reversed(range(len(self.ancestor_nodes) - 1)): for ancestor_index in reversed(range(len(self.ancestor_nodes) - 1)):
ancestor = self.ancestor_nodes[ancestor_index] ancestor = self.ancestor_nodes[ancestor_index]
cur_node = self.ancestor_nodes[ancestor_index + 1] cur_node = self.ancestor_nodes[ancestor_index + 1]
if hasattr(ancestor,
"body") and index_in_list(ancestor.body, cur_node) != -1: def _deal_branches(branch_name):
if cur_node == node: if hasattr(ancestor, branch_name):
self._replace_return_in_stmt_list(ancestor.body, cur_node, branch_node = getattr(ancestor, branch_name)
return_name, if index_in_list(branch_node, cur_node) != -1:
max_return_length,
parent_node_of_return)
self._replace_after_node_to_if_in_stmt_list(
ancestor.body, cur_node, return_name, parent_node_of_return)
elif hasattr(ancestor, "orelse") and index_in_list(
ancestor.orelse, cur_node) != -1:
if cur_node == node: if cur_node == node:
self._replace_return_in_stmt_list(ancestor.orelse, cur_node, self._replace_return_in_stmt_list(
return_name, branch_node, cur_node, return_name,
max_return_length, max_return_length, parent_node_of_return)
parent_node_of_return)
self._replace_after_node_to_if_in_stmt_list( self._replace_after_node_to_if_in_stmt_list(
ancestor.orelse, cur_node, return_name, branch_node, cur_node, return_name,
parent_node_of_return) parent_node_of_return)
_deal_branches("body")
_deal_branches("orelse")
# If return node in while loop, add `not return_name` in gast.While.test # If return node in while loop, add `not return_name` in gast.While.test
if isinstance(ancestor, gast.While): if isinstance(ancestor, gast.While):
cond_var_node = gast.UnaryOp(op=gast.Not(), cond_var_node = gast.UnaryOp(op=gast.Not(),
...@@ -302,7 +304,7 @@ class ReturnTransformer(BaseTransformer): ...@@ -302,7 +304,7 @@ class ReturnTransformer(BaseTransformer):
while_node = new_stmts[-1] while_node = new_stmts[-1]
self.ancestor_nodes[ancestor_index] = while_node self.ancestor_nodes[ancestor_index] = while_node
if ancestor == cur_func_node: if ancestor == self.root:
break break
# return_node is replaced so we shouldn't return here # return_node is replaced so we shouldn't return here
...@@ -315,34 +317,29 @@ class ReturnTransformer(BaseTransformer): ...@@ -315,34 +317,29 @@ class ReturnTransformer(BaseTransformer):
return False return False
assign_nodes = [] assign_nodes = []
# Here assume that the parent node of return is gast.If self.append_assign_to_return_node(True, parent_node_of_return,
if isinstance(parent_node_of_return, gast.If): return_name, assign_nodes)
# Prepend control flow boolean nodes such as '__return@1 = True'
node_str = "{} = _jst.create_bool_as_type({}, True)".format(
return_name,
ast_to_source_code(parent_node_of_return.test).strip())
assign_true_node = gast.parse(node_str).body[0]
assign_nodes.append(assign_true_node)
cur_func_node = self.function_def[-1]
return_length = get_return_size(return_node) return_length = get_return_size(return_node)
# In this case we should NOT append RETURN_NO_VALUE placeholder # In this case we should NOT append RETURN_NO_VALUE placeholder
if return_node.value is not None: if return_node.value is not None:
cur_func_node = self.function_def[-1] if self.return_value_name is None:
if self.return_value_name[cur_func_node] is None: self.return_value_name = unique_name.generate(
self.return_value_name[cur_func_node] = unique_name.generate(
RETURN_VALUE_PREFIX) RETURN_VALUE_PREFIX)
assign_nodes.append( assign_nodes.append(
gast.Assign(targets=[ gast.Assign(targets=[
gast.Name(id=self.return_value_name[cur_func_node], gast.Name(id=self.return_value_name,
ctx=gast.Store(), ctx=gast.Store(),
annotation=None, annotation=None,
type_comment=None) type_comment=None)
], ],
value=return_node.value)) value=return_node.value))
return_origin_info = getattr(return_node, ORIGI_INFO, None)
setattr(assign_nodes[-1], ORIGI_INFO, return_origin_info)
# If there is a return in the body or else of if, the remaining statements
# will not be executed, so they can be properly replaced.
stmt_list[i:] = assign_nodes stmt_list[i:] = assign_nodes
return True return True
...@@ -368,12 +365,8 @@ class ReturnTransformer(BaseTransformer): ...@@ -368,12 +365,8 @@ class ReturnTransformer(BaseTransformer):
stmt_list[i + 1:] = [if_stmt] stmt_list[i + 1:] = [if_stmt]
# Here assume that the parent node of return is gast.If # Here assume that the parent node of return is gast.If
if isinstance(parent_node_of_return, gast.If): assign_nodes = []
# Prepend control flow boolean nodes such as '__return@1 = False' self.append_assign_to_return_node(False, parent_node_of_return,
node_str = "{} = _jst.create_bool_as_type({}, False)".format( return_name, assign_nodes)
return_name, stmt_list[i:i] = assign_nodes
ast_to_source_code(parent_node_of_return.test).strip())
assign_false_node = gast.parse(node_str).body[0]
stmt_list[i:i] = [assign_false_node]
return True return True
...@@ -313,8 +313,10 @@ class TestDynamicToStaticCode2(TestDynamicToStaticCode): ...@@ -313,8 +313,10 @@ class TestDynamicToStaticCode2(TestDynamicToStaticCode):
class StaticCode(): class StaticCode():
def func_convert_then_not_to_static(x): def func_convert_then_not_to_static(x):
__return_value_0 = None
y = _jst.Call(func_not_to_static)(x) y = _jst.Call(func_not_to_static)(x)
return y __return_value_0 = y
return __return_value_0
self.answer_func = StaticCode.func_convert_then_not_to_static self.answer_func = StaticCode.func_convert_then_not_to_static
......
...@@ -65,7 +65,7 @@ class TestOriginInfo(unittest.TestCase): ...@@ -65,7 +65,7 @@ class TestOriginInfo(unittest.TestCase):
self.func = simple_func self.func = simple_func
def set_static_lineno(self): def set_static_lineno(self):
self.static_abs_lineno_list = [9, 10, 11] self.static_abs_lineno_list = [9, 11, 12]
def set_dygraph_info(self): def set_dygraph_info(self):
self.line_num = 3 self.line_num = 3
...@@ -93,7 +93,6 @@ class TestOriginInfo(unittest.TestCase): ...@@ -93,7 +93,6 @@ class TestOriginInfo(unittest.TestCase):
self.static_func, _ = ast_to_func(transformed_ast, self.dygraph_func) self.static_func, _ = ast_to_func(transformed_ast, self.dygraph_func)
info_map = create_and_update_origin_info_map(dygraph_ast, info_map = create_and_update_origin_info_map(dygraph_ast,
self.static_func) self.static_func)
return info_map return info_map
def test_origin_info_map(self): def test_origin_info_map(self):
...@@ -149,7 +148,7 @@ class TestOriginInfoWithNestedFunc(TestOriginInfo): ...@@ -149,7 +148,7 @@ class TestOriginInfoWithNestedFunc(TestOriginInfo):
self.func = nested_func self.func = nested_func
def set_static_lineno(self): def set_static_lineno(self):
self.static_abs_lineno_list = [9, 11, 12, 13, 14] self.static_abs_lineno_list = [9, 12, 14, 16, 17]
def set_dygraph_info(self): def set_dygraph_info(self):
self.line_num = 5 self.line_num = 5
...@@ -174,7 +173,7 @@ class TestOriginInfoWithDecoratedFunc(TestOriginInfo): ...@@ -174,7 +173,7 @@ class TestOriginInfoWithDecoratedFunc(TestOriginInfo):
self.func = decorated_func self.func = decorated_func
def set_static_lineno(self): def set_static_lineno(self):
self.static_abs_lineno_list = [9, 10] self.static_abs_lineno_list = [9, 11]
def set_dygraph_info(self): def set_dygraph_info(self):
self.line_num = 2 self.line_num = 2
...@@ -208,7 +207,7 @@ class TestOriginInfoWithDecoratedFunc2(TestOriginInfo): ...@@ -208,7 +207,7 @@ class TestOriginInfoWithDecoratedFunc2(TestOriginInfo):
self.func = decorated_func2 self.func = decorated_func2
def set_static_lineno(self): def set_static_lineno(self):
self.static_abs_lineno_list = [9, 10] self.static_abs_lineno_list = [9, 11]
def set_dygraph_info(self): def set_dygraph_info(self):
self.line_num = 2 self.line_num = 2
......
...@@ -224,6 +224,49 @@ def test_diff_return(x): ...@@ -224,6 +224,49 @@ def test_diff_return(x):
return y, z return y, z
@to_static
def test_return_if_else_2(x):
rr = 0
if True:
rr = 1
return 1
else:
a = 0
@to_static
def test_return_in_while_2(x):
while True:
a = 12
return 12
return 10
@to_static
def test_return_in_for_2(x):
a = 12
for i in range(10):
return 12
return 10
@to_static
def test_return_nested(x):
def func():
rr = 0
if True:
rr = 1
return 1
rr = 2
else:
a = 0
return 4
return 3
return func()
class TestReturnBase(unittest.TestCase): class TestReturnBase(unittest.TestCase):
def setUp(self): def setUp(self):
...@@ -256,7 +299,6 @@ class TestReturnBase(unittest.TestCase): ...@@ -256,7 +299,6 @@ class TestReturnBase(unittest.TestCase):
np.testing.assert_allclose(dygraph_res[i], np.testing.assert_allclose(dygraph_res[i],
static_res[i], static_res[i],
rtol=1e-05) rtol=1e-05)
elif isinstance(dygraph_res, np.ndarray): elif isinstance(dygraph_res, np.ndarray):
np.testing.assert_allclose(dygraph_res, static_res, rtol=1e-05) np.testing.assert_allclose(dygraph_res, static_res, rtol=1e-05)
else: else:
...@@ -282,6 +324,24 @@ class TestReturnIf(TestReturnBase): ...@@ -282,6 +324,24 @@ class TestReturnIf(TestReturnBase):
self.dygraph_func = test_return_if self.dygraph_func = test_return_if
class TestReturnOnlyIf(TestReturnBase):
def init_dygraph_func(self):
self.dygraph_func = test_return_if_else_2
class TestReturnInFor(TestReturnBase):
def init_dygraph_func(self):
self.dygraph_func = test_return_in_for
class TestReturnInWhile(TestReturnBase):
def init_dygraph_func(self):
self.dygraph_func = test_return_in_while
class TestReturnIfDiff(TestReturnBase): class TestReturnIfDiff(TestReturnBase):
def init_dygraph_func(self): def init_dygraph_func(self):
...@@ -294,16 +354,18 @@ class TestReturnIfElse(TestReturnBase): ...@@ -294,16 +354,18 @@ class TestReturnIfElse(TestReturnBase):
self.dygraph_func = test_return_if_else self.dygraph_func = test_return_if_else
class TestReturnInWhile(TestReturnBase): class TestReturnInWhile2(TestReturnBase):
def init_dygraph_func(self): def init_dygraph_func(self):
self.dygraph_func = test_return_in_while self.dygraph_func = test_return_in_while_2
self.error = "Found return statement in While or For body and loop"
class TestReturnInFor(TestReturnBase): class TestReturnInFor2(TestReturnBase):
def init_dygraph_func(self): def init_dygraph_func(self):
self.dygraph_func = test_return_in_for self.dygraph_func = test_return_in_for_2
self.error = "Found return statement in While or For body and loop"
class TestRecursiveReturn(TestReturnBase): class TestRecursiveReturn(TestReturnBase):
...@@ -371,6 +433,12 @@ class TestReturnTupleManyValue(TestReturnBase): ...@@ -371,6 +433,12 @@ class TestReturnTupleManyValue(TestReturnBase):
self.dygraph_func = test_return_tuple_many_values self.dygraph_func = test_return_tuple_many_values
class TestReturnNested(TestReturnBase):
def init_dygraph_func(self):
self.dygraph_func = test_return_nested
class TestReturnSpecial(TestReturnBase): class TestReturnSpecial(TestReturnBase):
def init_dygraph_func(self): def init_dygraph_func(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册