Unverified commit 21b7ed3e, authored by wuhuanzhou, committed via GitHub

add control/status API (#37885)

* add control/status API, test=develop

* fix import error, test=develop

* add is_grad_enabled unittest, test=develop

* add code comment for example code and API, test=develop

* add checking for type, test=develop

* add api description, test=develop

* fix docs index_en, test=document_fix

* fix doc of is_floating_point, test=document_fix
上级 745477fe
...@@ -72,6 +72,7 @@ from .tensor.attribute import rank # noqa: F401 ...@@ -72,6 +72,7 @@ from .tensor.attribute import rank # noqa: F401
from .tensor.attribute import shape # noqa: F401 from .tensor.attribute import shape # noqa: F401
from .tensor.attribute import real # noqa: F401 from .tensor.attribute import real # noqa: F401
from .tensor.attribute import imag # noqa: F401 from .tensor.attribute import imag # noqa: F401
from .tensor.attribute import is_floating_point # noqa: F401
from .tensor.creation import to_tensor # noqa: F401 from .tensor.creation import to_tensor # noqa: F401
from .tensor.creation import diag # noqa: F401 from .tensor.creation import diag # noqa: F401
from .tensor.creation import diagflat # noqa: F401 from .tensor.creation import diagflat # noqa: F401
...@@ -285,6 +286,7 @@ from .framework import CUDAPinnedPlace # noqa: F401 ...@@ -285,6 +286,7 @@ from .framework import CUDAPinnedPlace # noqa: F401
from .autograd import grad # noqa: F401 from .autograd import grad # noqa: F401
from .autograd import no_grad # noqa: F401 from .autograd import no_grad # noqa: F401
from .autograd import set_grad_enabled # noqa: F401 from .autograd import set_grad_enabled # noqa: F401
from .autograd import is_grad_enabled # noqa: F401
from .framework import save # noqa: F401 from .framework import save # noqa: F401
from .framework import load # noqa: F401 from .framework import load # noqa: F401
from .framework import DataParallel # noqa: F401 from .framework import DataParallel # noqa: F401
...@@ -453,6 +455,7 @@ __all__ = [ # noqa ...@@ -453,6 +455,7 @@ __all__ = [ # noqa
'shape', 'shape',
'real', 'real',
'imag', 'imag',
'is_floating_point',
'complex', 'complex',
'reciprocal', 'reciprocal',
'rand', 'rand',
...@@ -468,6 +471,7 @@ __all__ = [ # noqa ...@@ -468,6 +471,7 @@ __all__ = [ # noqa
'median', 'median',
'no_grad', 'no_grad',
'set_grad_enabled', 'set_grad_enabled',
'is_grad_enabled',
'mod', 'mod',
'abs', 'abs',
'tril', 'tril',
......
...@@ -16,7 +16,7 @@ from ..fluid.dygraph.base import grad # noqa: F401 ...@@ -16,7 +16,7 @@ from ..fluid.dygraph.base import grad # noqa: F401
from . import backward_mode # noqa: F401 from . import backward_mode # noqa: F401
from .backward_mode import backward # noqa: F401 from .backward_mode import backward # noqa: F401
from .py_layer import PyLayer, PyLayerContext # noqa: F401 from .py_layer import PyLayer, PyLayerContext # noqa: F401
from ..framework import set_grad_enabled # noqa: F401 from ..framework import set_grad_enabled, is_grad_enabled # noqa: F401
from ..fluid.dygraph.base import no_grad_ as no_grad # noqa: F401 from ..fluid.dygraph.base import no_grad_ as no_grad # noqa: F401
from .functional import jacobian, hessian, batch_jacobian, batch_hessian # noqa: F401 from .functional import jacobian, hessian, batch_jacobian, batch_hessian # noqa: F401
from .functional import vjp, jvp, vhp # noqa: F401 from .functional import vjp, jvp, vhp # noqa: F401
......
...@@ -317,6 +317,13 @@ class TestImperative(unittest.TestCase): ...@@ -317,6 +317,13 @@ class TestImperative(unittest.TestCase):
self.assertTrue(tmp2._grad_ivar() is not None) self.assertTrue(tmp2._grad_ivar() is not None)
self.assertTrue(l0.weight._grad_ivar() is not None) self.assertTrue(l0.weight._grad_ivar() is not None)
def test_paddle_imperative_is_grad_enabled(self):
    """Check that paddle.is_grad_enabled reflects the set_grad_enabled context.

    NOTE(review): extraction lost the original indentation; the inner
    ``set_grad_enabled(True)`` context is assumed to be nested inside the
    outer ``False`` one (it verifies the inner mode overrides the outer) —
    confirm against the upstream commit.
    """
    with fluid.dygraph.guard():
        with paddle.set_grad_enabled(False):
            # assertFalse gives a clearer failure message than
            # assertTrue(x is False) and does not depend on identity checks.
            self.assertFalse(paddle.is_grad_enabled())
            with paddle.set_grad_enabled(True):
                self.assertTrue(paddle.is_grad_enabled())
def test_sum_op(self): def test_sum_op(self):
x = np.ones([2, 2], np.float32) x = np.ones([2, 2], np.float32)
with fluid.dygraph.guard(): with fluid.dygraph.guard():
......
...@@ -19,6 +19,7 @@ from .random import seed # noqa: F401 ...@@ -19,6 +19,7 @@ from .random import seed # noqa: F401
from .framework import get_default_dtype # noqa: F401 from .framework import get_default_dtype # noqa: F401
from .framework import set_default_dtype # noqa: F401 from .framework import set_default_dtype # noqa: F401
from .framework import set_grad_enabled # noqa: F401 from .framework import set_grad_enabled # noqa: F401
from .framework import is_grad_enabled # noqa: F401
from ..fluid.param_attr import ParamAttr # noqa: F401 from ..fluid.param_attr import ParamAttr # noqa: F401
from ..fluid.layers.tensor import create_parameter # noqa: F401 from ..fluid.layers.tensor import create_parameter # noqa: F401
......
...@@ -116,3 +116,28 @@ def set_grad_enabled(mode): ...@@ -116,3 +116,28 @@ def set_grad_enabled(mode):
tracer._has_grad = prev_mode tracer._has_grad = prev_mode
else: else:
yield yield
def is_grad_enabled():
    """
    Returns whether current dygraph gradient calculation mode is enabled.

    Returns:
        bool: True if current dygraph gradient calculation mode is enabled,
            otherwise false.

    Examples:
        .. code-block:: python

            import paddle

            # Dygraph gradient calculation mode is enabled by default.
            paddle.is_grad_enabled() # True

            with paddle.set_grad_enabled(False):
                paddle.is_grad_enabled() # False

            paddle.enable_static()
            paddle.is_grad_enabled() # False
    """
    tracer = _dygraph_tracer()
    if not tracer:
        # No active dygraph tracer (e.g. static graph mode):
        # gradient calculation is not enabled.
        return False
    return tracer._has_grad
...@@ -18,6 +18,7 @@ from .attribute import rank # noqa: F401 ...@@ -18,6 +18,7 @@ from .attribute import rank # noqa: F401
from .attribute import shape # noqa: F401 from .attribute import shape # noqa: F401
from .attribute import real # noqa: F401 from .attribute import real # noqa: F401
from .attribute import imag # noqa: F401 from .attribute import imag # noqa: F401
from .attribute import is_floating_point # noqa: F401
from .creation import to_tensor # noqa: F401 from .creation import to_tensor # noqa: F401
from .creation import diag # noqa: F401 from .creation import diag # noqa: F401
from .creation import diagflat # noqa: F401 from .creation import diagflat # noqa: F401
...@@ -418,6 +419,7 @@ tensor_method_func = [ #noqa ...@@ -418,6 +419,7 @@ tensor_method_func = [ #noqa
'shape', 'shape',
'real', 'real',
'imag', 'imag',
'is_floating_point',
'digamma', 'digamma',
'diagonal', 'diagonal',
'trunc', 'trunc',
......
...@@ -81,6 +81,30 @@ def is_complex(x): ...@@ -81,6 +81,30 @@ def is_complex(x):
def is_floating_point(x): def is_floating_point(x):
"""
Returns whether the dtype of `x` is one of paddle.float64, paddle.float32, paddle.float16, and paddle.bfloat16.
Args:
x (Tensor): The input tensor.
Returns:
bool: True if the dtype of `x` is floating type, otherwise false.
Examples:
.. code-block:: python
import paddle
x = paddle.arange(1., 5., dtype='float32')
y = paddle.arange(1, 5, dtype='int32')
print(paddle.is_floating_point(x))
# True
print(paddle.is_floating_point(y))
# False
"""
if not isinstance(x, (paddle.Tensor, paddle.static.Variable)):
raise TypeError("Expected Tensor, but received type of x: {}".format(
type(x)))
dtype = x.dtype dtype = x.dtype
is_fp_dtype = (dtype == core.VarDesc.VarType.FP32 or is_fp_dtype = (dtype == core.VarDesc.VarType.FP32 or
dtype == core.VarDesc.VarType.FP64 or dtype == core.VarDesc.VarType.FP64 or
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册