未验证 提交 c435110a 编写于 作者: S songyouwei 提交者: GitHub

support no_grad inplace operating (#22522)

* support no_grad atomic operating
test=develop

* rm create param api
test=develop

* refine sample code
test=develop

* clean impl
test=develop
上级 42655ef7
...@@ -146,12 +146,12 @@ def _switch_tracer_mode_guard_(is_train=True): ...@@ -146,12 +146,12 @@ def _switch_tracer_mode_guard_(is_train=True):
yield yield
def _no_grad_(func): def no_grad(func=None):
""" """
This Decorator will avoid the func being decorated creating backward network in dygraph mode Create a context which disables dygraph gradient calculation.
In this mode, the result of every computation will have `stop_gradient=True`.
Parameter: Also functions as a decorator. (Make sure to use it without parentheses.)
- **func** (python func): the func don't need grad
Examples: Examples:
...@@ -160,6 +160,24 @@ def _no_grad_(func): ...@@ -160,6 +160,24 @@ def _no_grad_(func):
import numpy as np import numpy as np
import paddle.fluid as fluid import paddle.fluid as fluid
# use as generator
data = np.array([[2, 3], [4, 5]]).astype('float32')
with fluid.dygraph.guard():
l0 = fluid.Linear(2, 2) # l0.weight.gradient() is None
l1 = fluid.Linear(2, 2)
with fluid.dygraph.no_grad():
# l1.weight.stop_gradient is False
tmp = l1.weight * 2 # tmp.stop_gradient is True
x = fluid.dygraph.to_variable(data)
y = l0(x) + tmp
o = l1(y)
o.backward()
print(tmp.gradient() is None) # True
print(l0.weight.gradient() is None) # False
# use as decorator
@fluid.dygraph.no_grad @fluid.dygraph.no_grad
def test_layer(): def test_layer():
with fluid.dygraph.guard(): with fluid.dygraph.guard():
...@@ -173,17 +191,15 @@ def _no_grad_(func): ...@@ -173,17 +191,15 @@ def _no_grad_(func):
test_layer() test_layer()
""" """
if func is None:
return _switch_tracer_mode_guard_(is_train=False)
else:
def __impl__(*args, **kwargs): def __impl__(*args, **kwargs):
with _switch_tracer_mode_guard_(is_train=False): with _switch_tracer_mode_guard_(is_train=False):
return func(*args, **kwargs) return func(*args, **kwargs)
return __impl__
no_grad = wrap_decorator(_no_grad_) return __impl__
# for fluidDoc
no_grad.__doc__ = _no_grad_.__doc__
@signature_safe_contextmanager @signature_safe_contextmanager
......
...@@ -227,6 +227,24 @@ class TestImperative(unittest.TestCase): ...@@ -227,6 +227,24 @@ class TestImperative(unittest.TestCase):
self.assertTrue(np.array_equal(y, tmp4.numpy())) self.assertTrue(np.array_equal(y, tmp4.numpy()))
self.assertTrue(np.array_equal(x, tmp5.numpy())) self.assertTrue(np.array_equal(x, tmp5.numpy()))
def test_no_grad_guard(self):
data = np.array([[2, 3], [4, 5]]).astype('float32')
with fluid.dygraph.guard():
l0 = fluid.Linear(2, 2)
self.assertTrue(l0.weight._grad_ivar() is None)
l1 = fluid.Linear(2, 2)
with fluid.dygraph.no_grad():
self.assertTrue(l1.weight.stop_gradient is False)
tmp = l1.weight * 2
self.assertTrue(tmp.stop_gradient)
x = fluid.dygraph.to_variable(data)
y = l0(x) + tmp
o = l1(y)
o.backward()
self.assertTrue(tmp._grad_ivar() is None)
self.assertTrue(l0.weight._grad_ivar() is not None)
def test_sum_op(self): def test_sum_op(self):
x = np.ones([2, 2], np.float32) x = np.ones([2, 2], np.float32)
with fluid.dygraph.guard(): with fluid.dygraph.guard():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册