Unverified · Commit b8d493df authored by Aurelius84, committed by GitHub

[Dy2Static] Refactor param_guard logic of @to_static (#32867)

* Add param_guard in ParameterList to support @to_static

* Refactor param_guard of @to_static

* fix failed unittests

* add more unittests
Parent 59b74ee7
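For context, here is a minimal usage sketch of what this refactor targets: a layer that keeps its weights in a `paddle.nn.ParameterList` (see the first commit bullet) and runs its forward under `@paddle.jit.to_static`. The layer name and shapes below are hypothetical, loosely mirroring the unit tests added at the end of this diff.

```python
import paddle


class MyLayer(paddle.nn.Layer):
    def __init__(self, in_size, out_size):
        super(MyLayer, self).__init__()
        # Parameters stored in a ParameterList instead of as direct attributes;
        # param_guard must also handle this container under @to_static.
        self.params = paddle.nn.ParameterList([
            self.create_parameter([in_size, out_size]),
            self.create_parameter([out_size], is_bias=True)
        ])

    @paddle.jit.to_static
    def forward(self, x):
        out = paddle.matmul(x, self.params[0])
        return paddle.add(out, self.params[1])


net = MyLayer(10, 3)
out = net(paddle.rand([4, 10]))
```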
......@@ -63,37 +63,52 @@ _functional_dygraph_context_manager = None
@signature_safe_contextmanager
def param_guard(parameters):
    from paddle.fluid.dygraph.dygraph_to_static.program_translator import in_declarative_mode
    # Note: parameters is a reference of self._parameters or self._buffers
    if not framework.in_dygraph_mode() and parameters:
    if in_declarative_mode() and not framework.in_dygraph_mode() and parameters:
        origin_parameters = parameters.copy()
        for name, var_base in parameters.items():
            if isinstance(var_base, core.VarBase):
                # Convert ParamBase into Parameter with same attributes in dy2stat.
                if isinstance(var_base, framework.ParamBase):
                    new_var = var_base._to_static_var(to_parameter=True)
                else:
                    # Check whether it has been created before.
                    if var_base.name in var_base.block.vars:
                        new_var = var_base.block.vars[var_base.name]
                    # Note(Aurelius84): Convert VarBase in self._buffers into Variable with
                    # same attributes and set persistable=True to allow saving this var.
                    # Because users can create a VarBase in `__init__` like a
                    # `mask` Tensor or `hidden_0` in RNN layers, which is equivalent to a Parameter
                    # and necessary for inferring. It will be pruned if it's not necessary for inferring.
                    else:
                        # But if its shape is empty while created from `create_variable()`, we consider this buffer
                        # non-persistable. See case of `drop_state` in lstm api.
                        is_persistable = len(var_base.shape) > 0
                        new_var = var_base._to_static_var(
                            to_parameter=False, persistable=is_persistable)
                parameters[name] = new_var
            if isinstance(var_base, list):
                new_var = [_convert_into_variable(var) for var in var_base]
            else:
                new_var = _convert_into_variable(var_base)
            parameters[name] = new_var
        yield
        parameters.update(origin_parameters)
    else:
        yield
def _convert_into_variable(var_base):
    """
    Convert VarBase into Variable.
    """
    if isinstance(var_base, core.VarBase):
        # Check whether it has been created before.
        new_var = var_base.block._find_var_recursive(var_base.name)
        if new_var is not None:
            assert isinstance(new_var, framework.Variable)
        # Convert ParamBase into Parameter with same attributes in dy2stat.
        elif isinstance(var_base, framework.ParamBase):
            new_var = var_base._to_static_var(to_parameter=True)
        else:
            # Note(Aurelius84): Convert VarBase in self._buffers into Variable with
            # same attributes and set persistable=True to allow saving this var.
            # Because users can create a VarBase in `__init__` like a
            # `mask` Tensor or `hidden_0` in RNN layers, which is equivalent to a Parameter
            # and may be necessary for inference. It will be pruned later if it is not
            # needed for inference.
            # But if its shape is empty while created from `create_variable()`, we consider this buffer
            # non-persistable. See the case of `drop_state` in the lstm api.
            is_persistable = len(var_base.shape) > 0
            new_var = var_base._to_static_var(
                to_parameter=False, persistable=is_persistable)
        return new_var
    else:
        return var_base
def enabled():
"""
This function checks whether the program runs in dynamic graph mode or not.
......@@ -664,7 +679,7 @@ def to_variable(value, name=None, zero_copy=None, dtype=None):
    if isinstance(framework._current_expected_place(),
                  framework.core.CPUPlace):
        # TODO(zhiqiu): we found two problems when enabling zero_copy on CPUPlace.
        # (1): eigen requires 16-byte alignment, but the data of a numpy array may not satisfy it.
        # Details: https://eigen.tuxfamily.org/dox/group__TopicUnalignedArrayAssert.html
        # (2): when used in the flask framework, it may result in a hang.
        # Details: https://github.com/PaddlePaddle/Paddle/issues/26635
......
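The buffer-handling branch above (persistable unless the shape is empty) is what lets plain tensors created in `__init__` survive the dygraph-to-static conversion. Below is a hedged sketch of the kind of layer this covers: a constant `mask` registered as a buffer and used inside a `@to_static` forward. The class, names, and shapes are illustrative only, not part of this PR.

```python
import paddle


class MaskedLinear(paddle.nn.Layer):
    def __init__(self):
        super(MaskedLinear, self).__init__()
        self.linear = paddle.nn.Linear(10, 10)
        # A non-parameter tensor created in __init__; registering it as a buffer
        # lets param_guard/_convert_into_variable expose it as a persistable
        # Variable when the model is translated to a static graph.
        self.register_buffer('mask', paddle.ones([10]))

    @paddle.jit.to_static
    def forward(self, x):
        return self.linear(x) * self.mask


net = MaskedLinear()
y = net(paddle.rand([4, 10]))
```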
......@@ -873,6 +873,10 @@ class Layer(core.Layer):
        pass

    def __call__(self, *inputs, **kwargs):
        # NOTE(Aurelius84): Why do we still need param_guard here?
        # In case of ControlFlow, true_fn and false_fn will contain
        # parameters that may not trigger the logic of `Operator` to create
        # them. We add this to make sure all parameters are available.
        with param_guard(self._parameters), param_guard(self._buffers):
            for forward_pre_hook in self._forward_pre_hooks.values():
                hook_result = forward_pre_hook(self, inputs)
......
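The `true_fn`/`false_fn` note refers to models whose forward contains data-dependent control flow: dy2stat lowers each branch into a separate function, and `param_guard` in `Layer.__call__` keeps the branch layers' parameters visible as Variables there. A rough illustration follows; the layer is hypothetical and not taken from this PR.

```python
import paddle


class BranchNet(paddle.nn.Layer):
    def __init__(self):
        super(BranchNet, self).__init__()
        self.fc_a = paddle.nn.Linear(10, 3)
        self.fc_b = paddle.nn.Linear(10, 3)

    @paddle.jit.to_static
    def forward(self, x):
        # Data-dependent branch: dy2stat rewrites this into cond(true_fn, false_fn).
        # The parameters of fc_a / fc_b are only touched inside a branch, so
        # param_guard ensures they already exist as Variables when the branch runs.
        if paddle.mean(x) > 0:
            out = self.fc_a(x)
        else:
            out = self.fc_b(x)
        return out


net = BranchNet()
out = net(paddle.rand([4, 10]))
```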
......@@ -86,7 +86,7 @@ def monkey_patch_varbase():
"""
# Note: getattr(self, attr, None) will call x.grad=x.gradient(), but gradient() only available in dygraph.
# Note: getattr(self, attr, None) will call x.grad=x.gradient(), but gradient() only available in dygraph.
# It will fail. So, for propery in dygraph only, should not let it getattr(self, attr, None).
attr_not_need_keys = ['grad']
if isinstance(self, ParamBase):
......@@ -108,6 +108,8 @@ def monkey_patch_varbase():
        if to_parameter or isinstance(self, ParamBase):
            del attr_kwargs['persistable']
            # NOTE(Aurelius84): All parameters should be placed into global block.
            attr_kwargs['block'] = attr_kwargs['block'].program.global_block()
            static_var = Parameter(**attr_kwargs)
        else:
            static_var = Variable(**attr_kwargs)
......
......@@ -3158,14 +3158,22 @@ class Block(object):
if attrs else {},
kwargs.get("stop_gradient", False))
        else:
            from paddle.fluid.dygraph.base import param_guard
            op_desc = self.desc.append_op()
            op = Operator(
                block=self,
                desc=op_desc,
                type=kwargs.get("type", None),
                inputs=kwargs.get("inputs", None),
                outputs=kwargs.get("outputs", None),
                attrs=kwargs.get("attrs", None))
            # NOTE(Aurelius84): In case of @to_static, all VarBase(s) should be
            # converted into Variable(s) with the same name and block location.
            # This is the ONE and ONLY place where dy2static performs this type transformation.
            inputs = kwargs.get("inputs", None)
            outputs = kwargs.get("outputs", None)
            with param_guard(inputs), param_guard(outputs):
                op = Operator(
                    block=self,
                    desc=op_desc,
                    type=kwargs.get("type", None),
                    inputs=inputs,
                    outputs=outputs,
                    attrs=kwargs.get("attrs", None))
            self.ops.append(op)
......
......@@ -580,8 +580,12 @@ def assign(input, output=None):
        input = numpy.array([input])
    elif isinstance(input, (list, tuple)):
        input = numpy.array(input)
    if isinstance(input, Variable):
    # NOTE(Aurelius84): Why do we check core.VarBase?
    # In case of @to_static, a VarBase can be passed as the input of `assign`,
    # but in_dygraph_mode()==False under @to_static, which means
    # isinstance(VarBase, Variable) == False. That would cause this api to return None.
    if isinstance(input, (Variable, core.VarBase)):
        check_dtype(input.dtype, 'input', [
            'float16', 'uint16', 'float32', 'float64', 'int32', 'int64', 'bool'
        ], 'assign', '(When the type of input in assign is Variable.)')
......
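A small sketch of the situation that comment describes: `assign` receiving a tensor created eagerly in `__init__` (still a VarBase) from inside a `@to_static` forward. This is illustrative only and assumes `paddle.assign` routes to the fluid `assign` shown above, as it did in Paddle 2.x around this commit.

```python
import paddle


class AssignNet(paddle.nn.Layer):
    def __init__(self):
        super(AssignNet, self).__init__()
        # Created in dygraph, so inside a @to_static forward it is a VarBase.
        self.offset = paddle.to_tensor([1.0, 2.0, 3.0])

    @paddle.jit.to_static
    def forward(self, x):
        # Accepting core.VarBase in assign keeps this call from silently
        # returning None when the model is traced under dy2stat.
        y = paddle.assign(self.offset)
        return x + y


net = AssignNet()
out = net(paddle.rand([4, 3]))
```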
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
......@@ -91,5 +91,81 @@ class TestParameterList(unittest.TestCase):
static_loss))
class NetWithRawParamList(paddle.nn.Layer):
    def __init__(self, in_size, out_size):
        super(NetWithRawParamList, self).__init__()
        weight = self.add_parameter('w',
                                    self.create_parameter([in_size, out_size]))
        bias = self.add_parameter(
            'b', self.create_parameter(
                [out_size], is_bias=True))
        self.params = [weight]
        self.bias_dict = {'b': bias}

    @to_static
    def forward(self, x):
        out = paddle.matmul(x, self.params[0])
        out = paddle.add(out, self.bias_dict['b'])
        out = paddle.tanh(out)
        return out


class TestRawParameterList(unittest.TestCase):
    def setUp(self):
        self.seed = 2021
        self.iter_num = 5
        self.prog_trans = ProgramTranslator()

    def init_net(self):
        self.net = NetWithRawParamList(10, 3)

    def train(self, to_static):
        paddle.seed(self.seed)
        np.random.seed(self.seed)
        self.prog_trans.enable(to_static)
        self.init_net()
        sgd = paddle.optimizer.SGD(0.1, parameters=self.net.parameters())
        for batch_id in range(self.iter_num):
            x = paddle.rand([4, 10], dtype='float32')
            out = self.net(x)
            loss = paddle.mean(out)
            loss.backward()
            sgd.step()
            sgd.clear_grad()
        return loss

    def test_parameter_list(self):
        static_loss = self.train(to_static=True)
        dygraph_loss = self.train(to_static=False)
        self.assertTrue(
            np.allclose(dygraph_loss, static_loss),
            msg='dygraph result is {}\nstatic result is {}'.format(dygraph_loss,
                                                                   static_loss))


class NetWithSubLayerParamList(paddle.nn.Layer):
    def __init__(self, sub_layer):
        super(NetWithSubLayerParamList, self).__init__()
        self.sub_layer = sub_layer
        self.params = [sub_layer.weight]
        self.bias_dict = {'b': sub_layer.bias}

    @to_static
    def forward(self, x):
        out = paddle.matmul(x, self.params[0])
        out = paddle.add(out, self.bias_dict['b'])
        out = paddle.tanh(out)
        return out


class TestSubLayerParameterList(TestRawParameterList):
    def init_net(self):
        fc = paddle.nn.Linear(10, 3)
        self.net = NetWithSubLayerParamList(fc)


if __name__ == '__main__':
    unittest.main()