From 499b7f8741e32df2effe8ea32839ffafc4c79ffc Mon Sep 17 00:00:00 2001
From: Hui Zhang
Date: Wed, 22 Feb 2023 13:58:41 +0800
Subject: [PATCH] add enable_grad, refactor no_grad, set_grad_enabled (#50560)

* add enable_grad, refactor no_grad, set_grad_enabled

* fix bug

* fix bug

* format

* fix bug

* format

* format

* fix doc

* fix

* fix

* fix bug

* fix comment
---
 python/paddle/__init__.py                     |   1 +
 python/paddle/autograd/__init__.py            |   4 +-
 python/paddle/fluid/dygraph/base.py           | 179 +++++++++++++++---
 .../unittests/test_imperative_decorator.py    | 147 ++++++++++++++
 python/paddle/framework/__init__.py           |   2 -
 python/paddle/framework/framework.py          |  65 -------
 6 files changed, 307 insertions(+), 91 deletions(-)

diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py
index 3e78b716faa..8f5e26bc7c4 100755
--- a/python/paddle/__init__.py
+++ b/python/paddle/__init__.py
@@ -340,6 +340,7 @@ from .framework import CustomPlace  # noqa: F401
 
 from .autograd import grad  # noqa: F401
 from .autograd import no_grad  # noqa: F401
+from .autograd import enable_grad  # noqa:F401
 from .autograd import set_grad_enabled  # noqa: F401
 from .autograd import is_grad_enabled  # noqa: F401
 from .framework import save  # noqa: F401
diff --git a/python/paddle/autograd/__init__.py b/python/paddle/autograd/__init__.py
index bf4d3e117c4..f5435b081e3 100644
--- a/python/paddle/autograd/__init__.py
+++ b/python/paddle/autograd/__init__.py
@@ -13,8 +13,10 @@
 # limitations under the License.
 
 from ..fluid.dygraph.base import grad  # noqa: F401
+from ..fluid.dygraph.base import enable_grad  # noqa: F401
 from ..fluid.dygraph.base import no_grad_ as no_grad  # noqa: F401
-from ..framework import is_grad_enabled, set_grad_enabled  # noqa: F401
+from ..fluid.dygraph.base import is_grad_enabled  # noqa: F401
+from ..fluid.dygraph.base import set_grad_enabled  # noqa: F401
 from . import backward_mode  # noqa: F401
 from .backward_mode import backward  # noqa: F401
 from .py_layer import PyLayer as PyLayer  # noqa: F401
diff --git a/python/paddle/fluid/dygraph/base.py b/python/paddle/fluid/dygraph/base.py
index df500a12978..e404733530c 100644
--- a/python/paddle/fluid/dygraph/base.py
+++ b/python/paddle/fluid/dygraph/base.py
@@ -343,7 +343,114 @@ def no_grad(func=None):
         return __impl__(func)
 
 
-class no_grad_:
+class _DecoratorContextManager:
+    """Allow a context manager to be used as a decorator"""
+
+    def __call__(self, func):
+        @decorator.decorator
+        def _decorate_function(func, *args, **kwargs):
+            with self:
+                return func(*args, **kwargs)
+
+        @decorator.decorator
+        def _decorate_generator(func, *args, **kwargs):
+            gen = func(*args, **kwargs)
+            with self:
+                for x in gen:
+                    yield x
+
+        if inspect.isgeneratorfunction(func):
+            return _decorate_generator(func)
+        else:
+            return _decorate_function(func)
+
+    def __enter__(self):
+        raise NotImplementedError
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        raise NotImplementedError
+
+    def clone(self):
+        # override this method if your subclass takes __init__ parameters
+        return self.__class__()
+
+
+def is_grad_enabled():
+    """
+    Returns whether current dygraph gradient calculation mode is enabled.
+
+    Returns:
+        bool: True if current dygraph gradient calculation mode is enabled, otherwise False.
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+
+            # Dygraph gradient calculation mode is enabled by default.
+            paddle.is_grad_enabled() # True
+
+            with paddle.set_grad_enabled(False):
+                paddle.is_grad_enabled() # False
+
+            paddle.enable_static()
+            paddle.is_grad_enabled() # False
+    """
+    tracer = framework._dygraph_tracer()
+    return tracer._has_grad if tracer else False
+
+
+def _set_grad_enabled(mode):
+    tracer = framework._dygraph_tracer()
+    if tracer:
+        tracer._has_grad = mode
+
+
+class set_grad_enabled(_DecoratorContextManager):
+    """
+    Create a context which enables or disables dygraph gradient calculation.
+
+    Args:
+        mode(bool): whether to enable (`True`) or disable (`False`) grad.
+
+    Returns:
+        None.
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+            x = paddle.to_tensor([1.], stop_gradient=False)
+            is_train = False
+            with paddle.set_grad_enabled(is_train):
+                y = x * 2
+            assert(y.stop_gradient == True)
+
+            paddle.set_grad_enabled(True)
+            y = x * 2
+            assert(y.stop_gradient == False)
+
+            paddle.set_grad_enabled(False)
+            y = x * 2
+            assert(y.stop_gradient == True)
+    """
+
+    def __init__(self, mode):
+        self.prev = is_grad_enabled()
+        _set_grad_enabled(mode)
+        self.mode = mode
+
+    def __enter__(self):
+        ...
+
+    def __exit__(self, *args):
+        _set_grad_enabled(self.prev)
+
+    def clone(self):
+        return self.__class__(self.mode)
+
+
+class no_grad_(_DecoratorContextManager):
     """
     :api_attr: imperative
 
@@ -389,34 +496,60 @@ class no_grad_:
             test_layer()
     """
 
-    def __call__(self, func):
-        @decorator.decorator
-        def _decorate_function(func, *args, **kwargs):
-            with self:
-                return func(*args, **kwargs)
+    def __enter__(self):
+        self.prev = is_grad_enabled()
+        _set_grad_enabled(False)
 
-        @decorator.decorator
-        def _decorate_generator(func, *args, **kwargs):
-            gen = func(*args, **kwargs)
-            with self:
-                for x in gen:
-                    yield x
+    def __exit__(self, *args):
+        _set_grad_enabled(self.prev)
 
-        if inspect.isgeneratorfunction(func):
-            return _decorate_generator(func)
-        else:
-            return _decorate_function(func)
+
+class enable_grad(_DecoratorContextManager):
+    """
+    :api_attr: imperative
+
+    Create a context which enables dygraph gradient calculation,
+    if it has been disabled by `no_grad` or `set_grad_enabled`.
+
+    In this mode, the result of every computation will have `stop_gradient` set
+    to `False`.
+
+    Also functions as a decorator. (Make sure to use an instance.)
+
+    Examples:
+
+        .. code-block:: python
+
+            import paddle
+
+            # use as context manager
+
+            x = paddle.to_tensor([1.], stop_gradient=False)
+            with paddle.no_grad():
+                with paddle.enable_grad():
+                    y = x * 2
+            assert(y.stop_gradient == False)
+            y.backward()
+            assert(x.grad is not None)
+
+            # use as decorator
+
+            @paddle.enable_grad()
+            def double(x):
+                return x * 2
+
+            with paddle.no_grad():
+                z = double(x)
+
+            assert(z.stop_gradient == False)
+    """
 
     def __enter__(self):
-        tracer = framework._dygraph_tracer()
-        if tracer:
-            self.orig = tracer._has_grad
-            tracer._has_grad = False
+        self.prev = is_grad_enabled()
+        _set_grad_enabled(True)
 
     def __exit__(self, *args):
-        tracer = framework._dygraph_tracer()
-        if tracer:
-            tracer._has_grad = self.orig
+        _set_grad_enabled(self.prev)
 
 
 @signature_safe_contextmanager
diff --git a/python/paddle/fluid/tests/unittests/test_imperative_decorator.py b/python/paddle/fluid/tests/unittests/test_imperative_decorator.py
index 8293503f83c..07c11514787 100644
--- a/python/paddle/fluid/tests/unittests/test_imperative_decorator.py
+++ b/python/paddle/fluid/tests/unittests/test_imperative_decorator.py
@@ -90,6 +90,7 @@ class TestNoGradClass(unittest.TestCase):
 
         self.tracer = framework._dygraph_tracer()
         self.tracer._train_mode = True
+        self.tracer._has_grad = True
 
         self.assertEqual(self.no_grad_func(1), 1)
         self.assertEqual(self.no_grad_func.__name__, "no_grad_func")
@@ -123,5 +124,151 @@ class TestNoGradClass(unittest.TestCase):
         self.assertEqual(a, b)
 
 
+class TestEnableGradClass(unittest.TestCase):
+    @paddle.enable_grad()
+    def enable_grad_func(self, a):
+        self.assertEqual(self.tracer._train_mode, True)
+        self.assertEqual(self.tracer._has_grad, True)
+        return a
+
+    def test_main(self):
+        paddle.disable_static()
+
+        self.tracer = framework._dygraph_tracer()
+        self.tracer._train_mode = True
+        self.tracer._has_grad = False
+
+        self.assertEqual(self.enable_grad_func(1), 1)
+        self.assertEqual(self.enable_grad_func.__name__, "enable_grad_func")
+
+        def need_enable_grad_func(a, b=1):
+            return a + b
+
+        decorated_func = paddle.enable_grad()(need_enable_grad_func)
+        self.assertEqual(
+            str(inspect.getfullargspec(decorated_func)),
+            str(inspect.getfullargspec(need_enable_grad_func)),
+        )
+
+        def test_gen():
+            for i in range(3):
+                yield i
+
+        a = 0
+        for i in test_gen():
+            a += i
+
+        @paddle.enable_grad()
+        def test_wrapped_gen():
+            for i in range(3):
+                yield i
+
+        b = 0
+        for i in test_wrapped_gen():
+            b += i
+
+        self.assertEqual(a, b)
+
+    def test_stop_gradient(self):
+        x = paddle.to_tensor([1.0], stop_gradient=False)
+        with paddle.no_grad():
+            with paddle.enable_grad():
+                y = x * 2
+        self.assertTrue(y.stop_gradient is False)
+        y.backward()
+        self.assertTrue(x.grad is not None)
+
+        # use as decorator
+        @paddle.enable_grad()
+        def double(x):
+            return x * 2
+
+        with paddle.no_grad():
+            z = double(x)
+
+        self.assertTrue(z.stop_gradient is False)
+
+
+class TestSetGradEnabledClass(unittest.TestCase):
+    @paddle.set_grad_enabled(True)
+    def enable_grad_func(self, a):
+        self.assertEqual(self.tracer._train_mode, True)
+        self.assertEqual(self.tracer._has_grad, True)
+        return a
+
+    def test_main(self):
+        paddle.disable_static()
+
+        self.tracer = framework._dygraph_tracer()
+        self.tracer._train_mode = True
+
+        self.assertEqual(self.enable_grad_func(1), 1)
+        self.assertEqual(self.enable_grad_func.__name__, "enable_grad_func")
+
+        def need_enable_grad_func(a, b=1):
+            return a + b
+
+        decorated_func = paddle.set_grad_enabled(True)(need_enable_grad_func)
+        self.assertEqual(
+            str(inspect.getfullargspec(decorated_func)),
+            str(inspect.getfullargspec(need_enable_grad_func)),
+        )
+
+        def test_gen():
+            for i in range(3):
+                yield i
+
+        a = 0
+        for i in test_gen():
+            a += i
+
+        @paddle.set_grad_enabled(True)
+        def test_wrapped_gen():
+            for i in range(3):
+                yield i
+
+        b = 0
+        for i in test_wrapped_gen():
+            b += i
+
+        self.assertEqual(a, b)
+
+    def test_stop_gradient(self):
+        x = paddle.to_tensor([1.0], stop_gradient=False)
+        is_train = False
+        with paddle.set_grad_enabled(is_train):
+            y = x * 2
+        self.assertTrue(y.stop_gradient is True)
+
+        paddle.set_grad_enabled(True)
+        y = x * 2
+        self.assertTrue(y.stop_gradient is False)
+
+        paddle.set_grad_enabled(False)
+        y = x * 2
+        self.assertTrue(y.stop_gradient is True)
+
+
+class TestIsGradEnabledClass(unittest.TestCase):
+    def test_main(self):
+        paddle.disable_static()
+
+        self.tracer = framework._dygraph_tracer()
+        self.tracer._train_mode = True
+        self.tracer._has_grad = True
+
+        # Dygraph gradient calculation mode is enabled by default.
+        flag = paddle.is_grad_enabled()
+        self.assertTrue(flag is True)
+
+        with paddle.set_grad_enabled(False):
+            flag = paddle.is_grad_enabled()
+            self.assertTrue(flag is False)
+
+        flag = paddle.is_grad_enabled()
+        self.assertTrue(flag is True)
+        paddle.enable_static()
+
+
 if __name__ == '__main__':
     unittest.main()
diff --git a/python/paddle/framework/__init__.py b/python/paddle/framework/__init__.py
index 7702b5e61fa..f9d8181176a 100755
--- a/python/paddle/framework/__init__.py
+++ b/python/paddle/framework/__init__.py
@@ -18,8 +18,6 @@ from . import random  # noqa: F401
 from .random import seed  # noqa: F401
 from .framework import get_default_dtype  # noqa: F401
 from .framework import set_default_dtype  # noqa: F401
-from .framework import set_grad_enabled  # noqa: F401
-from .framework import is_grad_enabled  # noqa: F401
 
 from ..fluid.param_attr import ParamAttr  # noqa: F401
 from ..fluid.core import CPUPlace  # noqa: F401
diff --git a/python/paddle/framework/framework.py b/python/paddle/framework/framework.py
index e3b7519c4f8..cbcee9ead70 100644
--- a/python/paddle/framework/framework.py
+++ b/python/paddle/framework/framework.py
@@ -12,12 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from contextlib import contextmanager
 
 import numpy as np
 
-from paddle.fluid.framework import _dygraph_tracer
-
 # TODO: define framework api
 from paddle.fluid.layer_helper_base import LayerHelperBase
 
@@ -83,65 +80,3 @@ def get_default_dtype():
             paddle.get_default_dtype()
     """
     return LayerHelperBase.get_default_dtype()
-
-
-@contextmanager
-def set_grad_enabled(mode):
-    """
-    Create a context which enables or disables dygraph gradient calculation.
-
-    Args:
-        mode(bool): whether to enable (`True`), or disable (`False`) grad.
-
-    Returns:
-        None.
-
-    Examples:
-        .. code-block:: python
-
-            import paddle
-            x = paddle.ones([3, 2])
-            x.stop_gradient = False
-            with paddle.set_grad_enabled(False):
-                y = x * 2
-                with paddle.set_grad_enabled(True):
-                    z = x * 2
-            print(y.stop_gradient) # True
-            print(z.stop_gradient) # False
-    """
-
-    tracer = _dygraph_tracer()
-    if tracer:
-        prev_mode = tracer._has_grad
-        tracer._has_grad = mode
-        try:
-            yield
-        finally:
-            tracer._has_grad = prev_mode
-    else:
-        yield
-
-
-def is_grad_enabled():
-    """
-    Returns whether current dygraph gradient calculation mode is enabled.
-
-    Returns:
-        bool: True if current dygraph gradient calculation mode is enabled, otherwise false.
-
-    Examples:
-        .. code-block:: python
-
-            import paddle
-
-            # Dygraph gradient calculation mode is enabled by default.
-            paddle.is_grad_enabled() # True
-
-            with paddle.set_grad_enabled(False):
-                paddle.is_grad_enabled() # False
-
-            paddle.enable_static()
-            paddle.is_grad_enabled() # False
-    """
-    tracer = _dygraph_tracer()
-    return tracer._has_grad if tracer else False
--
GitLab
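
After this patch, no_grad, enable_grad, and set_grad_enabled are all thin wrappers
over the same tracer flag (tracer._has_grad), so they nest and restore one
another's state predictably. The sketch below is distilled from the docstrings and
tests above; it assumes a Paddle build that includes this change, and the variable
names are illustrative only:

    import paddle

    x = paddle.to_tensor([1.0], stop_gradient=False)

    # enable_grad re-enables tracking inside an enclosing no_grad scope;
    # both restore the previous flag value on exit.
    with paddle.no_grad():
        with paddle.enable_grad():
            y = x * 2
    assert y.stop_gradient is False

    # set_grad_enabled takes the desired mode as a bool, so it can be
    # driven by a runtime flag.
    is_train = False
    with paddle.set_grad_enabled(is_train):
        z = x * 2
    assert z.stop_gradient is True
    assert paddle.is_grad_enabled() is True  # previous mode restored on exit

    # All three also work as decorators, generator functions included.
    @paddle.no_grad()
    def infer(t):
        return t * 2

    assert infer(x).stop_gradient is True

One subtlety of the shared _DecoratorContextManager: when it wraps a generator
function, the context is entered on the first resume and held open until the
generator is exhausted, rather than being re-entered around each yield.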