diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py b/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py index dcb8b686eef0a37f99ce8fec65fa0d5ba0b04b64..383ee9deb195328dd580825aa3bfb26d16c3eb01 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py @@ -17,7 +17,7 @@ from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import to_static from paddle.fluid.framework import core, Variable from paddle.fluid.layers import Assert, Print from paddle.fluid.layers import array_length, array_read, array_write, create_array -from paddle.fluid.layers import assign, fill_constant, slice +from paddle.fluid.layers import assign, fill_constant, slice, reduce_all, reduce_any from paddle.fluid.layers import cast, control_flow, logical_and, logical_not, logical_or, nn from paddle.fluid.layers.control_flow import cond, while_loop, less_than, increment @@ -272,6 +272,67 @@ def convert_var_shape(x): return x.shape +def convert_shape_compare(left, *args): + """ + A function handles comparison difference between Paddle and Python. + For example, if x and y are Tensors, x.shape == y.shape will return single + boolean Value (True/False). However, paddle.shape(x) == paddle.shape(y) is + an element-wise comparison. The difference can cause dy2stat error. So we + create this function to handle the difference. + + Args: + left: variable + *args: compare_op(str), variable, compare_op(str), variable, where + compare_op means "<", ">", "==", "!=", etc. + Returns: + If the variables to compare are NOT Paddle Variables, we will return as + Python like "a op1 b and b op2 c and ... ". + If the variables to compare are Paddle Variables, we will do elementwise + comparsion first and then reduce to a boolean whose numel is 1. + + """ + args_len = len(args) + assert args_len >= 2, "convert_shape_compare needs at least one right compare variable" + assert args_len % 2 == 0, "Illegal input for convert_shape_compare, *args should be op(str), var, op(str), var ..." + num_cmp = args_len // 2 + if isinstance(left, Variable): + + def reduce_compare(x, op_str, y): + element_wise_result = eval("x " + op_str + " y") + if op_str == "!=": + return reduce_any(element_wise_result) + elif op_str == "is" or op_str == "is not" or op_str == "in" or op_str == "not in": + return element_wise_result + else: + return reduce_all(element_wise_result) + + final_result = reduce_compare(left, args[0], args[1]) + for i in range(1, num_cmp): + cmp_left = args[i * 2 - 1] + cmp_op = args[i * 2] + cmp_right = args[i * 2 + 1] + cur_result = reduce_compare(cmp_left, cmp_op, cmp_right) + final_result = convert_logical_and(lambda: final_result, + lambda: cur_result) + return final_result + else: + cmp_left = left + final_result = None + for i in range(num_cmp): + cmp_op = args[i * 2] + cmp_right = args[i * 2 + 1] + cur_result = eval("cmp_left " + cmp_op + " cmp_right") + if final_result is None: + final_result = cur_result + else: + final_result = final_result and cur_result + + if final_result is False: + return False + cmp_left = cmp_right + return final_result + + def cast_bool_if_necessary(var): assert isinstance(var, Variable) if convert_dtype(var.dtype) not in ['bool']: diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/logical_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/logical_transformer.py index b7aa808801797c0d1c50adfa127b97916d0d6d31..8470e895dd3c89c27ca5a2ed2e95763d37030d1b 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/logical_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/logical_transformer.py @@ -17,6 +17,23 @@ from __future__ import print_function import gast from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code +cmpop_type_to_str = { + gast.Eq: "==", + gast.NotEq: "!=", + gast.Lt: "<", + gast.LtE: "<=", + gast.Gt: ">", + gast.GtE: ">=", + gast.Is: "is", + gast.IsNot: "is not", + gast.In: "in", + gast.NotIn: "not in" +} + + +def cmpop_node_to_str(node): + return cmpop_type_to_str[type(node)] + class LogicalTransformer(gast.NodeTransformer): """ @@ -47,6 +64,29 @@ class LogicalTransformer(gast.NodeTransformer): return new_node return node + def visit_Compare(self, node): + self.generic_visit(node) + left_str = ast_to_source_code(node.left).strip() + if left_str.startswith("paddle.jit.dy2static.convert_var_shape"): + # check left and comparators are all converted var shape + compare_arg_strs = left_str + for i, comparator in enumerate(node.comparators): + comparator_str = ast_to_source_code(comparator).strip() + if not comparator_str.startswith( + "paddle.jit.dy2static.convert_var_shape"): + return node + op_str = cmpop_node_to_str(node.ops[i]) + compare_arg_strs += (", '" + op_str + "', " + comparator_str) + + # Now all left and comparators are converted shape + # Replace some comparsion operation because of difference between + # Python and Paddle + new_node_str = "paddle.jit.dy2static.convert_shape_compare({})".format( + compare_arg_strs) + new_node = gast.parse(new_node_str).body[0].value + return new_node + return node + def visit_BoolOp(self, node): self.generic_visit(node) if isinstance(node.op, gast.And): diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py index 31de609e9fc41b8ebea30014e3ba524e0b85236a..1fd4e5b6c7f17e8a31c00ccb35e5495f06d653e1 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py @@ -152,6 +152,17 @@ class TensorShapeTransformer(gast.NodeTransformer): setattr(parent_node, field, create_convert_shape_node(var_shape_node)) break + # Some child_node may be in a list such as gast.Compare + if isinstance(value, list): + has_converted_shape = False + for i, v in enumerate(value): + if child_node is v: + value[i] = create_convert_shape_node( + var_shape_node) + has_converted_shape = True + break + if has_converted_shape: + break return need_transformed def _used_by_paddle_api(self, node): diff --git a/python/paddle/fluid/dygraph/layers.py b/python/paddle/fluid/dygraph/layers.py index e6953e9ef255a63b2994cccecde824d603140950..fe60c24ff36ec7d3abd7ed1ae54217c2a1f310c6 100644 --- a/python/paddle/fluid/dygraph/layers.py +++ b/python/paddle/fluid/dygraph/layers.py @@ -865,30 +865,30 @@ class Layer(core.Layer): pass def __call__(self, *inputs, **kwargs): - for forward_pre_hook in self._forward_pre_hooks.values(): - hook_result = forward_pre_hook(self, inputs) - if hook_result is not None: - if not isinstance(hook_result, tuple): - hook_result = (hook_result, ) - inputs = hook_result - - if not self._built: - with program_desc_tracing_guard(False): - self._build_once(*inputs, **kwargs) - if parallel_helper._is_data_parallel_mode(): - parallel_helper._broadcast_parameters( - self._parameters.values()) - self._built = True - with param_guard(self._parameters), param_guard(self._buffers): + for forward_pre_hook in self._forward_pre_hooks.values(): + hook_result = forward_pre_hook(self, inputs) + if hook_result is not None: + if not isinstance(hook_result, tuple): + hook_result = (hook_result, ) + inputs = hook_result + + if not self._built: + with program_desc_tracing_guard(False): + self._build_once(*inputs, **kwargs) + if parallel_helper._is_data_parallel_mode(): + parallel_helper._broadcast_parameters( + self._parameters.values()) + self._built = True + outputs = self.forward(*inputs, **kwargs) - for forward_post_hook in self._forward_post_hooks.values(): - hook_result = forward_post_hook(self, inputs, outputs) - if hook_result is not None: - outputs = hook_result + for forward_post_hook in self._forward_post_hooks.values(): + hook_result = forward_post_hook(self, inputs, outputs) + if hook_result is not None: + outputs = hook_result - return outputs + return outputs def forward(self, *inputs, **kwargs): """ @@ -1083,7 +1083,15 @@ class Layer(core.Layer): # value via `assign`. if type(value) == framework.Variable: from paddle import assign - assign(value, _buffers[name]) + # Note(zhhsplendid): the condition below happens in PaddleGan model, + # but should all non-Variable _buffers[name] be re-assign? We + # should consider it in the future. I current wrote this as + # conservative code. + if _buffers[name] is None or type(_buffers[ + name]) == core.VarBase: + _buffers[name] = assign(value) + else: + assign(value, _buffers[name]) elif value is not None: raise TypeError( "assignment to buffers '{}' should be of type core.VarBase or None, but got '{}'" diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_convert_operators.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_convert_operators.py new file mode 100644 index 0000000000000000000000000000000000000000..16ed8670da4bc8362c783694974cee17a47ed477 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_convert_operators.py @@ -0,0 +1,107 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import paddle +import unittest + + +class TestConvertShapeCompare(unittest.TestCase): + def test_non_variable(self): + self.assertEqual( + paddle.jit.dy2static.convert_shape_compare(1, "<", 2), True) + self.assertEqual( + paddle.jit.dy2static.convert_shape_compare(1, "<", 2, "<=", 3), + True) + self.assertEqual( + paddle.jit.dy2static.convert_shape_compare(1, ">", 2, "<=", 3), + False) + + def error_func(): + """ + Function used to test that comparison doesn't run after first False + """ + raise ValueError("Used for test") + + self.assertEqual( + paddle.jit.dy2static.convert_shape_compare( + 1, ">", 2, "<=", lambda: error_func()), False) + + self.assertEqual( + paddle.jit.dy2static.convert_shape_compare(1, "<", 2, "in", + [1, 2, 3]), True) + self.assertEqual( + paddle.jit.dy2static.convert_shape_compare(1, "<", 2, "not in", + [1, 2, 3]), False) + self.assertEqual( + paddle.jit.dy2static.convert_shape_compare(1, "<", 2, "is", 3), + False) + self.assertEqual( + paddle.jit.dy2static.convert_shape_compare(1, "<", 2, "is not", + [1, 2, 3]), True) + + self.assertEqual( + paddle.jit.dy2static.convert_shape_compare([1, 2], "==", [1, 2], + "!=", [1, 2, 3]), True) + self.assertEqual( + paddle.jit.dy2static.convert_shape_compare([1, 2], "!=", [1, 2, 3], + "==", [1, 2]), False) + + def test_variable(self): + paddle.enable_static() + with paddle.static.program_guard(paddle.static.Program(), + paddle.static.Program()): + x = paddle.static.data(name='x', shape=[3, 2], dtype='float32') + y = paddle.static.data(name='y', shape=[3, 2], dtype='float32') + self.assertEqual( + paddle.jit.dy2static.convert_shape_compare(x, "is", x, "is not", + y), True) + self.assertEqual( + paddle.jit.dy2static.convert_shape_compare(x, "is not", x, + "is not", y), False) + self.assertEqual( + paddle.jit.dy2static.convert_shape_compare(x, "is", x, "is", y), + False) + + eq_out = paddle.jit.dy2static.convert_shape_compare(x, "==", y) + not_eq_out = paddle.jit.dy2static.convert_shape_compare(x, "!=", y) + long_eq_out = paddle.jit.dy2static.convert_shape_compare(x, "==", x, + "!=", y) + + place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda( + ) else paddle.CPUPlace() + exe = paddle.static.Executor(place) + x_y_eq_out = exe.run(feed={ + "x": np.ones([3, 2]).astype(np.float32), + "y": np.ones([3, 2]).astype(np.float32) + }, + fetch_list=[eq_out, not_eq_out, long_eq_out]) + np.testing.assert_array_equal( + np.array(x_y_eq_out), np.array([[True], [False], [False]])) + + set_a_zero = np.ones([3, 2]).astype(np.float32) + set_a_zero[0][0] = 0.0 + x_y_not_eq_out = exe.run( + feed={ + "x": np.ones([3, 2]).astype(np.float32), + "y": set_a_zero + }, + fetch_list=[eq_out, not_eq_out, long_eq_out]) + np.testing.assert_array_equal( + np.array(x_y_not_eq_out), np.array([[False], [True], [True]])) + paddle.disable_static() + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_logical.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_logical.py index 665e3f520ec978f5541456b852eeb61cf93442fa..c7193eb2a77bc85429e562000fe4f4101aff1a99 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_logical.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_logical.py @@ -18,11 +18,13 @@ from __future__ import print_function import unittest +import gast import numpy as np import paddle import paddle.fluid as fluid from paddle.fluid.dygraph import ProgramTranslator +from paddle.fluid.dygraph.dygraph_to_static.logical_transformer import cmpop_node_to_str program_translator = ProgramTranslator() @@ -149,6 +151,26 @@ def test_logical_not_and_or(x): return x +@paddle.jit.to_static +def test_shape_equal(x): + x = paddle.to_tensor(x) + y = paddle.zeros([1, 2, 3]) + if x.shape == y.shape: + return y + else: + return paddle.ones([1, 2, 3]) + + +@paddle.jit.to_static +def test_shape_not_equal(x): + x = paddle.to_tensor(x) + y = paddle.zeros([1, 2, 3]) + if x.shape != y.shape: + return y + else: + return paddle.ones([1, 2, 3]) + + class TestLogicalBase(unittest.TestCase): def setUp(self): self.input = np.array([3]).astype('int32') @@ -224,5 +246,35 @@ class TestLogicalNotAndOr(TestLogicalNot): self.dygraph_func = test_logical_not_and_or +class TestShapeEqual(TestLogicalNot): + def _set_test_func(self): + self.input = np.ones([1, 2, 3]).astype('float32') + self.dygraph_func = test_shape_equal + + +class TestShapeNotEqual(TestLogicalNot): + def _set_test_func(self): + self.input = np.ones([1, 2, 3]).astype('float32') + self.dygraph_func = test_shape_not_equal + + +class TestCmpopNodeToStr(unittest.TestCase): + def test_exception(self): + with self.assertRaises(KeyError): + cmpop_node_to_str(gast.Or()) + + def test_expected_result(self): + self.assertEqual(cmpop_node_to_str(gast.Eq()), "==") + self.assertEqual(cmpop_node_to_str(gast.NotEq()), "!=") + self.assertEqual(cmpop_node_to_str(gast.Lt()), "<") + self.assertEqual(cmpop_node_to_str(gast.LtE()), "<=") + self.assertEqual(cmpop_node_to_str(gast.Gt()), ">") + self.assertEqual(cmpop_node_to_str(gast.GtE()), ">=") + self.assertEqual(cmpop_node_to_str(gast.Is()), "is") + self.assertEqual(cmpop_node_to_str(gast.IsNot()), "is not") + self.assertEqual(cmpop_node_to_str(gast.In()), "in") + self.assertEqual(cmpop_node_to_str(gast.NotIn()), "not in") + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/jit/dy2static/convert_operators.py b/python/paddle/jit/dy2static/convert_operators.py index 443c7234454819dccba30c46f6a5191360774691..fcf6a10974f60aa94ce09aa3058160b7b44b5b64 100644 --- a/python/paddle/jit/dy2static/convert_operators.py +++ b/python/paddle/jit/dy2static/convert_operators.py @@ -22,6 +22,7 @@ from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_logical from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_logical_or #DEFINE_ALIAS from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_pop #DEFINE_ALIAS from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_print #DEFINE_ALIAS +from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_shape_compare #DEFINE_ALIAS from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_var_dtype #DEFINE_ALIAS from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_var_shape #DEFINE_ALIAS from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_while_loop #DEFINE_ALIAS @@ -29,6 +30,6 @@ from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_while_l __all__ = [ 'cast_bool_if_necessary', 'convert_assert', 'convert_ifelse', 'convert_len', 'convert_logical_and', 'convert_logical_not', 'convert_logical_or', - 'convert_pop', 'convert_print', 'convert_var_dtype', 'convert_var_shape', - 'convert_while_loop' + 'convert_pop', 'convert_print', 'convert_shape_compare', + 'convert_var_dtype', 'convert_var_shape', 'convert_while_loop' ]