#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from contextlib import contextmanager

import numpy as np

from paddle.fluid.framework import _dygraph_tracer

# TODO: define framework api
from paddle.fluid.layer_helper_base import LayerHelperBase

__all__ = []


def set_default_dtype(d):
    """
    Set default dtype. The default dtype is initially float32.

    Args:
        d(string|np.dtype|type): the dtype to make the default, given as a
                            string, an ``np.dtype`` object or a NumPy scalar
                            type such as ``np.float32``. It only supports
                            float16, float32 and float64.

    Returns:
        None.

    Raises:
        TypeError: if ``d`` is not float16, float32 or float64, or is not
            a string, ``np.dtype`` or NumPy scalar type at all.

    Examples:
        .. code-block:: python

            import paddle
            paddle.set_default_dtype("float32")

    """
    if isinstance(d, type):
        # This branch is for NumPy scalar types (np.float32 and friends are
        # classes), which are normalized to their class name.
        if d not in (np.float16, np.float32, np.float64):
            raise TypeError(
                "set_default_dtype only supports [float16, float32, float64]"
                ", but received %s" % d.__name__
            )
        d = d.__name__
    elif isinstance(d, (str, np.dtype)) and d in (
        'float16',
        'float32',
        'float64',
    ):
        # NOTE(SigureMo): an np.dtype object is not an instance of type, so
        # it is not handled by the previous branch; it compares equal to its
        # string name and is converted to str here.
        # The isinstance guard keeps the membership test from running on
        # arbitrary objects (e.g. an ndarray, whose elementwise ``==`` makes
        # ``in`` raise ValueError instead of the TypeError below).
        d = str(d)
    else:
        raise TypeError(
            "set_default_dtype only supports [float16, float32, float64]"
            ", but received %s" % str(d)
        )

    LayerHelperBase.set_default_dtype(d)


def get_default_dtype():
    """
    Get the current default dtype. The default dtype is initially float32.

    Args:
        None.

    Returns:
        String, this global dtype only supports float16, float32, float64.

    Examples:
        .. code-block:: python

            import paddle
            paddle.get_default_dtype()
    """
    # The default dtype is stored on LayerHelperBase; this is a thin accessor.
    return LayerHelperBase.get_default_dtype()


@contextmanager
def set_grad_enabled(mode):
    """
    Create a context which enables or disables dygraph gradient calculation.

    Args:
        mode(bool): whether to enable (`True`), or disable (`False`) grad.

    Returns:
        None.

    Yields:
        None. The previous gradient mode of the tracer is restored when the
        ``with`` block exits, including when it exits with an exception.

    Examples:
        .. code-block:: python

            import paddle
            x = paddle.ones([3, 2])
            x.stop_gradient = False
            with paddle.set_grad_enabled(False):
                y = x * 2
                with paddle.set_grad_enabled(True):
                    z = x * 2
            print(y.stop_gradient)   # True
            print(z.stop_gradient)   # False
    """

    tracer = _dygraph_tracer()
    if tracer:
        # Remember the current mode so nested contexts unwind correctly.
        prev_mode = tracer._has_grad
        tracer._has_grad = mode
        try:
            yield
        finally:
            tracer._has_grad = prev_mode
    else:
        # No tracer is available, so there is no gradient mode to toggle;
        # the context is a no-op.
        yield


def is_grad_enabled():
    """
    Returns whether current dygraph gradient calculation mode is enabled.

    Returns:
        bool: True if current dygraph gradient calculation mode is enabled, otherwise False.

    Examples:
        .. code-block:: python

            import paddle

            # Dygraph gradient calculation mode is enabled by default.
            paddle.is_grad_enabled() # True

            with paddle.set_grad_enabled(False):
                paddle.is_grad_enabled() # False

            paddle.enable_static()
            paddle.is_grad_enabled() # False
    """
    tracer = _dygraph_tracer()
    # Without a tracer there is no gradient mode to report, so answer False.
    return tracer._has_grad if tracer else False