diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 05892a815f6367cccc518faf684363248dfa7312..6b99a91cdb307464e9d0dbfd3461aaaef2137f53 100755 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -36,6 +36,7 @@ from .framework import disable_static # noqa: F401 from .framework import enable_static # noqa: F401 from .framework import in_dynamic_mode # noqa: F401 from .fluid.dataset import * # noqa: F401 +from .fluid.lazy_init import LazyInit # noqa: F401 from .framework.dtype import dtype as dtype # noqa: F401 from .framework.dtype import uint8 # noqa: F401 @@ -415,6 +416,7 @@ __all__ = [ # noqa 'cumprod', 'logcumsumexp', 'logit', + 'LazyInit', 'sign', 'is_empty', 'equal', diff --git a/python/paddle/fluid/dygraph/varbase_patch_methods.py b/python/paddle/fluid/dygraph/varbase_patch_methods.py index 5b0aba7a9dabba8bb615bd8a8198be9518e595e0..f70bfbde1e80a5076d71474b48808851d455c9e8 100644 --- a/python/paddle/fluid/dygraph/varbase_patch_methods.py +++ b/python/paddle/fluid/dygraph/varbase_patch_methods.py @@ -120,6 +120,10 @@ def monkey_patch_varbase(): for attr in attr_keys: attr_kwargs[attr] = getattr(self, attr, None) + # If specify block, use it instead of self.block + if 'block' in kwargs: + attr_kwargs['block'] = kwargs['block'] + attr_kwargs.update(kwargs) if to_parameter or isinstance(self, (ParamBase, EagerParamBase)): diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 3d7a743376c59475cd2f9c9013197d2ddd9d0481..1261eb898a3368d8533e7764ca096b2b4511fcb3 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -6795,6 +6795,17 @@ class EagerParamBase(_core_eager_eagertensor): self.is_distributed = kwargs.get('is_distributed', False) # self.block = default_main_program().global_block() + self.init_func = None + + def set_init_func(self, obj): + self.init_func = obj + + @dygraph_only + def initialize(self): + assert self.init_func is not None, "Required self.init_func is not None, but received None." + self.init_func() + # clear function handle to release resource + self.init_func = None @property def trainable(self): diff --git a/python/paddle/fluid/initializer.py b/python/paddle/fluid/initializer.py index 6381cb59f20eb3fd182947e73a5a7b19380f0cf8..3ecb5dc5602be38353216d933513804bb2c4bed8 100644 --- a/python/paddle/fluid/initializer.py +++ b/python/paddle/fluid/initializer.py @@ -15,9 +15,12 @@ from __future__ import print_function import math +import functools from . import framework from . import core from .framework import _non_static_mode, in_dygraph_mode, _in_legacy_dygraph, default_main_program, _current_expected_place +from .lazy_init import lazy_guard +from .framework import program_guard import numpy as np from .core import VarDesc from . import unique_name @@ -49,10 +52,32 @@ class Initializer(object): pass def __call__(self, param, block=None): + if not lazy_guard().state: + return self.forward(param, block) + + return self._lazy_init(param, block) + + def forward(self, param, block=None): """Add corresponding initialization operations to the network """ raise NotImplementedError() + def _lazy_init(self, param, block=None): + # Apply lazy initialization + assert in_dygraph_mode() + new_block = lazy_guard().startup_program.global_block() + new_var = param._to_static_var(True, block=new_block) + + # Record initializer operator + with lazy_guard(): + self.forward(new_var, new_block) + lazy_guard().enable(clear_cache=False) + # Add hook function for initializing param in dygraph mode + func = functools.partial(self.forward, param, block) + param.set_init_func(func) + + return param + def _check_block(self, block): if block is None: block = default_main_program().global_block() @@ -121,7 +146,7 @@ class ConstantInitializer(Initializer): self._value = value self._force_cpu = force_cpu - def __call__(self, var, block=None): + def forward(self, var, block=None): """Initialize the input tensor with constant. Args: @@ -214,7 +239,7 @@ class UniformInitializer(Initializer): self._diag_step = diag_step self._diag_val = diag_val - def __call__(self, var, block=None): + def forward(self, var, block=None): """Initialize the input tensor with Uniform distribution. Args: @@ -316,7 +341,7 @@ class NormalInitializer(Initializer): self._std_dev = scale self._seed = seed - def __call__(self, var, block=None): + def forward(self, var, block=None): """Initialize the input tensor with Normal distribution. Args: @@ -428,7 +453,7 @@ class TruncatedNormalInitializer(Initializer): self._std_dev = scale self._seed = seed - def __call__(self, var, block=None): + def forward(self, var, block=None): """Initialize the input tensor with TruncatedNormal distribution. Args: @@ -562,7 +587,7 @@ class XavierInitializer(Initializer): self._fan_out = fan_out self._seed = seed - def __call__(self, var, block=None): + def forward(self, var, block=None): """Initialize the input tensor with Xavier initialization. Args: @@ -740,7 +765,7 @@ class MSRAInitializer(Initializer): self._negative_slope = negative_slope self._nonlinearity = nonlinearity - def __call__(self, var, block=None): + def forward(self, var, block=None): """Initialize the input tensor with MSRA initialization. Args: @@ -901,7 +926,7 @@ class BilinearInitializer(Initializer): """ super(BilinearInitializer, self).__init__() - def __call__(self, var, block=None): + def forward(self, var, block=None): """Initialize the input tensor with Bilinear initialization. Args: @@ -1026,7 +1051,7 @@ class NumpyArrayInitializer(Initializer): super(NumpyArrayInitializer, self).__init__() self._value = value - def __call__(self, var, block=None): + def forward(self, var, block=None): """Initialize the input tensor with Numpy array. Args: diff --git a/python/paddle/fluid/lazy_init.py b/python/paddle/fluid/lazy_init.py new file mode 100644 index 0000000000000000000000000000000000000000..1e93c74c35b0cfba8ea7befe9409b9f705d711ef --- /dev/null +++ b/python/paddle/fluid/lazy_init.py @@ -0,0 +1,154 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import framework + +__all__ = ["LazyInit"] + + +class LazyGuard(object): + """ + Guard Context to trigger switching mode between dygraph and static mode, + and holds the startup program resource. + """ + + def __init__(self): + self._init_program() + + self._state = False + self._tracer = None + self._in_guard = False + + def enable(self, clear_cache=True): + """ + Switch into lazy mode. + + NOTE(dev): This is a very low level API and not exposed for user. + """ + if self._state: + return + assert framework.in_dygraph_mode( + ), "LazyInit.enable() is only available in dygraph mode." + self._state = True + + if clear_cache: + self._init_program() + + def disable(self): + """ + Exit from lazy mode. + + NOTE(dev): This is a very low level API and not exposed for user. + """ + if not self._state: + return + self._state = False + + def _init_program(self): + self.startup_program = framework.Program() + + def __enter__(self): + """ + Switch into lazy mode and set _dygraph_tracer_ with None to convert + dygraph mode into static mode. + """ + self.enable(clear_cache=True) + if self._in_guard: return + self._tracer = framework._dygraph_tracer_ + framework._dygraph_tracer_ = None + self._in_guard = True + + def __exit__(self, *args, **kwargs): + """ + Exit from lazy mode and recover _dygraph_tracer_. + """ + self.disable() + if not self._in_guard: return + assert self._tracer is not None + framework._dygraph_tracer_ = self._tracer + self._tracer = None + self._in_guard = False + + @property + def state(self): + return self._state + + +_lazy_guard = LazyGuard() + + +def lazy_guard(): + global _lazy_guard + return _lazy_guard + + +class LazyInit(object): + """ + LazyInit is a wrapper interface for nn.Layer, it forwards the construct + process of user defined Layer. Meanwhile, it provides necessary API to + trigger EagerParamBase Lazy Initialization and get startup Program. + """ + + def __init__(self, class_obj=None): + self.class_obj = class_obj + self.clear_cache = True + + def __call__(self, *args, **kwargs): + """ + Construct instance from class_obj by Lazy Initializing parameters. + + Examples: + + .. code-block:: python + + from paddle import LazyInit + from paddle.nn import Linear + + fc = LazyInit(Linear)(10, 10) + + for param in fc.parameters(): + param.initialize() + """ + assert isinstance( + self.class_obj, type + ), "Required class_obj must be a class type, but received %s." % self.class_obj + global _lazy_guard + _lazy_guard.enable(self.clear_cache) + # construct Layer instance + with framework.program_guard(framework.Program()): + instance = self.class_obj(*args, **kwargs) + _lazy_guard.disable() + # set @property dynamically to visit startup_program + instance.startup_program = _lazy_guard.startup_program + + return instance + + @staticmethod + def startup_program(): + """ + A static method to get startup program for the latest Layer. + + Examples: + + .. code-block:: python + + from paddle import LazyInit + from paddle.nn import Linear + + fc = LazyInit(Linear)(10, 10) + + print(LazyInit.startup_program()) + + """ + return _lazy_guard.startup_program diff --git a/python/paddle/fluid/tests/unittests/test_lazy_init.py b/python/paddle/fluid/tests/unittests/test_lazy_init.py new file mode 100644 index 0000000000000000000000000000000000000000..772ac78a69866a182073559589c1b2aad008d829 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_lazy_init.py @@ -0,0 +1,139 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import unittest +import numpy as np +from paddle import LazyInit +from paddle.nn import Linear +from paddle.nn.initializer import * +from paddle.fluid import unique_name + + +class TestInitializerBase(unittest.TestCase): + + def setUp(self): + self.set_initializer() + self.set_param_attr() + self.set_init_ops() + self.clear_nameset() + + def set_initializer(self): + self.w_initializer = Constant(0.6) + self.b_initializer = Constant(0.3) + + def set_param_attr(self): + self.weight_attr = paddle.ParamAttr(name="weight", + initializer=self.w_initializer) + + self.bias_attr = paddle.ParamAttr(name="bias", + initializer=self.b_initializer) + + def set_init_ops(self): + self.init_ops = ['fill_constant', 'fill_constant'] + + def clear_nameset(self): + unique_name.dygraph_parameter_name_checker._name_set = set() + + def test_wrapper(self): + fc = LazyInit(Linear)(10, + 10, + weight_attr=self.weight_attr, + bias_attr=self.bias_attr) + program = fc.startup_program + self.check_program(program) + + def check_program(self, program): + self.assertEqual(program.block(0).var("weight").shape, (10, 10)) + self.assertEqual(program.block(0).var("bias").shape, (10, )) + ops = [op.type for op in program.block(0).ops] + self.assertEqual(ops, self.init_ops) + + +class TestDygraphLazy(TestInitializerBase): + + def test_wrapper(self): + fc = LazyInit(Linear)(10, + 10, + weight_attr=self.weight_attr, + bias_attr=self.bias_attr) + + self.check_data(fc) + + def check_data(self, model): + x = paddle.randn([2, 10]) + # weight and bias have no memory + with self.assertRaises(RuntimeError): + out = model(x) + + for param in model.parameters(): + param.initialize() + + out = model(x) + self.assertEqual(out.shape, [2, 10]) + + np.testing.assert_allclose(model.weight.numpy(), + np.ones([10, 10], dtype=np.float32) * 0.6) + np.testing.assert_allclose(model.bias.numpy(), + np.ones([10], dtype=np.float32) * 0.3) + + +class TestUniform(TestInitializerBase): + + def set_initializer(self): + self.w_initializer = Uniform() + self.b_initializer = Uniform() + + def set_init_ops(self): + self.init_ops = ['uniform_random', 'uniform_random'] + + +class TestNormal(TestInitializerBase): + + def set_initializer(self): + self.w_initializer = Normal() + self.b_initializer = Normal() + + def set_init_ops(self): + self.init_ops = ['gaussian_random', 'gaussian_random'] + + +class TestTruncatedNormal(TestInitializerBase): + + def set_initializer(self): + self.w_initializer = TruncatedNormal() + self.b_initializer = TruncatedNormal() + + def set_init_ops(self): + self.init_ops = [ + 'truncated_gaussian_random', 'truncated_gaussian_random' + ] + + +class TestXavierNormal(TestNormal): + + def set_initializer(self): + self.w_initializer = XavierNormal() + self.b_initializer = XavierNormal() + + +class TestXavierUniform(TestUniform): + + def set_initializer(self): + self.w_initializer = XavierUniform() + self.b_initializer = XavierUniform() + + +if __name__ == '__main__': + unittest.main()