From 001c9fcca28521241b10dc1d2055c34933a7c476 Mon Sep 17 00:00:00 2001 From: liym27 <33742067+liym27@users.noreply.github.com> Date: Wed, 3 Jun 2020 11:05:29 +0800 Subject: [PATCH] [Dy2Static]Convert while stmt and convert logical_XX (#24799) * Support convert_while_loop. * Comment code that not supported 'if' in test_break_continue. * Convert int into tensor to support 'if' stmt in for/while loop. * Add unittest to test all cases of convert_logical_XX. * Add unittest to test all cases of convert_while_loop. * Fix bug in LogicalOpTransformer. test=develop --- .../dygraph/dygraph_to_static/__init__.py | 2 + .../dygraph_to_static/convert_operators.py | 138 ++++++++++++++++++ .../dygraph_to_static/loop_transformer.py | 22 ++- .../dygraph_to_static/ifelse_simple_func.py | 9 ++ .../dygraph_to_static/test_break_continue.py | 35 +++-- .../unittests/dygraph_to_static/test_list.py | 3 + .../unittests/dygraph_to_static/test_loop.py | 49 +++++-- 7 files changed, 225 insertions(+), 33 deletions(-) create mode 100644 python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/__init__.py b/python/paddle/fluid/dygraph/dygraph_to_static/__init__.py index d2d03d65b1b..1f91027e462 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/__init__.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/__init__.py @@ -32,6 +32,8 @@ from .program_translator import * from . import convert_call_func from .convert_call_func import * +from . import convert_operators + __all__ = [] __all__ += ast_transformer.__all__ __all__ += loop_transformer.__all__ diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py b/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py new file mode 100644 index 00000000000..21205e20598 --- /dev/null +++ b/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py @@ -0,0 +1,138 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle.fluid.framework import Variable +from paddle.fluid.layers import control_flow, logical_and, logical_or, logical_not +from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import to_static_variable + + +def convert_while_loop(cond, body, loop_vars): + """ + A function representation of a Python ``while`` statement. + + Args: + cond(Callable): A callable object that returns a boolean variable to control whether to execute the loop body. It takes ``loop_vars`` as arguments. + body(Callable): A callable object that returns a tuple or list of variables with the same arguments ``loops_vars`` as ``cond`` . + loop_vars(list|tuple): A list or tuple of variables passed to ``cond`` and ``body`` . + + Returns: + A list or tuple of variables which returned by ``body`` . + """ + + pred = cond(*loop_vars) + if isinstance(pred, Variable): + loop_vars = _run_paddle_while_loop(cond, body, loop_vars) + else: + loop_vars = _run_py_while(cond, body, loop_vars) + + return loop_vars + + +def _run_paddle_while_loop(cond, body, loop_vars): + loop_vars = [to_static_variable(var) for var in loop_vars] + loop_vars = control_flow.while_loop(cond, body, loop_vars) + return loop_vars + + +def _run_py_while(cond, body, loop_vars): + while cond(*loop_vars): + loop_vars = body(*loop_vars) + return loop_vars + + +def convert_logical_and(x, y): + """ + A function representation of a Python ``and`` statement. + + Args: + x(bool|Variable): Left hand operand of ``and`` operator. + y(bool|Variable): Right hand operand of ``and`` operator. + + Returns: + A python bool variable or a bool Tensor. + """ + + if isinstance(x, Variable) and isinstance(y, Variable): + return _run_paddle_logical_and(x, y) + + if not isinstance(x, Variable): + return _run_py_logical_and(x, y) + + return _run_py_logical_and(y, x) + + +def _run_paddle_logical_and(x, y): + return logical_and(x, y) + + +def _run_py_logical_and(x, y): + assert not isinstance(x, Variable) + # NOTE: Returns y if x is True + return x and y + + +def convert_logical_or(x, y): + """ + A function representation of a Python ``or`` statement. + + Args: + x(bool|Variable): Left hand operand of ``or`` operator. + y(bool|Variable): Right hand operand of ``or`` operator. + + Returns: + A python bool variable or a bool Tensor. + """ + + if isinstance(x, Variable) and isinstance(y, Variable): + return _run_paddle_logical_or(x, y) + + if not isinstance(x, Variable): + return _run_py_logical_or(x, y) + + return _run_py_logical_or(y, x) + + +def _run_paddle_logical_or(x, y): + return logical_or(x, y) + + +def _run_py_logical_or(x, y): + assert not isinstance(x, Variable) + # NOTE: Returns y if x is False + return x or y + + +def convert_logical_not(x): + """ + A function representation of a Python ``not`` statement. + + Args: + x(bool|Variable): Operand of of ``not`` operator. + + Returns: + A python bool variable or a bool Tensor. + """ + + if isinstance(x, Variable): + return _run_paddle_logical_not(x) + else: + return _run_py_logical_not(x) + + +def _run_paddle_logical_not(x): + return logical_not(x) + + +def _run_py_logical_not(x): + return not x diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py index b2369b6debd..b43c20424c3 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py @@ -59,7 +59,9 @@ def create_while_node(condition_name, body_name, loop_var_names): ] while_args.append(gast.List(elts=assign_targets, ctx=gast.Param())) - while_func_id = gast.parse('fluid.layers.while_loop').body[0].value + while_func_id = gast.parse( + 'fluid.dygraph.dygraph_to_static.convert_operators.convert_while_loop' + ).body[0].value while_node = gast.Call(func=while_func_id, args=while_args, keywords=[]) assign_node = gast.Assign( targets=[gast.Tuple( @@ -83,7 +85,8 @@ class LogicalOpTransformer(gast.NodeTransformer): self.generic_visit(node) if isinstance(node.op, gast.Not): arg = ast_to_source_code(node.operand) - new_node_str = "fluid.layers.logical_not({})".format(arg) + new_node_str = "fluid.dygraph.dygraph_to_static.convert_operators.convert_logical_not({})".format( + arg) # gast.parse returns Module(body=[expr(value=...)]) new_node = gast.parse(new_node_str).body[0].value return new_node @@ -108,11 +111,14 @@ class LogicalOpTransformer(gast.NodeTransformer): if len(nodes) > 2: # Creates logic_and/logic_or node recursively. pre_logic_node = self._create_bool_op_node(nodes[:2], api_type) - post_logic_node = self._create_bool_op_node(nodes[2:], api_type) + if len(nodes[2:]) == 1: + post_logic_node = nodes[2] + else: + post_logic_node = self._create_bool_op_node(nodes[2:], api_type) nodes = [pre_logic_node] + [post_logic_node] args = [ast_to_source_code(child) for child in nodes] - new_node_str = "fluid.layers.logical_{}(x={}, y={})".format( + new_node_str = "fluid.dygraph.dygraph_to_static.convert_operators.convert_logical_{}(x={}, y={})".format( api_type, args[0], args[1]) # gast.parse return Module(body=[expr(...)]) new_node = gast.parse(new_node_str).body[0].value @@ -538,10 +544,6 @@ class LoopTransformer(gast.NodeTransformer): return new_stmts def get_while_stmt_nodes(self, node): - # TODO: consider while - else in python - if not self.name_visitor.is_control_flow_loop(node): - return [node] - loop_var_names, create_var_names = self.name_visitor.get_loop_var_names( node) new_stmts = [] @@ -558,10 +560,6 @@ class LoopTransformer(gast.NodeTransformer): if "." not in name: new_stmts.append(create_static_variable_gast_node(name)) - # while x < 10 in dygraph should be convert into static tensor < 10 - for name in loop_var_names: - new_stmts.append(to_static_variable_gast_node(name)) - logical_op_transformer = LogicalOpTransformer(node.test) cond_value_node = logical_op_transformer.transform() diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/ifelse_simple_func.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/ifelse_simple_func.py index 15bd8237131..b2dbd6cc597 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/ifelse_simple_func.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/ifelse_simple_func.py @@ -14,6 +14,7 @@ from __future__ import print_function +import six import paddle.fluid as fluid @@ -257,6 +258,14 @@ def if_tensor_case(x): # It is equivalent to `if mean != 0` if mean: for i in range(0, 10): + # TODO(liym27): Delete it if the type of parameter `i` can be resolved in "if" stmt + if six.PY2: + i = fluid.layers.fill_constant( + shape=[1], value=i, dtype="int32") + else: + i = fluid.layers.fill_constant( + shape=[1], value=i, dtype="int64") + if i > 5: x += 1 break diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_break_continue.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_break_continue.py index 9a605078fb3..0c7542603d3 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_break_continue.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_break_continue.py @@ -89,29 +89,44 @@ def test_break_in_while(x): def test_break_continue_in_for(x): x = fluid.dygraph.to_variable(x) + + # TODO(liym27): Uncomment code after "if" statement can be transformed correctly. + # for i in range(1, 10, 1): + # if i <= 4: + # x += 1 + # continue + # else: + # x += 10010 + # break + # x += 10086 + + a = fluid.layers.fill_constant(shape=[1], dtype='int32', value=0) for i in range(1, 10, 1): - if i <= 4: + if a <= 4: x += 1 + a += 1 continue else: x += 10010 break x += 10086 + return x def test_for_in_else(x): x = fluid.dygraph.to_variable(x) - # Case 1: - if False: - pass - else: - for i in range(0, 10): - if i > 5: - x += 1 - break - x += i + # TODO(liym27): Uncomment code after "if" statement can be transformed correctly. + # # Case 1: + # if False: + # pass + # else: + # for i in range(0, 10): + # if i > 5: + # x += 1 + # break + # x += i # Case 2: if False: diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_list.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_list.py index d65b7967422..d59861a88b5 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_list.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_list.py @@ -156,6 +156,9 @@ def test_list_pop_in_while_loop(x, iter_num): shape=[1], value=iter_num, dtype="int32") a = [] i = 0 + + # TODO(liym27): Delete it if the type of parameter `i` can be resolved in "if" stmt + i = fluid.layers.fill_constant(shape=[1], value=i, dtype="int32") while i < iter_num: a.append(x + i) i += 1 diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_loop.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_loop.py index 00a7cc4f757..b3a209fbe29 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_loop.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_loop.py @@ -29,20 +29,25 @@ np.random.seed(SEED) def while_loop_dyfunc(x): i = fluid.dygraph.to_variable(x) - # Use `to_variable` so that static analysis can analyze the type of X is Tensor - x = fluid.dygraph.to_variable( - x) # TODO(liym27): Delete it if the type of parameter x can be resolved while x < 10: i = i + x x = x + 1 return i +def while_loop_dyfunc_without_tensor(x): + a = 1 + # There are no tensors in the while condition, which means it's a plain while in python, + # so it wont't be transformed to `while_loop` op. + while not a > 4 and a > 0: + x = x + 1 + a = a + 1 + + return x + + def while_loop_dyfun_with_conflict_var(x): i = fluid.dygraph.to_variable(x) - # Use `to_variable` so that static analysis can analyze the type of X is Tensor - x = fluid.dygraph.to_variable( - x) # TODO(liym27): Delete it if the type of parameter x can be resolved def relu(y): # 'y' is not visible outside the scope. @@ -82,15 +87,24 @@ def for_loop_dyfunc(max_len): def while_loop_bool_op(x): i = fluid.dygraph.to_variable(x) - # Use `to_variable` so that static analysis can analyze the type of X is Tensor - x = fluid.dygraph.to_variable( - x) # TODO(liym27): Delete it if the type of parameter x can be resolved while x <= -1 or x < -3 or (x < -7 or x < -5) or (x >= 0 and x < 10): i = i + x x = x + 1 return i +def while_loop_bool_op2(x): + i = fluid.dygraph.to_variable(x) + a = 1 + + # In the while condition, there are both Paddle Variable and non-Variable. + while x < 10 and (a < 4 or a > 0) or a < -1 or not x > -1: + i = i + x + x = x + 1 + a = a + 1 + return i + + def while_loop_class_var(x): class Foo(object): def __init__(self): @@ -120,6 +134,7 @@ def for_loop_class_var(max_len): # TODO(liym27): Delete it if the type of parameter x can be resolved max_len = fluid.layers.fill_constant( shape=[1], value=max_len, dtype="int32") + for i in range(max_len): foo.b = fluid.layers.zeros(shape=[1], dtype='float32') foo.c = foo.b + foo.a @@ -211,10 +226,12 @@ class TestTransformWhileLoop(unittest.TestCase): def _run(self, to_static): with fluid.dygraph.guard(self.place): + # Set the input of dyfunc to VarBase + tensor_x = fluid.dygraph.to_variable(self.x, zero_copy=False) if to_static: - ret = declarative(self.dyfunc)(self.x) + ret = declarative(self.dyfunc)(tensor_x) else: - ret = self.dyfunc(self.x) + ret = self.dyfunc(tensor_x) return ret.numpy() def test_ast_to_func(self): @@ -223,6 +240,11 @@ class TestTransformWhileLoop(unittest.TestCase): self.assertTrue(np.allclose(dygraph_numpy, static_numpy)) +class TestTransformWhileLoopWithoutTensor(TestTransformWhileLoop): + def _init_dyfunc(self): + self.dyfunc = while_loop_dyfunc_without_tensor + + class TestTransformWhileLoopWithConflicVar(TestTransformWhileLoop): def _init_dyfunc(self): self.dyfunc = while_loop_dyfun_with_conflict_var @@ -238,6 +260,11 @@ class TestWhileLoopBoolOp(TestTransformWhileLoop): self.dyfunc = while_loop_bool_op +class TestWhileLoopBoolOp2(TestTransformWhileLoop): + def _init_dyfunc(self): + self.dyfunc = while_loop_bool_op2 + + class TestWhileLoopClassVar(TestTransformWhileLoop): def _init_dyfunc(self): self.dyfunc = while_loop_class_var -- GitLab