# lazy_init.py
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from . import framework

# Public API of this module; the helper classes/functions are internal.
__all__ = ["LazyGuard"]


class LazyInitHelper:
    """
    A helper context to trigger switching mode between dygraph and static mode.

    ``enable``/``disable`` toggle the lazy-mode flag exposed by ``state``.
    Used as a context manager, it additionally stashes
    ``framework._dygraph_tracer_`` and replaces it with None, converting
    dygraph mode into static mode for the body of the ``with`` block; the
    saved tracer is put back on exit.

    NOTE(review): the original docstring also says this helper "holds the
    startup program resource", but no such resource is visible in this
    class — confirm against the rest of the framework.
    """

    def __init__(self):
        # Lazy-mode flag toggled by enable()/disable(); read via ``state``.
        self._state = False
        # Dygraph tracer saved by __enter__ and restored by __exit__.
        self._tracer = None
        # Whether the tracer is currently swapped out by __enter__.
        self._in_guard = False

    def enable(self):
        """
        Switch into lazy mode. A no-op if lazy mode is already on.

        NOTE(dev): This is a very low level API and not exposed for user.

        Raises:
            AssertionError: if called outside dygraph mode. The check is an
                ``assert``, so it is skipped under ``python -O``.
        """
        if self._state:
            return
        assert (
            framework._non_static_mode()
        ), "LazyInit.enable() is only available in dygraph mode."
        self._state = True

    def disable(self):
        """
        Exit from lazy mode. A no-op if lazy mode is already off.

        NOTE(dev): This is a very low level API and not exposed for user.
        """
        if not self._state:
            return
        self._state = False

    def __enter__(self):
        """
        Switch into lazy mode and set ``_dygraph_tracer_`` to None to convert
        dygraph mode into static mode.

        If the guard is already active, the tracer is not saved again. No
        nesting depth is tracked, so the first ``__exit__`` restores the
        tracer and turns lazy mode off, even inside an outer ``with``.
        """
        self.enable()
        if self._in_guard:
            return
        # Stash the active tracer so __exit__ can put it back.
        self._tracer = framework._dygraph_tracer_
        framework._dygraph_tracer_ = None
        self._in_guard = True

    def __exit__(self, *args, **kwargs):
        """
        Exit from lazy mode and recover ``_dygraph_tracer_``.

        Returns None, so exceptions raised in the ``with`` body propagate.
        """
        self.disable()
        if not self._in_guard:
            return
        # enable() requires dygraph mode, so a tracer is presumably saved
        # here; TODO confirm the framework guarantees a non-None tracer then.
        assert self._tracer is not None
        framework._dygraph_tracer_ = self._tracer
        self._tracer = None
        self._in_guard = False

    @property
    def state(self):
        """Whether lazy mode is currently enabled."""
        return self._state


# Module-level singleton shared by every caller of lazy_init_helper().
_lazy_init_helper = LazyInitHelper()


def lazy_init_helper():
    """
    Return the module-level ``LazyInitHelper`` singleton.

    Every call returns the same instance, so lazy-mode state toggled through
    one caller is visible to all others.
    """
    # Only reads the module global, so no ``global`` statement is needed.
    return _lazy_init_helper


class LazyGuard:
    """
    LazyGuard is a wrapper interface for nn.Layer. It forwards the
    construction process of a user-defined Layer and, meanwhile, provides the
    necessary API to trigger EagerParamBase lazy initialization and get the
    startup Program.

    Entering the guard switches the module-level lazy-init helper into lazy
    mode, and leaving it switches lazy mode off again. It only toggles that
    mode flag; unlike ``LazyInitHelper.__enter__``, it does not swap out the
    dygraph tracer.
    """

    def __enter__(self):
        """
        Enter lazy-initialization mode.

        Layers constructed inside the ``with`` block are expected to defer
        their parameter initialization until ``param.initialize()`` is called,
        as shown below. NOTE(review): the deferral itself is implemented by
        parameter-creation code outside this class.

        Examples:

            .. code-block:: python

                from paddle import LazyGuard
                from paddle.nn import Linear

                with LazyGuard():
                    fc = Linear(10, 10)

                for param in fc.parameters():
                    param.initialize()
        """
        lazy_init_helper().enable()

    def __exit__(self, *args, **kwargs):
        """
        Leave lazy-initialization mode.

        Returns None, so any exception raised inside the ``with`` block
        propagates to the caller.
        """
        lazy_init_helper().disable()