未验证 提交 a7433cc3 编写于 作者: L liym27 提交者: GitHub

[Dy2Stat] Fix bug: the return statement should be transformed to an equivalent...

[Dy2Stat] Fix bug: the return statement should be transformed to an equivalent Paddle/Python if statement, which depends on if conditions of the return stmt. (#29165)
上级 4a0a8701
......@@ -20,6 +20,7 @@ from paddle.fluid import unique_name
from paddle.fluid.dygraph.dygraph_to_static.utils import index_in_list
from paddle.fluid.dygraph.dygraph_to_static.break_continue_transformer import ForToWhileTransformer
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_fill_constant_node
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
__all__ = [
'RETURN_NO_VALUE_MAGIC_NUM', 'RETURN_NO_VALUE_VAR_NAME', 'ReturnTransformer'
......@@ -251,10 +252,7 @@ class ReturnTransformer(gast.NodeTransformer):
value=return_value_nodes)
node.body.insert(0, assign_return_value_node)
node.body[:0] = assign_zero_nodes
# Prepend control flow boolean nodes such as '__return@1 = False'
for name in self.return_name[node]:
assign_false_node = create_fill_constant_node(name, False)
node.body.insert(0, assign_false_node)
# Prepend no value placeholders
for name in self.return_no_value_name[node]:
assign_no_value_node = create_fill_constant_node(
......@@ -270,6 +268,8 @@ class ReturnTransformer(gast.NodeTransformer):
self.return_name[cur_func_node].append(return_name)
max_return_length = self.pre_analysis.get_func_max_return_length(
cur_func_node)
parent_node_of_return = self.ancestor_nodes[-2]
for ancestor_index in reversed(range(len(self.ancestor_nodes) - 1)):
ancestor = self.ancestor_nodes[ancestor_index]
cur_node = self.ancestor_nodes[ancestor_index + 1]
......@@ -277,18 +277,21 @@ class ReturnTransformer(gast.NodeTransformer):
"body") and index_in_list(ancestor.body, cur_node) != -1:
if cur_node == node:
self._replace_return_in_stmt_list(
ancestor.body, cur_node, return_name, max_return_length)
ancestor.body, cur_node, return_name, max_return_length,
parent_node_of_return)
self._replace_after_node_to_if_in_stmt_list(
ancestor.body, cur_node, return_name)
ancestor.body, cur_node, return_name, parent_node_of_return)
elif hasattr(ancestor, "orelse") and index_in_list(ancestor.orelse,
cur_node) != -1:
if cur_node == node:
self._replace_return_in_stmt_list(ancestor.orelse, cur_node,
return_name,
max_return_length)
self._replace_return_in_stmt_list(
ancestor.orelse, cur_node, return_name,
max_return_length, parent_node_of_return)
self._replace_after_node_to_if_in_stmt_list(
ancestor.orelse, cur_node, return_name)
ancestor.orelse, cur_node, return_name,
parent_node_of_return)
# If return node in while loop, add `not return_name` in gast.While.test
if isinstance(ancestor, gast.While):
cond_var_node = gast.UnaryOp(
op=gast.Not(),
......@@ -301,6 +304,7 @@ class ReturnTransformer(gast.NodeTransformer):
op=gast.And(), values=[ancestor.test, cond_var_node])
continue
# If return node in for loop, add `not return_name` in gast.While.test
if isinstance(ancestor, gast.For):
cond_var_node = gast.UnaryOp(
op=gast.Not(),
......@@ -321,12 +325,24 @@ class ReturnTransformer(gast.NodeTransformer):
# return_node is replaced so we shouldn't return here
def _replace_return_in_stmt_list(self, stmt_list, return_node, return_name,
max_return_length):
max_return_length, parent_node_of_return):
assert max_return_length >= 0, "Input illegal max_return_length"
i = index_in_list(stmt_list, return_node)
if i == -1:
return False
assign_nodes = [create_fill_constant_node(return_name, True)]
assign_nodes = []
# Here assume that the parent node of return is gast.If
if isinstance(parent_node_of_return, gast.If):
# Prepend control flow boolean nodes such as '__return@1 = True'
node_str = "{} = paddle.jit.dy2static.create_bool_as_type({}, True)".format(
return_name,
ast_to_source_code(parent_node_of_return.test).strip())
assign_true_node = gast.parse(node_str).body[0]
assign_nodes.append(assign_true_node)
cur_func_node = self.function_def[-1]
return_length = get_return_size(return_node)
if return_length < max_return_length:
......@@ -409,14 +425,15 @@ class ReturnTransformer(gast.NodeTransformer):
stmt_list[i:] = assign_nodes
return True
def _replace_after_node_to_if_in_stmt_list(self, stmt_list, node,
return_name):
def _replace_after_node_to_if_in_stmt_list(
self, stmt_list, node, return_name, parent_node_of_return):
i = index_in_list(stmt_list, node)
if i < 0 or i >= len(stmt_list):
return False
if i == len(stmt_list) - 1:
# No need to add, we consider this as added successfully
return True
if_stmt = gast.If(test=gast.UnaryOp(
op=gast.Not(),
operand=gast.Name(
......@@ -426,5 +443,16 @@ class ReturnTransformer(gast.NodeTransformer):
type_comment=None)),
body=stmt_list[i + 1:],
orelse=[])
stmt_list[i + 1:] = [if_stmt]
# Here assume that the parent node of return is gast.If
if isinstance(parent_node_of_return, gast.If):
# Prepend control flow boolean nodes such as '__return@1 = False'
node_str = "{} = paddle.jit.dy2static.create_bool_as_type({}, False)".format(
return_name,
ast_to_source_code(parent_node_of_return.test).strip())
assign_false_node = gast.parse(node_str).body[0]
stmt_list[i:i] = [assign_false_node]
return True
......@@ -18,12 +18,14 @@ import six
import gast
from paddle.fluid import core
from paddle.fluid.framework import Variable
from paddle.fluid.layers import fill_constant
from paddle.fluid.layer_helper import LayerHelper
__all__ = [
'create_fill_constant_node', 'create_static_variable_gast_node',
'data_layer_not_check', 'to_static_variable', 'to_static_variable_gast_node'
'create_bool_as_type', 'create_fill_constant_node',
'create_static_variable_gast_node', 'data_layer_not_check',
'to_static_variable', 'to_static_variable_gast_node'
]
......@@ -122,3 +124,13 @@ def to_static_variable(x):
return fill_constant(shape=[1], dtype='int64', value=x)
return x
def create_bool_as_type(x, value=True):
'''
Create a bool variable, which type is the same as x.
'''
if isinstance(x, Variable):
return fill_constant(shape=[1], value=value, dtype="bool")
else:
return value
......@@ -62,10 +62,7 @@ def get_source_code(func):
class StaticCode1():
# TODO: Transform return statement
def dyfunc_with_if_else(x_v, label=None):
__return_1 = paddle.fluid.layers.fill_constant(shape=[1], dtype='bool', value=False)
__return_0 = paddle.fluid.layers.fill_constant(shape=[1], dtype='bool', value=False)
__return_value_init_0 = paddle.fluid.layers.fill_constant(
shape=[1], dtype='float64', value=0.0)
__return_value_0 = __return_value_init_0
......@@ -81,11 +78,13 @@ class StaticCode1():
x_v = paddle.jit.dy2static.convert_ifelse(
fluid.layers.mean(x_v)[0] > 5, true_fn_0, false_fn_0, (x_v, ),
(x_v, ), (x_v, ))
__return_0 = paddle.jit.dy2static.create_bool_as_type(label is not None,
False)
def true_fn_1(__return_0, __return_value_0, label, x_v):
loss = fluid.layers.cross_entropy(x_v, label)
__return_0 = paddle.fluid.layers.fill_constant(
shape=[1], dtype='bool', value=True)
__return_0 = paddle.jit.dy2static.create_bool_as_type(
label is not None, True)
__return_value_0 = loss
return __return_0, __return_value_0
......@@ -97,27 +96,25 @@ class StaticCode1():
(__return_0, __return_value_0, label, x_v),
(__return_0, __return_value_0), (__return_0, __return_value_0)))
def true_fn_2(__return_1, __return_value_0, x_v):
__return_1 = paddle.fluid.layers.fill_constant(
shape=[1], dtype='bool', value=True)
def true_fn_2(__return_0, __return_value_0, x_v):
__return_1 = paddle.jit.dy2static.create_bool_as_type(
paddle.jit.dy2static.convert_logical_not(__return_0), True)
__return_value_0 = x_v
return __return_1, __return_value_0
return __return_value_0
def false_fn_2(__return_1, __return_value_0):
return __return_1, __return_value_0
def false_fn_2(__return_value_0):
return __return_value_0
__return_1, __return_value_0 = (paddle.jit.dy2static.convert_ifelse(
__return_value_0 = paddle.jit.dy2static.convert_ifelse(
paddle.jit.dy2static.convert_logical_not(__return_0), true_fn_2,
false_fn_2, (__return_1, __return_value_0, x_v),
(__return_1, __return_value_0), (__return_1, __return_value_0)))
false_fn_2, (__return_0, __return_value_0,
x_v), (__return_value_0, ), (__return_value_0, ))
return __return_value_0
class StaticCode2():
# TODO: Transform return statement
def dyfunc_with_if_else(x_v, label=None):
__return_3 = paddle.fluid.layers.fill_constant(shape=[1], dtype='bool', value=False)
__return_2 = paddle.fluid.layers.fill_constant(shape=[1], dtype='bool', value=False)
__return_value_init_1 = paddle.fluid.layers.fill_constant(
shape=[1], dtype='float64', value=0.0)
__return_value_1 = __return_value_init_1
......@@ -133,35 +130,37 @@ class StaticCode2():
x_v = paddle.jit.dy2static.convert_ifelse(
fluid.layers.mean(x_v)[0] > 5, true_fn_3, false_fn_3, (x_v, ),
(x_v, ), (x_v, ))
__return_2 = paddle.jit.dy2static.create_bool_as_type(label is not None,
False)
def true_fn_4(__return_2, __return_value_1, label, x_v):
loss = fluid.layers.cross_entropy(x_v, label)
__return_2 = paddle.fluid.layers.fill_constant(
shape=[1], dtype='bool', value=True)
__return_2 = paddle.jit.dy2static.create_bool_as_type(
label is not None, True)
__return_value_1 = loss
return __return_2, __return_value_1
def false_fn_4(__return_2, __return_value_1):
return __return_2, __return_value_1
__return_2, __return_value_1 = (paddle.jit.dy2static.convert_ifelse(
label is not None, true_fn_4, false_fn_4,
(__return_2, __return_value_1, label, x_v),
(__return_2, __return_value_1), (__return_2, __return_value_1)))
__return_2, __return_value_1 = paddle.jit.dy2static.convert_ifelse(
label is not None, true_fn_4, false_fn_4, (
__return_2, __return_value_1, label, x_v),
(__return_2, __return_value_1), (__return_2, __return_value_1))
def true_fn_5(__return_3, __return_value_1, x_v):
__return_3 = paddle.fluid.layers.fill_constant(
shape=[1], dtype='bool', value=True)
def true_fn_5(__return_2, __return_value_1, x_v):
__return_3 = paddle.jit.dy2static.create_bool_as_type(
paddle.jit.dy2static.convert_logical_not(__return_2), True)
__return_value_1 = x_v
return __return_3, __return_value_1
return __return_value_1
def false_fn_5(__return_3, __return_value_1):
return __return_3, __return_value_1
def false_fn_5(__return_value_1):
return __return_value_1
__return_3, __return_value_1 = (paddle.jit.dy2static.convert_ifelse(
__return_value_1 = paddle.jit.dy2static.convert_ifelse(
paddle.jit.dy2static.convert_logical_not(__return_2), true_fn_5,
false_fn_5, (__return_3, __return_value_1, x_v),
(__return_3, __return_value_1), (__return_3, __return_value_1)))
false_fn_5, (__return_2, __return_value_1,
x_v), (__return_value_1, ), (__return_value_1, ))
return __return_value_1
......
......@@ -14,13 +14,15 @@
from __future__ import print_function
import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.jit import to_static
from paddle.jit import ProgramTranslator
import unittest
import numpy as np
from ifelse_simple_func import dyfunc_with_if_else
SEED = 2020
......@@ -179,6 +181,26 @@ def test_return_tuple_many_values(x):
return (x, y, z)
def inner_func(x):
a = 2
if a < 0:
y = x + 1
return y
y = x * 2
return y
@to_static
def test_return_without_paddle_cond(x):
# y shape is [10]
y = paddle.ones([10])
# the shape of inner_func(y) should be [10], not [1]
y = inner_func(y)
y = paddle.reshape(y, [2, 5])
return y
class TestReturnBase(unittest.TestCase):
def setUp(self):
self.input = np.ones((1)).astype('int32')
......@@ -297,5 +319,10 @@ class TestReturnTupleManyValue(TestReturnBase):
self.dygraph_func = test_return_tuple_many_values
class TestReturnSpecial(TestReturnBase):
def init_dygraph_func(self):
self.dygraph_func = test_return_without_paddle_cond
if __name__ == '__main__':
unittest.main()
......@@ -14,6 +14,7 @@
from __future__ import print_function
from ...fluid.dygraph.dygraph_to_static.variable_trans_func import create_bool_as_type #DEFINE_ALIAS
from ...fluid.dygraph.dygraph_to_static.variable_trans_func import create_fill_constant_node #DEFINE_ALIAS
from ...fluid.dygraph.dygraph_to_static.variable_trans_func import create_static_variable_gast_node #DEFINE_ALIAS
from ...fluid.dygraph.dygraph_to_static.variable_trans_func import data_layer_not_check #DEFINE_ALIAS
......@@ -21,6 +22,7 @@ from ...fluid.dygraph.dygraph_to_static.variable_trans_func import to_static_var
from ...fluid.dygraph.dygraph_to_static.variable_trans_func import to_static_variable_gast_node #DEFINE_ALIAS
__all__ = [
'create_fill_constant_node', 'create_static_variable_gast_node',
'data_layer_not_check', 'to_static_variable', 'to_static_variable_gast_node'
'create_bool_as_type', 'create_fill_constant_node',
'create_static_variable_gast_node', 'data_layer_not_check',
'to_static_variable', 'to_static_variable_gast_node'
]
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册