未验证 提交 85292e0b 编写于 作者: L liym27 提交者: GitHub

[Dynamic-to-Static] Fix bug of convert_logical_and/convert_logical_or: the...

[Dynamic-to-Static] Fix bug of convert_logical_and/convert_logical_or: the operands are executed sequentially(#28993)

  
  1) The operands are executed sequentially according to the running logic of Python.
  
  2) If the left hand operand is True(for convert_logical_or)/False(for convert_logical_and), the right hand operand should be executed.
上级 96126532
......@@ -56,25 +56,39 @@ def _run_py_while(cond, body, loop_vars):
return loop_vars
def convert_logical_and(x, y):
def convert_logical_and(x_func, y_func):
"""
A function representation of a Python ``and`` statement.
Args:
x(bool|Tensor): Left hand operand of ``and`` operator.
y(bool|Tensor): Right hand operand of ``and`` operator.
x_func(callable): x_func() is the left hand operand of ``and`` operator. x_func() is bool or Tensor.
y_func(callable): y_func() is the right hand operand of ``and`` operator. y_func() is bool or Tensor.
Returns:
A python bool variable or a bool Tensor.
"""
if isinstance(x, Variable) and isinstance(y, Variable):
return _run_paddle_logical_and(x, y)
NOTE(liym27):
1) The operands are executed sequentially according to the running logic of Python. So here the arguments
should be callable.
2) If the left hand operand is False, the right hand operand should be executed.
For example:
a = x > 1 and y < 1
Transformed code:
a = paddle.jit.dy2static.convert_logical_and(lambda:x>1, lambda:y<1)
In `convert_logical_and(lambda:x>1, lambda:y<1)`, `lambda:y<1` must be run after `lambda:x>1`. And
if `x>1` is False, `y<1` should NOT be run.
"""
x_value = x_func()
if not isinstance(x_value, Variable):
return _run_py_logical_and(lambda: x_value, y_func)
if not isinstance(x, Variable):
return _run_py_logical_and(x, y)
y_value = y_func()
if not isinstance(y_value, Variable):
return _run_py_logical_and(lambda: y_value, lambda: x_value)
return _run_py_logical_and(y, x)
return _run_paddle_logical_and(x_value, y_value)
def _run_paddle_logical_and(x, y):
......@@ -83,31 +97,49 @@ def _run_paddle_logical_and(x, y):
return logical_and(x, y)
def _run_py_logical_and(x, y):
assert not isinstance(x, Variable)
# NOTE: Returns y if x is True
return x and y
def _run_py_logical_and(x_func, y_func):
x_value = x_func()
assert not isinstance(x_value, Variable)
# NOTE(liym27):
# 1. Returns y_func() if x_value is False;
# 2. If x_value is False, y_func() should not be run.
return x_value and y_func()
def convert_logical_or(x, y):
def convert_logical_or(x_func, y_func):
"""
A function representation of a Python ``or`` statement.
Args:
x(bool|Tensor): Left hand operand of ``or`` operator.
y(bool|Tensor): Right hand operand of ``or`` operator.
x_func(callable): x_func() is the left hand operand of ``or`` operator. x_func() is bool or Tensor.
y_func(callable): y_func() is the right hand operand of ``or`` operator. y_func() is bool or Tensor.
Returns:
A python bool variable or a bool Tensor.
"""
if isinstance(x, Variable) and isinstance(y, Variable):
return _run_paddle_logical_or(x, y)
NOTE(liym27):
1) The operands are executed sequentially according to the running logic of Python. So here the arguments
should be callable.
2) If the left hand operand is True, the right hand operand should be executed.
if not isinstance(x, Variable):
return _run_py_logical_or(x, y)
For example:
a = x > 1 or y < 1
Transformed code:
a = paddle.jit.dy2static.convert_logical_or(lambda:x>1, lambda:y<1)
return _run_py_logical_or(y, x)
In `convert_logical_or(lambda:x>1, lambda:y<1)`, `lambda:y<1` must be run after `lambda:x>1`. And
if `x>1` is True, `y<1` should NOT be run.
"""
x_value = x_func()
if not isinstance(x_value, Variable):
return _run_py_logical_or(lambda: x_value, y_func)
y_value = y_func()
if not isinstance(y_value, Variable):
return _run_py_logical_or(lambda: y_value, lambda: x_value)
return _run_paddle_logical_or(x_value, y_value)
def _run_paddle_logical_or(x, y):
......@@ -116,10 +148,14 @@ def _run_paddle_logical_or(x, y):
return logical_or(x, y)
def _run_py_logical_or(x, y):
assert not isinstance(x, Variable)
# NOTE: Returns y if x is False
return x or y
def _run_py_logical_or(x_func, y_func):
x_value = x_func()
assert not isinstance(x_value, Variable)
# NOTE(liym27):
# 1. Returns y_func() if x_value is False;
# 2. If x_value is True, y_func() should not be run.
return x_value or y_func()
def convert_logical_not(x):
......@@ -193,7 +229,6 @@ def _run_paddle_cond(pred, true_fn, false_fn, true_args, false_args,
def _run_py_ifelse(pred, true_fn, false_fn, true_args, false_args):
return true_fn(*true_args) if pred else false_fn(*false_args)
......
......@@ -20,7 +20,13 @@ from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
class LogicalTransformer(gast.NodeTransformer):
"""
Transform python boolean op into Paddle logical op
Transform python boolean op into Paddle logical op.
For example:
a = x > 1 and y < 1
Transformed code:
a = paddle.jit.dy2static.convert_logical_and(lambda:x>1, lambda:y<1)
"""
def __init__(self, wrapper_root):
......@@ -53,6 +59,12 @@ class LogicalTransformer(gast.NodeTransformer):
return new_node
def _create_bool_op_node(self, nodes, api_type):
'''
NOTE(liym27):
The arguments of function convert_logical_XX should be callable so that they can be run
according to the actual order. In `convert_logical_and(lambda:x>1, lambda:y<1)`, `lambda:y<1`
must be run after `lambda:x>1`, If `x>1` is False, `y<1` should NOT be run.
'''
assert len(
nodes
) > 1, "The length of BoolOp should be at least 2, but received {}.".format(
......@@ -67,7 +79,7 @@ class LogicalTransformer(gast.NodeTransformer):
nodes = [pre_logic_node] + [post_logic_node]
args = [ast_to_source_code(child) for child in nodes]
new_node_str = "paddle.jit.dy2static.convert_logical_{}(x={}, y={})".format(
new_node_str = "paddle.jit.dy2static.convert_logical_{}(lambda:{}, lambda:{})".format(
api_type, args[0], args[1])
# NOTE: gast.parse return Module(body=[expr(...)])
new_node = gast.parse(new_node_str).body[0].value
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for logical operators of Dynamic-to-Static.
Only test simple cases here. The complex test samples like nested ifelse
or nested loop have been covered in file test_ifelse.py and test_loop.py"""
from __future__ import print_function
import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph import ProgramTranslator
program_translator = ProgramTranslator()
SEED = 2020
np.random.seed(22)
@paddle.jit.to_static
def test_logical_not(x):
x = paddle.to_tensor(x)
if not x:
x = x - 1
else:
x = x + 1
if x != 10:
x = x - 1
else:
x = x + 1
y = 0
if not y:
x = x + 4
if y != 3:
x = x + 2
return x
@paddle.jit.to_static
def test_logical_not_2(x):
x = paddle.to_tensor(x)
y = None
if y is not None and not y:
x = x + 4
if y != 3:
x = x + 2
return x
@paddle.jit.to_static
def test_logical_and(x):
x = paddle.to_tensor(x)
if x < 10 and x > 1:
x = x - 1
else:
x = x + 1
y = 3
if y < 10 and y > 1:
x = x - 2
else:
x = x + 2
return x
@paddle.jit.to_static
def test_logical_and_2(x):
x = paddle.to_tensor(x)
a = None
# NOTE(liym27):
# because `a is not None` is False, then `a > 1` won't be run,
# which means `convert_logical_and(a is not None, a > 1)` should not
# run a>1.
if a is not None and a > 1:
x = x - 1
else:
x = x + 1
b = 3
if b is not None and b > 1:
x = x - 1
else:
x = x + 1
return x
@paddle.jit.to_static
def test_logical_or(x):
x = paddle.to_tensor(x)
if x < 10 or x > 1:
x = x - 1
else:
x = x + 1
a = 10
if a > 3 or a < 1:
x = x - 1
else:
x = x + 1
return x
@paddle.jit.to_static
def test_logical_or_2(x):
x = paddle.to_tensor(x)
a = None
if x > 1 or a is None or a > 1:
x = x - 1
else:
x = x + 1
return x
@paddle.jit.to_static
def test_logical_not_and_or(x):
x = paddle.to_tensor(x)
a = 1
if x < 10 and (a < 4 or a > 0) or a < -1 or not x > -1:
x = x - 1
else:
x = x + 1
return x
class TestLogicalBase(unittest.TestCase):
def setUp(self):
self.input = np.array([3]).astype('int32')
self.place = paddle.CUDAPlace(0) if fluid.is_compiled_with_cuda(
) else paddle.CPUPlace()
self._set_test_func()
def _set_test_func(self):
raise NotImplementedError(
"Method 'set_test_func' should be implemented.")
def _run(self, to_static):
program_translator.enable(to_static)
with fluid.dygraph.guard(self.place):
result = self.dygraph_func(self.input)
return result.numpy()
def _run_dygraph(self):
return self._run(to_static=False)
def _run_static(self):
return self._run(to_static=True)
class TestLogicalNot(TestLogicalBase):
def _set_test_func(self):
self.dygraph_func = test_logical_not
def test_transformed_result(self):
dygraph_res = self._run_dygraph()
static_res = self._run_static()
self.assertTrue(
np.allclose(dygraph_res, static_res),
msg='dygraph result is {}\nstatic_result is {}'.format(dygraph_res,
static_res))
class TestLogicalNot2(TestLogicalBase):
def _set_test_func(self):
self.dygraph_func = test_logical_not_2
def test_transformed_result(self):
dygraph_res = self._run_dygraph()
static_res = self._run_static()
self.assertTrue(
np.allclose(dygraph_res, static_res),
msg='dygraph result is {}\nstatic_result is {}'.format(dygraph_res,
static_res))
class TestLogicalAnd(TestLogicalNot):
def _set_test_func(self):
self.dygraph_func = test_logical_and
class TestLogicalAnd2(TestLogicalNot):
def _set_test_func(self):
self.dygraph_func = test_logical_and_2
class TestLogicalOr(TestLogicalNot):
def _set_test_func(self):
self.dygraph_func = test_logical_or
class TestLogicalOr2(TestLogicalNot):
def _set_test_func(self):
self.dygraph_func = test_logical_or_2
class TestLogicalNotAndOr(TestLogicalNot):
def _set_test_func(self):
self.dygraph_func = test_logical_not_and_or
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册