diff --git a/python/paddle/fluid/dygraph/base.py b/python/paddle/fluid/dygraph/base.py index 2c55dc8951a3733c9a3ccffbad1cb3ad9a83511a..92ddb3223dfea687bb0778fb23e8e2f95d9e92f4 100644 --- a/python/paddle/fluid/dygraph/base.py +++ b/python/paddle/fluid/dygraph/base.py @@ -146,12 +146,12 @@ def _switch_tracer_mode_guard_(is_train=True): yield -def _no_grad_(func): +def no_grad(func=None): """ - This Decorator will avoid the func being decorated creating backward network in dygraph mode + Create a context which disables dygraph gradient calculation. + In this mode, the result of every computation will have `stop_gradient=True`. - Parameter: - - **func** (python func): the func don't need grad + Also functions as a decorator. (Make sure to instantiate without parenthesis.) Examples: @@ -160,6 +160,24 @@ def _no_grad_(func): import numpy as np import paddle.fluid as fluid + # use as generator + + data = np.array([[2, 3], [4, 5]]).astype('float32') + with fluid.dygraph.guard(): + l0 = fluid.Linear(2, 2) # l0.weight.gradient() is None + l1 = fluid.Linear(2, 2) + with fluid.dygraph.no_grad(): + # l1.weight.stop_gradient is False + tmp = l1.weight * 2 # tmp.stop_gradient is True + x = fluid.dygraph.to_variable(data) + y = l0(x) + tmp + o = l1(y) + o.backward() + print(tmp.gradient() is None) # True + print(l0.weight.gradient() is None) # False + + # use as decorator + @fluid.dygraph.no_grad def test_layer(): with fluid.dygraph.guard(): @@ -173,17 +191,15 @@ def _no_grad_(func): test_layer() """ + if func is None: + return _switch_tracer_mode_guard_(is_train=False) + else: - def __impl__(*args, **kwargs): - with _switch_tracer_mode_guard_(is_train=False): - return func(*args, **kwargs) - - return __impl__ - + def __impl__(*args, **kwargs): + with _switch_tracer_mode_guard_(is_train=False): + return func(*args, **kwargs) -no_grad = wrap_decorator(_no_grad_) -# for fluidDoc -no_grad.__doc__ = _no_grad_.__doc__ + return __impl__ @signature_safe_contextmanager diff --git a/python/paddle/fluid/tests/unittests/test_imperative_basic.py b/python/paddle/fluid/tests/unittests/test_imperative_basic.py index 32c887a618f3ed5df46c0bfd06fc38c417845487..47831341c489e1741923a6fe0938618fc5a508c3 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_basic.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_basic.py @@ -227,6 +227,24 @@ class TestImperative(unittest.TestCase): self.assertTrue(np.array_equal(y, tmp4.numpy())) self.assertTrue(np.array_equal(x, tmp5.numpy())) + def test_no_grad_guard(self): + data = np.array([[2, 3], [4, 5]]).astype('float32') + with fluid.dygraph.guard(): + l0 = fluid.Linear(2, 2) + self.assertTrue(l0.weight._grad_ivar() is None) + l1 = fluid.Linear(2, 2) + with fluid.dygraph.no_grad(): + self.assertTrue(l1.weight.stop_gradient is False) + tmp = l1.weight * 2 + self.assertTrue(tmp.stop_gradient) + x = fluid.dygraph.to_variable(data) + y = l0(x) + tmp + o = l1(y) + o.backward() + + self.assertTrue(tmp._grad_ivar() is None) + self.assertTrue(l0.weight._grad_ivar() is not None) + def test_sum_op(self): x = np.ones([2, 2], np.float32) with fluid.dygraph.guard():