Unverified commit f59c666c, authored by Aurelius84 and committed by GitHub

[Eager]Support Lazy initialization for nn.Layer (#44990)

* [Eager]Support Lazy initialization for nn.Layer
Parent 0c98d2bf
@@ -36,6 +36,7 @@ from .framework import disable_static # noqa: F401
 from .framework import enable_static # noqa: F401
 from .framework import in_dynamic_mode # noqa: F401
 from .fluid.dataset import * # noqa: F401
+from .fluid.lazy_init import LazyInit # noqa: F401
 from .framework.dtype import dtype as dtype # noqa: F401
 from .framework.dtype import uint8 # noqa: F401
@@ -415,6 +416,7 @@ __all__ = [ # noqa
     'cumprod',
     'logcumsumexp',
     'logit',
+    'LazyInit',
     'sign',
     'is_empty',
     'equal',
......
@@ -120,6 +120,10 @@ def monkey_patch_varbase():
         for attr in attr_keys:
             attr_kwargs[attr] = getattr(self, attr, None)

+        # If a block is specified, use it instead of self.block.
+        if 'block' in kwargs:
+            attr_kwargs['block'] = kwargs['block']
+
         attr_kwargs.update(kwargs)

         if to_parameter or isinstance(self, (ParamBase, EagerParamBase)):
......
@@ -6795,6 +6795,17 @@ class EagerParamBase(_core_eager_eagertensor):
         self.is_distributed = kwargs.get('is_distributed', False)
         # self.block = default_main_program().global_block()

+        self.init_func = None
+
+    def set_init_func(self, obj):
+        self.init_func = obj
+
+    @dygraph_only
+    def initialize(self):
+        assert self.init_func is not None, "Required self.init_func is not None, but received None."
+        self.init_func()
+        # clear function handle to release resource
+        self.init_func = None
+
     @property
     def trainable(self):
......
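The new set_init_func/initialize pair is a small deferred-initialization contract: an initializer records a zero-argument callable on the parameter, and the caller runs it later. A minimal, framework-free sketch of that contract (LazyParam and fill_constant below are illustrative stand-ins, not Paddle APIs):

import functools

class LazyParam:
    """Toy stand-in for EagerParamBase: it only stores a pending init callable."""

    def __init__(self):
        self.init_func = None
        self.data = None

    def set_init_func(self, obj):
        self.init_func = obj

    def initialize(self):
        assert self.init_func is not None, "initialize() called but no init_func was recorded."
        self.init_func()
        # Drop the handle so the captured objects can be released.
        self.init_func = None

def fill_constant(param, value):
    # Toy "initializer op": materialize the data.
    param.data = value

param = LazyParam()
param.set_init_func(functools.partial(fill_constant, param, 0.6))  # record, do not run
assert param.data is None    # nothing has been initialized yet
param.initialize()           # the recorded op runs now
assert param.data == 0.6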
@@ -15,9 +15,12 @@
 from __future__ import print_function

 import math
+import functools
 from . import framework
 from . import core
 from .framework import _non_static_mode, in_dygraph_mode, _in_legacy_dygraph, default_main_program, _current_expected_place
+from .lazy_init import lazy_guard
+from .framework import program_guard
 import numpy as np
 from .core import VarDesc
 from . import unique_name
@@ -49,10 +52,32 @@ class Initializer(object):
         pass

     def __call__(self, param, block=None):
+        if not lazy_guard().state:
+            return self.forward(param, block)
+
+        return self._lazy_init(param, block)
+
+    def forward(self, param, block=None):
         """Add corresponding initialization operations to the network
         """
         raise NotImplementedError()

+    def _lazy_init(self, param, block=None):
+        # Apply lazy initialization
+        assert in_dygraph_mode()
+
+        new_block = lazy_guard().startup_program.global_block()
+        new_var = param._to_static_var(True, block=new_block)
+        # Record initializer operator
+        with lazy_guard():
+            self.forward(new_var, new_block)
+        lazy_guard().enable(clear_cache=False)
+
+        # Add hook function for initializing param in dygraph mode
+        func = functools.partial(self.forward, param, block)
+        param.set_init_func(func)
+
+        return param
+
     def _check_block(self, block):
         if block is None:
             block = default_main_program().global_block()
@@ -121,7 +146,7 @@ class ConstantInitializer(Initializer):
         self._value = value
         self._force_cpu = force_cpu

-    def __call__(self, var, block=None):
+    def forward(self, var, block=None):
         """Initialize the input tensor with constant.

         Args:
@@ -214,7 +239,7 @@ class UniformInitializer(Initializer):
         self._diag_step = diag_step
         self._diag_val = diag_val

-    def __call__(self, var, block=None):
+    def forward(self, var, block=None):
         """Initialize the input tensor with Uniform distribution.

         Args:
@@ -316,7 +341,7 @@ class NormalInitializer(Initializer):
         self._std_dev = scale
         self._seed = seed

-    def __call__(self, var, block=None):
+    def forward(self, var, block=None):
         """Initialize the input tensor with Normal distribution.

         Args:
@@ -428,7 +453,7 @@ class TruncatedNormalInitializer(Initializer):
         self._std_dev = scale
         self._seed = seed

-    def __call__(self, var, block=None):
+    def forward(self, var, block=None):
         """Initialize the input tensor with TruncatedNormal distribution.

         Args:
@@ -562,7 +587,7 @@ class XavierInitializer(Initializer):
         self._fan_out = fan_out
         self._seed = seed

-    def __call__(self, var, block=None):
+    def forward(self, var, block=None):
         """Initialize the input tensor with Xavier initialization.

         Args:
@@ -740,7 +765,7 @@ class MSRAInitializer(Initializer):
         self._negative_slope = negative_slope
         self._nonlinearity = nonlinearity

-    def __call__(self, var, block=None):
+    def forward(self, var, block=None):
         """Initialize the input tensor with MSRA initialization.

         Args:
@@ -901,7 +926,7 @@ class BilinearInitializer(Initializer):
         """
         super(BilinearInitializer, self).__init__()

-    def __call__(self, var, block=None):
+    def forward(self, var, block=None):
         """Initialize the input tensor with Bilinear initialization.

         Args:
@@ -1026,7 +1051,7 @@ class NumpyArrayInitializer(Initializer):
         super(NumpyArrayInitializer, self).__init__()
         self._value = value

-    def __call__(self, var, block=None):
+    def forward(self, var, block=None):
         """Initialize the input tensor with Numpy array.

         Args:
......
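All eight hunks above are the same mechanical rename: each concrete initializer moves its body from __call__ to forward, while the base-class __call__ decides between eager and lazy execution. A framework-free sketch of that dispatch, where the class-level lazy_state flag stands in for lazy_guard().state and a plain list stands in for the lazy startup Program (all names here are illustrative, not Paddle APIs):

class Initializer:
    lazy_state = False   # stands in for lazy_guard().state
    startup_ops = []     # stands in for the lazy startup Program

    def __call__(self, var, block=None):
        # Eager path: run the init op now. Lazy path: record it for later.
        if not Initializer.lazy_state:
            return self.forward(var, block)
        return self._lazy_init(var, block)

    def forward(self, var, block=None):
        raise NotImplementedError()

    def _lazy_init(self, var, block=None):
        # Record the would-be op instead of executing it.
        Initializer.startup_ops.append((type(self).__name__, var))
        return var

class Constant(Initializer):
    def __init__(self, value):
        self.value = value

    def forward(self, var, block=None):
        var["data"] = self.value   # toy "fill_constant" op
        return var

w = {"data": None}
Constant(0.6)(w)                   # eager: fills immediately
assert w["data"] == 0.6

Initializer.lazy_state = True
b = {"data": None}
Constant(0.3)(b)                   # lazy: nothing runs yet, the op is recorded
assert b["data"] is None and len(Initializer.startup_ops) == 1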
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import framework
__all__ = ["LazyInit"]
class LazyGuard(object):
    """
    Guard context that switches between dygraph and static mode,
    and holds the startup program resource.
    """

    def __init__(self):
        self._init_program()

        self._state = False
        self._tracer = None
        self._in_guard = False

    def enable(self, clear_cache=True):
        """
        Switch into lazy mode.

        NOTE(dev): This is a very low-level API and is not exposed to users.
        """
        if self._state:
            return
        assert framework.in_dygraph_mode(
        ), "LazyInit.enable() is only available in dygraph mode."
        self._state = True

        if clear_cache:
            self._init_program()

    def disable(self):
        """
        Exit lazy mode.

        NOTE(dev): This is a very low-level API and is not exposed to users.
        """
        if not self._state:
            return
        self._state = False

    def _init_program(self):
        self.startup_program = framework.Program()

    def __enter__(self):
        """
        Switch into lazy mode and set _dygraph_tracer_ to None to convert
        dygraph mode into static mode.
        """
        self.enable(clear_cache=True)
        if self._in_guard: return
        self._tracer = framework._dygraph_tracer_
        framework._dygraph_tracer_ = None
        self._in_guard = True

    def __exit__(self, *args, **kwargs):
        """
        Exit lazy mode and restore _dygraph_tracer_.
        """
        self.disable()
        if not self._in_guard: return
        assert self._tracer is not None
        framework._dygraph_tracer_ = self._tracer
        self._tracer = None
        self._in_guard = False

    @property
    def state(self):
        return self._state


_lazy_guard = LazyGuard()


def lazy_guard():
    global _lazy_guard
    return _lazy_guard
class LazyInit(object):
    """
    LazyInit is a wrapper interface for nn.Layer. It forwards the construction
    of a user-defined Layer and provides the APIs needed to trigger lazy
    initialization of EagerParamBase and to fetch the startup Program.
    """

    def __init__(self, class_obj=None):
        self.class_obj = class_obj
        self.clear_cache = True

    def __call__(self, *args, **kwargs):
        """
        Construct an instance of class_obj with lazily initialized parameters.

        Examples:

            .. code-block:: python

                from paddle import LazyInit
                from paddle.nn import Linear

                fc = LazyInit(Linear)(10, 10)

                for param in fc.parameters():
                    param.initialize()
        """
        assert isinstance(
            self.class_obj, type
        ), "Required class_obj must be a class type, but received %s." % self.class_obj
        global _lazy_guard
        _lazy_guard.enable(self.clear_cache)
        # construct Layer instance
        with framework.program_guard(framework.Program()):
            instance = self.class_obj(*args, **kwargs)
        _lazy_guard.disable()
        # expose startup_program on the instance so it can be retrieved later
        instance.startup_program = _lazy_guard.startup_program

        return instance

    @staticmethod
    def startup_program():
        """
        A static method to get the startup program of the most recently
        constructed Layer.

        Examples:

            .. code-block:: python

                from paddle import LazyInit
                from paddle.nn import Linear

                fc = LazyInit(Linear)(10, 10)

                print(LazyInit.startup_program())
        """
        return _lazy_guard.startup_program
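Taken together with the docstrings above and the unit test below, the intended end-to-end workflow looks roughly like this:

import paddle
from paddle import LazyInit
from paddle.nn import Linear

# Construct the layer; parameters are recorded but not materialized.
fc = LazyInit(Linear)(10, 10)

# The recorded initializer ops live in a static startup Program.
print(fc.startup_program)            # attribute attached by LazyInit.__call__
print(LazyInit.startup_program())    # the same Program via the static helper

# Materialize the parameters, then run the layer eagerly as usual.
for param in fc.parameters():
    param.initialize()
out = fc(paddle.randn([2, 10]))
print(out.shape)                     # [2, 10]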
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import unittest
import numpy as np
from paddle import LazyInit
from paddle.nn import Linear
from paddle.nn.initializer import *
from paddle.fluid import unique_name


class TestInitializerBase(unittest.TestCase):

    def setUp(self):
        self.set_initializer()
        self.set_param_attr()
        self.set_init_ops()
        self.clear_nameset()

    def set_initializer(self):
        self.w_initializer = Constant(0.6)
        self.b_initializer = Constant(0.3)

    def set_param_attr(self):
        self.weight_attr = paddle.ParamAttr(name="weight",
                                            initializer=self.w_initializer)
        self.bias_attr = paddle.ParamAttr(name="bias",
                                          initializer=self.b_initializer)

    def set_init_ops(self):
        self.init_ops = ['fill_constant', 'fill_constant']

    def clear_nameset(self):
        unique_name.dygraph_parameter_name_checker._name_set = set()

    def test_wrapper(self):
        fc = LazyInit(Linear)(10,
                              10,
                              weight_attr=self.weight_attr,
                              bias_attr=self.bias_attr)
        program = fc.startup_program
        self.check_program(program)

    def check_program(self, program):
        self.assertEqual(program.block(0).var("weight").shape, (10, 10))
        self.assertEqual(program.block(0).var("bias").shape, (10, ))
        ops = [op.type for op in program.block(0).ops]
        self.assertEqual(ops, self.init_ops)
class TestDygraphLazy(TestInitializerBase):

    def test_wrapper(self):
        fc = LazyInit(Linear)(10,
                              10,
                              weight_attr=self.weight_attr,
                              bias_attr=self.bias_attr)

        self.check_data(fc)

    def check_data(self, model):
        x = paddle.randn([2, 10])
        # weight and bias have no memory
        with self.assertRaises(RuntimeError):
            out = model(x)

        for param in model.parameters():
            param.initialize()

        out = model(x)
        self.assertEqual(out.shape, [2, 10])

        np.testing.assert_allclose(model.weight.numpy(),
                                   np.ones([10, 10], dtype=np.float32) * 0.6)
        np.testing.assert_allclose(model.bias.numpy(),
                                   np.ones([10], dtype=np.float32) * 0.3)
class TestUniform(TestInitializerBase):

    def set_initializer(self):
        self.w_initializer = Uniform()
        self.b_initializer = Uniform()

    def set_init_ops(self):
        self.init_ops = ['uniform_random', 'uniform_random']


class TestNormal(TestInitializerBase):

    def set_initializer(self):
        self.w_initializer = Normal()
        self.b_initializer = Normal()

    def set_init_ops(self):
        self.init_ops = ['gaussian_random', 'gaussian_random']


class TestTruncatedNormal(TestInitializerBase):

    def set_initializer(self):
        self.w_initializer = TruncatedNormal()
        self.b_initializer = TruncatedNormal()

    def set_init_ops(self):
        self.init_ops = [
            'truncated_gaussian_random', 'truncated_gaussian_random'
        ]


class TestXavierNormal(TestNormal):

    def set_initializer(self):
        self.w_initializer = XavierNormal()
        self.b_initializer = XavierNormal()


class TestXavierUniform(TestUniform):

    def set_initializer(self):
        self.w_initializer = XavierUniform()
        self.b_initializer = XavierUniform()


if __name__ == '__main__':
    unittest.main()