base.py 7.2 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14 15
from ..wrapped_decorator import signature_safe_contextmanager, wrap_decorator
import contextlib
16 17 18
import numpy as np
from paddle.fluid import core
from paddle.fluid import framework
M
minqiyang 已提交
19
from .tracer import Tracer
Z
Zeng Jinle 已提交
20
import logging
J
Jiabin Yang 已提交
21
import objgraph
22

23 24 25 26 27
__all__ = [
    'no_grad',
    'guard',
    'to_variable',
]
28 29


30 31 32 33 34 35 36 37 38 39 40
def _switch_to_static_graph_(func):
    """
    Wrap ``func`` so that every call runs inside ``framework._dygraph_guard(None)``,
    i.e. with no dygraph tracer installed for the duration of the call.

    Parameters:
        func(callable): the function to wrap.

    Returns:
        callable: a wrapper forwarding all positional and keyword arguments to
        ``func`` and returning its result.
    """

    def __impl__(*args, **kwargs):
        no_tracer_ctx = framework._dygraph_guard(None)
        with no_tracer_ctx:
            result = func(*args, **kwargs)
        return result

    return __impl__


# Decorator form of ``_switch_to_static_graph_`` built with the project's wrap_decorator helper.
switch_to_static_graph = wrap_decorator(_switch_to_static_graph_)


41 42 43 44 45 46 47 48 49 50 51
@signature_safe_contextmanager
def program_desc_tracing_guard(enable):
    """
    Temporarily set the current dygraph tracer's ``_enable_program_desc_tracing``
    attribute to ``enable`` for the duration of a ``with`` block.

    If ``framework._dygraph_tracer()`` returns a falsy value (no tracer), the guard
    does nothing. Otherwise the previous value is restored on exit, including
    when the body of the ``with`` block raises.

    Parameters:
        enable: value assigned to ``tracer._enable_program_desc_tracing`` inside the block.
    """
    tracer = framework._dygraph_tracer()
    if not tracer:
        yield
        return
    original_val = tracer._enable_program_desc_tracing
    tracer._enable_program_desc_tracing = enable
    try:
        yield
    finally:
        # Restore in `finally` so an exception in the body cannot leave the
        # tracer with the temporary value.
        tracer._enable_program_desc_tracing = original_val


L
lujun 已提交
52
# This function should be removed in V1.6, because it can easily lead to cyclic dependencies.
def enabled():
    """
    Internal use only: report whether dygraph mode is currently on.

    Returns:
        The result of ``framework.in_dygraph_mode()``.
    """
    is_dygraph = framework.in_dygraph_mode()
    return is_dygraph
56 57


58 59 60 61 62 63 64 65 66 67 68 69 70
@contextlib.contextmanager
def _switch_tracer_mode_guard_(is_train=True):
    """
    Temporarily set the current dygraph tracer's ``_train_mode`` to ``is_train``.

    If ``framework._dygraph_tracer()`` returns a falsy value (no tracer), the guard
    does nothing. Otherwise the previous mode is restored on exit, including
    when the body of the ``with`` block raises. ``no_grad`` relies on this to
    switch the tracer back after the decorated function finishes.

    Parameters:
        is_train(bool, optional): mode to set inside the block. Default: True
    """
    tracer = framework._dygraph_tracer()
    if not tracer:
        yield
        return
    mode = tracer._train_mode
    tracer._train_mode = is_train
    try:
        yield
    finally:
        # Without `finally`, an exception raised inside a no_grad function
        # would leave the tracer permanently in the temporary mode.
        tracer._train_mode = mode


def _no_grad_(func):
    """
    This decorator prevents the decorated function from creating a backward network in dygraph mode.

    Parameters:
        - **func** (python func): the function that does not need gradients

    Examples:

     .. code-block:: python

        import numpy as np
        import paddle.fluid as fluid

        @fluid.dygraph.no_grad
        def test_layer():
            with fluid.dygraph.guard():
                inp = np.ones([3, 1024], dtype='float32')
                t = fluid.dygraph.base.to_variable(inp)
                linear1 = fluid.Linear(1024, 4, bias_attr=False)
                linear2 = fluid.Linear(4, 4)
                ret = linear1(t)
                dy_ret = linear2(ret)

        test_layer()

    """

    def __impl__(*args, **kwargs):
        # Run the wrapped call with the tracer's train mode switched off.
        eval_mode_ctx = _switch_tracer_mode_guard_(is_train=False)
        with eval_mode_ctx:
            outcome = func(*args, **kwargs)
        return outcome

    return __impl__


no_grad = wrap_decorator(_no_grad_)

# for fluidDoc: expose the undecorated function's docstring on the public decorator.
no_grad.__doc__ = _no_grad_.__doc__
108 109


