diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py index 0815f61432f189bb8153ce692e68c3194206b55e..9d0eec7f779fb59afeb4d0a8553f3f0b1e72c306 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py @@ -29,6 +29,7 @@ from paddle.fluid.dygraph.dygraph_to_static.list_transformer import ListTransfor from paddle.fluid.dygraph.dygraph_to_static.logical_transformer import LogicalTransformer from paddle.fluid.dygraph.dygraph_to_static.loop_transformer import LoopTransformer from paddle.fluid.dygraph.dygraph_to_static.print_transformer import PrintTransformer +from paddle.fluid.dygraph.dygraph_to_static.return_transformer import ReturnTransformer from paddle.fluid.dygraph.dygraph_to_static.tensor_shape_transformer import TensorShapeTransformer from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor @@ -71,6 +72,9 @@ class DygraphToStaticAst(gast.NodeTransformer): # Transform break/continue in loops BreakContinueTransformer(node_wrapper).transform() + # Transform return in functions + ReturnTransformer(node_wrapper).transform() + # Transform logical and/or/not LogicalTransformer(node_wrapper).transform() diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/break_continue_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/break_continue_transformer.py index 9b25ff07ec4c06f8b61ea31d5f824df7108df922..c78f6e8f403196fc098914c4cc58c8a16a4d885c 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/break_continue_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/break_continue_transformer.py @@ -49,7 +49,7 @@ class ForToWhileTransformer(gast.NodeTransformer): new_stmts = self.get_for_stmt_nodes(body_list[i]) body_list[i:i + 1] = new_stmts i += len(new_stmts) - return + return new_stmts if hasattr(self.parent_node, 'orelse'): body_list = self.parent_node.orelse i = index_in_list(body_list, self.loop_node) @@ -57,7 +57,7 @@ class ForToWhileTransformer(gast.NodeTransformer): new_stmts = self.get_for_stmt_nodes(body_list[i]) body_list[i:i + 1] = new_stmts i += len(new_stmts) - return + return new_stmts raise ValueError( "parent_node doesn't contain the loop_node in ForToWhileTransformer") diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/return_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/return_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..5125a190fb9341d130d8ebf92a87e18de2e300ea --- /dev/null +++ b/python/paddle/fluid/dygraph/dygraph_to_static/return_transformer.py @@ -0,0 +1,247 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import gast + +from paddle.fluid import unique_name +from paddle.fluid.dygraph.dygraph_to_static.utils import index_in_list +from paddle.fluid.dygraph.dygraph_to_static.break_continue_transformer import ForToWhileTransformer +from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_fill_constant_node + +__all__ = ['ReturnTransformer'] + +# Constant for the name of the variable which stores the boolean state that we +# should return +RETURN_PREFIX = '__return' + +# Constant for the name of the variable which stores the final return value +RETURN_VALUE_PREFIX = '__return_value' + + +class ReturnPreAnalysisVisitor(gast.NodeVisitor): + """ + Visits gast Tree and pre-analyze the information about 'return'. + """ + + def __init__(self, root_node): + self.root = root_node + + # A list to store where the current function is. + self.function_def = [] + + # Mapping from gast.FunctionDef node to the number of return statements + # Python allows define function inside function so we have to handle it + self.count_return = {} + self.visit(self.root) + + def visit_FunctionDef(self, node): + self.function_def.append(node) + self.count_return[node] = 0 + self.generic_visit(node) + self.function_def.pop() + return node + + def visit_Return(self, node): + assert len( + self.function_def) > 0, "Found 'return' statement out of function." + cur_func = self.function_def[-1] + if cur_func in self.count_return: + self.count_return[cur_func] += 1 + else: + self.count_return[cur_func] = 1 + self.generic_visit(node) + + def get_func_return_count(self, func_node): + return self.count_return[func_node] + + def set_func_return_count(self, func_node, count): + self.count_return[func_node] = count + + +class ReturnTransformer(gast.NodeTransformer): + """ + Transforms return statements into equivalent python statements containing + only one return statement at last. The basics idea is using a return value + variable to store the early return statements and boolean states with + if-else to skip the statements after the return. + + """ + + def __init__(self, wrapper_root): + self.wrapper_root = wrapper_root + self.root = wrapper_root.node + self.ancestor_nodes = [] + + # The name of the variable which stores the final return value + # Mapping from FunctionDef node to string + self.return_value_name = {} + # The names of the variable which stores the boolean state that skip + # statments. Mapping from FunctionDef node to list + self.return_name = {} + # A list of FunctionDef to store where the current function is. + self.function_def = [] + + def transform(self): + self.visit(self.root) + + def generic_visit(self, node): + # Because we change ancestor nodes during visit_Return, not current + # node, original generic_visit of NodeTransformer will visit node + # which may be deleted. To prevent that node being added into + # transformed AST, We self-write a generic_visit and visit + for field, value in gast.iter_fields(node): + if isinstance(value, list): + for item in value: + if isinstance(item, gast.AST): + self.visit(item) + elif isinstance(value, gast.AST): + self.visit(value) + + def visit(self, node): + """ + Self-defined visit for appending ancestor + """ + self.ancestor_nodes.append(node) + method = 'visit_' + node.__class__.__name__ + visitor = getattr(self, method, self.generic_visit) + ret = visitor(node) + self.ancestor_nodes.pop() + return ret + + def visit_FunctionDef(self, node): + self.function_def.append(node) + self.return_value_name[node] = None + self.return_name[node] = [] + + pre_analysis = ReturnPreAnalysisVisitor(node) + while pre_analysis.get_func_return_count(node) > 1: + self.generic_visit(node) + pre_analysis = ReturnPreAnalysisVisitor(node) + + # prepend initialization of final return and append final return statement + value_name = self.return_value_name[node] + if value_name is not None: + node.body.append( + gast.Return(value=gast.Name( + id=value_name, + ctx=gast.Load(), + annotation=None, + type_comment=None))) + assign_zero_node = create_fill_constant_node(value_name, 0.0) + node.body.insert(0, assign_zero_node) + # Prepend control flow boolean nodes such as '__return@1 = False' + for name in self.return_name[node]: + assign_false_node = create_fill_constant_node(name, False) + node.body.insert(0, assign_false_node) + + self.function_def.pop() + return node + + def visit_Return(self, node): + cur_func_node = self.function_def[-1] + return_name = unique_name.generate(RETURN_PREFIX) + self.return_name[cur_func_node].append(return_name) + for ancestor_index in reversed(range(len(self.ancestor_nodes) - 1)): + ancestor = self.ancestor_nodes[ancestor_index] + cur_node = self.ancestor_nodes[ancestor_index + 1] + if hasattr(ancestor, + "body") and index_in_list(ancestor.body, cur_node) != -1: + if cur_node == node: + self._replace_return_in_stmt_list(ancestor.body, cur_node, + return_name) + self._replace_after_node_to_if_in_stmt_list( + ancestor.body, cur_node, return_name) + elif hasattr(ancestor, "orelse") and index_in_list(ancestor.orelse, + cur_node) != -1: + if cur_node == node: + self._replace_return_in_stmt_list(ancestor.orelse, cur_node, + return_name) + self._replace_after_node_to_if_in_stmt_list( + ancestor.orelse, cur_node, return_name) + + if isinstance(ancestor, gast.While): + cond_var_node = gast.UnaryOp( + op=gast.Not(), + operand=gast.Name( + id=return_name, + ctx=gast.Load(), + annotation=None, + type_comment=None)) + ancestor.test = gast.BoolOp( + op=gast.And(), values=[ancestor.test, cond_var_node]) + continue + + if isinstance(ancestor, gast.For): + cond_var_node = gast.UnaryOp( + op=gast.Not(), + operand=gast.Name( + id=return_name, + ctx=gast.Load(), + annotation=None, + type_comment=None)) + parent_node = self.ancestor_nodes[ancestor_index - 1] + for_to_while = ForToWhileTransformer(parent_node, ancestor, + cond_var_node) + new_stmts = for_to_while.transform() + while_node = new_stmts[-1] + self.ancestor_nodes[ancestor_index] = while_node + + if ancestor == cur_func_node: + break + # return_node is replaced so we shouldn't return here + + def _replace_return_in_stmt_list(self, stmt_list, return_node, return_name): + i = index_in_list(stmt_list, return_node) + if i == -1: + return False + assign_nodes = [create_fill_constant_node(return_name, True)] + if return_node.value is not None: + cur_func_node = self.function_def[-1] + if self.return_value_name[cur_func_node] is None: + self.return_value_name[cur_func_node] = unique_name.generate( + RETURN_VALUE_PREFIX) + assign_nodes.append( + gast.Assign( + targets=[ + gast.Name( + id=self.return_value_name[cur_func_node], + ctx=gast.Store(), + annotation=None, + type_comment=None) + ], + value=return_node.value)) + stmt_list[i:] = assign_nodes + return True + + def _replace_after_node_to_if_in_stmt_list(self, stmt_list, node, + return_name): + i = index_in_list(stmt_list, node) + if i < 0 or i >= len(stmt_list): + return False + if i == len(stmt_list) - 1: + # No need to add, we consider this as added successfully + return True + if_stmt = gast.If(test=gast.UnaryOp( + op=gast.Not(), + operand=gast.Name( + id=return_name, + ctx=gast.Store(), + annotation=None, + type_comment=None)), + body=stmt_list[i + 1:], + orelse=[]) + stmt_list[i + 1:] = [if_stmt] + return True diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/bert_dygraph_model.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/bert_dygraph_model.py index df0afe8b5647ca45c870a9e40e0122a78764c858..b302dd37794fd05f6ca9ca76694070708a3d9549 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/bert_dygraph_model.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/bert_dygraph_model.py @@ -247,8 +247,11 @@ class BertModelLayer(Layer): enc_output = self._encoder(emb_out, n_head_self_attn_mask) - if not self.return_pooled_out: - return enc_output + # TODO(zhhsplendid): uncomment this in next PR which we support various + # length of early return + # + #if not self.return_pooled_out: + # return enc_output next_sent_feat = fluid.layers.slice( input=enc_output, axes=[1], starts=[0], ends=[1]) next_sent_feat = self.pooled_fc(next_sent_feat) diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_program_translator.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_program_translator.py index 3cf8f5b71d7760e9cfea11049f02e07ee31a8087..1cba65bc8e48fd44b41a72486b78e17a456d7f1a 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_program_translator.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_program_translator.py @@ -63,6 +63,13 @@ def get_source_code(func): class StaticCode1(): # TODO: Transform return statement def dyfunc_with_if_else(x_v, label=None): + __return_1 = fluid.layers.fill_constant( + shape=[1], dtype='bool', value=False) + __return_0 = fluid.layers.fill_constant( + shape=[1], dtype='bool', value=False) + __return_value_0 = fluid.layers.fill_constant( + shape=[1], dtype='float64', value=0.0) + def true_fn_0(x_v): x_v = x_v - 1 return x_v @@ -75,45 +82,94 @@ class StaticCode1(): fluid.layers.mean(x_v)[0] > 5, true_fn_0, false_fn_0, (x_v, ), (x_v, ), (x_v, )) - def true_fn_1(label, x_v): + def true_fn_1(__return_0, __return_value_0, label, x_v): loss = fluid.layers.cross_entropy(x_v, label) - return loss - return - - def false_fn_1(): - return - - fluid.dygraph.dygraph_to_static.convert_operators.convert_ifelse( - label is not None, true_fn_1, false_fn_1, (label, x_v), (), ()) - return x_v + __return_0 = fluid.layers.fill_constant( + shape=[1], dtype='bool', value=True) + __return_value_0 = loss + return __return_0, __return_value_0 + + def false_fn_1(__return_0, __return_value_0): + return __return_0, __return_value_0 + + __return_0, __return_value_0 = ( + fluid.dygraph.dygraph_to_static.convert_operators.convert_ifelse( + label is not None, true_fn_1, false_fn_1, + (__return_0, __return_value_0, label, x_v), + (__return_0, __return_value_0), (__return_0, __return_value_0))) + + def true_fn_2(__return_1, __return_value_0, x_v): + __return_1 = fluid.layers.fill_constant( + shape=[1], dtype='bool', value=True) + __return_value_0 = x_v + return __return_1, __return_value_0 + + def false_fn_2(__return_1, __return_value_0): + return __return_1, __return_value_0 + + __return_1, __return_value_0 = ( + fluid.dygraph.dygraph_to_static.convert_operators.convert_ifelse( + fluid.dygraph.dygraph_to_static.convert_operators. + convert_logical_not(__return_0), true_fn_2, false_fn_2, + (__return_1, __return_value_0, x_v), + (__return_1, __return_value_0), (__return_1, __return_value_0))) + return __return_value_0 class StaticCode2(): # TODO: Transform return statement def dyfunc_with_if_else(x_v, label=None): - def true_fn_2(x_v): + __return_3 = fluid.layers.fill_constant( + shape=[1], dtype='bool', value=False) + __return_2 = fluid.layers.fill_constant( + shape=[1], dtype='bool', value=False) + __return_value_1 = fluid.layers.fill_constant( + shape=[1], dtype='float64', value=0.0) + + def true_fn_3(x_v): x_v = x_v - 1 return x_v - def false_fn_2(x_v): + def false_fn_3(x_v): x_v = x_v + 1 return x_v x_v = fluid.dygraph.dygraph_to_static.convert_operators.convert_ifelse( - fluid.layers.mean(x_v)[0] > 5, true_fn_2, false_fn_2, (x_v, ), + fluid.layers.mean(x_v)[0] > 5, true_fn_3, false_fn_3, (x_v, ), (x_v, ), (x_v, )) - def true_fn_3(label, x_v): + def true_fn_4(__return_2, __return_value_1, label, x_v): loss = fluid.layers.cross_entropy(x_v, label) - return loss - return - - def false_fn_3(): - return - - fluid.dygraph.dygraph_to_static.convert_operators.convert_ifelse( - label is not None, true_fn_3, false_fn_3, (label, x_v), (), ()) - return x_v + __return_2 = fluid.layers.fill_constant( + shape=[1], dtype='bool', value=True) + __return_value_1 = loss + return __return_2, __return_value_1 + + def false_fn_4(__return_2, __return_value_1): + return __return_2, __return_value_1 + + __return_2, __return_value_1 = ( + fluid.dygraph.dygraph_to_static.convert_operators.convert_ifelse( + label is not None, true_fn_4, false_fn_4, + (__return_2, __return_value_1, label, x_v), + (__return_2, __return_value_1), (__return_2, __return_value_1))) + + def true_fn_5(__return_3, __return_value_1, x_v): + __return_3 = fluid.layers.fill_constant( + shape=[1], dtype='bool', value=True) + __return_value_1 = x_v + return __return_3, __return_value_1 + + def false_fn_5(__return_3, __return_value_1): + return __return_3, __return_value_1 + + __return_3, __return_value_1 = ( + fluid.dygraph.dygraph_to_static.convert_operators.convert_ifelse( + fluid.dygraph.dygraph_to_static.convert_operators. + convert_logical_not(__return_2), true_fn_5, false_fn_5, + (__return_3, __return_value_1, x_v), + (__return_3, __return_value_1), (__return_3, __return_value_1))) + return __return_value_1 class NetWithError(fluid.dygraph.layers.Layer): diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_return.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_return.py new file mode 100644 index 0000000000000000000000000000000000000000..59ea532832e8637b469841b823302cd76f037d5a --- /dev/null +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_return.py @@ -0,0 +1,163 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +import paddle.fluid as fluid +from paddle.fluid.dygraph import declarative +from paddle.fluid.dygraph import ProgramTranslator + +from ifelse_simple_func import dyfunc_with_if_else + +SEED = 2020 +np.random.seed(SEED) + + +@declarative +def test_return_base(x): + x = fluid.dygraph.to_variable(x) + return x + + +@declarative +def test_inside_func_base(x): + x = fluid.dygraph.to_variable(x) + + def inner_func(x): + return x + + return inner_func(x) + + +@declarative +def test_return_if(x): + x = fluid.dygraph.to_variable(x) + if x < 0: + x -= 1 + return -x + x += 3 + return x + + +@declarative +def test_return_if_else(x): + x = fluid.dygraph.to_variable(x) + if x > 0: + x += 10086 + return x + x -= 3 # useless statement to test our code can handle it. + else: + x += 6666 + return x + x -= 8888 # useless statement to test our code can handle it. + + +@declarative +def test_return_in_while(x): + x = fluid.dygraph.to_variable(x) + i = fluid.layers.fill_constant(shape=[1], dtype='int32', value=0) + while i < 10: + i += 1 + if i > 5: + x += 110 + return x + x += i + return x + + +@declarative +def test_return_in_for(x): + x = fluid.dygraph.to_variable(x) + for i in range(10): + if i <= 4: + x += 1 + continue + else: + return x + 10086 + return x - 1 + + +@declarative +def test_recursive_return(x): + x = fluid.dygraph.to_variable(x) + return dyfunc_with_if_else(x) + + +class TestReturnBase(unittest.TestCase): + def setUp(self): + self.input = np.ones((1)).astype('int32') + self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda( + ) else fluid.CPUPlace() + self.init_dygraph_func() + self.program_translator = ProgramTranslator() + + def init_dygraph_func(self): + self.dygraph_func = test_return_base + + def run_dygraph_mode(self): + self.program_translator.enable(False) + with fluid.dygraph.guard(): + res = self.dygraph_func(self.input) + return res.numpy() + + def run_static_mode(self): + self.program_translator.enable(True) + with fluid.dygraph.guard(): + res = self.dygraph_func(self.input) + return res.numpy() + + def test_transformed_static_result(self): + static_res = self.run_static_mode() + dygraph_res = self.run_dygraph_mode() + self.assertTrue( + np.allclose(dygraph_res, static_res), + msg='dygraph res is {}\nstatic_res is {}'.format(dygraph_res, + static_res)) + + +class TestInsideFuncBase(TestReturnBase): + def init_dygraph_func(self): + self.dygraph_func = test_inside_func_base + + +class TestReturnIf(TestReturnBase): + def init_dygraph_func(self): + self.dygraph_func = test_return_if + + +class TestReturnIfElse(TestReturnBase): + def init_dygraph_func(self): + self.dygraph_func = test_return_if_else + + +class TestReturnInWhile(TestReturnBase): + def init_dygraph_func(self): + self.dygraph_func = test_return_in_while + + +class TestReturnInFor(TestReturnBase): + def init_dygraph_func(self): + self.dygraph_func = test_return_in_for + + +class TestRecursiveReturn(TestReturnBase): + def init_dygraph_func(self): + self.input = self.input.astype(np.float32) + self.dygraph_func = test_recursive_return + + +if __name__ == '__main__': + unittest.main()