# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from . import framework

# Names exported by a star-import of this module; only LazyGuard is public.
__all__ = ["LazyGuard"]


class LazyInitHelper(object):
    """
    A helper context to trigger switching mode between dygraph and static mode.

    The helper tracks a boolean "lazy" state (see ``enable``/``disable``) and,
    when used in a ``with`` statement, additionally stashes
    ``framework._dygraph_tracer_`` and replaces it with ``None`` until the
    block exits.

    NOTE: the guard is not counted. When ``with`` blocks on the same helper
    are nested, the inner ``__exit__`` already restores the tracer and leaves
    lazy mode; the outer ``__exit__`` is then a no-op.
    """

    def __init__(self):
        # True while lazy mode is switched on.
        self._state = False
        # Tracer saved by __enter__ so that __exit__ can put it back.
        self._tracer = None
        # True between __enter__ and the matching __exit__; prevents a nested
        # __enter__ from overwriting the saved tracer with None.
        self._in_guard = False

    def enable(self):
        """
        Switch into lazy mode.

        NOTE(dev): This is a very low level API and not exposed for user.

        Calling it while lazy mode is already on is a no-op (the mode check
        below is skipped in that case).

        Raises:
            AssertionError: if lazy mode is off and
                ``framework._non_static_mode()`` is falsy.
        """
        if self._state:
            return
        assert (
            framework._non_static_mode()
        ), "LazyInit.enable() is only available in dygraph mode."
        self._state = True

    def disable(self):
        """
        Exit from lazy mode.

        NOTE(dev): This is a very low level API and not exposed for user.

        Safe to call when lazy mode is already off.
        """
        if not self._state:
            return
        self._state = False

    def __enter__(self):
        """
        Switch into lazy mode and set _dygraph_tracer_ with None to convert
        dygraph mode into static mode.

        Returns ``None``; ``with helper as x`` therefore binds ``x`` to None.
        """
        self.enable()
        if self._in_guard:
            # Already inside a guard: the tracer has been stashed, and
            # stashing again would save the None placeholder instead.
            return
        self._tracer = framework._dygraph_tracer_
        framework._dygraph_tracer_ = None
        self._in_guard = True

    def __exit__(self, *args, **kwargs):
        """
        Exit from lazy mode and recover _dygraph_tracer_.

        Returns ``None``, so an exception raised inside the ``with`` block is
        not suppressed.

        Raises:
            AssertionError: if a guard is active but no tracer was saved.
        """
        self.disable()
        if not self._in_guard:
            return
        assert self._tracer is not None
        framework._dygraph_tracer_ = self._tracer
        self._tracer = None
        self._in_guard = False

    @property
    def state(self):
        """Whether lazy mode is currently switched on."""
        return self._state


# Module-level helper instance, created once at import time and returned by
# lazy_init_helper().
_lazy_init_helper = LazyInitHelper()


def lazy_init_helper():
    """Return the module-level ``_lazy_init_helper`` instance."""
    # The global is only read here, so no ``global`` declaration is needed.
    return _lazy_init_helper


class LazyGuard(object):
    """
    LazyGuard is a context manager that switches the module-level lazy-init
    helper into lazy mode on entry and back out of it on exit.

    It is meant to wrap the construction of user defined ``nn.Layer`` objects
    so that their parameters are initialized lazily (see the example on
    ``__enter__``).
    """

    def __enter__(self):
        """
        Enable lazy mode for everything constructed inside the ``with`` block.

        Returns ``None``; ``with LazyGuard() as g`` binds ``g`` to None.

        Examples:

            .. code-block:: python

                from paddle import LazyGuard
                from paddle.nn import Linear

                with LazyGuard():
                    fc = Linear(10, 10)

                for param in fc.parameters():
                    param.initialize()
        """
        lazy_init_helper().enable()

    def __exit__(self, *args, **kwargs):
        """
        Leave lazy mode.

        Returns ``None``, so an exception raised inside the ``with`` block is
        not suppressed.
        """
        lazy_init_helper().disable()