framework.py 4.1 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# TODO: define framework api
from paddle.fluid.layer_helper_base import LayerHelperBase
from paddle.fluid.data_feeder import convert_dtype
18
from paddle.fluid.framework import _dygraph_tracer
19
import numpy as np
20
from contextlib import contextmanager
21

22 23
__all__ = []

24 25 26

def set_default_dtype(d):
    """
    Set default dtype. The default dtype is initially float32.

    Args:
        d(string|np.dtype): the dtype to make the default. It only
                            supports float16, float32 and float64.

    Returns:
        None.

    Raises:
        TypeError: if ``d`` is not one of float16, float32 or float64.

    Examples:
        .. code-block:: python

            import paddle
            paddle.set_default_dtype("float32")

    """
    # Shared error template; the two adjacent literals are concatenated
    # before the % formatting is applied.
    unsupported = (
        "set_default_dtype only supports [float16, float32, float64] "
        ", but received %s"
    )

    if isinstance(d, type):
        # A class object such as np.float32: validate it, then reduce it to
        # its name so the value stored below is always a plain string.
        if d not in (np.float16, np.float32, np.float64):
            raise TypeError(unsupported % d.__name__)
        d = d.__name__
    else:
        # Anything else must compare equal to one of the accepted names.
        # str() normalizes values that merely compare equal to a name
        # (rather than being a str themselves) into a real string.
        if d not in ('float16', 'float32', 'float64'):
            raise TypeError(unsupported % str(d))
        d = str(d)

    LayerHelperBase.set_default_dtype(d)


def get_default_dtype():
    """
    Get the current default dtype. The default dtype is initially float32.

    Args:
        None.

    Returns:
        The default dtype, as reported by ``LayerHelperBase``.

    Examples:
        .. code-block:: python

            import paddle
            paddle.get_default_dtype()
    """
    # The default dtype is owned by LayerHelperBase; this is a thin accessor.
    current_dtype = LayerHelperBase.get_default_dtype()
    return current_dtype
91 92 93 94 95 96 97 98 99 100 101 102


@contextmanager
def set_grad_enabled(mode):
    """
    Create a context which enables or disables dygraph gradient calculation.

    The tracer's previous gradient mode is restored when the context exits,
    including when the body raises. If no dygraph tracer is active, the
    context does nothing.

    Args:
        mode(bool): whether to enable (`True`), or disable (`False`) grad.

    Examples:
        .. code-block:: python

            import paddle
            x = paddle.ones([3, 2])
            x.stop_gradient = False
            with paddle.set_grad_enabled(False):
                y = x * 2
                with paddle.set_grad_enabled(True):
                    z = x * 2
            print(y.stop_gradient)   # True
            print(z.stop_gradient)   # False
    """
    active_tracer = _dygraph_tracer()

    # No tracer available: nothing to toggle, just run the body.
    if not active_tracer:
        yield
        return

    # Remember the current mode so nested contexts unwind correctly.
    saved_mode = active_tracer._has_grad
    active_tracer._has_grad = mode
    try:
        yield
    finally:
        active_tracer._has_grad = saved_mode
W
wuhuanzhou 已提交
125 126 127 128 129 130 131 132 133 134 135


def is_grad_enabled():
    """
    Returns whether current dygraph gradient calculation mode is enabled.

    Returns:
        bool: True if current dygraph gradient calculation mode is enabled, otherwise false.

    Examples:
        .. code-block:: python

            import paddle

            # Dygraph gradient calculation mode is enabled by default.
            paddle.is_grad_enabled() # True

            with paddle.set_grad_enabled(False):
                paddle.is_grad_enabled() # False

            paddle.enable_static()
            paddle.is_grad_enabled() # False
    """
    active_tracer = _dygraph_tracer()
    if active_tracer:
        return active_tracer._has_grad
    # Without an active dygraph tracer, gradient mode is reported as disabled.
    return False