Unverified commit aec05d81, authored by Huihuang Zheng, committed by GitHub

[Dy2stat] Fix PaddleGan Deoldify Model Dy2stat Problems (#29226)

This PR fixes several dy2stat problems hit by the Deoldify model in PaddleGan:

1. In the model, the engineers wrote "if x.shape == y.shape". A Tensor's shape is a tuple in dygraph, so == returns a single True/False, but in static graph == becomes an element-wise comparison, which is different behavior. In this PR we reduce the element-wise comparison result to a single boolean, as sketched below.
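As an illustration (a minimal sketch written for this description, not code from the PR; it assumes the public Paddle 2.0 API):

```python
import paddle

# Dygraph: .shape is a plain Python list/tuple, so == yields one bool.
x = paddle.ones([1, 2, 3])
y = paddle.zeros([1, 2, 3])
print(x.shape == y.shape)  # True

# Static graph: paddle.shape() returns a Tensor, so == is element-wise
# and produces a boolean Tensor with one element per dimension.
paddle.enable_static()
xs = paddle.static.data(name='xs', shape=[1, 2, 3], dtype='float32')
ys = paddle.static.data(name='ys', shape=[1, 2, 3], dtype='float32')
eq = paddle.shape(xs) == paddle.shape(ys)  # 3 booleans, not a single one
```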

2. If engineers write computations that use parameters inside hooks, the static graph can lose the parameter Variable, because we only applied param_guard around the forward of a Layer. In this PR we make param_guard cover the pre-hooks and post-hooks as well (see the sketch below).
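For example (a hedged sketch; ScaleLayer and scale_hook are hypothetical names, not from PaddleGan):

```python
import paddle

class ScaleLayer(paddle.nn.Layer):
    def __init__(self):
        super(ScaleLayer, self).__init__()
        self.scale = self.create_parameter(shape=[1], dtype='float32')

    def forward(self, x):
        return x + self.scale

def scale_hook(layer, inputs):
    # The hook reads a parameter outside forward(). Before this PR,
    # param_guard did not cover hooks, so under dy2stat layer.scale was
    # not converted to a static-graph Variable here.
    return (inputs[0] * layer.scale, )

net = ScaleLayer()
net.register_forward_pre_hook(scale_hook)
out = net(paddle.ones([2]))
```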

3. In PaddleGan, engineers compute some parameter values in __init__ by running dygraph code. That code also runs during dy2stat, so a buffer may be assigned a VarBase (Tensor) first and then a Variable, which raised an error. This PR fixes the bug by handling that case, as in the sketch below.
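A hedged sketch of the pattern (KernelLayer is a hypothetical stand-in for the PaddleGan code):

```python
import paddle

class KernelLayer(paddle.nn.Layer):
    def __init__(self):
        super(KernelLayer, self).__init__()
        # Dygraph code at construction time: the buffer starts life as a
        # VarBase (dygraph Tensor).
        self.register_buffer("avg_kernel", paddle.ones([3]) / 3.0)

    def forward(self, x):
        # When this runs under dy2stat, the right-hand side is a static
        # Variable while the buffer still holds a VarBase; this PR makes
        # Layer.__setattr__ re-bind the buffer via assign in that case.
        self.avg_kernel = self.avg_kernel * 1.0
        return x * self.avg_kernel
```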

TODO: we only added a test case for fix 1 (shape comparison). We should also add test cases for fixes 2 and 3, but since we are chasing the 2.0RC release I will do that in a near-future PR.
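For reference, this is roughly how the new API behaves (a usage sketch of the paddle.jit.dy2static functions added in this diff, not code from the PR):

```python
import paddle

# Plain Python operands behave like chained Python comparisons.
assert paddle.jit.dy2static.convert_shape_compare(1, "<", 2, "<=", 3) is True

# Paddle Variables are compared element-wise, then reduced to one boolean.
paddle.enable_static()
x = paddle.static.data(name='x', shape=[3, 2], dtype='float32')
y = paddle.static.data(name='y', shape=[3, 2], dtype='float32')
eq = paddle.jit.dy2static.convert_shape_compare(x, "==", y)  # numel-1 bool
```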
Parent fc80d2e0
@@ -17,7 +17,7 @@ from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import to_static
 from paddle.fluid.framework import core, Variable
 from paddle.fluid.layers import Assert, Print
 from paddle.fluid.layers import array_length, array_read, array_write, create_array
-from paddle.fluid.layers import assign, fill_constant, slice
+from paddle.fluid.layers import assign, fill_constant, slice, reduce_all, reduce_any
 from paddle.fluid.layers import cast, control_flow, logical_and, logical_not, logical_or, nn
 from paddle.fluid.layers.control_flow import cond, while_loop, less_than, increment
@@ -272,6 +272,67 @@ def convert_var_shape(x):
    return x.shape

def convert_shape_compare(left, *args):
    """
    A function that handles the comparison difference between Paddle and
    Python. For example, if x and y are Tensors, x.shape == y.shape returns a
    single boolean value (True/False). However, paddle.shape(x) ==
    paddle.shape(y) is an element-wise comparison. The difference can cause a
    dy2stat error, so we create this function to handle it.

    Args:
        left: variable
        *args: compare_op(str), variable, compare_op(str), variable, where
            compare_op means "<", ">", "==", "!=", etc.
    Returns:
        If the variables to compare are NOT Paddle Variables, we return the
        Python result of "a op1 b and b op2 c and ...".
        If the variables to compare are Paddle Variables, we do an element-wise
        comparison first and then reduce it to a boolean whose numel is 1.
    """
    args_len = len(args)
    assert args_len >= 2, "convert_shape_compare needs at least one right compare variable"
    assert args_len % 2 == 0, "Illegal input for convert_shape_compare, *args should be op(str), var, op(str), var ..."
    num_cmp = args_len // 2
    if isinstance(left, Variable):

        def reduce_compare(x, op_str, y):
            element_wise_result = eval("x " + op_str + " y")
            if op_str == "!=":
                return reduce_any(element_wise_result)
            elif op_str == "is" or op_str == "is not" or op_str == "in" or op_str == "not in":
                return element_wise_result
            else:
                return reduce_all(element_wise_result)

        final_result = reduce_compare(left, args[0], args[1])
        for i in range(1, num_cmp):
            cmp_left = args[i * 2 - 1]
            cmp_op = args[i * 2]
            cmp_right = args[i * 2 + 1]
            cur_result = reduce_compare(cmp_left, cmp_op, cmp_right)
            final_result = convert_logical_and(lambda: final_result,
                                               lambda: cur_result)
        return final_result
    else:
        cmp_left = left
        final_result = None
        for i in range(num_cmp):
            cmp_op = args[i * 2]
            cmp_right = args[i * 2 + 1]
            cur_result = eval("cmp_left " + cmp_op + " cmp_right")
            if final_result is None:
                final_result = cur_result
            else:
                final_result = final_result and cur_result
            if final_result is False:
                return False
            cmp_left = cmp_right
        return final_result

def cast_bool_if_necessary(var):
    assert isinstance(var, Variable)
    if convert_dtype(var.dtype) not in ['bool']:
......
@@ -17,6 +17,23 @@ from __future__ import print_function
import gast
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code

cmpop_type_to_str = {
    gast.Eq: "==",
    gast.NotEq: "!=",
    gast.Lt: "<",
    gast.LtE: "<=",
    gast.Gt: ">",
    gast.GtE: ">=",
    gast.Is: "is",
    gast.IsNot: "is not",
    gast.In: "in",
    gast.NotIn: "not in"
}


def cmpop_node_to_str(node):
    return cmpop_type_to_str[type(node)]


class LogicalTransformer(gast.NodeTransformer):
    """
@@ -47,6 +64,29 @@ class LogicalTransformer(gast.NodeTransformer):
            return new_node
        return node
    def visit_Compare(self, node):
        self.generic_visit(node)
        left_str = ast_to_source_code(node.left).strip()
        if left_str.startswith("paddle.jit.dy2static.convert_var_shape"):
            # Check that left and all comparators are converted var shapes
            compare_arg_strs = left_str
            for i, comparator in enumerate(node.comparators):
                comparator_str = ast_to_source_code(comparator).strip()
                if not comparator_str.startswith(
                        "paddle.jit.dy2static.convert_var_shape"):
                    return node
                op_str = cmpop_node_to_str(node.ops[i])
                compare_arg_strs += (", '" + op_str + "', " + comparator_str)

            # Now left and all comparators are converted shapes.
            # Replace the comparison operation because of the difference
            # between Python and Paddle.
            new_node_str = "paddle.jit.dy2static.convert_shape_compare({})".format(
                compare_arg_strs)
            new_node = gast.parse(new_node_str).body[0].value
            return new_node
        return node

    def visit_BoolOp(self, node):
        self.generic_visit(node)
        if isinstance(node.op, gast.And):
......
@@ -152,6 +152,17 @@ class TensorShapeTransformer(gast.NodeTransformer):
                    setattr(parent_node, field,
                            create_convert_shape_node(var_shape_node))
                    break
                # Some child_node may be in a list such as gast.Compare
                if isinstance(value, list):
                    has_converted_shape = False
                    for i, v in enumerate(value):
                        if child_node is v:
                            value[i] = create_convert_shape_node(
                                var_shape_node)
                            has_converted_shape = True
                            break
                    if has_converted_shape:
                        break
        return need_transformed

    def _used_by_paddle_api(self, node):
......
@@ -865,30 +865,30 @@ class Layer(core.Layer):
         pass

     def __call__(self, *inputs, **kwargs):
-        for forward_pre_hook in self._forward_pre_hooks.values():
-            hook_result = forward_pre_hook(self, inputs)
-            if hook_result is not None:
-                if not isinstance(hook_result, tuple):
-                    hook_result = (hook_result, )
-                inputs = hook_result
-
-        if not self._built:
-            with program_desc_tracing_guard(False):
-                self._build_once(*inputs, **kwargs)
-                if parallel_helper._is_data_parallel_mode():
-                    parallel_helper._broadcast_parameters(
-                        self._parameters.values())
-            self._built = True
-
         with param_guard(self._parameters), param_guard(self._buffers):
+            for forward_pre_hook in self._forward_pre_hooks.values():
+                hook_result = forward_pre_hook(self, inputs)
+                if hook_result is not None:
+                    if not isinstance(hook_result, tuple):
+                        hook_result = (hook_result, )
+                    inputs = hook_result
+
+            if not self._built:
+                with program_desc_tracing_guard(False):
+                    self._build_once(*inputs, **kwargs)
+                    if parallel_helper._is_data_parallel_mode():
+                        parallel_helper._broadcast_parameters(
+                            self._parameters.values())
+                self._built = True
+
             outputs = self.forward(*inputs, **kwargs)

-        for forward_post_hook in self._forward_post_hooks.values():
-            hook_result = forward_post_hook(self, inputs, outputs)
-            if hook_result is not None:
-                outputs = hook_result
+            for forward_post_hook in self._forward_post_hooks.values():
+                hook_result = forward_post_hook(self, inputs, outputs)
+                if hook_result is not None:
+                    outputs = hook_result

-        return outputs
+            return outputs

     def forward(self, *inputs, **kwargs):
         """
@@ -1083,7 +1083,15 @@ class Layer(core.Layer):
                     # value via `assign`.
                     if type(value) == framework.Variable:
                         from paddle import assign
-                        assign(value, _buffers[name])
+                        # Note(zhhsplendid): the condition below happens in PaddleGan model,
+                        # but should all non-Variable _buffers[name] be re-assigned? We
+                        # should consider it in the future. I currently wrote this as
+                        # conservative code.
+                        if _buffers[name] is None or type(_buffers[
+                                name]) == core.VarBase:
+                            _buffers[name] = assign(value)
+                        else:
+                            assign(value, _buffers[name])
                     elif value is not None:
                         raise TypeError(
                             "assignment to buffers '{}' should be of type core.VarBase or None, but got '{}'"
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
import unittest
class TestConvertShapeCompare(unittest.TestCase):
    def test_non_variable(self):
        self.assertEqual(
            paddle.jit.dy2static.convert_shape_compare(1, "<", 2), True)
        self.assertEqual(
            paddle.jit.dy2static.convert_shape_compare(1, "<", 2, "<=", 3),
            True)
        self.assertEqual(
            paddle.jit.dy2static.convert_shape_compare(1, ">", 2, "<=", 3),
            False)

        def error_func():
            """
            Function used to test that comparison doesn't run after first False
            """
            raise ValueError("Used for test")

        self.assertEqual(
            paddle.jit.dy2static.convert_shape_compare(
                1, ">", 2, "<=", lambda: error_func()), False)

        self.assertEqual(
            paddle.jit.dy2static.convert_shape_compare(1, "<", 2, "in",
                                                       [1, 2, 3]), True)
        self.assertEqual(
            paddle.jit.dy2static.convert_shape_compare(1, "<", 2, "not in",
                                                       [1, 2, 3]), False)
        self.assertEqual(
            paddle.jit.dy2static.convert_shape_compare(1, "<", 2, "is", 3),
            False)
        self.assertEqual(
            paddle.jit.dy2static.convert_shape_compare(1, "<", 2, "is not",
                                                       [1, 2, 3]), True)
        self.assertEqual(
            paddle.jit.dy2static.convert_shape_compare([1, 2], "==", [1, 2],
                                                       "!=", [1, 2, 3]), True)
        self.assertEqual(
            paddle.jit.dy2static.convert_shape_compare([1, 2], "!=", [1, 2, 3],
                                                       "==", [1, 2]), False)

    def test_variable(self):
        paddle.enable_static()
        with paddle.static.program_guard(paddle.static.Program(),
                                         paddle.static.Program()):
            x = paddle.static.data(name='x', shape=[3, 2], dtype='float32')
            y = paddle.static.data(name='y', shape=[3, 2], dtype='float32')
            self.assertEqual(
                paddle.jit.dy2static.convert_shape_compare(x, "is", x,
                                                           "is not", y), True)
            self.assertEqual(
                paddle.jit.dy2static.convert_shape_compare(x, "is not", x,
                                                           "is not", y), False)
            self.assertEqual(
                paddle.jit.dy2static.convert_shape_compare(x, "is", x, "is",
                                                           y), False)

            eq_out = paddle.jit.dy2static.convert_shape_compare(x, "==", y)
            not_eq_out = paddle.jit.dy2static.convert_shape_compare(x, "!=", y)
            long_eq_out = paddle.jit.dy2static.convert_shape_compare(
                x, "==", x, "!=", y)

            place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda(
            ) else paddle.CPUPlace()
            exe = paddle.static.Executor(place)
            x_y_eq_out = exe.run(feed={
                "x": np.ones([3, 2]).astype(np.float32),
                "y": np.ones([3, 2]).astype(np.float32)
            },
                                 fetch_list=[eq_out, not_eq_out, long_eq_out])
            np.testing.assert_array_equal(
                np.array(x_y_eq_out), np.array([[True], [False], [False]]))

            set_a_zero = np.ones([3, 2]).astype(np.float32)
            set_a_zero[0][0] = 0.0
            x_y_not_eq_out = exe.run(
                feed={
                    "x": np.ones([3, 2]).astype(np.float32),
                    "y": set_a_zero
                },
                fetch_list=[eq_out, not_eq_out, long_eq_out])
            np.testing.assert_array_equal(
                np.array(x_y_not_eq_out), np.array([[False], [True], [True]]))
        paddle.disable_static()


if __name__ == '__main__':
    unittest.main()
@@ -18,11 +18,13 @@ from __future__ import print_function
import unittest
import gast
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph import ProgramTranslator
from paddle.fluid.dygraph.dygraph_to_static.logical_transformer import cmpop_node_to_str
program_translator = ProgramTranslator()
@@ -149,6 +151,26 @@ def test_logical_not_and_or(x):
    return x

@paddle.jit.to_static
def test_shape_equal(x):
    x = paddle.to_tensor(x)
    y = paddle.zeros([1, 2, 3])
    if x.shape == y.shape:
        return y
    else:
        return paddle.ones([1, 2, 3])


@paddle.jit.to_static
def test_shape_not_equal(x):
    x = paddle.to_tensor(x)
    y = paddle.zeros([1, 2, 3])
    if x.shape != y.shape:
        return y
    else:
        return paddle.ones([1, 2, 3])

class TestLogicalBase(unittest.TestCase):
    def setUp(self):
        self.input = np.array([3]).astype('int32')
@@ -224,5 +246,35 @@ class TestLogicalNotAndOr(TestLogicalNot):
        self.dygraph_func = test_logical_not_and_or

class TestShapeEqual(TestLogicalNot):
    def _set_test_func(self):
        self.input = np.ones([1, 2, 3]).astype('float32')
        self.dygraph_func = test_shape_equal


class TestShapeNotEqual(TestLogicalNot):
    def _set_test_func(self):
        self.input = np.ones([1, 2, 3]).astype('float32')
        self.dygraph_func = test_shape_not_equal

class TestCmpopNodeToStr(unittest.TestCase):
    def test_exception(self):
        with self.assertRaises(KeyError):
            cmpop_node_to_str(gast.Or())

    def test_expected_result(self):
        self.assertEqual(cmpop_node_to_str(gast.Eq()), "==")
        self.assertEqual(cmpop_node_to_str(gast.NotEq()), "!=")
        self.assertEqual(cmpop_node_to_str(gast.Lt()), "<")
        self.assertEqual(cmpop_node_to_str(gast.LtE()), "<=")
        self.assertEqual(cmpop_node_to_str(gast.Gt()), ">")
        self.assertEqual(cmpop_node_to_str(gast.GtE()), ">=")
        self.assertEqual(cmpop_node_to_str(gast.Is()), "is")
        self.assertEqual(cmpop_node_to_str(gast.IsNot()), "is not")
        self.assertEqual(cmpop_node_to_str(gast.In()), "in")
        self.assertEqual(cmpop_node_to_str(gast.NotIn()), "not in")


if __name__ == '__main__':
    unittest.main()
@@ -22,6 +22,7 @@ from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_logical
 from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_logical_or #DEFINE_ALIAS
 from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_pop #DEFINE_ALIAS
 from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_print #DEFINE_ALIAS
+from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_shape_compare #DEFINE_ALIAS
 from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_var_dtype #DEFINE_ALIAS
 from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_var_shape #DEFINE_ALIAS
 from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_while_loop #DEFINE_ALIAS
@@ -29,6 +30,6 @@ from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_while_l
 __all__ = [
     'cast_bool_if_necessary', 'convert_assert', 'convert_ifelse', 'convert_len',
     'convert_logical_and', 'convert_logical_not', 'convert_logical_or',
-    'convert_pop', 'convert_print', 'convert_var_dtype', 'convert_var_shape',
-    'convert_while_loop'
+    'convert_pop', 'convert_print', 'convert_shape_compare',
+    'convert_var_dtype', 'convert_var_shape', 'convert_while_loop'
 ]