From 1950a3603897dac8598e19bc0890c75d6fd32456 Mon Sep 17 00:00:00 2001 From: WangZhen <23097963+0x45f@users.noreply.github.com> Date: Tue, 14 Jun 2022 17:54:41 +0800 Subject: [PATCH] [Dy2St]Refine ifelse early return (#43328) * Refine ifelse early return --- .../dygraph_to_static/ast_transformer.py | 2 + .../early_return_transformer.py | 88 +++++++++++++++++++ .../dygraph_to_static/ifelse_simple_func.py | 24 +++++ .../test_program_translator.py | 79 ++++++++--------- 4 files changed, 149 insertions(+), 44 deletions(-) create mode 100644 python/paddle/fluid/dygraph/dygraph_to_static/early_return_transformer.py diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py index de53a56468..aa01945ac8 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py @@ -20,6 +20,7 @@ from __future__ import print_function # See details in https://github.com/serge-sans-paille/gast/ import os from paddle.utils import gast +from paddle.fluid.dygraph.dygraph_to_static.early_return_transformer import EarlyReturnTransformer from paddle.fluid.dygraph.dygraph_to_static.assert_transformer import AssertTransformer from paddle.fluid.dygraph.dygraph_to_static.basic_api_transformer import BasicApiTransformer from paddle.fluid.dygraph.dygraph_to_static.break_continue_transformer import BreakContinueTransformer @@ -87,6 +88,7 @@ class DygraphToStaticAst(gast.NodeTransformer): self.visit(node_wrapper.node) transformers = [ + EarlyReturnTransformer, BasicApiTransformer, # Basic Api TensorShapeTransformer, # Tensor.shape -> layers.shape(Tensor) ListTransformer, # List used in control flow diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/early_return_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/early_return_transformer.py new file mode 100644 index 0000000000..bef1efb042 --- /dev/null +++ b/python/paddle/fluid/dygraph/dygraph_to_static/early_return_transformer.py @@ -0,0 +1,88 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +from paddle.utils import gast +from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper + + +class EarlyReturnTransformer(gast.NodeTransformer): + """ + Transform if/else return statement of Dygraph into Static Graph. + """ + + def __init__(self, wrapper_root): + assert isinstance( + wrapper_root, AstNodeWrapper + ), "Type of input node should be AstNodeWrapper, but received %s ." % type( + wrapper_root) + self.root = wrapper_root.node + + def transform(self): + """ + Main function to transform AST. + """ + self.visit(self.root) + + def is_define_return_in_if(self, node): + assert isinstance( + node, gast.If + ), "Type of input node should be gast.If, but received %s ." % type( + node) + for child in node.body: + if isinstance(child, gast.Return): + return True + return False + + def visit_block_nodes(self, nodes): + result_nodes = [] + destination_nodes = result_nodes + for node in nodes: + rewritten_node = self.visit(node) + + if isinstance(rewritten_node, (list, tuple)): + destination_nodes.extend(rewritten_node) + else: + destination_nodes.append(rewritten_node) + + # append other nodes to if.orelse even though if.orelse is not empty + if isinstance(node, gast.If) and self.is_define_return_in_if(node): + destination_nodes = node.orelse + # handle stmt like `if/elif/elif` + while len(destination_nodes) > 0 and \ + isinstance(destination_nodes[0], gast.If) and \ + self.is_define_return_in_if(destination_nodes[0]): + destination_nodes = destination_nodes[0].orelse + + return result_nodes + + def visit_If(self, node): + node.body = self.visit_block_nodes(node.body) + node.orelse = self.visit_block_nodes(node.orelse) + return node + + def visit_While(self, node): + node.body = self.visit_block_nodes(node.body) + node.orelse = self.visit_block_nodes(node.orelse) + return node + + def visit_For(self, node): + node.body = self.visit_block_nodes(node.body) + node.orelse = self.visit_block_nodes(node.orelse) + return node + + def visit_FunctionDef(self, node): + node.body = self.visit_block_nodes(node.body) + return node diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/ifelse_simple_func.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/ifelse_simple_func.py index 0c7d2903c3..39565044e7 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/ifelse_simple_func.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/ifelse_simple_func.py @@ -100,6 +100,30 @@ def dyfunc_with_if_else3(x): return x +def dyfunc_with_if_else_early_return1(): + x = paddle.to_tensor([10]) + if x == 0: + a = paddle.zeros([2, 2]) + b = paddle.zeros([3, 3]) + return a, b + a = paddle.zeros([2, 2]) + 1 + return a + + +def dyfunc_with_if_else_early_return2(): + x = paddle.to_tensor([10]) + if x == 0: + a = paddle.zeros([2, 2]) + b = paddle.zeros([3, 3]) + return a, b + elif x == 1: + c = paddle.zeros([2, 2]) + 1 + d = paddle.zeros([3, 3]) + 1 + return c, d + e = paddle.zeros([2, 2]) + 3 + return e + + def dyfunc_with_if_else_with_list_geneator(x): if 10 > 5: y = paddle.add_n( diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_program_translator.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_program_translator.py index cbc6e3c540..cf8be66403 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_program_translator.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_program_translator.py @@ -29,7 +29,7 @@ from paddle.fluid.dygraph.nn import Linear from paddle.fluid.dygraph.dygraph_to_static.utils import func_to_source_code import paddle.jit.dy2static as _jst -from ifelse_simple_func import dyfunc_with_if_else +from ifelse_simple_func import dyfunc_with_if_else, dyfunc_with_if_else_early_return1, dyfunc_with_if_else_early_return2 np.random.seed(0) @@ -83,34 +83,22 @@ class StaticCode1(): x_v = _jst.convert_ifelse( fluid.layers.mean(x_v)[0] > 5, true_fn_0, false_fn_0, (x_v, ), (x_v, )) - __return_0 = _jst.create_bool_as_type(label is not None, False) - def true_fn_1(__return_0, __return_value_0, label, x_v): + def true_fn_1(__return_value_0, label, x_v): loss = fluid.layers.cross_entropy(x_v, label) __return_0 = _jst.create_bool_as_type(label is not None, True) __return_value_0 = loss - return __return_0, __return_value_0 - - def false_fn_1(__return_0, __return_value_0): - return __return_0, __return_value_0 - - __return_0, __return_value_0 = _jst.convert_ifelse( - label is not None, true_fn_1, false_fn_1, - (__return_0, __return_value_0, label, x_v), - (__return_0, __return_value_0)) - - def true_fn_2(__return_0, __return_value_0, x_v): - __return_1 = _jst.create_bool_as_type( - _jst.convert_logical_not(__return_0), True) - __return_value_0 = x_v return __return_value_0 - def false_fn_2(__return_value_0): + def false_fn_1(__return_value_0, label, x_v): + __return_1 = _jst.create_bool_as_type(label is not None, True) + __return_value_0 = x_v return __return_value_0 - __return_value_0 = _jst.convert_ifelse( - _jst.convert_logical_not(__return_0), true_fn_2, false_fn_2, - (__return_0, __return_value_0, x_v), (__return_value_0, )) + __return_value_0 = _jst.convert_ifelse(label is not None, true_fn_1, + false_fn_1, + (__return_value_0, label, x_v), + (__return_value_0, label, x_v)) return __return_value_0 @@ -123,45 +111,33 @@ class StaticCode2(): name='__return_value_init_1') __return_value_1 = __return_value_init_1 - def true_fn_3(x_v): + def true_fn_2(x_v): x_v = x_v - 1 return x_v - def false_fn_3(x_v): + def false_fn_2(x_v): x_v = x_v + 1 return x_v x_v = _jst.convert_ifelse( - fluid.layers.mean(x_v)[0] > 5, true_fn_3, false_fn_3, (x_v, ), + fluid.layers.mean(x_v)[0] > 5, true_fn_2, false_fn_2, (x_v, ), (x_v, )) - __return_2 = _jst.create_bool_as_type(label is not None, False) - def true_fn_4(__return_2, __return_value_1, label, x_v): + def true_fn_3(__return_value_1, label, x_v): loss = fluid.layers.cross_entropy(x_v, label) __return_2 = _jst.create_bool_as_type(label is not None, True) __return_value_1 = loss - return __return_2, __return_value_1 - - def false_fn_4(__return_2, __return_value_1): - return __return_2, __return_value_1 - - __return_2, __return_value_1 = _jst.convert_ifelse( - label is not None, true_fn_4, false_fn_4, - (__return_2, __return_value_1, label, x_v), - (__return_2, __return_value_1)) - - def true_fn_5(__return_2, __return_value_1, x_v): - __return_3 = _jst.create_bool_as_type( - _jst.convert_logical_not(__return_2), True) - __return_value_1 = x_v return __return_value_1 - def false_fn_5(__return_value_1): + def false_fn_3(__return_value_1, label, x_v): + __return_3 = _jst.create_bool_as_type(label is not None, True) + __return_value_1 = x_v return __return_value_1 - __return_value_1 = _jst.convert_ifelse( - _jst.convert_logical_not(__return_2), true_fn_5, false_fn_5, - (__return_2, __return_value_1, x_v), (__return_value_1, )) + __return_value_1 = _jst.convert_ifelse(label is not None, true_fn_3, + false_fn_3, + (__return_value_1, label, x_v), + (__return_value_1, label, x_v)) return __return_value_1 @@ -358,6 +334,21 @@ class TestFunctionTrainEvalMode(unittest.TestCase): net.foo.train() +class TestIfElseEarlyReturn(unittest.TestCase): + + def test_ifelse_early_return1(self): + answer = np.zeros([2, 2]) + 1 + static_func = paddle.jit.to_static(dyfunc_with_if_else_early_return1) + out = static_func() + self.assertTrue(np.allclose(answer, out.numpy())) + + def test_ifelse_early_return2(self): + answer = np.zeros([2, 2]) + 3 + static_func = paddle.jit.to_static(dyfunc_with_if_else_early_return2) + out = static_func() + self.assertTrue(np.allclose(answer, out.numpy())) + + class TestRemoveCommentInDy2St(unittest.TestCase): def func_with_comment(self): -- GitLab