#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# TODO: define framework api

from paddle.fluid.layer_helper_base import LayerHelperBase
from paddle.fluid.data_feeder import convert_dtype
from paddle.fluid.framework import _dygraph_tracer
import numpy as np
from contextlib import contextmanager

# Names exported by ``from <module> import *``; the list is currently empty.
__all__ = []


def set_default_dtype(d):
    """
    Set default dtype. The default dtype is initially float32.

    Args:
        d(string|np.dtype): the dtype to make the default. It only
                            supports float16, float32 and float64.

    Returns:
        None.

    Raises:
        TypeError: if ``d`` is not one of float16, float32 or float64.

    Examples:
        .. code-block:: python

            import paddle
            paddle.set_default_dtype("float32")

    """
    if isinstance(d, type):
        # ``d`` is a class such as ``np.float32``; normalise it to its name.
        if d in [np.float16, np.float32, np.float64]:
            d = d.__name__
        else:
            raise TypeError(
                "set_default_dtype only supports [float16, float32, float64] "
                ", but received %s" % d.__name__)
    else:
        # Anything that compares equal to one of the supported names is
        # accepted and normalised to a plain ``str``. The accepted set is
        # ASCII-only, so the ``str`` conversion cannot fail on encoding.
        if d in ['float16', 'float32', 'float64']:
            d = str(d)
        else:
            raise TypeError(
                "set_default_dtype only supports [float16, float32, float64] "
                ", but received %s" % str(d))

    # ``d`` is now one of the three supported dtype names.
    LayerHelperBase.set_default_dtype(d)


def get_default_dtype():
    """
    Get the current default dtype. The default dtype is initially float32.

    Args:
        None.

    Returns:
        The default dtype, as returned by
        ``LayerHelperBase.get_default_dtype()``.

    Examples:
        .. code-block:: python

            import paddle
            paddle.get_default_dtype()
    """
    return LayerHelperBase.get_default_dtype()


@contextmanager
def set_grad_enabled(mode):
    """
    Create a context which enables or disables dygraph gradient calculation.

    The previous gradient mode of the dygraph tracer is restored when the
    context exits, including when the body raises. If no dygraph tracer is
    available, the context does nothing.

    Args:
        mode(bool): whether to enable (`True`), or disable (`False`) grad.

    Examples:
        .. code-block:: python

            import paddle
            x = paddle.ones([3, 2])
            x.stop_gradient = False
            with paddle.set_grad_enabled(False):
                y = x * 2
                with paddle.set_grad_enabled(True):
                    z = x * 2
            print(y.stop_gradient)   # True
            print(z.stop_gradient)   # False
    """

    tracer = _dygraph_tracer()
    if tracer:
        # Remember the current mode so nested contexts unwind correctly.
        prev_mode = tracer._has_grad
        tracer._has_grad = mode
        try:
            yield
        finally:
            # Always restore, even if the with-body raised.
            tracer._has_grad = prev_mode
    else:
        # No tracer to toggle; just run the with-body.
        yield


def is_grad_enabled():
    """
    Report whether dygraph gradient calculation is currently enabled.

    Returns:
        bool: the tracer's current gradient flag when a dygraph tracer is
        available, otherwise False.

    Examples:
        .. code-block:: python

            import paddle

            # Dygraph gradient calculation mode is enabled by default.
            paddle.is_grad_enabled() # True

            with paddle.set_grad_enabled(False):
                paddle.is_grad_enabled() # False

            paddle.enable_static()
            paddle.is_grad_enabled() # False
    """
    active_tracer = _dygraph_tracer()
    if not active_tracer:
        # Without a tracer there is no gradient mode to report.
        return False
    return active_tracer._has_grad