未验证 提交 32c139d3 编写于 作者: H Huihuang Zheng 提交者: GitHub

[Dy2stat] Fix PaddleGan Deoldify Model Dy2stat Problems (#29226) (#29281)

Cherry-pick of PR #29226
上级 34df32d6
...@@ -17,7 +17,7 @@ from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import to_static ...@@ -17,7 +17,7 @@ from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import to_static
from paddle.fluid.framework import core, Variable from paddle.fluid.framework import core, Variable
from paddle.fluid.layers import Assert, Print from paddle.fluid.layers import Assert, Print
from paddle.fluid.layers import array_length, array_read, array_write, create_array from paddle.fluid.layers import array_length, array_read, array_write, create_array
from paddle.fluid.layers import assign, fill_constant, slice from paddle.fluid.layers import assign, fill_constant, slice, reduce_all, reduce_any
from paddle.fluid.layers import cast, control_flow, logical_and, logical_not, logical_or, nn from paddle.fluid.layers import cast, control_flow, logical_and, logical_not, logical_or, nn
from paddle.fluid.layers.control_flow import cond, while_loop, less_than, increment from paddle.fluid.layers.control_flow import cond, while_loop, less_than, increment
...@@ -272,6 +272,67 @@ def convert_var_shape(x): ...@@ -272,6 +272,67 @@ def convert_var_shape(x):
return x.shape return x.shape
def convert_shape_compare(left, *args):
"""
A function handles comparison difference between Paddle and Python.
For example, if x and y are Tensors, x.shape == y.shape will return single
boolean Value (True/False). However, paddle.shape(x) == paddle.shape(y) is
an element-wise comparison. The difference can cause dy2stat error. So we
create this function to handle the difference.
Args:
left: variable
*args: compare_op(str), variable, compare_op(str), variable, where
compare_op means "<", ">", "==", "!=", etc.
Returns:
If the variables to compare are NOT Paddle Variables, we will return as
Python like "a op1 b and b op2 c and ... ".
If the variables to compare are Paddle Variables, we will do elementwise
comparsion first and then reduce to a boolean whose numel is 1.
"""
args_len = len(args)
assert args_len >= 2, "convert_shape_compare needs at least one right compare variable"
assert args_len % 2 == 0, "Illegal input for convert_shape_compare, *args should be op(str), var, op(str), var ..."
num_cmp = args_len // 2
if isinstance(left, Variable):
def reduce_compare(x, op_str, y):
element_wise_result = eval("x " + op_str + " y")
if op_str == "!=":
return reduce_any(element_wise_result)
elif op_str == "is" or op_str == "is not" or op_str == "in" or op_str == "not in":
return element_wise_result
else:
return reduce_all(element_wise_result)
final_result = reduce_compare(left, args[0], args[1])
for i in range(1, num_cmp):
cmp_left = args[i * 2 - 1]
cmp_op = args[i * 2]
cmp_right = args[i * 2 + 1]
cur_result = reduce_compare(cmp_left, cmp_op, cmp_right)
final_result = convert_logical_and(lambda: final_result,
lambda: cur_result)
return final_result
else:
cmp_left = left
final_result = None
for i in range(num_cmp):
cmp_op = args[i * 2]
cmp_right = args[i * 2 + 1]
cur_result = eval("cmp_left " + cmp_op + " cmp_right")
if final_result is None:
final_result = cur_result
else:
final_result = final_result and cur_result
if final_result is False:
return False
cmp_left = cmp_right
return final_result
def cast_bool_if_necessary(var): def cast_bool_if_necessary(var):
assert isinstance(var, Variable) assert isinstance(var, Variable)
if convert_dtype(var.dtype) not in ['bool']: if convert_dtype(var.dtype) not in ['bool']:
......
...@@ -17,6 +17,23 @@ from __future__ import print_function ...@@ -17,6 +17,23 @@ from __future__ import print_function
import gast import gast
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
cmpop_type_to_str = {
gast.Eq: "==",
gast.NotEq: "!=",
gast.Lt: "<",
gast.LtE: "<=",
gast.Gt: ">",
gast.GtE: ">=",
gast.Is: "is",
gast.IsNot: "is not",
gast.In: "in",
gast.NotIn: "not in"
}
def cmpop_node_to_str(node):
return cmpop_type_to_str[type(node)]
class LogicalTransformer(gast.NodeTransformer): class LogicalTransformer(gast.NodeTransformer):
""" """
...@@ -47,6 +64,29 @@ class LogicalTransformer(gast.NodeTransformer): ...@@ -47,6 +64,29 @@ class LogicalTransformer(gast.NodeTransformer):
return new_node return new_node
return node return node
def visit_Compare(self, node):
self.generic_visit(node)
left_str = ast_to_source_code(node.left).strip()
if left_str.startswith("paddle.jit.dy2static.convert_var_shape"):
# check left and comparators are all converted var shape
compare_arg_strs = left_str
for i, comparator in enumerate(node.comparators):
comparator_str = ast_to_source_code(comparator).strip()
if not comparator_str.startswith(
"paddle.jit.dy2static.convert_var_shape"):
return node
op_str = cmpop_node_to_str(node.ops[i])
compare_arg_strs += (", '" + op_str + "', " + comparator_str)
# Now all left and comparators are converted shape
# Replace some comparsion operation because of difference between
# Python and Paddle
new_node_str = "paddle.jit.dy2static.convert_shape_compare({})".format(
compare_arg_strs)
new_node = gast.parse(new_node_str).body[0].value
return new_node
return node
def visit_BoolOp(self, node): def visit_BoolOp(self, node):
self.generic_visit(node) self.generic_visit(node)
if isinstance(node.op, gast.And): if isinstance(node.op, gast.And):
......
...@@ -152,6 +152,17 @@ class TensorShapeTransformer(gast.NodeTransformer): ...@@ -152,6 +152,17 @@ class TensorShapeTransformer(gast.NodeTransformer):
setattr(parent_node, field, setattr(parent_node, field,
create_convert_shape_node(var_shape_node)) create_convert_shape_node(var_shape_node))
break break
# Some child_node may be in a list such as gast.Compare
if isinstance(value, list):
has_converted_shape = False
for i, v in enumerate(value):
if child_node is v:
value[i] = create_convert_shape_node(
var_shape_node)
has_converted_shape = True
break
if has_converted_shape:
break
return need_transformed return need_transformed
def _used_by_paddle_api(self, node): def _used_by_paddle_api(self, node):
......
...@@ -865,30 +865,30 @@ class Layer(core.Layer): ...@@ -865,30 +865,30 @@ class Layer(core.Layer):
pass pass
def __call__(self, *inputs, **kwargs): def __call__(self, *inputs, **kwargs):
for forward_pre_hook in self._forward_pre_hooks.values():
hook_result = forward_pre_hook(self, inputs)
if hook_result is not None:
if not isinstance(hook_result, tuple):
hook_result = (hook_result, )
inputs = hook_result
if not self._built:
with program_desc_tracing_guard(False):
self._build_once(*inputs, **kwargs)
if parallel_helper._is_data_parallel_mode():
parallel_helper._broadcast_parameters(
self._parameters.values())
self._built = True
with param_guard(self._parameters), param_guard(self._buffers): with param_guard(self._parameters), param_guard(self._buffers):
for forward_pre_hook in self._forward_pre_hooks.values():
hook_result = forward_pre_hook(self, inputs)
if hook_result is not None:
if not isinstance(hook_result, tuple):
hook_result = (hook_result, )
inputs = hook_result
if not self._built:
with program_desc_tracing_guard(False):
self._build_once(*inputs, **kwargs)
if parallel_helper._is_data_parallel_mode():
parallel_helper._broadcast_parameters(
self._parameters.values())
self._built = True
outputs = self.forward(*inputs, **kwargs) outputs = self.forward(*inputs, **kwargs)
for forward_post_hook in self._forward_post_hooks.values(): for forward_post_hook in self._forward_post_hooks.values():
hook_result = forward_post_hook(self, inputs, outputs) hook_result = forward_post_hook(self, inputs, outputs)
if hook_result is not None: if hook_result is not None:
outputs = hook_result outputs = hook_result
return outputs return outputs
def forward(self, *inputs, **kwargs): def forward(self, *inputs, **kwargs):
""" """
...@@ -1083,7 +1083,15 @@ class Layer(core.Layer): ...@@ -1083,7 +1083,15 @@ class Layer(core.Layer):
# value via `assign`. # value via `assign`.
if type(value) == framework.Variable: if type(value) == framework.Variable:
from paddle import assign from paddle import assign
assign(value, _buffers[name]) # Note(zhhsplendid): the condition below happens in PaddleGan model,
# but should all non-Variable _buffers[name] be re-assign? We
# should consider it in the future. I current wrote this as
# conservative code.
if _buffers[name] is None or type(_buffers[
name]) == core.VarBase:
_buffers[name] = assign(value)
else:
assign(value, _buffers[name])
elif value is not None: elif value is not None:
raise TypeError( raise TypeError(
"assignment to buffers '{}' should be of type core.VarBase or None, but got '{}'" "assignment to buffers '{}' should be of type core.VarBase or None, but got '{}'"
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
import unittest
class TestConvertShapeCompare(unittest.TestCase):
def test_non_variable(self):
self.assertEqual(
paddle.jit.dy2static.convert_shape_compare(1, "<", 2), True)
self.assertEqual(
paddle.jit.dy2static.convert_shape_compare(1, "<", 2, "<=", 3),
True)
self.assertEqual(
paddle.jit.dy2static.convert_shape_compare(1, ">", 2, "<=", 3),
False)
def error_func():
"""
Function used to test that comparison doesn't run after first False
"""
raise ValueError("Used for test")
self.assertEqual(
paddle.jit.dy2static.convert_shape_compare(
1, ">", 2, "<=", lambda: error_func()), False)
self.assertEqual(
paddle.jit.dy2static.convert_shape_compare(1, "<", 2, "in",
[1, 2, 3]), True)
self.assertEqual(
paddle.jit.dy2static.convert_shape_compare(1, "<", 2, "not in",
[1, 2, 3]), False)
self.assertEqual(
paddle.jit.dy2static.convert_shape_compare(1, "<", 2, "is", 3),
False)
self.assertEqual(
paddle.jit.dy2static.convert_shape_compare(1, "<", 2, "is not",
[1, 2, 3]), True)
self.assertEqual(
paddle.jit.dy2static.convert_shape_compare([1, 2], "==", [1, 2],
"!=", [1, 2, 3]), True)
self.assertEqual(
paddle.jit.dy2static.convert_shape_compare([1, 2], "!=", [1, 2, 3],
"==", [1, 2]), False)
def test_variable(self):
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program(),
paddle.static.Program()):
x = paddle.static.data(name='x', shape=[3, 2], dtype='float32')
y = paddle.static.data(name='y', shape=[3, 2], dtype='float32')
self.assertEqual(
paddle.jit.dy2static.convert_shape_compare(x, "is", x, "is not",
y), True)
self.assertEqual(
paddle.jit.dy2static.convert_shape_compare(x, "is not", x,
"is not", y), False)
self.assertEqual(
paddle.jit.dy2static.convert_shape_compare(x, "is", x, "is", y),
False)
eq_out = paddle.jit.dy2static.convert_shape_compare(x, "==", y)
not_eq_out = paddle.jit.dy2static.convert_shape_compare(x, "!=", y)
long_eq_out = paddle.jit.dy2static.convert_shape_compare(x, "==", x,
"!=", y)
place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda(
) else paddle.CPUPlace()
exe = paddle.static.Executor(place)
x_y_eq_out = exe.run(feed={
"x": np.ones([3, 2]).astype(np.float32),
"y": np.ones([3, 2]).astype(np.float32)
},
fetch_list=[eq_out, not_eq_out, long_eq_out])
np.testing.assert_array_equal(
np.array(x_y_eq_out), np.array([[True], [False], [False]]))
set_a_zero = np.ones([3, 2]).astype(np.float32)
set_a_zero[0][0] = 0.0
x_y_not_eq_out = exe.run(
feed={
"x": np.ones([3, 2]).astype(np.float32),
"y": set_a_zero
},
fetch_list=[eq_out, not_eq_out, long_eq_out])
np.testing.assert_array_equal(
np.array(x_y_not_eq_out), np.array([[False], [True], [True]]))
paddle.disable_static()
if __name__ == '__main__':
unittest.main()
...@@ -18,11 +18,13 @@ from __future__ import print_function ...@@ -18,11 +18,13 @@ from __future__ import print_function
import unittest import unittest
import gast
import numpy as np import numpy as np
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid.dygraph import ProgramTranslator from paddle.fluid.dygraph import ProgramTranslator
from paddle.fluid.dygraph.dygraph_to_static.logical_transformer import cmpop_node_to_str
program_translator = ProgramTranslator() program_translator = ProgramTranslator()
...@@ -149,6 +151,26 @@ def test_logical_not_and_or(x): ...@@ -149,6 +151,26 @@ def test_logical_not_and_or(x):
return x return x
@paddle.jit.to_static
def test_shape_equal(x):
x = paddle.to_tensor(x)
y = paddle.zeros([1, 2, 3])
if x.shape == y.shape:
return y
else:
return paddle.ones([1, 2, 3])
@paddle.jit.to_static
def test_shape_not_equal(x):
x = paddle.to_tensor(x)
y = paddle.zeros([1, 2, 3])
if x.shape != y.shape:
return y
else:
return paddle.ones([1, 2, 3])
class TestLogicalBase(unittest.TestCase): class TestLogicalBase(unittest.TestCase):
def setUp(self): def setUp(self):
self.input = np.array([3]).astype('int32') self.input = np.array([3]).astype('int32')
...@@ -224,5 +246,35 @@ class TestLogicalNotAndOr(TestLogicalNot): ...@@ -224,5 +246,35 @@ class TestLogicalNotAndOr(TestLogicalNot):
self.dygraph_func = test_logical_not_and_or self.dygraph_func = test_logical_not_and_or
class TestShapeEqual(TestLogicalNot):
def _set_test_func(self):
self.input = np.ones([1, 2, 3]).astype('float32')
self.dygraph_func = test_shape_equal
class TestShapeNotEqual(TestLogicalNot):
def _set_test_func(self):
self.input = np.ones([1, 2, 3]).astype('float32')
self.dygraph_func = test_shape_not_equal
class TestCmpopNodeToStr(unittest.TestCase):
def test_exception(self):
with self.assertRaises(KeyError):
cmpop_node_to_str(gast.Or())
def test_expected_result(self):
self.assertEqual(cmpop_node_to_str(gast.Eq()), "==")
self.assertEqual(cmpop_node_to_str(gast.NotEq()), "!=")
self.assertEqual(cmpop_node_to_str(gast.Lt()), "<")
self.assertEqual(cmpop_node_to_str(gast.LtE()), "<=")
self.assertEqual(cmpop_node_to_str(gast.Gt()), ">")
self.assertEqual(cmpop_node_to_str(gast.GtE()), ">=")
self.assertEqual(cmpop_node_to_str(gast.Is()), "is")
self.assertEqual(cmpop_node_to_str(gast.IsNot()), "is not")
self.assertEqual(cmpop_node_to_str(gast.In()), "in")
self.assertEqual(cmpop_node_to_str(gast.NotIn()), "not in")
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -22,6 +22,7 @@ from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_logical ...@@ -22,6 +22,7 @@ from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_logical
from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_logical_or #DEFINE_ALIAS from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_logical_or #DEFINE_ALIAS
from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_pop #DEFINE_ALIAS from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_pop #DEFINE_ALIAS
from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_print #DEFINE_ALIAS from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_print #DEFINE_ALIAS
from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_shape_compare #DEFINE_ALIAS
from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_var_dtype #DEFINE_ALIAS from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_var_dtype #DEFINE_ALIAS
from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_var_shape #DEFINE_ALIAS from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_var_shape #DEFINE_ALIAS
from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_while_loop #DEFINE_ALIAS from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_while_loop #DEFINE_ALIAS
...@@ -29,6 +30,6 @@ from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_while_l ...@@ -29,6 +30,6 @@ from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_while_l
__all__ = [ __all__ = [
'cast_bool_if_necessary', 'convert_assert', 'convert_ifelse', 'convert_len', 'cast_bool_if_necessary', 'convert_assert', 'convert_ifelse', 'convert_len',
'convert_logical_and', 'convert_logical_not', 'convert_logical_or', 'convert_logical_and', 'convert_logical_not', 'convert_logical_or',
'convert_pop', 'convert_print', 'convert_var_dtype', 'convert_var_shape', 'convert_pop', 'convert_print', 'convert_shape_compare',
'convert_while_loop' 'convert_var_dtype', 'convert_var_shape', 'convert_while_loop'
] ]
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册