lazy_init.py 3.6 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

D
Difer 已提交
15
from ...fluid import framework
16

17
# Only the user-facing guard is exported; the helper machinery below is
# internal (see the NOTE(dev) remarks on LazyInitHelper).
__all__ = [
    "LazyGuard",
]
18 19


20
class LazyInitHelper:
    """
    A helper that switches between dygraph and static graph mode for lazy
    parameter initialization.

    ``enable()`` / ``disable()`` toggle the lazy-mode flag reported by
    ``state``. Used as a context manager, the helper additionally detaches
    the dygraph tracer (``framework.global_var._dygraph_tracer_``) for the
    duration of the ``with`` block and re-attaches it on exit.

    The context manager is re-entrant: only the outermost ``__enter__`` /
    ``__exit__`` pair touches the tracer, and when the outermost block is
    left, lazy mode is switched off only if it was off before that block
    was entered (so a ``with`` block opened while lazy mode is already on,
    e.g. under ``LazyGuard``, does not switch it off).
    """

    def __init__(self):
        # Whether lazy mode is currently enabled (exposed via ``state``).
        self._state = False
        # Dygraph tracer detached by the outermost ``__enter__``.
        self._tracer = None
        # True while at least one ``with`` block on this helper is open.
        self._in_guard = False
        # Number of currently open ``with`` blocks on this helper.
        self._depth = 0
        # Value of ``_state`` observed just before the outermost ``__enter__``.
        self._prev_state = False

    def enable(self):
        """
        Switch into lazy mode.

        NOTE(dev): This is a very low level API and not exposed for user.

        Raises:
            AssertionError: if lazy mode is off and the program is not in
                dygraph mode. Raised explicitly rather than via ``assert``
                so the check is not stripped under ``python -O``.
        """
        if self._state:
            return
        if not framework.in_dygraph_mode():
            raise AssertionError(
                "LazyInit.enable() is only available in dygraph mode."
            )
        self._state = True

    def disable(self):
        """
        Exit from lazy mode.

        NOTE(dev): This is a very low level API and not exposed for user.
        """
        self._state = False

    def __enter__(self):
        """
        Switch into lazy mode and set ``_dygraph_tracer_`` to None to convert
        dygraph mode into static graph mode.

        Nested entries only increase a depth counter; the tracer is detached
        once, by the outermost entry.
        """
        outermost = self._depth == 0
        if outermost:
            # Remember the flag as found so the outermost __exit__ does not
            # switch off lazy mode that someone else enabled.
            self._prev_state = self._state
        self.enable()
        if outermost:
            self._tracer = framework.global_var._dygraph_tracer_
            framework.global_var._dygraph_tracer_ = None
            self._in_guard = True
        # Incremented last: if enable() raises, the with-block is never
        # entered and __exit__ will not run, so the depth must not change.
        self._depth += 1

    def __exit__(self, *args, **kwargs):
        """
        Leave one ``with`` level. On leaving the outermost level, recover
        ``_dygraph_tracer_`` and switch lazy mode off again if it was off
        before that level was entered. Exceptions are not suppressed.
        """
        if self._depth == 0:
            # Unbalanced call without a matching __enter__: just leave lazy
            # mode, as the helper always did.
            self.disable()
            return
        self._depth -= 1
        if self._depth > 0:
            # An enclosing ``with`` block on this helper is still active and
            # must keep running with the tracer detached.
            return
        assert self._tracer is not None
        framework.global_var._dygraph_tracer_ = self._tracer
        self._tracer = None
        self._in_guard = False
        if not self._prev_state:
            self.disable()

    @property
    def state(self):
        """Whether lazy mode is currently enabled."""
        return self._state


83
# Module-level helper instance shared by every caller of lazy_init_helper().
_lazy_init_helper = LazyInitHelper()


def lazy_init_helper():
    """
    Return the module-level ``LazyInitHelper`` instance.

    Returns:
        LazyInitHelper: the same object on every call.
    """
    return _lazy_init_helper
89 90


91
class LazyGuard:
    """
    LazyGuard is a context manager for constructing user defined ``nn.Layer``
    objects with lazily initialized parameters: while it is active, lazy mode
    is enabled on the shared lazy-init helper, so parameters created inside
    the ``with`` block are initialized lazily and only get memory once
    ``param.initialize()`` is called explicitly.

    Guards may be nested; leaving an inner guard keeps lazy mode enabled for
    the enclosing one.

    Examples:

        .. code-block:: python

            from paddle import LazyGuard
            from paddle.nn import Linear

            with LazyGuard():
                # w and b are initialized lazily and have no memory.
                net = Linear(10, 10)

            for param in net.parameters():
                # Initialize param and allocate memory explicitly.
                param.initialize()
    """

    def __init__(self):
        # Lazy-mode flag seen by each __enter__, popped by the matching
        # __exit__. A stack, so the same instance can be re-entered safely.
        self._prev_states = []

    def __enter__(self):
        """
        Enable lazy initialization of parameters for the ``with`` block.

        Examples:

            .. code-block:: python

                from paddle import LazyGuard
                from paddle.nn import Linear

                with LazyGuard():
                    fc = Linear(10, 10)

                for param in fc.parameters():
                    param.initialize()

        Returns:
            LazyGuard: this guard, so ``with LazyGuard() as guard:`` works.
        """
        helper = lazy_init_helper()
        was_enabled = helper.state
        helper.enable()
        # Recorded only after enable() succeeded: if it raises, __exit__ is
        # never called and nothing must be left on the stack.
        self._prev_states.append(was_enabled)
        return self

    def __exit__(self, *args, **kwargs):
        """
        Leave the guard. Lazy mode is switched off only if it was off when
        the matching ``__enter__`` ran; exceptions are not suppressed.
        """
        # An unbalanced exit (empty stack) keeps the old behaviour of simply
        # disabling lazy mode.
        was_enabled = self._prev_states.pop() if self._prev_states else False
        if not was_enabled:
            lazy_init_helper().disable()