S
rename  
sneaxiy 已提交
110
@signature_safe_contextmanager
def guard(place=None):
    """
    This context will create a dygraph context for dygraph to run, using python ``with`` statement.

    Parameters:
        place(fluid.CPUPlace or fluid.CUDAPlace, optional): Place to execute dygraph.
            If None, the running place will be determined according to the way of paddle compilation. Default: None

    Returns:
        None

    Examples:

     .. code-block:: python

        import numpy as np
        import paddle.fluid as fluid

        with fluid.dygraph.guard():
            inp = np.ones([3, 1024], dtype='float32')
            t = fluid.dygraph.base.to_variable(inp)
            linear1 = fluid.Linear(1024, 4, bias_attr=False)
            linear2 = fluid.Linear(4, 4)
            ret = linear1(t)
            dy_ret = linear2(ret)

    """
    train = framework.Program()
    startup = framework.Program()
    tracer = Tracer()

    if place is None:
        # Default to GPU 0 when Paddle was compiled with CUDA, otherwise CPU.
        if core.is_compiled_with_cuda():
            place = core.CUDAPlace(0)
        else:
            place = core.CPUPlace()
    tracer._expected_place = place

    # Run the body under new main/startup programs, a new unique-name scope,
    # the new tracer, and the chosen place; each guard undoes itself on exit.
    with framework.program_guard(train, startup):
        with framework.unique_name.guard():
            with framework._dygraph_guard(tracer):
                with framework._dygraph_place_guard(place):
                    yield
155 156


157
def _print_debug_msg(parameter_list, limit=5, is_test=False):
    """
    Report dygraph debugging statistics (internal helper).

    Requires ``core._is_dygraph_debug_enabled()`` to be true; otherwise a
    warning is logged and the function returns None.

    Parameters:
        parameter_list: object whose ``len`` is reported as the tracer vars count.
        limit(int, optional): forwarded to ``objgraph.show_growth``. Default: 5
        is_test(bool, optional): if True, return the counts instead of logging
            them and calling ``objgraph.show_growth``. Default: False

    Returns:
        tuple: ``(unique_name_size, tracer_var_size, alive_cpp_var_size)`` when
        debug mode is enabled and ``is_test`` is True; None otherwise.
    """
    # logging.warning replaces logging.warn, which was deprecated and then
    # removed in Python 3.13.
    if not core._is_dygraph_debug_enabled():
        logging.warning(
            'Debug mode is not enabled. Please set FLAGS_dygraph_debug=1 to enable debug'
        )
        return
    unique_name_size = len(framework.unique_name.generator.ids)
    tracer_var_size = len(parameter_list)
    alive_cpp_var_size = len(core.VarBase._alive_vars())
    if is_test:
        return unique_name_size, tracer_var_size, alive_cpp_var_size
    logging.warning(
        'unique_name num: {}, tracer vars num: {}, alive cpp vars num: {}'
        .format(unique_name_size, tracer_var_size, alive_cpp_var_size))
    objgraph.show_growth(limit=limit)
Z
Zeng Jinle 已提交
173 174


175
@framework.dygraph_only
def to_variable(value, name=None, zero_copy=None):
    r"""
    The API will create a ``Variable`` object from numpy\.ndarray or Variable object.

    Parameters:
        value(ndarray|Variable): The numpy\.ndarray or Variable object that needs to be converted, it can be multi-dimension, and the data type is one of numpy\.{float16, float32, float64, int16, int32, int64, uint8, uint16}.
        name(str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`
        zero_copy(bool, optional): Whether to share memory with the input numpy array. This parameter only works with CPUPlace and will be set to True when it is None. Default: None.

    Returns:
        Variable: If ``value`` is a numpy\.ndarray object, return ``Tensor`` created from the specified numpy\.ndarray object, which has same data type and shape with ``value``. If ``value`` is a Variable object, just return ``value``.

    Raises:
        TypeError: if ``value`` is neither a numpy\.ndarray nor a Variable.

    Examples:

     .. code-block:: python

        import numpy as np
        import paddle.fluid as fluid

        with fluid.dygraph.guard(fluid.CPUPlace()):
            x = np.ones([2, 2], np.float32)
            y = fluid.dygraph.to_variable(x, zero_copy=False)
            x[0][0] = -1
            y[0][0].numpy()  # array([1.], dtype=float32)
            y = fluid.dygraph.to_variable(x)
            x[0][0] = 0
            y[0][0].numpy()  # array([0.], dtype=float32)

    """
    # The docstring is raw (r""") so the literal "\." sequences are not parsed
    # as invalid escapes (SyntaxWarning on Python 3.12+); the text is unchanged.
    if isinstance(value, np.ndarray):
        assert framework.in_dygraph_mode(
        ), "to_variable could only be called in dygraph mode"
        # Look up the expected place once; it is used both for the zero-copy
        # check and for constructing the VarBase.
        place = framework._current_expected_place()
        if isinstance(place, framework.core.CPUPlace):
            # Memory sharing is only supported on CPUPlace; default to sharing there.
            if zero_copy is None:
                zero_copy = True
        else:
            assert not zero_copy, "zero_copy mode can only be used with CPUPlace"
            zero_copy = False
        py_var = core.VarBase(
            value=value,
            place=place,
            persistable=False,
            zero_copy=zero_copy,
            name=name if name else '')
        return py_var
    elif isinstance(value, (core.VarBase, framework.Variable)):
        return value
    else:
        raise TypeError(
            "to_variable only accepts 'ndarray' and 'Variable' as value's input")