未验证 提交 499b7f87 编写于 作者: H Hui Zhang 提交者: GitHub

add enable_grad, reafctor no_grad, set_grad_enabled (#50560)

* add enable_grad, reafctor no_grad, set_grad_enabled

* fix bug

* fix bug

* format

* fix bug

* format

* format

* fix doc

* fix

* fix

* fix bug

* fix comment
上级 a35dbc29
......@@ -340,6 +340,7 @@ from .framework import CustomPlace # noqa: F401
from .autograd import grad # noqa: F401
from .autograd import no_grad # noqa: F401
from .autograd import enable_grad # noqa:F401
from .autograd import set_grad_enabled # noqa: F401
from .autograd import is_grad_enabled # noqa: F401
from .framework import save # noqa: F401
......
......@@ -13,8 +13,10 @@
# limitations under the License.
from ..fluid.dygraph.base import grad # noqa: F401
from ..fluid.dygraph.base import enable_grad # noqa: F401
from ..fluid.dygraph.base import no_grad_ as no_grad # noqa: F401
from ..framework import is_grad_enabled, set_grad_enabled # noqa: F401
from ..fluid.dygraph.base import is_grad_enabled # noqa: F401
from ..fluid.dygraph.base import set_grad_enabled # noqa: F401
from . import backward_mode # noqa: F401
from .backward_mode import backward # noqa: F401
from .py_layer import PyLayer as PyLayer # noqa: F401
......
......@@ -343,7 +343,114 @@ def no_grad(func=None):
return __impl__(func)
class no_grad_:
class _DecoratorContextManager:
"""Allow a context manager to be used as a decorator"""
def __call__(self, func):
@decorator.decorator
def _decorate_function(func, *args, **kwargs):
with self:
return func(*args, **kwargs)
@decorator.decorator
def _decorate_generator(func, *args, **kwargs):
gen = func(*args, **kwargs)
with self:
for x in gen:
yield x
if inspect.isgeneratorfunction(func):
return _decorate_generator(func)
else:
return _decorate_function(func)
def __enter__(self):
raise NotImplementedError
def __exit__(self, exc_type, exc_value, traceback):
raise NotImplementedError
def clone(self):
# override this method if your children class takes __init__ parameters
return self.__class__()
def is_grad_enabled():
"""
Returns whether current dygraph gradient calculation mode is enabled.
Returns:
bool: True if current dygraph gradient calculation mode is enabled, otherwise false.
Examples:
.. code-block:: python
import paddle
# Dygraph gradient calculation mode is enabled by default.
paddle.is_grad_enabled() # True
with paddle.set_grad_enabled(False):
paddle.is_grad_enabled() # False
paddle.enable_static()
paddle.is_grad_enabled() # False
"""
tracer = framework._dygraph_tracer()
return tracer._has_grad if tracer else False
def _set_grad_enabled(mode):
tracer = framework._dygraph_tracer()
if tracer:
tracer._has_grad = mode
class set_grad_enabled(_DecoratorContextManager):
"""
Create a context which enables or disables dygraph gradient calculation.
Args:
mode(bool): whether to enable (`True`), or disable (`False`) grad.
Returns:
None.
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor([1.], stop_gradient=False)
is_train = False
with paddle.set_grad_enabled(is_train):
y = x * 2
assert(y.stop_gradient == True)
paddle.set_grad_enabled(True)
y = x * 2
assert(y.stop_gradient == False)
paddle.set_grad_enabled(False)
y = x * 2
assert(y.stop_gradient == True)
"""
def __init__(self, mode):
self.prev = is_grad_enabled()
_set_grad_enabled(mode)
self.mode = mode
def __enter__(self):
...
def __exit__(self, *args):
_set_grad_enabled(self.prev)
def clone(self):
return self.__class__(self.mode)
class no_grad_(_DecoratorContextManager):
"""
:api_attr: imperative
......@@ -389,34 +496,60 @@ class no_grad_:
test_layer()
"""
def __call__(self, func):
@decorator.decorator
def _decorate_function(func, *args, **kwargs):
with self:
return func(*args, **kwargs)
def __enter__(self):
self.prev = is_grad_enabled()
_set_grad_enabled(False)
@decorator.decorator
def _decorate_generator(func, *args, **kwargs):
gen = func(*args, **kwargs)
with self:
for x in gen:
yield x
def __exit__(self, *args):
_set_grad_enabled(self.prev)
if inspect.isgeneratorfunction(func):
return _decorate_generator(func)
else:
return _decorate_function(func)
class enable_grad(_DecoratorContextManager):
"""
:api_attr: imperative
Create a context which enable dygraph gradient calculation,
if it has been disabled by `no_grad` or `set_grad_enabled`.
In this mode, the result of every computation will have `stop_gradient` set
to `False`.
Also functions as a decorator. (Make sure to use an instance.)
Examples:
.. code-block:: python
import paddle
# use as generator
x = paddle.to_tensor([1.], stop_gradient=False)
with paddle.no_grad():
with paddle.enable_grad():
y = x * 2
assert(y.stop_gradient == False)
y.backward()
assert(x.grad is not None)
# use as decorator
@paddle.enable_grad()
def double(x):
return x * 2
with paddle.no_grad():
z = double(x)
assert(z.stop_gradient == False)
"""
def __enter__(self):
tracer = framework._dygraph_tracer()
if tracer:
self.orig = tracer._has_grad
tracer._has_grad = False
self.prev = is_grad_enabled()
_set_grad_enabled(True)
def __exit__(self, *args):
tracer = framework._dygraph_tracer()
if tracer:
tracer._has_grad = self.orig
_set_grad_enabled(self.prev)
@signature_safe_contextmanager
......
......@@ -90,6 +90,7 @@ class TestNoGradClass(unittest.TestCase):
self.tracer = framework._dygraph_tracer()
self.tracer._train_mode = True
self.tracer._has_grad = True
self.assertEqual(self.no_grad_func(1), 1)
self.assertEqual(self.no_grad_func.__name__, "no_grad_func")
......@@ -123,5 +124,151 @@ class TestNoGradClass(unittest.TestCase):
self.assertEqual(a, b)
class TestEnableGradClass(unittest.TestCase):
@paddle.enable_grad()
def enable_grad_func(self, a):
self.assertEqual(self.tracer._train_mode, True)
self.assertEqual(self.tracer._has_grad, True)
return a
def test_main(self):
paddle.disable_static()
self.tracer = framework._dygraph_tracer()
self.tracer._train_mode = True
self.tracer._has_grad = False
self.assertEqual(self.enable_grad_func(1), 1)
self.assertEqual(self.enable_grad_func.__name__, "enable_grad_func")
def need_enable_grad_func(a, b=1):
return a + b
decorated_func = paddle.enable_grad()(need_enable_grad_func)
self.assertEqual(
str(inspect.getfullargspec(decorated_func)),
str(inspect.getfullargspec(need_enable_grad_func)),
)
def test_gen():
for i in range(3):
yield i
a = 0
for i in test_gen():
a += i
@paddle.enable_grad()
def test_wrapped_gen():
for i in range(3):
yield i
b = 0
for i in test_wrapped_gen():
b += i
self.assertEqual(a, b)
def test_stop_gradient(self):
x = paddle.to_tensor([1.0], stop_gradient=False)
with paddle.no_grad():
with paddle.enable_grad():
y = x * 2
self.assertTrue(y.stop_gradient is False)
y.backward()
self.assertTrue(x.grad is not None)
# use as decorator
@paddle.enable_grad()
def double(x):
return x * 2
with paddle.no_grad():
z = double(x)
self.assertTrue(z.stop_gradient is False)
class TestSetGradEnabledClass(unittest.TestCase):
@paddle.set_grad_enabled(True)
def enable_grad_func(self, a):
self.assertEqual(self.tracer._train_mode, True)
self.assertEqual(self.tracer._has_grad, True)
return a
def test_main(self):
paddle.disable_static()
self.tracer = framework._dygraph_tracer()
self.tracer._train_mode = True
self.assertEqual(self.enable_grad_func(1), 1)
self.assertEqual(self.enable_grad_func.__name__, "enable_grad_func")
def need_enable_grad_func(a, b=1):
return a + b
decorated_func = paddle.set_grad_enabled(True)(need_enable_grad_func)
self.assertEqual(
str(inspect.getfullargspec(decorated_func)),
str(inspect.getfullargspec(need_enable_grad_func)),
)
def test_gen():
for i in range(3):
yield i
a = 0
for i in test_gen():
a += i
@paddle.set_grad_enabled(True)
def test_wrapped_gen():
for i in range(3):
yield i
b = 0
for i in test_wrapped_gen():
b += i
self.assertEqual(a, b)
def test_stop_gradient(self):
x = paddle.to_tensor([1.0], stop_gradient=False)
is_train = False
with paddle.set_grad_enabled(is_train):
y = x * 2
self.assertTrue(y.stop_gradient is True)
paddle.set_grad_enabled(True)
y = x * 2
self.assertTrue(y.stop_gradient is False)
paddle.set_grad_enabled(False)
y = x * 2
self.assertTrue(y.stop_gradient is True)
class TestIsGradEnabledClass(unittest.TestCase):
def test_main(self):
paddle.disable_static()
self.tracer = framework._dygraph_tracer()
self.tracer._train_mode = True
self.tracer._has_grad = True
# Dygraph gradient calculation mode is enabled by default.
flag = paddle.is_grad_enabled()
self.assertTrue(flag is True)
with paddle.set_grad_enabled(False):
flag = paddle.is_grad_enabled()
self.assertTrue(flag is False)
flag = paddle.is_grad_enabled()
self.assertTrue(flag is True)
paddle.enable_static()
if __name__ == '__main__':
unittest.main()
......@@ -18,8 +18,6 @@ from . import random # noqa: F401
from .random import seed # noqa: F401
from .framework import get_default_dtype # noqa: F401
from .framework import set_default_dtype # noqa: F401
from .framework import set_grad_enabled # noqa: F401
from .framework import is_grad_enabled # noqa: F401
from ..fluid.param_attr import ParamAttr # noqa: F401
from ..fluid.core import CPUPlace # noqa: F401
......
......@@ -12,12 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
import numpy as np
from paddle.fluid.framework import _dygraph_tracer
# TODO: define framework api
from paddle.fluid.layer_helper_base import LayerHelperBase
......@@ -83,65 +80,3 @@ def get_default_dtype():
paddle.get_default_dtype()
"""
return LayerHelperBase.get_default_dtype()
@contextmanager
def set_grad_enabled(mode):
"""
Create a context which enables or disables dygraph gradient calculation.
Args:
mode(bool): whether to enable (`True`), or disable (`False`) grad.
Returns:
None.
Examples:
.. code-block:: python
import paddle
x = paddle.ones([3, 2])
x.stop_gradient = False
with paddle.set_grad_enabled(False):
y = x * 2
with paddle.set_grad_enabled(True):
z = x * 2
print(y.stop_gradient) # True
print(z.stop_gradient) # False
"""
tracer = _dygraph_tracer()
if tracer:
prev_mode = tracer._has_grad
tracer._has_grad = mode
try:
yield
finally:
tracer._has_grad = prev_mode
else:
yield
def is_grad_enabled():
"""
Returns whether current dygraph gradient calculation mode is enabled.
Returns:
bool: True if current dygraph gradient calculation mode is enabled, otherwise false.
Examples:
.. code-block:: python
import paddle
# Dygraph gradient calculation mode is enabled by default.
paddle.is_grad_enabled() # True
with paddle.set_grad_enabled(False):
paddle.is_grad_enabled() # False
paddle.enable_static()
paddle.is_grad_enabled() # False
"""
tracer = _dygraph_tracer()
return tracer._has_grad if tracer else False
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册