From 908132dd03fbf8179dfdfa848996e2a6f00da8be Mon Sep 17 00:00:00 2001 From: xiongkun Date: Mon, 19 Sep 2022 12:05:30 +0800 Subject: [PATCH] [Dy2Static] refactor the return transformer (#45900) * 1. refactor the return transformer. 2. fix some bugs in return transformer. * support raise error while return stmt's father is For or while * fix ci error. * fix ci error and add some unittest * code format * fix ci error --- .../create_variable_transformer.py | 1 + .../dygraph_to_static/return_transformer.py | 225 +++++++++--------- .../dygraph_to_static/test_convert_call.py | 4 +- .../dygraph_to_static/test_origin_info.py | 9 +- .../dygraph_to_static/test_return.py | 78 +++++- 5 files changed, 191 insertions(+), 126 deletions(-) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/create_variable_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/create_variable_transformer.py index 8ae4c12eb8e..bcfa3e3ec1c 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/create_variable_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/create_variable_transformer.py @@ -41,6 +41,7 @@ class CreateVariableTransformer(BaseTransformer): def visit_FunctionDef(self, node): #attributes = set(filter(lambda x: '.' in x, node.pd_scope.modified_vars())) + self.generic_visit(node) bodys = node.body names = sorted(node.pd_scope.created_vars()) for name in names: diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/return_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/return_transformer.py index ed2a739936e..f2503d09ce4 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/return_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/return_transformer.py @@ -22,6 +22,8 @@ from paddle.fluid.dygraph.dygraph_to_static.break_continue_transformer import Fo from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_fill_constant_node from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code from paddle.fluid.dygraph.dygraph_to_static.base_transformer import BaseTransformer +from paddle.fluid.dygraph.dygraph_to_static.utils import Dygraph2StaticException +from paddle.fluid.dygraph.dygraph_to_static.utils import ORIGI_INFO __all__ = [ 'RETURN_NO_VALUE_MAGIC_NUM', 'RETURN_NO_VALUE_VAR_NAME', 'ReturnTransformer' @@ -90,50 +92,37 @@ class ReturnAnalysisVisitor(gast.NodeVisitor): def __init__(self, root_node): self.root = root_node + assert isinstance( + self.root, gast.FunctionDef), "Input is not gast.FunctionDef node" - # A list to store where the current function is. - self.function_def = [] + # the number of return statements + self.count_return = 0 - # Mapping from gast.FunctionDef node to the number of return statements - # Python allows define function inside function so we have to handle it - self.count_return = {} + # maximum number of variables + self.max_return_length = 0 - # Mapping from gast.FunctionDef node to the maximum number of variables - # returned by the function's return statement - self.max_return_length = {} self.visit(self.root) def visit_FunctionDef(self, node): - self.function_def.append(node) - self.count_return[node] = 0 - self.max_return_length[node] = 0 - self.generic_visit(node) - self.function_def.pop() - return node + """ + don't analysis closure, just analyze current func def level. + """ + if node == self.root: + self.generic_visit(node) def visit_Return(self, node): - assert len( - self.function_def) > 0, "Found 'return' statement out of function." - cur_func = self.function_def[-1] - if cur_func in self.count_return: - self.count_return[cur_func] += 1 - else: - self.count_return[cur_func] = 1 + self.count_return += 1 return_length = get_return_size(node) - if cur_func in self.max_return_length: - self.max_return_length[cur_func] = max( - self.max_return_length[cur_func], return_length) - else: - self.max_return_length[cur_func] = return_length + self.max_return_length = max(self.max_return_length, return_length) self.generic_visit(node) - def get_func_return_count(self, func_node): - return self.count_return[func_node] + def get_func_return_count(self): + return self.count_return - def get_func_max_return_length(self, func_node): - return self.max_return_length[func_node] + def get_func_max_return_length(self): + return self.max_return_length class ReturnTransformer(BaseTransformer): @@ -143,32 +132,51 @@ class ReturnTransformer(BaseTransformer): variable to store the early return statements and boolean states with if-else to skip the statements after the return. + Go through all the function definition and call SingleReturnTransformer for each function. + SingleReturnTransformer don't care the nested function def. """ def __init__(self, wrapper_root): self.wrapper_root = wrapper_root self.root = wrapper_root.node - pre_transformer = ReplaceReturnNoneTransformer(self.root) pre_transformer.transform() + def transform(self): + self.visit(self.root) + + def visit_FunctionDef(self, node): + node = self.generic_visit(node) + node = SingleReturnTransformer(node).transform() + return node + + +class SingleReturnTransformer(BaseTransformer): + """ + This function only apply to single function. don't care the nested function_def + """ + + def __init__(self, root): + self.root = root + assert isinstance( + self.root, gast.FunctionDef), "Input is not gast.FunctionDef node" + self.ancestor_nodes = [] - # The name of the variable which stores the final return value - # Mapping from FunctionDef node to string - self.return_value_name = {} - # The names of the variable which stores the boolean state that skip - # statments. Mapping from FunctionDef node to list - self.return_name = {} - # The names of the variable which is placeholder to handle various- - # length return. Mapping from FunctionDef node to list - self.return_no_value_name = {} - # A list of FunctionDef to store where the current function is. - self.function_def = [] + + # The name of return placeholder + self.return_value_name = None + + # Every return stmt corresponds to a bool value variable, and return name is the name of the boolean variable + self.return_name = [] self.pre_analysis = None - def transform(self): - self.visit(self.root) + def assert_parent_is_not_while(self, parent_node_of_return): + if isinstance(parent_node_of_return, (gast.While, gast.For)): + raise Dygraph2StaticException( + "Found return statement in While or For body and loop " + "is meaningless, please check you code and remove return in while/for." + ) def generic_visit(self, node): # Because we change ancestor nodes during visit_Return, not current @@ -188,28 +196,46 @@ class ReturnTransformer(BaseTransformer): Self-defined visit for appending ancestor """ self.ancestor_nodes.append(node) - ret = super(ReturnTransformer, self).visit(node) + ret = super(SingleReturnTransformer, self).visit(node) self.ancestor_nodes.pop() return ret def visit_FunctionDef(self, node): - self.function_def.append(node) - self.return_value_name[node] = None - self.return_name[node] = [] - self.return_no_value_name[node] = [] + """ + don't analysis closure, just analyze current func def level. + """ + if node == self.root: + self.generic_visit(node) + return node + def append_assign_to_return_node(self, value, parent_node_of_return, + return_name, assign_nodes): + self.assert_parent_is_not_while(parent_node_of_return) + assert value in [True, False], "value must be True or False." + if isinstance(parent_node_of_return, gast.If): + # Prepend control flow boolean nodes such as '__return@1 = True' + node_str = "{} = _jst.create_bool_as_type({}, {})".format( + return_name, + ast_to_source_code(parent_node_of_return.test).strip(), value) + + assign_node = gast.parse(node_str).body[0] + assign_nodes.append(assign_node) + + def transform(self): + node = self.root self.pre_analysis = ReturnAnalysisVisitor(node) - max_return_length = self.pre_analysis.get_func_max_return_length(node) - while self.pre_analysis.get_func_return_count(node) > 1: - self.generic_visit(node) + max_return_length = self.pre_analysis.get_func_max_return_length() + while self.pre_analysis.get_func_return_count() > 0: + # every visit will decrease the number of returns. + # so we need a while. + self.visit(node) self.pre_analysis = ReturnAnalysisVisitor(node) if max_return_length == 0: - self.function_def.pop() return node # Prepend initialization of final return and append final return statement - value_name = self.return_value_name[node] + value_name = self.return_value_name if value_name is not None: node.body.append( gast.Return(value=gast.Name(id=value_name, @@ -227,54 +253,32 @@ class ReturnTransformer(BaseTransformer): node.body.insert(0, assign_return_value_node) # Prepend no value placeholders - self.function_def.pop() - - # Need update self.pre_analysis after pop - # For fix this case: - ''' - def fun(cond): - def inner(): - pass - if cond: - return True - else: - return False - ''' - if self.function_def: - self.pre_analysis = ReturnAnalysisVisitor(self.function_def[-1]) return node def visit_Return(self, node): - cur_func_node = self.function_def[-1] return_name = unique_name.generate(RETURN_PREFIX) - self.return_name[cur_func_node].append(return_name) - max_return_length = self.pre_analysis.get_func_max_return_length( - cur_func_node) + self.return_name.append(return_name) + max_return_length = self.pre_analysis.get_func_max_return_length() parent_node_of_return = self.ancestor_nodes[-2] for ancestor_index in reversed(range(len(self.ancestor_nodes) - 1)): ancestor = self.ancestor_nodes[ancestor_index] cur_node = self.ancestor_nodes[ancestor_index + 1] - if hasattr(ancestor, - "body") and index_in_list(ancestor.body, cur_node) != -1: - if cur_node == node: - self._replace_return_in_stmt_list(ancestor.body, cur_node, - return_name, - max_return_length, - parent_node_of_return) - self._replace_after_node_to_if_in_stmt_list( - ancestor.body, cur_node, return_name, parent_node_of_return) - elif hasattr(ancestor, "orelse") and index_in_list( - ancestor.orelse, cur_node) != -1: - if cur_node == node: - self._replace_return_in_stmt_list(ancestor.orelse, cur_node, - return_name, - max_return_length, - parent_node_of_return) - self._replace_after_node_to_if_in_stmt_list( - ancestor.orelse, cur_node, return_name, - parent_node_of_return) + def _deal_branches(branch_name): + if hasattr(ancestor, branch_name): + branch_node = getattr(ancestor, branch_name) + if index_in_list(branch_node, cur_node) != -1: + if cur_node == node: + self._replace_return_in_stmt_list( + branch_node, cur_node, return_name, + max_return_length, parent_node_of_return) + self._replace_after_node_to_if_in_stmt_list( + branch_node, cur_node, return_name, + parent_node_of_return) + + _deal_branches("body") + _deal_branches("orelse") # If return node in while loop, add `not return_name` in gast.While.test if isinstance(ancestor, gast.While): cond_var_node = gast.UnaryOp(op=gast.Not(), @@ -302,7 +306,7 @@ class ReturnTransformer(BaseTransformer): while_node = new_stmts[-1] self.ancestor_nodes[ancestor_index] = while_node - if ancestor == cur_func_node: + if ancestor == self.root: break # return_node is replaced so we shouldn't return here @@ -315,34 +319,29 @@ class ReturnTransformer(BaseTransformer): return False assign_nodes = [] - # Here assume that the parent node of return is gast.If - if isinstance(parent_node_of_return, gast.If): - # Prepend control flow boolean nodes such as '__return@1 = True' - node_str = "{} = _jst.create_bool_as_type({}, True)".format( - return_name, - ast_to_source_code(parent_node_of_return.test).strip()) + self.append_assign_to_return_node(True, parent_node_of_return, + return_name, assign_nodes) - assign_true_node = gast.parse(node_str).body[0] - assign_nodes.append(assign_true_node) - - cur_func_node = self.function_def[-1] return_length = get_return_size(return_node) # In this case we should NOT append RETURN_NO_VALUE placeholder if return_node.value is not None: - cur_func_node = self.function_def[-1] - if self.return_value_name[cur_func_node] is None: - self.return_value_name[cur_func_node] = unique_name.generate( + if self.return_value_name is None: + self.return_value_name = unique_name.generate( RETURN_VALUE_PREFIX) assign_nodes.append( gast.Assign(targets=[ - gast.Name(id=self.return_value_name[cur_func_node], + gast.Name(id=self.return_value_name, ctx=gast.Store(), annotation=None, type_comment=None) ], value=return_node.value)) + return_origin_info = getattr(return_node, ORIGI_INFO, None) + setattr(assign_nodes[-1], ORIGI_INFO, return_origin_info) + # If there is a return in the body or else of if, the remaining statements + # will not be executed, so they can be properly replaced. stmt_list[i:] = assign_nodes return True @@ -368,12 +367,8 @@ class ReturnTransformer(BaseTransformer): stmt_list[i + 1:] = [if_stmt] # Here assume that the parent node of return is gast.If - if isinstance(parent_node_of_return, gast.If): - # Prepend control flow boolean nodes such as '__return@1 = False' - node_str = "{} = _jst.create_bool_as_type({}, False)".format( - return_name, - ast_to_source_code(parent_node_of_return.test).strip()) - assign_false_node = gast.parse(node_str).body[0] - - stmt_list[i:i] = [assign_false_node] + assign_nodes = [] + self.append_assign_to_return_node(False, parent_node_of_return, + return_name, assign_nodes) + stmt_list[i:i] = assign_nodes return True diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_convert_call.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_convert_call.py index 4f4d42c8092..48d7d3eb20c 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_convert_call.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_convert_call.py @@ -313,8 +313,10 @@ class TestDynamicToStaticCode2(TestDynamicToStaticCode): class StaticCode(): def func_convert_then_not_to_static(x): + __return_value_0 = None y = _jst.Call(func_not_to_static)(x) - return y + __return_value_0 = y + return __return_value_0 self.answer_func = StaticCode.func_convert_then_not_to_static diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_origin_info.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_origin_info.py index b422164cf38..e7435922b1c 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_origin_info.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_origin_info.py @@ -65,7 +65,7 @@ class TestOriginInfo(unittest.TestCase): self.func = simple_func def set_static_lineno(self): - self.static_abs_lineno_list = [9, 10, 11] + self.static_abs_lineno_list = [9, 11, 12] def set_dygraph_info(self): self.line_num = 3 @@ -93,7 +93,6 @@ class TestOriginInfo(unittest.TestCase): self.static_func, _ = ast_to_func(transformed_ast, self.dygraph_func) info_map = create_and_update_origin_info_map(dygraph_ast, self.static_func) - return info_map def test_origin_info_map(self): @@ -149,7 +148,7 @@ class TestOriginInfoWithNestedFunc(TestOriginInfo): self.func = nested_func def set_static_lineno(self): - self.static_abs_lineno_list = [9, 11, 12, 13, 14] + self.static_abs_lineno_list = [9, 12, 14, 16, 17] def set_dygraph_info(self): self.line_num = 5 @@ -174,7 +173,7 @@ class TestOriginInfoWithDecoratedFunc(TestOriginInfo): self.func = decorated_func def set_static_lineno(self): - self.static_abs_lineno_list = [9, 10] + self.static_abs_lineno_list = [9, 11] def set_dygraph_info(self): self.line_num = 2 @@ -208,7 +207,7 @@ class TestOriginInfoWithDecoratedFunc2(TestOriginInfo): self.func = decorated_func2 def set_static_lineno(self): - self.static_abs_lineno_list = [9, 10] + self.static_abs_lineno_list = [9, 11] def set_dygraph_info(self): self.line_num = 2 diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_return.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_return.py index 2905bd07439..748ff59534a 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_return.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_return.py @@ -224,6 +224,49 @@ def test_diff_return(x): return y, z +@to_static +def test_return_if_else_2(x): + rr = 0 + if True: + rr = 1 + return 1 + else: + a = 0 + + +@to_static +def test_return_in_while_2(x): + while True: + a = 12 + return 12 + return 10 + + +@to_static +def test_return_in_for_2(x): + a = 12 + for i in range(10): + return 12 + return 10 + + +@to_static +def test_return_nested(x): + + def func(): + rr = 0 + if True: + rr = 1 + return 1 + rr = 2 + else: + a = 0 + return 4 + return 3 + + return func() + + class TestReturnBase(unittest.TestCase): def setUp(self): @@ -256,7 +299,6 @@ class TestReturnBase(unittest.TestCase): np.testing.assert_allclose(dygraph_res[i], static_res[i], rtol=1e-05) - elif isinstance(dygraph_res, np.ndarray): np.testing.assert_allclose(dygraph_res, static_res, rtol=1e-05) else: @@ -282,6 +324,24 @@ class TestReturnIf(TestReturnBase): self.dygraph_func = test_return_if +class TestReturnOnlyIf(TestReturnBase): + + def init_dygraph_func(self): + self.dygraph_func = test_return_if_else_2 + + +class TestReturnInFor(TestReturnBase): + + def init_dygraph_func(self): + self.dygraph_func = test_return_in_for + + +class TestReturnInWhile(TestReturnBase): + + def init_dygraph_func(self): + self.dygraph_func = test_return_in_while + + class TestReturnIfDiff(TestReturnBase): def init_dygraph_func(self): @@ -294,16 +354,18 @@ class TestReturnIfElse(TestReturnBase): self.dygraph_func = test_return_if_else -class TestReturnInWhile(TestReturnBase): +class TestReturnInWhile2(TestReturnBase): def init_dygraph_func(self): - self.dygraph_func = test_return_in_while + self.dygraph_func = test_return_in_while_2 + self.error = "Found return statement in While or For body and loop" -class TestReturnInFor(TestReturnBase): +class TestReturnInFor2(TestReturnBase): def init_dygraph_func(self): - self.dygraph_func = test_return_in_for + self.dygraph_func = test_return_in_for_2 + self.error = "Found return statement in While or For body and loop" class TestRecursiveReturn(TestReturnBase): @@ -371,6 +433,12 @@ class TestReturnTupleManyValue(TestReturnBase): self.dygraph_func = test_return_tuple_many_values +class TestReturnNested(TestReturnBase): + + def init_dygraph_func(self): + self.dygraph_func = test_return_nested + + class TestReturnSpecial(TestReturnBase): def init_dygraph_func(self): -- GitLab