未验证 提交 dd339f4e 编写于 作者: 0 0x45f 提交者: GitHub

[CherryPick][Dy2St]Fix error when calling sublayer's non-forward func in dy2stat (#38418)

Fix error when calling sublayer's non-forward func in dy2stat
cherrypick: #37713、#37759、#37296、#38540、#37888
上级 1046636b
......@@ -93,10 +93,10 @@ def check_type(input, input_name, expected_type, op_name, extra_message=''):
if in_dygraph_mode():
return
from .dygraph.dygraph_to_static.program_translator import in_declarative_mode
# NOTE: `in_declarative_mode` is used to determined whether this op is called under
# @declarative in transformation from dygrah to static layer. We add VarBase in
# expected_type to skip checking because varBase may be created and used in unusual way.
from .dygraph.base import in_declarative_mode
# Need a better design to be fix this.
if in_declarative_mode():
if not isinstance(expected_type, tuple):
......
......@@ -33,6 +33,17 @@ __all__ = [
'enabled', 'to_variable'
]
# Flag that indicates whether running code under `@declarative`
_in_declarative_mode_ = False
def in_declarative_mode():
"""
Return a bool value that indicates whether running code under `@declarative`
"""
return _in_declarative_mode_
def _switch_to_static_graph_(func):
def __impl__(*args, **kwargs):
......@@ -45,6 +56,16 @@ def _switch_to_static_graph_(func):
switch_to_static_graph = wrap_decorator(_switch_to_static_graph_)
@signature_safe_contextmanager
def _switch_declarative_mode_guard_(is_declarative=True):
global _in_declarative_mode_
original_val = _in_declarative_mode_
_in_declarative_mode_ = is_declarative
yield
_in_declarative_mode_ = original_val
@signature_safe_contextmanager
def program_desc_tracing_guard(enable):
tracer = framework._dygraph_tracer()
......@@ -63,7 +84,6 @@ _functional_dygraph_context_manager = None
@signature_safe_contextmanager
def param_guard(parameters):
from paddle.fluid.dygraph.dygraph_to_static.program_translator import in_declarative_mode
# Note: parameters is a reference of self._parameters or self._buffers
if in_declarative_mode() and not framework.in_dygraph_mode() and parameters:
origin_parameters = parameters.copy()
......
......@@ -248,20 +248,6 @@ def _remove_no_value_return_var(out):
def _run_paddle_cond(pred, true_fn, false_fn, true_args, false_args,
return_vars):
return_var_ids = [id(var) for var in return_vars]
# NOTE 1: Returned vars of Paddle op `control_flow.cond` must be Paddle Tensors
# NOTE 2: Here uses id(var) not var, because `if var in return_var` use operator `==`,
# which will call `fluid.layers.equal` and causes error when var in return_vars is not initialized.
true_args = [
to_static_variable(var) if id(var) in return_var_ids else var
for var in true_args
]
false_args = [
to_static_variable(var) if id(var) in return_var_ids else var
for var in false_args
]
pred = cast_bool_if_necessary(pred)
return control_flow.cond(pred, lambda: true_fn(*true_args),
lambda: false_fn(*false_args))
......
......@@ -573,28 +573,6 @@ class StaticFunction(object):
return self._function_spec
# Flag that indicates whether running code under `@declarative`
_in_declarative_mode_ = False
def in_declarative_mode():
"""
Return a bool value that indicates whether running code under `@declarative`
"""
return _in_declarative_mode_
@signature_safe_contextmanager
def _switch_declarative_mode_guard_(is_declarative=True):
global _in_declarative_mode_
original_val = _in_declarative_mode_
_in_declarative_mode_ = is_declarative
yield
_in_declarative_mode_ = original_val
def _verify_init_in_dynamic_mode(class_instance):
"""
Verifies the instance is initialized in dynamic mode.
......@@ -658,6 +636,7 @@ class ConcreteProgram(object):
startup_program.random_seed = framework.default_startup_program(
).random_seed
from paddle.fluid.dygraph.base import _switch_declarative_mode_guard_
with framework.program_guard(main_program, startup_program):
with _switch_declarative_mode_guard_(is_declarative=True):
# 1. Adds `fluid.data` layers for input if needed
......
......@@ -1077,8 +1077,13 @@ def append_var_from_block_desc_static(block,
else:
lod_level = None
if var_desc.persistable():
current_block = block.program.global_block()
else:
current_block = block
vars_append.append(
block.create_var(
current_block.create_var(
name=var_desc.name(),
dtype=data_type,
type=var_type,
......
......@@ -31,7 +31,7 @@ from .. import unique_name
from paddle.fluid import core
from .layer_object_helper import LayerObjectHelper
from .layer_hooks import record_program_ops_pre_hook, set_op_customized_attrs_post_hook, LayerOpsRecoder
from .base import program_desc_tracing_guard, param_guard
from .base import program_desc_tracing_guard, param_guard, in_declarative_mode, _convert_into_variable
from paddle.fluid import framework
from ..param_attr import ParamAttr
from paddle.fluid.executor import Executor, global_scope
......@@ -882,12 +882,7 @@ class Layer(core.Layer):
def _build_once(self, *args, **kwargs):
pass
def __call__(self, *inputs, **kwargs):
# NOTE(Aurelius84): Why we still need param_guard here?
# In case of ControlFlow, true_fn and false_fn will contain
# parameters that may not trigger logic of `Operator` to create
# them. we add this to make sure all parameters is available.
with param_guard(self._parameters), param_guard(self._buffers):
def _dygraph_call_func(self, *inputs, **kwargs):
for forward_pre_hook in self._forward_pre_hooks.values():
hook_result = forward_pre_hook(self, inputs)
if hook_result is not None:
......@@ -918,6 +913,9 @@ class Layer(core.Layer):
return outputs
def __call__(self, *inputs, **kwargs):
return self._dygraph_call_func(*inputs, **kwargs)
def forward(self, *inputs, **kwargs):
"""
Defines the computation performed at every call.
......@@ -1096,6 +1094,8 @@ class Layer(core.Layer):
if '_parameters' in self.__dict__:
_parameters = self.__dict__['_parameters']
if name in self._parameters:
if in_declarative_mode():
return _convert_into_variable(self._parameters[name])
return self._parameters[name]
if '_sub_layers' in self.__dict__:
_sub_layers = self.__dict__['_sub_layers']
......@@ -1104,6 +1104,8 @@ class Layer(core.Layer):
if '_buffers' in self.__dict__:
_buffers = self.__dict__['_buffers']
if name in _buffers:
if in_declarative_mode():
return _convert_into_variable(_buffers[name])
return _buffers[name]
return object.__getattribute__(self, name)
......@@ -1174,11 +1176,16 @@ class Layer(core.Layer):
# but should all non-Variable _buffers[name] be re-assign? We
# should consider it in the future. I current wrote this as
# conservative code.
if _buffers[name] is None or type(_buffers[
name]) == core.VarBase:
if in_declarative_mode() and _buffers[name] is None:
raise RuntimeError(
'In Dy2stat, self.{0} is a buffer and self.{0} is '
'not allowed to be set to Variable when self.{0} is None.'.
format(name))
elif _buffers[name] is None or type(
getattr(self, name)) == core.VarBase:
_buffers[name] = assign(value)
else:
assign(value, _buffers[name])
assign(value, getattr(self, name))
elif value is not None:
raise TypeError(
"assignment to buffers '{}' should be of type core.VarBase or None, but got '{}'"
......
......@@ -102,6 +102,41 @@ def select_input(inputs, mask):
return out
def select_input_with_buildin_type(inputs, mask):
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import to_static_variable
support_ret_buildin_type = (bool, float, six.integer_types)
false_var, true_var = inputs
if isinstance(false_var, Variable) and isinstance(true_var, Variable):
return select_input(inputs, mask)
elif (isinstance(false_var, (support_ret_buildin_type)) and
isinstance(false_var, type(true_var))):
if false_var == true_var:
return false_var
else:
inputs = [
to_static_variable(false_var), to_static_variable(true_var)
]
# Deal with the situations like this: false_var is int and true_var is Variable
elif ((isinstance(false_var, support_ret_buildin_type) and
isinstance(true_var, Variable)) or
(isinstance(true_var, support_ret_buildin_type) and
isinstance(false_var, Variable))):
inputs = [to_static_variable(false_var), to_static_variable(true_var)]
warnings.warn(
"Return results from different branches in cond are not same type: "
"false_var returned by fasle_fn is '{}' and true_var of true_fn is "
"'{}'".format(type(false_var), type(true_var)))
else:
raise TypeError(
"Unsupported return type of true_fn and false_fn in cond: false_var "
"returned by fasle_fn is '{}' and true_var of true_fn is '{}'".
format(type(false_var), type(true_var)))
return select_input(inputs, mask)
def split_lod_tensor(input, mask, level=0):
"""
This function takes in an input that contains the complete lod information,
......@@ -2282,8 +2317,8 @@ class ConditionalBlock(object):
def copy_var_to_parent_block(var, layer_helper):
if var is None:
return None
if not isinstance(var, Variable):
return var
prog = layer_helper.main_program
parent_idx = prog.current_block().parent_idx
assert parent_idx >= 0, "Got wrong parent block index when assigning var to parent scope in control_flow"
......@@ -2466,7 +2501,7 @@ def cond(pred, true_fn=None, false_fn=None, name=None):
format(e))
mask = cast(pred, dtype='int32')
merge_func = lambda false_var, true_var : select_input([false_var, true_var], mask)
merge_func = lambda false_var, true_var : select_input_with_buildin_type([false_var, true_var], mask)
merged_output = map_structure(merge_func, false_output, true_output)
return merged_output
......
......@@ -205,6 +205,7 @@ class NetWithControlFlowIf(fluid.dygraph.Layer):
self.alpha = 10.
self.constant_vars = {}
@paddle.jit.to_static
def forward(self, input):
hidden_dim = input.shape[-1]
if hidden_dim != self.hidden_dim:
......@@ -340,3 +341,53 @@ def if_tensor_case(x):
x += 1
return x
def dyfunc_ifelse_ret_int1(x):
index = 0
pred = paddle.to_tensor([1])
if pred:
y = x[index] + 1
index = index + 1
return y, index
else:
y = x[index] + 2
index = index + 1
return y, index
def dyfunc_ifelse_ret_int2(x):
index = 0
pred = paddle.to_tensor([1])
if pred:
y = x[index] + 1
index = index + 1
return y, index
else:
y = x[index] + 2
index = index + 1
return y
def dyfunc_ifelse_ret_int3(x):
index = 0
pred = paddle.to_tensor([1])
if pred:
y = x[index] + 1
index = index + 1
return index
else:
y = x[index] + 2
return y
def dyfunc_ifelse_ret_int4(x):
index = 0
pred = paddle.to_tensor([1])
if pred:
y = x[index] + 1
index = index + 1
return 'unsupport ret'
else:
y = x[index] + 2
return y
......@@ -379,5 +379,74 @@ class TestErrorWithInitFromStaticMode(unittest.TestCase):
net.forward.outputs
class CallNonForwardFuncNet(paddle.nn.Layer):
def __init__(self):
super(CallNonForwardFuncNet, self).__init__()
self.sub = CallNonForwardFuncSubNet()
@paddle.jit.to_static
def forward(self):
return self.sub.func()
class CallNonForwardFuncSubNet(paddle.nn.Layer):
def __init__(self):
super(CallNonForwardFuncSubNet, self).__init__()
self.a = paddle.to_tensor([1, 2])
def func(self):
x = self.a * 2
return x
class TestCallNonForwardFunc(unittest.TestCase):
def test_call_non_forward(self):
paddle.disable_static()
net = CallNonForwardFuncNet()
out = net()
self.assertEqual(out.numpy().tolist(), [2, 4])
paddle.enable_static()
class SetBuffersNet1(paddle.nn.Layer):
def __init__(self):
super(SetBuffersNet1, self).__init__()
self.a = paddle.to_tensor([1])
@paddle.jit.to_static
def forward(self):
self.a = self.a + 1
return self.a
class SetBuffersNet2(paddle.nn.Layer):
def __init__(self):
super(SetBuffersNet2, self).__init__()
self.b = paddle.to_tensor([2])
@paddle.jit.to_static
def forward(self):
self.b = None
self.b = paddle.to_tensor([3])
return self.b
class TestSetBuffers(unittest.TestCase):
def test_set_buffers1(self):
paddle.disable_static()
net = SetBuffersNet1()
out = net()
self.assertEqual(out.numpy().tolist(), [2])
paddle.jit.save(net, './SetBuffersNet1')
paddle.enable_static()
def test_set_buffers2(self):
paddle.disable_static()
net = SetBuffersNet2()
with self.assertRaises(RuntimeError):
out = net()
paddle.enable_static()
if __name__ == '__main__':
unittest.main()
......@@ -365,5 +365,63 @@ class TestNewVarCreateInOneBranch(unittest.TestCase):
self.assertEqual(paddle.jit.to_static(case_func)(True), -2)
class TestDy2StIfElseRetInt1(unittest.TestCase):
def setUp(self):
self.x = np.random.random([5]).astype('float32')
self.dyfunc = dyfunc_ifelse_ret_int1
self.out = self.get_dy2stat_out()
def get_dy2stat_out(self):
ProgramTranslator().enable(True)
static_func = paddle.jit.to_static(self.dyfunc)
out = static_func(self.x)
ProgramTranslator().enable(False)
return out
def test_ast_to_func(self):
self.assertIsInstance(self.out[0], paddle.Tensor)
self.assertIsInstance(self.out[1], int)
class TestDy2StIfElseRetInt2(TestDy2StIfElseRetInt1):
def setUp(self):
self.x = np.random.random([5]).astype('float32')
self.dyfunc = dyfunc_ifelse_ret_int2
self.out = self.get_dy2stat_out()
def test_ast_to_func(self):
self.assertIsInstance(self.out[0], paddle.Tensor)
self.assertIsInstance(self.out[1], paddle.Tensor)
class TestDy2StIfElseRetInt3(TestDy2StIfElseRetInt1):
def setUp(self):
self.x = np.random.random([5]).astype('float32')
self.dyfunc = dyfunc_ifelse_ret_int3
self.out = self.get_dy2stat_out()
def test_ast_to_func(self):
self.assertIsInstance(self.out, paddle.Tensor)
class TestDy2StIfElseRetInt4(TestDy2StIfElseRetInt1):
def setUp(self):
self.x = np.random.random([5]).astype('float32')
self.dyfunc = dyfunc_ifelse_ret_int4
def test_ast_to_func(self):
ProgramTranslator().enable(True)
with self.assertRaises(TypeError):
static_func = paddle.jit.to_static(self.dyfunc)
out = static_func(self.x)
# Why need set `_in_declarative_mode_` here?
# In Dy2St we use `with _switch_declarative_mode_guard_()` to indicate
# that the code block is under @to_static, but in this UT
# an exception is thrown during Dy2St, making the `_in_declarative_mode_`
# a wrong value. So We need set `_in_declarative_mode_` to False manually.
paddle.fluid.dygraph.base._in_declarative_mode_ = False
ProgramTranslator().enable(False)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册