未验证 提交 21b7ed3e 编写于 作者: W wuhuanzhou 提交者: GitHub

add control/status API (#37885)

* add control/status API, test=develop

* fix import error, test=develop

* add is_grad_enabled unittest, test=develop

* add code comment for example code and API, test=develop

* add checking for type, test=develop

* add api description, test=develop

* fix docs index_en, test=document_fix

* fix doc of is_floating_point, test=document_fix
上级 745477fe
......@@ -72,6 +72,7 @@ from .tensor.attribute import rank # noqa: F401
from .tensor.attribute import shape # noqa: F401
from .tensor.attribute import real # noqa: F401
from .tensor.attribute import imag # noqa: F401
from .tensor.attribute import is_floating_point # noqa: F401
from .tensor.creation import to_tensor # noqa: F401
from .tensor.creation import diag # noqa: F401
from .tensor.creation import diagflat # noqa: F401
......@@ -285,6 +286,7 @@ from .framework import CUDAPinnedPlace # noqa: F401
from .autograd import grad # noqa: F401
from .autograd import no_grad # noqa: F401
from .autograd import set_grad_enabled # noqa: F401
from .autograd import is_grad_enabled # noqa: F401
from .framework import save # noqa: F401
from .framework import load # noqa: F401
from .framework import DataParallel # noqa: F401
......@@ -453,6 +455,7 @@ __all__ = [ # noqa
'shape',
'real',
'imag',
'is_floating_point',
'complex',
'reciprocal',
'rand',
......@@ -468,6 +471,7 @@ __all__ = [ # noqa
'median',
'no_grad',
'set_grad_enabled',
'is_grad_enabled',
'mod',
'abs',
'tril',
......
......@@ -16,7 +16,7 @@ from ..fluid.dygraph.base import grad # noqa: F401
from . import backward_mode # noqa: F401
from .backward_mode import backward # noqa: F401
from .py_layer import PyLayer, PyLayerContext # noqa: F401
from ..framework import set_grad_enabled # noqa: F401
from ..framework import set_grad_enabled, is_grad_enabled # noqa: F401
from ..fluid.dygraph.base import no_grad_ as no_grad # noqa: F401
from .functional import jacobian, hessian, batch_jacobian, batch_hessian # noqa: F401
from .functional import vjp, jvp, vhp # noqa: F401
......
......@@ -317,6 +317,13 @@ class TestImperative(unittest.TestCase):
self.assertTrue(tmp2._grad_ivar() is not None)
self.assertTrue(l0.weight._grad_ivar() is not None)
def test_paddle_imperative_is_grad_enabled(self):
with fluid.dygraph.guard():
with paddle.set_grad_enabled(False):
self.assertTrue(paddle.is_grad_enabled() is False)
with paddle.set_grad_enabled(True):
self.assertTrue(paddle.is_grad_enabled())
def test_sum_op(self):
x = np.ones([2, 2], np.float32)
with fluid.dygraph.guard():
......
......@@ -19,6 +19,7 @@ from .random import seed # noqa: F401
from .framework import get_default_dtype # noqa: F401
from .framework import set_default_dtype # noqa: F401
from .framework import set_grad_enabled # noqa: F401
from .framework import is_grad_enabled # noqa: F401
from ..fluid.param_attr import ParamAttr # noqa: F401
from ..fluid.layers.tensor import create_parameter # noqa: F401
......
......@@ -116,3 +116,28 @@ def set_grad_enabled(mode):
tracer._has_grad = prev_mode
else:
yield
def is_grad_enabled():
"""
Returns whether current dygraph gradient calculation mode is enabled.
Returns:
bool: True if current dygraph gradient calculation mode is enabled, otherwise false.
Examples:
.. code-block:: python
import paddle
# Dygraph gradient calculation mode is enabled by default.
paddle.is_grad_enabled() # True
with paddle.set_grad_enabled(False):
paddle.is_grad_enabled() # False
paddle.enable_static()
paddle.is_grad_enabled() # False
"""
tracer = _dygraph_tracer()
return tracer._has_grad if tracer else False
......@@ -18,6 +18,7 @@ from .attribute import rank # noqa: F401
from .attribute import shape # noqa: F401
from .attribute import real # noqa: F401
from .attribute import imag # noqa: F401
from .attribute import is_floating_point # noqa: F401
from .creation import to_tensor # noqa: F401
from .creation import diag # noqa: F401
from .creation import diagflat # noqa: F401
......@@ -418,6 +419,7 @@ tensor_method_func = [ #noqa
'shape',
'real',
'imag',
'is_floating_point',
'digamma',
'diagonal',
'trunc',
......
......@@ -81,6 +81,30 @@ def is_complex(x):
def is_floating_point(x):
"""
Returns whether the dtype of `x` is one of paddle.float64, paddle.float32, paddle.float16, and paddle.bfloat16.
Args:
x (Tensor): The input tensor.
Returns:
bool: True if the dtype of `x` is floating type, otherwise false.
Examples:
.. code-block:: python
import paddle
x = paddle.arange(1., 5., dtype='float32')
y = paddle.arange(1, 5, dtype='int32')
print(paddle.is_floating_point(x))
# True
print(paddle.is_floating_point(y))
# False
"""
if not isinstance(x, (paddle.Tensor, paddle.static.Variable)):
raise TypeError("Expected Tensor, but received type of x: {}".format(
type(x)))
dtype = x.dtype
is_fp_dtype = (dtype == core.VarDesc.VarType.FP32 or
dtype == core.VarDesc.VarType.FP64 or
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册