Unverified · Commit 6f631a27 · authored by Huihuang Zheng · committed by GitHub

[Dy2stat] Add Basic Support for Grammar 'return' (#25176)

This PR adds basic support for the 'return' grammar in dy2stat, i.e. the control flow that 'return' introduces.

The basic idea is to use a return-value variable to store the value of an early return, plus boolean state variables combined with if-else to skip the statements after the return statement.
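
As an illustration, here is a minimal sketch of the idea (the names are hypothetical; the real transformer generates unique names such as `__return_0` via `unique_name` and initializes them with `fill_constant`):

```python
# Dygraph code with an early return:
def foo(x):
    if x > 0:
        return x + 1
    x = x - 1
    return x

# Conceptually equivalent transformed code:
def foo(x):
    __return_0 = False       # boolean state: has a return been executed?
    __return_value_0 = 0.0   # stores the final return value
    if x > 0:
        __return_0 = True
        __return_value_0 = x + 1
    if not __return_0:       # skip the statements after the early return
        x = x - 1
        __return_0 = True
        __return_value_0 = x
    return __return_value_0
```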

**This PR provides only very basic support; there are some corner cases I haven't implemented or tested yet**, for example 'return None', 'return different lengths of variables', 'return non-tensor and tensor together', and 'no return statement'. **These corner cases will be handled in my next PRs**. The target date is this week.

**Note**: 
1. For the unit tests, I changed test_program_translator.py because the StaticCode of `dyfunc_with_if_else` changes. To guarantee the correctness of `dyfunc_with_if_else`, I also run it in `TestRecursiveReturn` in test_return.py.

2. I commented out the early-return code in bert_dygraph_model.py because 'return different lengths of variables' is not supported yet. I also know that some other models use early return and that we haven't enabled them in the unit tests. I will add support in my next PRs and then re-enable those tests.
Parent 1458cc0c
@@ -29,6 +29,7 @@ from paddle.fluid.dygraph.dygraph_to_static.list_transformer import ListTransformer
 from paddle.fluid.dygraph.dygraph_to_static.logical_transformer import LogicalTransformer
 from paddle.fluid.dygraph.dygraph_to_static.loop_transformer import LoopTransformer
 from paddle.fluid.dygraph.dygraph_to_static.print_transformer import PrintTransformer
+from paddle.fluid.dygraph.dygraph_to_static.return_transformer import ReturnTransformer
 from paddle.fluid.dygraph.dygraph_to_static.tensor_shape_transformer import TensorShapeTransformer
 from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor
@@ -71,6 +72,9 @@ class DygraphToStaticAst(gast.NodeTransformer):
         # Transform break/continue in loops
         BreakContinueTransformer(node_wrapper).transform()

+        # Transform return in functions
+        ReturnTransformer(node_wrapper).transform()
+
         # Transform logical and/or/not
         LogicalTransformer(node_wrapper).transform()
...
@@ -49,7 +49,7 @@ class ForToWhileTransformer(gast.NodeTransformer):
                 new_stmts = self.get_for_stmt_nodes(body_list[i])
                 body_list[i:i + 1] = new_stmts
                 i += len(new_stmts)
-                return
+                return new_stmts
         if hasattr(self.parent_node, 'orelse'):
             body_list = self.parent_node.orelse
             i = index_in_list(body_list, self.loop_node)
@@ -57,7 +57,7 @@ class ForToWhileTransformer(gast.NodeTransformer):
                 new_stmts = self.get_for_stmt_nodes(body_list[i])
                 body_list[i:i + 1] = new_stmts
                 i += len(new_stmts)
-                return
+                return new_stmts
         raise ValueError(
             "parent_node doesn't contain the loop_node in ForToWhileTransformer")
...
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import gast
from paddle.fluid import unique_name
from paddle.fluid.dygraph.dygraph_to_static.utils import index_in_list
from paddle.fluid.dygraph.dygraph_to_static.break_continue_transformer import ForToWhileTransformer
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_fill_constant_node
__all__ = ['ReturnTransformer']
# Constant for the name of the variable which stores the boolean state that we
# should return
RETURN_PREFIX = '__return'
# Constant for the name of the variable which stores the final return value
RETURN_VALUE_PREFIX = '__return_value'
class ReturnPreAnalysisVisitor(gast.NodeVisitor):
"""
    Visits the gast tree and pre-analyzes information about 'return'.
"""
def __init__(self, root_node):
self.root = root_node
        # A stack recording which function we are currently in.
self.function_def = []
        # Mapping from gast.FunctionDef node to the number of return statements.
        # Python allows defining a function inside a function, so we have to handle nesting.
self.count_return = {}
self.visit(self.root)
def visit_FunctionDef(self, node):
self.function_def.append(node)
self.count_return[node] = 0
self.generic_visit(node)
self.function_def.pop()
return node
def visit_Return(self, node):
assert len(
self.function_def) > 0, "Found 'return' statement out of function."
cur_func = self.function_def[-1]
if cur_func in self.count_return:
self.count_return[cur_func] += 1
else:
self.count_return[cur_func] = 1
self.generic_visit(node)
def get_func_return_count(self, func_node):
return self.count_return[func_node]
def set_func_return_count(self, func_node, count):
self.count_return[func_node] = count
class ReturnTransformer(gast.NodeTransformer):
"""
    Transforms return statements into equivalent Python statements so that each
    function contains only one return statement, at the end. The basic idea is
    to use a return-value variable to store the value of an early return, plus
    boolean state variables combined with if-else to skip the statements after
    the return.
"""
def __init__(self, wrapper_root):
self.wrapper_root = wrapper_root
self.root = wrapper_root.node
self.ancestor_nodes = []
# The name of the variable which stores the final return value
# Mapping from FunctionDef node to string
self.return_value_name = {}
        # The names of the variables which store the boolean state that decides
        # whether to skip statements. Mapping from FunctionDef node to a list of names.
self.return_name = {}
# A list of FunctionDef to store where the current function is.
self.function_def = []
def transform(self):
self.visit(self.root)
def generic_visit(self, node):
        # Because visit_Return modifies ancestor nodes rather than the current
        # node, the default generic_visit of NodeTransformer may visit nodes
        # that have already been deleted. To prevent such nodes from being
        # added back into the transformed AST, we define our own generic_visit
        # and visit.
for field, value in gast.iter_fields(node):
if isinstance(value, list):
for item in value:
if isinstance(item, gast.AST):
self.visit(item)
elif isinstance(value, gast.AST):
self.visit(value)
def visit(self, node):
"""
Self-defined visit for appending ancestor
"""
self.ancestor_nodes.append(node)
method = 'visit_' + node.__class__.__name__
visitor = getattr(self, method, self.generic_visit)
ret = visitor(node)
self.ancestor_nodes.pop()
return ret
def visit_FunctionDef(self, node):
self.function_def.append(node)
self.return_value_name[node] = None
self.return_name[node] = []
pre_analysis = ReturnPreAnalysisVisitor(node)
while pre_analysis.get_func_return_count(node) > 1:
self.generic_visit(node)
pre_analysis = ReturnPreAnalysisVisitor(node)
        # Prepend initialization of the final return value and append the final return statement
value_name = self.return_value_name[node]
if value_name is not None:
node.body.append(
gast.Return(value=gast.Name(
id=value_name,
ctx=gast.Load(),
annotation=None,
type_comment=None)))
assign_zero_node = create_fill_constant_node(value_name, 0.0)
node.body.insert(0, assign_zero_node)
# Prepend control flow boolean nodes such as '__return@1 = False'
for name in self.return_name[node]:
assign_false_node = create_fill_constant_node(name, False)
node.body.insert(0, assign_false_node)
self.function_def.pop()
return node
def visit_Return(self, node):
cur_func_node = self.function_def[-1]
return_name = unique_name.generate(RETURN_PREFIX)
self.return_name[cur_func_node].append(return_name)
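        # Walk the ancestor chain from the innermost statement outward: replace
        # the return itself, wrap the statements after it in 'if not <flag>',
        # and strengthen the condition of every enclosing loop so it exits once
        # the flag is set.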
for ancestor_index in reversed(range(len(self.ancestor_nodes) - 1)):
ancestor = self.ancestor_nodes[ancestor_index]
cur_node = self.ancestor_nodes[ancestor_index + 1]
if hasattr(ancestor,
"body") and index_in_list(ancestor.body, cur_node) != -1:
if cur_node == node:
self._replace_return_in_stmt_list(ancestor.body, cur_node,
return_name)
self._replace_after_node_to_if_in_stmt_list(
ancestor.body, cur_node, return_name)
elif hasattr(ancestor, "orelse") and index_in_list(ancestor.orelse,
cur_node) != -1:
if cur_node == node:
self._replace_return_in_stmt_list(ancestor.orelse, cur_node,
return_name)
self._replace_after_node_to_if_in_stmt_list(
ancestor.orelse, cur_node, return_name)
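            # An enclosing while loop must also stop iterating once the return
            # flag is set, so 'and not <flag>' is appended to its test condition.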
if isinstance(ancestor, gast.While):
cond_var_node = gast.UnaryOp(
op=gast.Not(),
operand=gast.Name(
id=return_name,
ctx=gast.Load(),
annotation=None,
type_comment=None))
ancestor.test = gast.BoolOp(
op=gast.And(), values=[ancestor.test, cond_var_node])
continue
if isinstance(ancestor, gast.For):
cond_var_node = gast.UnaryOp(
op=gast.Not(),
operand=gast.Name(
id=return_name,
ctx=gast.Load(),
annotation=None,
type_comment=None))
parent_node = self.ancestor_nodes[ancestor_index - 1]
for_to_while = ForToWhileTransformer(parent_node, ancestor,
cond_var_node)
new_stmts = for_to_while.transform()
while_node = new_stmts[-1]
self.ancestor_nodes[ancestor_index] = while_node
if ancestor == cur_func_node:
break
# return_node is replaced so we shouldn't return here
def _replace_return_in_stmt_list(self, stmt_list, return_node, return_name):
i = index_in_list(stmt_list, return_node)
if i == -1:
return False
assign_nodes = [create_fill_constant_node(return_name, True)]
if return_node.value is not None:
cur_func_node = self.function_def[-1]
if self.return_value_name[cur_func_node] is None:
self.return_value_name[cur_func_node] = unique_name.generate(
RETURN_VALUE_PREFIX)
assign_nodes.append(
gast.Assign(
targets=[
gast.Name(
id=self.return_value_name[cur_func_node],
ctx=gast.Store(),
annotation=None,
type_comment=None)
],
value=return_node.value))
stmt_list[i:] = assign_nodes
return True
def _replace_after_node_to_if_in_stmt_list(self, stmt_list, node,
return_name):
i = index_in_list(stmt_list, node)
if i < 0 or i >= len(stmt_list):
return False
if i == len(stmt_list) - 1:
            # Nothing follows the node, so there is nothing to wrap; we consider this as added successfully
return True
if_stmt = gast.If(test=gast.UnaryOp(
op=gast.Not(),
operand=gast.Name(
id=return_name,
            ctx=gast.Load(),
annotation=None,
type_comment=None)),
body=stmt_list[i + 1:],
orelse=[])
stmt_list[i + 1:] = [if_stmt]
return True
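
To make the loop handling concrete, here is a hand-written sketch (hypothetical names; the `fill_constant` boilerplate and the generated `convert_ifelse` calls are omitted) of how a return inside a while loop ends up after the transformation, mirroring what `visit_Return` and `_replace_after_node_to_if_in_stmt_list` emit:

```python
# Before:
def bar(x, i):
    while i < 10:
        if i > 5:
            return x
        i = i + 1
    return x - 1

# After (sketch):
def bar(x, i):
    __return_1 = False
    __return_0 = False
    __return_value_0 = 0.0
    while i < 10 and not __return_0:   # loop also exits once the flag is set
        if i > 5:
            __return_0 = True
            __return_value_0 = x
        if not __return_0:             # skip statements after the return
            i = i + 1
    if not __return_0:                 # the final return is rewritten too
        __return_1 = True
        __return_value_0 = x - 1
    return __return_value_0
```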
@@ -247,8 +247,11 @@ class BertModelLayer(Layer):
         enc_output = self._encoder(emb_out, n_head_self_attn_mask)

-        if not self.return_pooled_out:
-            return enc_output
+        # TODO(zhhsplendid): uncomment this in next PR, in which we support various
+        # lengths of early return
+        #
+        #if not self.return_pooled_out:
+        #    return enc_output

         next_sent_feat = fluid.layers.slice(
             input=enc_output, axes=[1], starts=[0], ends=[1])
         next_sent_feat = self.pooled_fc(next_sent_feat)
...
@@ -63,6 +63,13 @@ def get_source_code(func):
 class StaticCode1():
     # TODO: Transform return statement
     def dyfunc_with_if_else(x_v, label=None):
+        __return_1 = fluid.layers.fill_constant(
+            shape=[1], dtype='bool', value=False)
+        __return_0 = fluid.layers.fill_constant(
+            shape=[1], dtype='bool', value=False)
+        __return_value_0 = fluid.layers.fill_constant(
+            shape=[1], dtype='float64', value=0.0)
+
         def true_fn_0(x_v):
             x_v = x_v - 1
             return x_v
@@ -75,45 +82,94 @@ class StaticCode1():
             fluid.layers.mean(x_v)[0] > 5, true_fn_0, false_fn_0, (x_v, ),
             (x_v, ), (x_v, ))

-        def true_fn_1(label, x_v):
+        def true_fn_1(__return_0, __return_value_0, label, x_v):
             loss = fluid.layers.cross_entropy(x_v, label)
-            return loss
-            return
+            __return_0 = fluid.layers.fill_constant(
+                shape=[1], dtype='bool', value=True)
+            __return_value_0 = loss
+            return __return_0, __return_value_0

-        def false_fn_1():
-            return
+        def false_fn_1(__return_0, __return_value_0):
+            return __return_0, __return_value_0

-        fluid.dygraph.dygraph_to_static.convert_operators.convert_ifelse(
-            label is not None, true_fn_1, false_fn_1, (label, x_v), (), ())
-        return x_v
+        __return_0, __return_value_0 = (
+            fluid.dygraph.dygraph_to_static.convert_operators.convert_ifelse(
+                label is not None, true_fn_1, false_fn_1,
+                (__return_0, __return_value_0, label, x_v),
+                (__return_0, __return_value_0), (__return_0, __return_value_0)))
+
+        def true_fn_2(__return_1, __return_value_0, x_v):
+            __return_1 = fluid.layers.fill_constant(
+                shape=[1], dtype='bool', value=True)
+            __return_value_0 = x_v
+            return __return_1, __return_value_0
+
+        def false_fn_2(__return_1, __return_value_0):
+            return __return_1, __return_value_0
+
+        __return_1, __return_value_0 = (
+            fluid.dygraph.dygraph_to_static.convert_operators.convert_ifelse(
+                fluid.dygraph.dygraph_to_static.convert_operators.
+                convert_logical_not(__return_0), true_fn_2, false_fn_2,
+                (__return_1, __return_value_0, x_v),
+                (__return_1, __return_value_0), (__return_1, __return_value_0)))
+        return __return_value_0

 class StaticCode2():
     # TODO: Transform return statement
     def dyfunc_with_if_else(x_v, label=None):
-        def true_fn_2(x_v):
+        __return_3 = fluid.layers.fill_constant(
+            shape=[1], dtype='bool', value=False)
+        __return_2 = fluid.layers.fill_constant(
+            shape=[1], dtype='bool', value=False)
+        __return_value_1 = fluid.layers.fill_constant(
+            shape=[1], dtype='float64', value=0.0)
+
+        def true_fn_3(x_v):
             x_v = x_v - 1
             return x_v

-        def false_fn_2(x_v):
+        def false_fn_3(x_v):
             x_v = x_v + 1
             return x_v

         x_v = fluid.dygraph.dygraph_to_static.convert_operators.convert_ifelse(
-            fluid.layers.mean(x_v)[0] > 5, true_fn_2, false_fn_2, (x_v, ),
+            fluid.layers.mean(x_v)[0] > 5, true_fn_3, false_fn_3, (x_v, ),
             (x_v, ), (x_v, ))

-        def true_fn_3(label, x_v):
+        def true_fn_4(__return_2, __return_value_1, label, x_v):
             loss = fluid.layers.cross_entropy(x_v, label)
-            return loss
-            return
+            __return_2 = fluid.layers.fill_constant(
+                shape=[1], dtype='bool', value=True)
+            __return_value_1 = loss
+            return __return_2, __return_value_1

-        def false_fn_3():
-            return
+        def false_fn_4(__return_2, __return_value_1):
+            return __return_2, __return_value_1

-        fluid.dygraph.dygraph_to_static.convert_operators.convert_ifelse(
-            label is not None, true_fn_3, false_fn_3, (label, x_v), (), ())
-        return x_v
+        __return_2, __return_value_1 = (
+            fluid.dygraph.dygraph_to_static.convert_operators.convert_ifelse(
+                label is not None, true_fn_4, false_fn_4,
+                (__return_2, __return_value_1, label, x_v),
+                (__return_2, __return_value_1), (__return_2, __return_value_1)))
+
+        def true_fn_5(__return_3, __return_value_1, x_v):
+            __return_3 = fluid.layers.fill_constant(
+                shape=[1], dtype='bool', value=True)
+            __return_value_1 = x_v
+            return __return_3, __return_value_1
+
+        def false_fn_5(__return_3, __return_value_1):
+            return __return_3, __return_value_1
+
+        __return_3, __return_value_1 = (
+            fluid.dygraph.dygraph_to_static.convert_operators.convert_ifelse(
+                fluid.dygraph.dygraph_to_static.convert_operators.
+                convert_logical_not(__return_2), true_fn_5, false_fn_5,
+                (__return_3, __return_value_1, x_v),
+                (__return_3, __return_value_1), (__return_3, __return_value_1)))
+        return __return_value_1

 class NetWithError(fluid.dygraph.layers.Layer):
...
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph import declarative
from paddle.fluid.dygraph import ProgramTranslator
from ifelse_simple_func import dyfunc_with_if_else
SEED = 2020
np.random.seed(SEED)
@declarative
def test_return_base(x):
x = fluid.dygraph.to_variable(x)
return x
@declarative
def test_inside_func_base(x):
x = fluid.dygraph.to_variable(x)
def inner_func(x):
return x
return inner_func(x)
@declarative
def test_return_if(x):
x = fluid.dygraph.to_variable(x)
if x < 0:
x -= 1
return -x
x += 3
return x
@declarative
def test_return_if_else(x):
x = fluid.dygraph.to_variable(x)
if x > 0:
x += 10086
return x
        x -= 3  # unreachable statement, added to test that our code can handle it
else:
x += 6666
return x
        x -= 8888  # unreachable statement, added to test that our code can handle it
@declarative
def test_return_in_while(x):
x = fluid.dygraph.to_variable(x)
i = fluid.layers.fill_constant(shape=[1], dtype='int32', value=0)
while i < 10:
i += 1
if i > 5:
x += 110
return x
x += i
return x
@declarative
def test_return_in_for(x):
x = fluid.dygraph.to_variable(x)
for i in range(10):
if i <= 4:
x += 1
continue
else:
return x + 10086
return x - 1
@declarative
def test_recursive_return(x):
x = fluid.dygraph.to_variable(x)
return dyfunc_with_if_else(x)
class TestReturnBase(unittest.TestCase):
def setUp(self):
self.input = np.ones((1)).astype('int32')
self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda(
) else fluid.CPUPlace()
self.init_dygraph_func()
self.program_translator = ProgramTranslator()
def init_dygraph_func(self):
self.dygraph_func = test_return_base
def run_dygraph_mode(self):
self.program_translator.enable(False)
with fluid.dygraph.guard():
res = self.dygraph_func(self.input)
return res.numpy()
def run_static_mode(self):
self.program_translator.enable(True)
with fluid.dygraph.guard():
res = self.dygraph_func(self.input)
return res.numpy()
def test_transformed_static_result(self):
static_res = self.run_static_mode()
dygraph_res = self.run_dygraph_mode()
self.assertTrue(
np.allclose(dygraph_res, static_res),
msg='dygraph res is {}\nstatic_res is {}'.format(dygraph_res,
static_res))
class TestInsideFuncBase(TestReturnBase):
def init_dygraph_func(self):
self.dygraph_func = test_inside_func_base
class TestReturnIf(TestReturnBase):
def init_dygraph_func(self):
self.dygraph_func = test_return_if
class TestReturnIfElse(TestReturnBase):
def init_dygraph_func(self):
self.dygraph_func = test_return_if_else
class TestReturnInWhile(TestReturnBase):
def init_dygraph_func(self):
self.dygraph_func = test_return_in_while
class TestReturnInFor(TestReturnBase):
def init_dygraph_func(self):
self.dygraph_func = test_return_in_for
class TestRecursiveReturn(TestReturnBase):
def init_dygraph_func(self):
self.input = self.input.astype(np.float32)
self.dygraph_func = test_recursive_return
if __name__ == '__main__':
unittest.main()
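
As a usage note, each test compares eager execution against the transformed static graph by toggling `ProgramTranslator().enable(...)`; a minimal standalone sketch of that pattern (assuming it runs in the same module as the functions above):

```python
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph import ProgramTranslator

translator = ProgramTranslator()
with fluid.dygraph.guard():
    x = np.ones(1).astype('int32')
    translator.enable(True)    # run the declarative (static-graph) version
    static_out = test_return_if(x).numpy()
    translator.enable(False)   # fall back to eager dygraph execution
    dygraph_out = test_return_if(x).numpy()
    assert np.allclose(static_out, dygraph_out)
```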