From a7433cc3795748759e36fc59b8588864844d6786 Mon Sep 17 00:00:00 2001 From: liym27 <33742067+liym27@users.noreply.github.com> Date: Sat, 28 Nov 2020 10:45:26 +0800 Subject: [PATCH] [Dy2Stat] Fix bug: the return statement should be transformed to an equivalent Paddle/Python if statement, which depends on if conditions of the return stmt. (#29165) --- .../dygraph_to_static/return_transformer.py | 56 ++++++++++++----- .../dygraph_to_static/variable_trans_func.py | 16 ++++- .../test_program_translator.py | 61 +++++++++---------- .../dygraph_to_static/test_return.py | 31 +++++++++- .../jit/dy2static/variable_trans_func.py | 6 +- 5 files changed, 119 insertions(+), 51 deletions(-) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/return_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/return_transformer.py index ef03e63dbb..4bcd49dc8e 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/return_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/return_transformer.py @@ -20,6 +20,7 @@ from paddle.fluid import unique_name from paddle.fluid.dygraph.dygraph_to_static.utils import index_in_list from paddle.fluid.dygraph.dygraph_to_static.break_continue_transformer import ForToWhileTransformer from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_fill_constant_node +from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code __all__ = [ 'RETURN_NO_VALUE_MAGIC_NUM', 'RETURN_NO_VALUE_VAR_NAME', 'ReturnTransformer' @@ -251,10 +252,7 @@ class ReturnTransformer(gast.NodeTransformer): value=return_value_nodes) node.body.insert(0, assign_return_value_node) node.body[:0] = assign_zero_nodes - # Prepend control flow boolean nodes such as '__return@1 = False' - for name in self.return_name[node]: - assign_false_node = create_fill_constant_node(name, False) - node.body.insert(0, assign_false_node) + # Prepend no value placeholders for name in self.return_no_value_name[node]: assign_no_value_node = create_fill_constant_node( @@ -270,6 +268,8 @@ class ReturnTransformer(gast.NodeTransformer): self.return_name[cur_func_node].append(return_name) max_return_length = self.pre_analysis.get_func_max_return_length( cur_func_node) + parent_node_of_return = self.ancestor_nodes[-2] + for ancestor_index in reversed(range(len(self.ancestor_nodes) - 1)): ancestor = self.ancestor_nodes[ancestor_index] cur_node = self.ancestor_nodes[ancestor_index + 1] @@ -277,18 +277,21 @@ class ReturnTransformer(gast.NodeTransformer): "body") and index_in_list(ancestor.body, cur_node) != -1: if cur_node == node: self._replace_return_in_stmt_list( - ancestor.body, cur_node, return_name, max_return_length) + ancestor.body, cur_node, return_name, max_return_length, + parent_node_of_return) self._replace_after_node_to_if_in_stmt_list( - ancestor.body, cur_node, return_name) + ancestor.body, cur_node, return_name, parent_node_of_return) elif hasattr(ancestor, "orelse") and index_in_list(ancestor.orelse, cur_node) != -1: if cur_node == node: - self._replace_return_in_stmt_list(ancestor.orelse, cur_node, - return_name, - max_return_length) + self._replace_return_in_stmt_list( + ancestor.orelse, cur_node, return_name, + max_return_length, parent_node_of_return) self._replace_after_node_to_if_in_stmt_list( - ancestor.orelse, cur_node, return_name) + ancestor.orelse, cur_node, return_name, + parent_node_of_return) + # If return node in while loop, add `not return_name` in gast.While.test if isinstance(ancestor, gast.While): cond_var_node = gast.UnaryOp( op=gast.Not(), @@ -301,6 +304,7 @@ class ReturnTransformer(gast.NodeTransformer): op=gast.And(), values=[ancestor.test, cond_var_node]) continue + # If return node in for loop, add `not return_name` in gast.While.test if isinstance(ancestor, gast.For): cond_var_node = gast.UnaryOp( op=gast.Not(), @@ -321,12 +325,24 @@ class ReturnTransformer(gast.NodeTransformer): # return_node is replaced so we shouldn't return here def _replace_return_in_stmt_list(self, stmt_list, return_node, return_name, - max_return_length): + max_return_length, parent_node_of_return): + assert max_return_length >= 0, "Input illegal max_return_length" i = index_in_list(stmt_list, return_node) if i == -1: return False - assign_nodes = [create_fill_constant_node(return_name, True)] + + assign_nodes = [] + # Here assume that the parent node of return is gast.If + if isinstance(parent_node_of_return, gast.If): + # Prepend control flow boolean nodes such as '__return@1 = True' + node_str = "{} = paddle.jit.dy2static.create_bool_as_type({}, True)".format( + return_name, + ast_to_source_code(parent_node_of_return.test).strip()) + + assign_true_node = gast.parse(node_str).body[0] + assign_nodes.append(assign_true_node) + cur_func_node = self.function_def[-1] return_length = get_return_size(return_node) if return_length < max_return_length: @@ -409,14 +425,15 @@ class ReturnTransformer(gast.NodeTransformer): stmt_list[i:] = assign_nodes return True - def _replace_after_node_to_if_in_stmt_list(self, stmt_list, node, - return_name): + def _replace_after_node_to_if_in_stmt_list( + self, stmt_list, node, return_name, parent_node_of_return): i = index_in_list(stmt_list, node) if i < 0 or i >= len(stmt_list): return False if i == len(stmt_list) - 1: # No need to add, we consider this as added successfully return True + if_stmt = gast.If(test=gast.UnaryOp( op=gast.Not(), operand=gast.Name( @@ -426,5 +443,16 @@ class ReturnTransformer(gast.NodeTransformer): type_comment=None)), body=stmt_list[i + 1:], orelse=[]) + stmt_list[i + 1:] = [if_stmt] + + # Here assume that the parent node of return is gast.If + if isinstance(parent_node_of_return, gast.If): + # Prepend control flow boolean nodes such as '__return@1 = False' + node_str = "{} = paddle.jit.dy2static.create_bool_as_type({}, False)".format( + return_name, + ast_to_source_code(parent_node_of_return.test).strip()) + assign_false_node = gast.parse(node_str).body[0] + + stmt_list[i:i] = [assign_false_node] return True diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/variable_trans_func.py b/python/paddle/fluid/dygraph/dygraph_to_static/variable_trans_func.py index 617c05c336..673d30cffb 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/variable_trans_func.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/variable_trans_func.py @@ -18,12 +18,14 @@ import six import gast from paddle.fluid import core +from paddle.fluid.framework import Variable from paddle.fluid.layers import fill_constant from paddle.fluid.layer_helper import LayerHelper __all__ = [ - 'create_fill_constant_node', 'create_static_variable_gast_node', - 'data_layer_not_check', 'to_static_variable', 'to_static_variable_gast_node' + 'create_bool_as_type', 'create_fill_constant_node', + 'create_static_variable_gast_node', 'data_layer_not_check', + 'to_static_variable', 'to_static_variable_gast_node' ] @@ -122,3 +124,13 @@ def to_static_variable(x): return fill_constant(shape=[1], dtype='int64', value=x) return x + + +def create_bool_as_type(x, value=True): + ''' + Create a bool variable, which type is the same as x. + ''' + if isinstance(x, Variable): + return fill_constant(shape=[1], value=value, dtype="bool") + else: + return value diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_program_translator.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_program_translator.py index 00b2d8dd1a..2ea3e36909 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_program_translator.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_program_translator.py @@ -62,10 +62,7 @@ def get_source_code(func): class StaticCode1(): - # TODO: Transform return statement def dyfunc_with_if_else(x_v, label=None): - __return_1 = paddle.fluid.layers.fill_constant(shape=[1], dtype='bool', value=False) - __return_0 = paddle.fluid.layers.fill_constant(shape=[1], dtype='bool', value=False) __return_value_init_0 = paddle.fluid.layers.fill_constant( shape=[1], dtype='float64', value=0.0) __return_value_0 = __return_value_init_0 @@ -81,11 +78,13 @@ class StaticCode1(): x_v = paddle.jit.dy2static.convert_ifelse( fluid.layers.mean(x_v)[0] > 5, true_fn_0, false_fn_0, (x_v, ), (x_v, ), (x_v, )) + __return_0 = paddle.jit.dy2static.create_bool_as_type(label is not None, + False) def true_fn_1(__return_0, __return_value_0, label, x_v): loss = fluid.layers.cross_entropy(x_v, label) - __return_0 = paddle.fluid.layers.fill_constant( - shape=[1], dtype='bool', value=True) + __return_0 = paddle.jit.dy2static.create_bool_as_type( + label is not None, True) __return_value_0 = loss return __return_0, __return_value_0 @@ -97,27 +96,25 @@ class StaticCode1(): (__return_0, __return_value_0, label, x_v), (__return_0, __return_value_0), (__return_0, __return_value_0))) - def true_fn_2(__return_1, __return_value_0, x_v): - __return_1 = paddle.fluid.layers.fill_constant( - shape=[1], dtype='bool', value=True) + def true_fn_2(__return_0, __return_value_0, x_v): + __return_1 = paddle.jit.dy2static.create_bool_as_type( + paddle.jit.dy2static.convert_logical_not(__return_0), True) __return_value_0 = x_v - return __return_1, __return_value_0 + return __return_value_0 - def false_fn_2(__return_1, __return_value_0): - return __return_1, __return_value_0 + def false_fn_2(__return_value_0): + return __return_value_0 - __return_1, __return_value_0 = (paddle.jit.dy2static.convert_ifelse( + __return_value_0 = paddle.jit.dy2static.convert_ifelse( paddle.jit.dy2static.convert_logical_not(__return_0), true_fn_2, - false_fn_2, (__return_1, __return_value_0, x_v), - (__return_1, __return_value_0), (__return_1, __return_value_0))) + false_fn_2, (__return_0, __return_value_0, + x_v), (__return_value_0, ), (__return_value_0, )) return __return_value_0 class StaticCode2(): # TODO: Transform return statement def dyfunc_with_if_else(x_v, label=None): - __return_3 = paddle.fluid.layers.fill_constant(shape=[1], dtype='bool', value=False) - __return_2 = paddle.fluid.layers.fill_constant(shape=[1], dtype='bool', value=False) __return_value_init_1 = paddle.fluid.layers.fill_constant( shape=[1], dtype='float64', value=0.0) __return_value_1 = __return_value_init_1 @@ -133,35 +130,37 @@ class StaticCode2(): x_v = paddle.jit.dy2static.convert_ifelse( fluid.layers.mean(x_v)[0] > 5, true_fn_3, false_fn_3, (x_v, ), (x_v, ), (x_v, )) + __return_2 = paddle.jit.dy2static.create_bool_as_type(label is not None, + False) def true_fn_4(__return_2, __return_value_1, label, x_v): loss = fluid.layers.cross_entropy(x_v, label) - __return_2 = paddle.fluid.layers.fill_constant( - shape=[1], dtype='bool', value=True) + __return_2 = paddle.jit.dy2static.create_bool_as_type( + label is not None, True) __return_value_1 = loss return __return_2, __return_value_1 def false_fn_4(__return_2, __return_value_1): return __return_2, __return_value_1 - __return_2, __return_value_1 = (paddle.jit.dy2static.convert_ifelse( - label is not None, true_fn_4, false_fn_4, - (__return_2, __return_value_1, label, x_v), - (__return_2, __return_value_1), (__return_2, __return_value_1))) + __return_2, __return_value_1 = paddle.jit.dy2static.convert_ifelse( + label is not None, true_fn_4, false_fn_4, ( + __return_2, __return_value_1, label, x_v), + (__return_2, __return_value_1), (__return_2, __return_value_1)) - def true_fn_5(__return_3, __return_value_1, x_v): - __return_3 = paddle.fluid.layers.fill_constant( - shape=[1], dtype='bool', value=True) + def true_fn_5(__return_2, __return_value_1, x_v): + __return_3 = paddle.jit.dy2static.create_bool_as_type( + paddle.jit.dy2static.convert_logical_not(__return_2), True) __return_value_1 = x_v - return __return_3, __return_value_1 + return __return_value_1 - def false_fn_5(__return_3, __return_value_1): - return __return_3, __return_value_1 + def false_fn_5(__return_value_1): + return __return_value_1 - __return_3, __return_value_1 = (paddle.jit.dy2static.convert_ifelse( + __return_value_1 = paddle.jit.dy2static.convert_ifelse( paddle.jit.dy2static.convert_logical_not(__return_2), true_fn_5, - false_fn_5, (__return_3, __return_value_1, x_v), - (__return_3, __return_value_1), (__return_3, __return_value_1))) + false_fn_5, (__return_2, __return_value_1, + x_v), (__return_value_1, ), (__return_value_1, )) return __return_value_1 diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_return.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_return.py index f592b7ed24..7ab60082c3 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_return.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_return.py @@ -14,13 +14,15 @@ from __future__ import print_function -import unittest -import numpy as np +import paddle import paddle.fluid as fluid import paddle.fluid.core as core from paddle.jit import to_static from paddle.jit import ProgramTranslator +import unittest +import numpy as np + from ifelse_simple_func import dyfunc_with_if_else SEED = 2020 @@ -179,6 +181,26 @@ def test_return_tuple_many_values(x): return (x, y, z) +def inner_func(x): + a = 2 + if a < 0: + y = x + 1 + return y + y = x * 2 + return y + + +@to_static +def test_return_without_paddle_cond(x): + # y shape is [10] + y = paddle.ones([10]) + + # the shape of inner_func(y) should be [10], not [1] + y = inner_func(y) + y = paddle.reshape(y, [2, 5]) + return y + + class TestReturnBase(unittest.TestCase): def setUp(self): self.input = np.ones((1)).astype('int32') @@ -297,5 +319,10 @@ class TestReturnTupleManyValue(TestReturnBase): self.dygraph_func = test_return_tuple_many_values +class TestReturnSpecial(TestReturnBase): + def init_dygraph_func(self): + self.dygraph_func = test_return_without_paddle_cond + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/jit/dy2static/variable_trans_func.py b/python/paddle/jit/dy2static/variable_trans_func.py index 08c057934a..2deb1bbb0e 100644 --- a/python/paddle/jit/dy2static/variable_trans_func.py +++ b/python/paddle/jit/dy2static/variable_trans_func.py @@ -14,6 +14,7 @@ from __future__ import print_function +from ...fluid.dygraph.dygraph_to_static.variable_trans_func import create_bool_as_type #DEFINE_ALIAS from ...fluid.dygraph.dygraph_to_static.variable_trans_func import create_fill_constant_node #DEFINE_ALIAS from ...fluid.dygraph.dygraph_to_static.variable_trans_func import create_static_variable_gast_node #DEFINE_ALIAS from ...fluid.dygraph.dygraph_to_static.variable_trans_func import data_layer_not_check #DEFINE_ALIAS @@ -21,6 +22,7 @@ from ...fluid.dygraph.dygraph_to_static.variable_trans_func import to_static_var from ...fluid.dygraph.dygraph_to_static.variable_trans_func import to_static_variable_gast_node #DEFINE_ALIAS __all__ = [ - 'create_fill_constant_node', 'create_static_variable_gast_node', - 'data_layer_not_check', 'to_static_variable', 'to_static_variable_gast_node' + 'create_bool_as_type', 'create_fill_constant_node', + 'create_static_variable_gast_node', 'data_layer_not_check', + 'to_static_variable', 'to_static_variable_gast_node' ] -- GitLab