Unverified commit f8ca5a9d, authored by Yang Zhang, committed by GitHub

Add `paddle.set_grad_enabled` (#31794)

Parent c3328288
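This change adds `paddle.set_grad_enabled(mode)` as a context manager that toggles dygraph gradient tracking. A minimal usage sketch, mirroring the docstring example added below (assumes Paddle 2.x with dygraph mode enabled by default):

```python
import paddle

x = paddle.ones([3, 2])
x.stop_gradient = False

# Inside this context no gradients are tracked, so `y` is detached.
with paddle.set_grad_enabled(False):
    y = x * 2
    # Nested contexts re-enable tracking; the previous mode is restored on exit.
    with paddle.set_grad_enabled(True):
        z = x * 2

print(y.stop_gradient)  # True  -> computed with grad disabled
print(z.stop_gradient)  # False -> computed with grad enabled
```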
...@@ -257,6 +257,7 @@ from .framework import CUDAPinnedPlace #DEFINE_ALIAS
from .framework import grad #DEFINE_ALIAS
from .framework import no_grad #DEFINE_ALIAS
from .framework import set_grad_enabled #DEFINE_ALIAS
from .framework import save #DEFINE_ALIAS
from .framework import load #DEFINE_ALIAS
from .framework import DataParallel #DEFINE_ALIAS
...
...@@ -296,6 +296,28 @@ class TestImperative(unittest.TestCase):
            self.assertTrue(tmp._grad_ivar() is None)
            self.assertTrue(l0.weight._grad_ivar() is not None)

    def test_paddle_imperative_set_grad_enabled(self):
        data = np.array([[2, 3], [4, 5]]).astype('float32')
        with fluid.dygraph.guard():
            l0 = fluid.Linear(2, 2)
            self.assertTrue(l0.weight._grad_ivar() is None)
            l1 = fluid.Linear(2, 2)
            with paddle.set_grad_enabled(False):
                self.assertTrue(l1.weight.stop_gradient is False)
                tmp = l1.weight * 2
                with paddle.set_grad_enabled(True):
                    tmp2 = l1.weight * 2
            self.assertTrue(tmp.stop_gradient)
            self.assertTrue(tmp2.stop_gradient is False)
            x = fluid.dygraph.to_variable(data)
            y = l0(x) + tmp2
            o = l1(y)
            o.backward()
            self.assertTrue(tmp._grad_ivar() is None)
            self.assertTrue(tmp2._grad_ivar() is not None)
            self.assertTrue(l0.weight._grad_ivar() is not None)

    def test_sum_op(self):
        x = np.ones([2, 2], np.float32)
        with fluid.dygraph.guard():
...
...@@ -18,12 +18,16 @@ __all__ = [
    'NPUPlace', 'get_default_dtype', 'set_default_dtype'
]
__all__ += [
    'grad', 'set_grad_enabled', 'LayerList', 'load', 'save', 'no_grad',
    'DataParallel'
]
from . import random
from .random import seed
from .framework import get_default_dtype
from .framework import set_default_dtype
from .framework import set_grad_enabled
from ..fluid.param_attr import ParamAttr #DEFINE_ALIAS
# from ..fluid.layers.tensor import create_global_var #DEFINE_ALIAS
...
...@@ -15,7 +15,9 @@
# TODO: define framework api
from paddle.fluid.layer_helper_base import LayerHelperBase
from paddle.fluid.data_feeder import convert_dtype
from paddle.fluid.framework import _dygraph_tracer
import numpy as np
from contextlib import contextmanager

__all__ = ['set_default_dtype', 'get_default_dtype']
...@@ -80,3 +82,37 @@ def get_default_dtype():
            paddle.get_default_dtype()
    """
    return LayerHelperBase.get_default_dtype()

@contextmanager
def set_grad_enabled(mode):
    """
    :api_attr: imperative

    Create a context which enables or disables dygraph gradient calculation.

    Args:
        mode(bool): whether to enable (`True`) or disable (`False`) gradient calculation.

    Examples:
        .. code-block:: python

            import paddle

            x = paddle.ones([3, 2])
            x.stop_gradient = False
            with paddle.set_grad_enabled(False):
                y = x * 2
                with paddle.set_grad_enabled(True):
                    z = x * 2
            print(y.stop_gradient)  # True
            print(z.stop_gradient)  # False

    """
    tracer = _dygraph_tracer()
    if tracer:
        prev_mode = tracer._has_grad
        tracer._has_grad = mode
        try:
            yield
        finally:
            tracer._has_grad = prev_mode
    else:
        yield
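The implementation follows the standard save-and-restore pattern for a context manager: it records the tracer's current `_has_grad` flag, switches it to the requested mode, and the `try`/`finally` guarantees the previous mode is restored even if the body raises, which is also what makes nesting behave correctly. A standalone sketch of the same pattern (the `Tracer` class here is a simplified stand-in for illustration, not Paddle's actual tracer):

```python
from contextlib import contextmanager


class Tracer:
    """Simplified stand-in for the dygraph tracer; only the flag matters."""

    def __init__(self):
        self.has_grad = True


tracer = Tracer()


@contextmanager
def set_grad_enabled(mode):
    prev_mode = tracer.has_grad      # remember the mode currently in effect
    tracer.has_grad = mode           # switch to the requested mode
    try:
        yield
    finally:
        tracer.has_grad = prev_mode  # restore on exit, even after exceptions


with set_grad_enabled(False):
    assert tracer.has_grad is False
    with set_grad_enabled(True):     # nesting restores the outer mode on exit
        assert tracer.has_grad is True
    assert tracer.has_grad is False
assert tracer.has_grad is True
```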