diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index fafe08a81f24af21f3e22bb2d080cbe71d4dfeeb..6f8229e6f18f52058bf957d55d31ddabe73ebfdd 100755 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -257,6 +257,7 @@ from .framework import CUDAPinnedPlace #DEFINE_ALIAS from .framework import grad #DEFINE_ALIAS from .framework import no_grad #DEFINE_ALIAS +from .framework import set_grad_enabled #DEFINE_ALIAS from .framework import save #DEFINE_ALIAS from .framework import load #DEFINE_ALIAS from .framework import DataParallel #DEFINE_ALIAS diff --git a/python/paddle/fluid/tests/unittests/test_imperative_basic.py b/python/paddle/fluid/tests/unittests/test_imperative_basic.py index e6e7b8222a4b30c066cb9dc07a98c5570e6a0a8a..9dae36c3c223f89617fdad0fe8c4e42daa0a2613 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_basic.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_basic.py @@ -296,6 +296,28 @@ class TestImperative(unittest.TestCase): self.assertTrue(tmp._grad_ivar() is None) self.assertTrue(l0.weight._grad_ivar() is not None) + def test_paddle_imperative_set_grad_enabled(self): + data = np.array([[2, 3], [4, 5]]).astype('float32') + with fluid.dygraph.guard(): + l0 = fluid.Linear(2, 2) + self.assertTrue(l0.weight._grad_ivar() is None) + l1 = fluid.Linear(2, 2) + with paddle.set_grad_enabled(False): + self.assertTrue(l1.weight.stop_gradient is False) + tmp = l1.weight * 2 + with paddle.set_grad_enabled(True): + tmp2 = l1.weight * 2 + self.assertTrue(tmp.stop_gradient) + self.assertTrue(tmp2.stop_gradient is False) + x = fluid.dygraph.to_variable(data) + y = l0(x) + tmp2 + o = l1(y) + o.backward() + + self.assertTrue(tmp._grad_ivar() is None) + self.assertTrue(tmp2._grad_ivar() is not None) + self.assertTrue(l0.weight._grad_ivar() is not None) + def test_sum_op(self): x = np.ones([2, 2], np.float32) with fluid.dygraph.guard(): diff --git a/python/paddle/framework/__init__.py b/python/paddle/framework/__init__.py index 5a616d81659b2361266542ca9e6f5ce40503506d..b8684874085a9c44048d1473ddec956dcff3cf3e 100644 --- a/python/paddle/framework/__init__.py +++ b/python/paddle/framework/__init__.py @@ -18,12 +18,16 @@ __all__ = [ 'NPUPlace', 'get_default_dtype', 'set_default_dtype' ] -__all__ += ['grad', 'LayerList', 'load', 'save', 'no_grad', 'DataParallel'] +__all__ += [ + 'grad', 'set_grad_enabled', 'LayerList', 'load', 'save', 'no_grad', + 'DataParallel' +] from . import random from .random import seed from .framework import get_default_dtype from .framework import set_default_dtype +from .framework import set_grad_enabled from ..fluid.param_attr import ParamAttr #DEFINE_ALIAS # from ..fluid.layers.tensor import create_global_var #DEFINE_ALIAS diff --git a/python/paddle/framework/framework.py b/python/paddle/framework/framework.py index 41ec18ce32d3036c3db86aaa98053f59ff61f717..77be85a3195fd6dbac6f07c57e867f01343dfc2a 100644 --- a/python/paddle/framework/framework.py +++ b/python/paddle/framework/framework.py @@ -15,7 +15,9 @@ # TODO: define framework api from paddle.fluid.layer_helper_base import LayerHelperBase from paddle.fluid.data_feeder import convert_dtype +from paddle.fluid.framework import _dygraph_tracer import numpy as np +from contextlib import contextmanager __all__ = ['set_default_dtype', 'get_default_dtype'] @@ -80,3 +82,37 @@ def get_default_dtype(): paddle.get_default_dtype() """ return LayerHelperBase.get_default_dtype() + + +@contextmanager +def set_grad_enabled(mode): + """ + :api_attr: imperative + + Create a context which enables or disables dygraph gradient calculation. + + Args: + mode(bool): whether to enable (`True`), or disable (`False`) grad. + + Examples: + .. code-block:: python + x = paddle.ones([3, 2]) + x.stop_gradient = False + with torch.set_grad_enabled(False): + y = x * 2 + with torch.set_grad_enabled(True): + z = x * 2 + print(y.stop_gradient) # True + print(z.stop_gradient) # False + """ + + tracer = _dygraph_tracer() + if tracer: + prev_mode = tracer._has_grad + tracer._has_grad = mode + try: + yield + finally: + tracer._has_grad = prev_mode + else: + yield