Unverified commit b8d493df authored by Aurelius84, committed by GitHub

[Dy2Static] Refactor param_guard logic of @to_static (#32867)

* Add param_guard in ParameterList to support @to_static

* Refactor param_guard of @to_static

* fix unittest failed

* add more unittest
Parent 59b74ee7
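
Below is a minimal, illustrative sketch (not part of the patch) of the kind of model the first bullet targets: a Layer that keeps its weights in paddle.nn.ParameterList and is decorated with @to_static. The class name and sizes are made up for the example.

import paddle
from paddle.jit import to_static

class ParamListNet(paddle.nn.Layer):
    def __init__(self, in_size, out_size):
        super(ParamListNet, self).__init__()
        # weights kept in a ParameterList instead of as plain attributes
        self.params = paddle.nn.ParameterList(
            [paddle.create_parameter([in_size, out_size], dtype='float32')])

    @to_static
    def forward(self, x):
        return paddle.matmul(x, self.params[0])

net = ParamListNet(10, 3)
print(net(paddle.rand([4, 10])).shape)  # expected: [4, 3]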
@@ -63,35 +63,50 @@ _functional_dygraph_context_manager = None
 @signature_safe_contextmanager
 def param_guard(parameters):
+    from paddle.fluid.dygraph.dygraph_to_static.program_translator import in_declarative_mode
     # Note: parameters is a reference of self._parameters or self._buffers
-    if not framework.in_dygraph_mode() and parameters:
+    if in_declarative_mode() and not framework.in_dygraph_mode() and parameters:
         origin_parameters = parameters.copy()
         for name, var_base in parameters.items():
-            if isinstance(var_base, core.VarBase):
-                # Convert ParamBase into Parameter with same attributes in dy2stat.
-                if isinstance(var_base, framework.ParamBase):
-                    new_var = var_base._to_static_var(to_parameter=True)
-                else:
-                    # Check whether has been created before.
-                    if var_base.name in var_base.block.vars:
-                        new_var = var_base.block.vars[var_base.name]
-                    # Note(Aurelius84): Convert VarBase in self._buffers into Variabe with
-                    # same attributes and set persistable=True to allow saving this var.
-                    # Because users can create a VarBase in `__init__` like a
-                    # `mask` Tensor or `hidden_0` in RNN layers, which is equivalent to a Parameter
-                    # and necessary for inferring. It will be pruned if it's not necessary for inferring.
-                    else:
-                        # But if its shape is empty while created from `create_variable()`, we consider this buffer
-                        # non-persistable. See case of `drop_state` in lstm api.
-                        is_persistable = len(var_base.shape) > 0
-                        new_var = var_base._to_static_var(
-                            to_parameter=False, persistable=is_persistable)
-                parameters[name] = new_var
+            if isinstance(var_base, list):
+                new_var = [_convert_into_variable(var) for var in var_base]
+            else:
+                new_var = _convert_into_variable(var_base)
+            parameters[name] = new_var
         yield
         parameters.update(origin_parameters)
     else:
         yield
+
+
+def _convert_into_variable(var_base):
+    """
+    Convert Varbase into Variable.
+    """
+    if isinstance(var_base, core.VarBase):
+        # Check whether has been created before.
+        new_var = var_base.block._find_var_recursive(var_base.name)
+        if new_var is not None:
+            assert isinstance(new_var, framework.Variable)
+        # Convert ParamBase into Parameter with same attributes in dy2stat.
+        elif isinstance(var_base, framework.ParamBase):
+            new_var = var_base._to_static_var(to_parameter=True)
+        else:
+            # Note(Aurelius84): Convert VarBase in self._buffers into Variable with
+            # same attributes and set persistable=True to allow saving this var.
+            # Because users can create a VarBase in `__init__` like a
+            # `mask` Tensor or `hidden_0` in RNN layers, which is equivalent to a Parameter
+            # and necessary for inferring. It will be pruned if it's not necessary for inferring.
+            # But if its shape is empty while created from `create_variable()`, we consider this buffer
+            # non-persistable. See case of `drop_state` in lstm api.
+            is_persistable = len(var_base.shape) > 0
+            new_var = var_base._to_static_var(
+                to_parameter=False, persistable=is_persistable)
+        return new_var
+    else:
+        return var_base
 
 
 def enabled():
...
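
The comments in _convert_into_variable describe buffers created in __init__ that must stay persistable so they can be saved for inference. A hedged illustration of that case, assuming a buffer registered via register_buffer; the class and names are invented for the example.

import paddle
from paddle.jit import to_static

class MaskedLinear(paddle.nn.Layer):
    def __init__(self, in_size, out_size):
        super(MaskedLinear, self).__init__()
        self.fc = paddle.nn.Linear(in_size, out_size)
        # a plain tensor created in __init__ and registered as a buffer; its
        # shape is non-empty, so _convert_into_variable keeps it persistable
        self.register_buffer('mask', paddle.ones([out_size]))

    @to_static
    def forward(self, x):
        return self.fc(x) * self.mask

net = MaskedLinear(10, 3)
print(net(paddle.rand([4, 10])).shape)  # expected: [4, 3]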
@@ -873,6 +873,10 @@ class Layer(core.Layer):
         pass
 
     def __call__(self, *inputs, **kwargs):
+        # NOTE(Aurelius84): Why we still need param_guard here?
+        # In case of ControlFlow, true_fn and false_fn will contain
+        # parameters that may not trigger logic of `Operator` to create
+        # them. we add this to make sure all parameters is available.
         with param_guard(self._parameters), param_guard(self._buffers):
             for forward_pre_hook in self._forward_pre_hooks.values():
                 hook_result = forward_pre_hook(self, inputs)
...
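
A hedged sketch of the control-flow case the NOTE refers to: a Tensor-dependent if/else is rewritten by dy2static into cond(true_fn, false_fn), and a parameter touched only inside one branch may never pass through an Operator that would convert it, so __call__ converts parameters and buffers up front. The network below is illustrative, not from the patch.

import paddle
from paddle.jit import to_static

class BranchNet(paddle.nn.Layer):
    def __init__(self):
        super(BranchNet, self).__init__()
        self.w = self.create_parameter([10, 3])

    @to_static
    def forward(self, x):
        # a Tensor condition becomes cond(true_fn, false_fn) under dy2static;
        # self.w is only used inside the true branch
        if paddle.mean(x) > 0:
            out = paddle.matmul(x, self.w)
        else:
            out = paddle.zeros([x.shape[0], 3])
        return out

net = BranchNet()
print(net(paddle.rand([4, 10])).shape)  # expected: [4, 3]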
@@ -108,6 +108,8 @@ def monkey_patch_varbase():
         if to_parameter or isinstance(self, ParamBase):
             del attr_kwargs['persistable']
+            # NOTE(Aurelius84): All parameters should be placed into global block.
+            attr_kwargs['block'] = attr_kwargs['block'].program.global_block()
             static_var = Parameter(**attr_kwargs)
         else:
             static_var = Variable(**attr_kwargs)
...
@@ -3158,13 +3158,21 @@ class Block(object):
                                    if attrs else {},
                                    kwargs.get("stop_gradient", False))
         else:
+            from paddle.fluid.dygraph.base import param_guard
             op_desc = self.desc.append_op()
+            # NOTE(Aurelius84): In case of @to_static, all VarBase(s) should
+            # be converted into Variable(s) with same name and block location.
+            # This is ONE and ONLY logic of type transformation of dy2static.
+            inputs = kwargs.get("inputs", None)
+            outputs = kwargs.get("outputs", None)
+            with param_guard(inputs), param_guard(outputs):
                 op = Operator(
                     block=self,
                     desc=op_desc,
                     type=kwargs.get("type", None),
-                    inputs=kwargs.get("inputs", None),
-                    outputs=kwargs.get("outputs", None),
+                    inputs=inputs,
+                    outputs=outputs,
                     attrs=kwargs.get("attrs", None))
             self.ops.append(op)
...
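
One observable effect of making append_op the single conversion point (a hedged check, assuming concrete_program is populated after the first call of a @to_static method): a dygraph parameter used during tracing should appear under the same name in the traced main program. The class and names below are invented for the example.

import paddle
from paddle.jit import to_static

class TinyNet(paddle.nn.Layer):
    def __init__(self):
        super(TinyNet, self).__init__()
        self.w = self.create_parameter([4, 2])

    @to_static
    def forward(self, x):
        return paddle.matmul(x, self.w)

net = TinyNet()
net(paddle.rand([3, 4]))  # first call traces the program
# the ParamBase self.w should now exist as a Parameter of the same name
# in the global block of the traced program
main_prog = net.forward.concrete_program.main_program
print(net.w.name in main_prog.global_block().vars)  # expected: True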
@@ -580,8 +580,12 @@ def assign(input, output=None):
         input = numpy.array([input])
     elif isinstance(input, (list, tuple)):
         input = numpy.array(input)
-    if isinstance(input, Variable):
+    # NOTE(Aurelius84): Why we judge core.VarBase?
+    # In case of @to_static, a VarBase can be as input of `assign`,
+    # but in_dygraph_mode()==False under @to_static, which means
+    # isinstance(VarBase, Variable) == False. It will cause return None
+    # after this api.
+    if isinstance(input, (Variable, core.VarBase)):
         check_dtype(input.dtype, 'input', [
             'float16', 'uint16', 'float32', 'float64', 'int32', 'int64', 'bool'
         ], 'assign', '(When the type of input in assign is Variable.)')
...
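
A hedged demonstration of the type relationship the NOTE relies on, assuming that in this release a dygraph Tensor is a core.VarBase and does not inherit from the static-graph Variable:

import paddle
from paddle.fluid import core
from paddle.fluid.framework import Variable

t = paddle.ones([2, 3])                          # dygraph Tensor, i.e. a core.VarBase
print(isinstance(t, Variable))                   # False, so the old check fell through
print(isinstance(t, (Variable, core.VarBase)))   # True, the widened check above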
@@ -91,5 +91,81 @@ class TestParameterList(unittest.TestCase):
                 static_loss))
 
 
+class NetWithRawParamList(paddle.nn.Layer):
+    def __init__(self, in_size, out_size):
+        super(NetWithRawParamList, self).__init__()
+        weight = self.add_parameter('w',
+                                    self.create_parameter([in_size, out_size]))
+        bias = self.add_parameter(
+            'b', self.create_parameter(
+                [out_size], is_bias=True))
+        self.params = [weight]
+        self.bias_dict = {'b': bias}
+
+    @to_static
+    def forward(self, x):
+        out = paddle.matmul(x, self.params[0])
+        out = paddle.add(out, self.bias_dict['b'])
+        out = paddle.tanh(out)
+        return out
+
+
+class TestRawParameterList(unittest.TestCase):
+    def setUp(self):
+        self.seed = 2021
+        self.iter_num = 5
+        self.prog_trans = ProgramTranslator()
+
+    def init_net(self):
+        self.net = NetWithRawParamList(10, 3)
+
+    def train(self, to_static):
+        paddle.seed(self.seed)
+        np.random.seed(self.seed)
+        self.prog_trans.enable(to_static)
+        self.init_net()
+        sgd = paddle.optimizer.SGD(0.1, parameters=self.net.parameters())
+        for batch_id in range(self.iter_num):
+            x = paddle.rand([4, 10], dtype='float32')
+            out = self.net(x)
+            loss = paddle.mean(out)
+            loss.backward()
+            sgd.step()
+            sgd.clear_grad()
+        return loss
+
+    def test_parameter_list(self):
+        static_loss = self.train(to_static=True)
+        dygraph_loss = self.train(to_static=False)
+        self.assertTrue(
+            np.allclose(dygraph_loss, static_loss),
+            msg='dygraph result is {}\nstatic result is {}'.format(dygraph_loss,
+                                                                   static_loss))
+
+
+class NetWithSubLayerParamList(paddle.nn.Layer):
+    def __init__(self, sub_layer):
+        super(NetWithSubLayerParamList, self).__init__()
+        self.sub_layer = sub_layer
+        self.params = [sub_layer.weight]
+        self.bias_dict = {'b': sub_layer.bias}
+
+    @to_static
+    def forward(self, x):
+        out = paddle.matmul(x, self.params[0])
+        out = paddle.add(out, self.bias_dict['b'])
+        out = paddle.tanh(out)
+        return out
+
+
+class TestSubLayerParameterList(TestRawParameterList):
+    def init_net(self):
+        fc = paddle.nn.Linear(10, 3)
+        self.net = NetWithSubLayerParamList(fc)
+
+
 if __name__ == '__main__':
     unittest.main()