Unverified · Commit 1950a360 authored by WangZhen, committed by GitHub

[Dy2St]Refine ifelse early return (#43328)

* Refine ifelse early return
Parent commit: 083d769b
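For orientation, the hypothetical snippet below (not part of the diff; the names `func` and `func_transformed` are made up for illustration) shows the kind of rewrite this commit adds: statements that follow an `if` whose body returns are moved into that `if`'s else branch, so the early return disappears and the later dy2st transformers only see a plain if/else.

# Hypothetical user code with an early return inside `if` (illustrative only):
def func(x):
    if x == 0:
        return x - 1          # early return
    y = x + 1                 # statements after the early return
    return y

# Equivalent form after the trailing statements are pushed into the else
# branch, which is what the new EarlyReturnTransformer produces at the AST level:
def func_transformed(x):
    if x == 0:
        return x - 1
    else:
        y = x + 1
        return y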
...@@ -20,6 +20,7 @@ from __future__ import print_function
 # See details in https://github.com/serge-sans-paille/gast/
 import os
 from paddle.utils import gast
+from paddle.fluid.dygraph.dygraph_to_static.early_return_transformer import EarlyReturnTransformer
 from paddle.fluid.dygraph.dygraph_to_static.assert_transformer import AssertTransformer
 from paddle.fluid.dygraph.dygraph_to_static.basic_api_transformer import BasicApiTransformer
 from paddle.fluid.dygraph.dygraph_to_static.break_continue_transformer import BreakContinueTransformer
...@@ -87,6 +88,7 @@ class DygraphToStaticAst(gast.NodeTransformer):
         self.visit(node_wrapper.node)
         transformers = [
+            EarlyReturnTransformer,
             BasicApiTransformer,  # Basic Api
             TensorShapeTransformer,  # Tensor.shape -> layers.shape(Tensor)
             ListTransformer,  # List used in control flow
......
New file: early_return_transformer.py (under paddle/fluid/dygraph/dygraph_to_static, as referenced by the import above)

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function

from paddle.utils import gast
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper


class EarlyReturnTransformer(gast.NodeTransformer):
    """
    Transform the early-return pattern in if/else statements of dygraph code
    so that it can be converted into a static graph.
    """

    def __init__(self, wrapper_root):
        assert isinstance(
            wrapper_root, AstNodeWrapper
        ), "Type of input node should be AstNodeWrapper, but received %s ." % type(
            wrapper_root)
        self.root = wrapper_root.node

    def transform(self):
        """
        Main function to transform AST.
        """
        self.visit(self.root)

    def is_define_return_in_if(self, node):
        assert isinstance(
            node, gast.If
        ), "Type of input node should be gast.If, but received %s ." % type(
            node)
        for child in node.body:
            if isinstance(child, gast.Return):
                return True
        return False

    def visit_block_nodes(self, nodes):
        result_nodes = []
        destination_nodes = result_nodes
        for node in nodes:
            rewritten_node = self.visit(node)

            if isinstance(rewritten_node, (list, tuple)):
                destination_nodes.extend(rewritten_node)
            else:
                destination_nodes.append(rewritten_node)

            # Append the following nodes to if.orelse, even if if.orelse is
            # not empty already.
            if isinstance(node, gast.If) and self.is_define_return_in_if(node):
                destination_nodes = node.orelse
                # Handle `if/elif/elif` chains.
                while len(destination_nodes) > 0 and \
                        isinstance(destination_nodes[0], gast.If) and \
                        self.is_define_return_in_if(destination_nodes[0]):
                    destination_nodes = destination_nodes[0].orelse

        return result_nodes

    def visit_If(self, node):
        node.body = self.visit_block_nodes(node.body)
        node.orelse = self.visit_block_nodes(node.orelse)
        return node

    def visit_While(self, node):
        node.body = self.visit_block_nodes(node.body)
        node.orelse = self.visit_block_nodes(node.orelse)
        return node

    def visit_For(self, node):
        node.body = self.visit_block_nodes(node.body)
        node.orelse = self.visit_block_nodes(node.orelse)
        return node

    def visit_FunctionDef(self, node):
        node.body = self.visit_block_nodes(node.body)
        return node
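The core of the algorithm is `visit_block_nodes`: it walks a statement list, and whenever it meets an `if` whose body contains a `return`, it redirects every following statement into that `if`'s `orelse` (descending through `elif` chains). A minimal, self-contained sketch of the same idea, written against the stdlib `ast` module instead of `paddle.utils.gast` and without the recursive visiting, might look like this (illustrative only; `ast.unparse` needs Python 3.9+):

import ast
import textwrap

def returns_in_body(if_node):
    # Mirrors is_define_return_in_if: does the if-body contain a Return?
    return any(isinstance(stmt, ast.Return) for stmt in if_node.body)

def push_tail_into_orelse(stmts):
    result = []
    dest = result
    for stmt in stmts:
        dest.append(stmt)
        if isinstance(stmt, ast.If) and returns_in_body(stmt):
            # Everything after this early-return `if` goes into its orelse.
            dest = stmt.orelse
            # Descend through `if/elif/elif` chains, as visit_block_nodes does.
            while dest and isinstance(dest[0], ast.If) and returns_in_body(dest[0]):
                dest = dest[0].orelse
    return result

src = textwrap.dedent("""
    def f(x):
        if x == 0:
            return -1
        y = x + 1
        return y
""")
tree = ast.parse(src)
func = tree.body[0]
func.body = push_tail_into_orelse(func.body)
print(ast.unparse(tree))
# Prints:
# def f(x):
#     if x == 0:
#         return -1
#     else:
#         y = x + 1
#         return y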
...@@ -100,6 +100,30 @@ def dyfunc_with_if_else3(x):
     return x
+
+def dyfunc_with_if_else_early_return1():
+    x = paddle.to_tensor([10])
+    if x == 0:
+        a = paddle.zeros([2, 2])
+        b = paddle.zeros([3, 3])
+        return a, b
+    a = paddle.zeros([2, 2]) + 1
+    return a
+
+def dyfunc_with_if_else_early_return2():
+    x = paddle.to_tensor([10])
+    if x == 0:
+        a = paddle.zeros([2, 2])
+        b = paddle.zeros([3, 3])
+        return a, b
+    elif x == 1:
+        c = paddle.zeros([2, 2]) + 1
+        d = paddle.zeros([3, 3]) + 1
+        return c, d
+    e = paddle.zeros([2, 2]) + 3
+    return e
+
 def dyfunc_with_if_else_with_list_geneator(x):
     if 10 > 5:
         y = paddle.add_n(
......
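For reference, the two helpers above can also be checked eagerly (a hypothetical snippet, assuming a working paddle install and that the functions are importable from ifelse_simple_func); the expected values mirror the new unit tests later in this diff:

import numpy as np
from ifelse_simple_func import dyfunc_with_if_else_early_return1, dyfunc_with_if_else_early_return2

# x is the tensor [10], so neither early-return branch fires and the
# statements after the if/elif produce the result.
assert np.allclose(dyfunc_with_if_else_early_return1().numpy(), np.zeros([2, 2]) + 1)
assert np.allclose(dyfunc_with_if_else_early_return2().numpy(), np.zeros([2, 2]) + 3)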
...@@ -29,7 +29,7 @@ from paddle.fluid.dygraph.nn import Linear
 from paddle.fluid.dygraph.dygraph_to_static.utils import func_to_source_code
 import paddle.jit.dy2static as _jst
-from ifelse_simple_func import dyfunc_with_if_else
+from ifelse_simple_func import dyfunc_with_if_else, dyfunc_with_if_else_early_return1, dyfunc_with_if_else_early_return2
 np.random.seed(0)
...@@ -83,34 +83,22 @@ class StaticCode1():
         x_v = _jst.convert_ifelse(
             fluid.layers.mean(x_v)[0] > 5, true_fn_0, false_fn_0, (x_v, ),
             (x_v, ))
-        __return_0 = _jst.create_bool_as_type(label is not None, False)
-        def true_fn_1(__return_0, __return_value_0, label, x_v):
+        def true_fn_1(__return_value_0, label, x_v):
             loss = fluid.layers.cross_entropy(x_v, label)
             __return_0 = _jst.create_bool_as_type(label is not None, True)
             __return_value_0 = loss
-            return __return_0, __return_value_0
-        def false_fn_1(__return_0, __return_value_0):
-            return __return_0, __return_value_0
-        __return_0, __return_value_0 = _jst.convert_ifelse(
-            label is not None, true_fn_1, false_fn_1,
-            (__return_0, __return_value_0, label, x_v),
-            (__return_0, __return_value_0))
-        def true_fn_2(__return_0, __return_value_0, x_v):
-            __return_1 = _jst.create_bool_as_type(
-                _jst.convert_logical_not(__return_0), True)
-            __return_value_0 = x_v
             return __return_value_0
-        def false_fn_2(__return_value_0):
+        def false_fn_1(__return_value_0, label, x_v):
+            __return_1 = _jst.create_bool_as_type(label is not None, True)
+            __return_value_0 = x_v
             return __return_value_0
-        __return_value_0 = _jst.convert_ifelse(
-            _jst.convert_logical_not(__return_0), true_fn_2, false_fn_2,
-            (__return_0, __return_value_0, x_v), (__return_value_0, ))
+        __return_value_0 = _jst.convert_ifelse(label is not None, true_fn_1,
+                                               false_fn_1,
+                                               (__return_value_0, label, x_v),
+                                               (__return_value_0, label, x_v))
         return __return_value_0
...@@ -123,45 +111,33 @@ class StaticCode2():
             name='__return_value_init_1')
         __return_value_1 = __return_value_init_1
-        def true_fn_3(x_v):
+        def true_fn_2(x_v):
             x_v = x_v - 1
             return x_v
-        def false_fn_3(x_v):
+        def false_fn_2(x_v):
             x_v = x_v + 1
             return x_v
         x_v = _jst.convert_ifelse(
-            fluid.layers.mean(x_v)[0] > 5, true_fn_3, false_fn_3, (x_v, ),
+            fluid.layers.mean(x_v)[0] > 5, true_fn_2, false_fn_2, (x_v, ),
             (x_v, ))
-        __return_2 = _jst.create_bool_as_type(label is not None, False)
-        def true_fn_4(__return_2, __return_value_1, label, x_v):
+        def true_fn_3(__return_value_1, label, x_v):
             loss = fluid.layers.cross_entropy(x_v, label)
             __return_2 = _jst.create_bool_as_type(label is not None, True)
             __return_value_1 = loss
-            return __return_2, __return_value_1
-        def false_fn_4(__return_2, __return_value_1):
-            return __return_2, __return_value_1
-        __return_2, __return_value_1 = _jst.convert_ifelse(
-            label is not None, true_fn_4, false_fn_4,
-            (__return_2, __return_value_1, label, x_v),
-            (__return_2, __return_value_1))
-        def true_fn_5(__return_2, __return_value_1, x_v):
-            __return_3 = _jst.create_bool_as_type(
-                _jst.convert_logical_not(__return_2), True)
-            __return_value_1 = x_v
             return __return_value_1
-        def false_fn_5(__return_value_1):
+        def false_fn_3(__return_value_1, label, x_v):
+            __return_3 = _jst.create_bool_as_type(label is not None, True)
+            __return_value_1 = x_v
             return __return_value_1
-        __return_value_1 = _jst.convert_ifelse(
-            _jst.convert_logical_not(__return_2), true_fn_5, false_fn_5,
-            (__return_2, __return_value_1, x_v), (__return_value_1, ))
+        __return_value_1 = _jst.convert_ifelse(label is not None, true_fn_3,
+                                               false_fn_3,
+                                               (__return_value_1, label, x_v),
+                                               (__return_value_1, label, x_v))
         return __return_value_1
...@@ -358,6 +334,21 @@ class TestFunctionTrainEvalMode(unittest.TestCase):
         net.foo.train()
+
+class TestIfElseEarlyReturn(unittest.TestCase):
+
+    def test_ifelse_early_return1(self):
+        answer = np.zeros([2, 2]) + 1
+        static_func = paddle.jit.to_static(dyfunc_with_if_else_early_return1)
+        out = static_func()
+        self.assertTrue(np.allclose(answer, out.numpy()))
+
+    def test_ifelse_early_return2(self):
+        answer = np.zeros([2, 2]) + 3
+        static_func = paddle.jit.to_static(dyfunc_with_if_else_early_return2)
+        out = static_func()
+        self.assertTrue(np.allclose(answer, out.numpy()))
+
 class TestRemoveCommentInDy2St(unittest.TestCase):
     def func_with_comment(self):
......