Unverified commit 617eb67f, authored by Yang Zhang, committed by GitHub

Upgrade `no_grad` decorator (#25472)

* Upgrade `no_grad` decorator

test=develop

- match torch decorator usage (i.e., with parentheses)
- handle generator functions
- add `paddle.no_grad` alias

* Switch from `functools` to `decorator`

preserves the wrapped function's signature

* Reword decorator usage note
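
Taken together, the upgraded usage looks like this. A quick sketch, not part of the diff below; it assumes dygraph mode is switched on via paddle.enable_imperative(), as in the updated docstring:

    import paddle

    paddle.enable_imperative()

    # as a context manager
    with paddle.no_grad():
        pass  # results computed here get stop_gradient=True

    # as a decorator -- note the parentheses, matching torch
    @paddle.no_grad()
    def evaluate():
        pass

    # generator functions are now handled as well
    @paddle.no_grad()
    def batches():
        for i in range(3):
            yield i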
Parent 7a89a0a7
@@ -239,6 +239,7 @@ from .incubate import hapi
 from .fluid.dygraph.base import enable_dygraph as disable_static  #DEFINE_ALIAS
 from .fluid.dygraph.base import disable_dygraph as enable_static  #DEFINE_ALIAS
 from .fluid.framework import in_dygraph_mode as in_dynamic_mode  #DEFINE_ALIAS
+from .fluid.dygraph.base import no_grad  #DEFINE_ALIAS
 from . import jit
 from . import static
@@ -129,7 +129,7 @@ class GradientClipBase(object):
     def __str__(self):
         raise NotImplementedError()

-    @imperative_base.no_grad
+    @imperative_base.no_grad()
     def _dygraph_clip(self, params_grads):
         raise NotImplementedError

@@ -258,7 +258,7 @@ class GradientClipByValue(GradientClipBase):
     def __str__(self):
         return "Gradient Clip By Value, min = %f, max=%f" % (self.min, self.max)

-    @imperative_base.no_grad
+    @imperative_base.no_grad()
     def _dygraph_clip(self, params_grads):
         params_and_grads = []
         for p, g in params_grads:

@@ -413,7 +413,7 @@ class GradientClipByNorm(GradientClipBase):
     def __str__(self):
         return "Gradient Clip By Norm, clip_norm=%f" % self.clip_norm

-    @imperative_base.no_grad
+    @imperative_base.no_grad()
     def _dygraph_clip(self, params_grads):
         params_and_grads = []
         for p, g in params_grads:

@@ -565,7 +565,7 @@ class GradientClipByGlobalNorm(GradientClipBase):
     def __str__(self):
         return "Gradient Clip By GlobalNorm, global_norm=%f" % (self.clip_norm)

-    @imperative_base.no_grad
+    @imperative_base.no_grad()
     def _dygraph_clip(self, params_grads):
         params_and_grads = []
         sum_square_list = []
...
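
Every call site gains a pair of parentheses because no_grad is now a class rather than a decorator function: the decorator must be an instance, so that Python hands the wrapped function to __call__. A minimal standalone sketch of the pattern (illustrative only, not Paddle's actual code):

    class no_grad:
        def __enter__(self):
            pass  # a real implementation would disable gradient tracking

        def __exit__(self, *exc):
            pass  # ... and restore the previous mode here

        def __call__(self, func):
            def wrapper(*args, **kwargs):
                with self:  # run the wrapped call with gradients off
                    return func(*args, **kwargs)
            return wrapper

    @no_grad()  # OK: instantiate first, __call__ then receives the function
    def f():
        pass

    # @no_grad  # TypeError: the bare class would receive `f` in its
    #           # constructor, which takes no arguments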
@@ -12,9 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from ..wrapped_decorator import signature_safe_contextmanager, wrap_decorator
+import inspect
 import decorator
 import contextlib
-import functools
 import sys
 import numpy as np
 from paddle.fluid import core
@@ -172,28 +172,15 @@ def disable_dygraph():
     _functional_dygraph_context_manager = None


-@signature_safe_contextmanager
-def _switch_tracer_mode_guard_(is_train=True):
-    tracer = framework._dygraph_tracer()
-
-    if tracer:
-        mode = tracer._train_mode
-        tracer._train_mode = is_train
-        try:
-            yield
-        finally:
-            tracer._train_mode = mode
-    else:
-        yield
-
-
-def no_grad(func=None):
+class no_grad:
     """
     :api_attr: imperative

     Create a context which disables dygraph gradient calculation.
-    In this mode, the result of every computation will have `stop_gradient=True`.
+    In this mode, the result of every computation will have `stop_gradient` set
+    to `True`.

-    Also functions as a decorator. (Make sure to instantiate without parenthesis.)
+    Also functions as a decorator. (Make sure to use an instance.)

     Examples:

@@ -202,47 +189,65 @@ def no_grad(func=None):
         import numpy as np
         import paddle.fluid as fluid

+        paddle.enable_imperative()
+
         # use as generator

         data = np.array([[2, 3], [4, 5]]).astype('float32')
-        with fluid.dygraph.guard():
-            l0 = fluid.Linear(2, 2)  # l0.weight.gradient() is None
-            l1 = fluid.Linear(2, 2)
-            with fluid.dygraph.no_grad():
-                # l1.weight.stop_gradient is False
-                tmp = l1.weight * 2  # tmp.stop_gradient is True
-            x = fluid.dygraph.to_variable(data)
-            y = l0(x) + tmp
-            o = l1(y)
-            o.backward()
-            print(tmp.gradient() is None)  # True
-            print(l0.weight.gradient() is None)  # False
+        l0 = fluid.Linear(2, 2)  # l0.weight.gradient() is None
+        l1 = fluid.Linear(2, 2)
+        with fluid.no_grad():
+            # l1.weight.stop_gradient is False
+            tmp = l1.weight * 2  # tmp.stop_gradient is True
+        x = fluid.dygraph.to_variable(data)
+        y = l0(x) + tmp
+        o = l1(y)
+        o.backward()
+        print(tmp.gradient() is None)  # True
+        print(l0.weight.gradient() is None)  # False

         # use as decorator

-        @fluid.dygraph.no_grad
+        @fluid.no_grad()
         def test_layer():
-            with fluid.dygraph.guard():
-                inp = np.ones([3, 1024], dtype='float32')
-                t = fluid.dygraph.base.to_variable(inp)
-                linear1 = fluid.Linear(1024, 4, bias_attr=False)
-                linear2 = fluid.Linear(4, 4)
-                ret = linear1(t)
-                dy_ret = linear2(ret)
+            inp = np.ones([3, 1024], dtype='float32')
+            t = fluid.dygraph.base.to_variable(inp)
+            linear1 = fluid.Linear(1024, 4, bias_attr=False)
+            linear2 = fluid.Linear(4, 4)
+            ret = linear1(t)
+            dy_ret = linear2(ret)

         test_layer()
     """
-    if func is None:
-        return _switch_tracer_mode_guard_(is_train=False)
-    else:
-        @decorator.decorator
-        def __impl__(func, *args, **kwargs):
-            with _switch_tracer_mode_guard_(is_train=False):
-                return func(*args, **kwargs)
-        return __impl__(func)
+    def __call__(self, func):
+        @decorator.decorator
+        def _decorate_function(func, *args, **kwargs):
+            with self:
+                return func(*args, **kwargs)
+
+        @decorator.decorator
+        def _decorate_generator(func, *args, **kwargs):
+            gen = func(*args, **kwargs)
+            with self:
+                for x in gen:
+                    yield x
+
+        if inspect.isgeneratorfunction(func):
+            return _decorate_generator(func)
+        else:
+            return _decorate_function(func)
+
+    def __enter__(self):
+        tracer = framework._dygraph_tracer()
+        if tracer:
+            self.orig = tracer._train_mode
+            tracer._train_mode = False
+
+    def __exit__(self, *args):
+        tracer = framework._dygraph_tracer()
+        if tracer:
+            tracer._train_mode = self.orig


 @signature_safe_contextmanager
...
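
The separate generator branch is needed because routing a generator function through _decorate_function would silently do nothing: calling the function merely creates the generator object, so the with block is entered and exited before the body ever runs. A hypothetical demonstration of that failure mode (names are illustrative, not from the PR):

    class trace_ctx:
        def __enter__(self):
            print("enter")

        def __exit__(self, *exc):
            print("exit")

    def naive_wrap(func):
        def wrapper(*args, **kwargs):
            with trace_ctx():
                return func(*args, **kwargs)  # only *creates* the generator
        return wrapper

    @naive_wrap
    def gen():
        print("body runs")
        yield 1

    list(gen())  # prints "enter", "exit", then "body runs":
                 # the context is already closed when the body executes

_decorate_generator avoids this by yielding from inside `with self:`, so the tracer stays in no-grad mode while the wrapped generator's body is actually resumed.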
@@ -37,7 +37,7 @@ def monkey_patch_math_varbase():
     The difference is, in dygraph mode, use auto-generated op functions for better performance.
     """

-    @no_grad
+    @no_grad()
     def create_tensor(value, dtype, shape):
         out = _varbase_creator(dtype=dtype)
         out = core.ops.fill_constant(out, 'dtype', dtype, 'shape', shape,
...
@@ -380,7 +380,7 @@ class DataParallel(layers.Layer):
                 self._reshape_inplace(x=g_var, shape=g_shape)
                 assert g_var.shape == g_shape

-    @no_grad
+    @no_grad()
     def apply_collective_grads(self):
         """
         AllReduce the Parameters' gradient.
...
@@ -60,7 +60,7 @@ class Optimizer(object):
     but need to use one of it's implementation.
     """

-    @imperative_base.no_grad
+    @imperative_base.no_grad()
     def __init__(self,
                  learning_rate,
                  parameter_list=None,

@@ -863,7 +863,7 @@ class Optimizer(object):
                 if p.trainable:
                     p.clear_gradient()

-    @imperative_base.no_grad
+    @imperative_base.no_grad()
     def minimize(self,
                  loss,
                  startup_program=None,

@@ -981,7 +981,7 @@ class SGDOptimizer(Optimizer):
             name=name)
         self.type = "sgd"

-    @no_grad
+    @no_grad()
     def _append_optimize_op(self, block, param_and_grad):
         lr = self._create_param_lr(param_and_grad)
         if framework.in_dygraph_mode():

@@ -1518,7 +1518,7 @@ class DGCMomentumOptimizer(Optimizer):
         dgc_op._set_attr(op_maker.kOpRoleVarAttrName(),
                          [param_var.name, grad_var.name])

-    @imperative_base.no_grad
+    @imperative_base.no_grad()
     def apply_gradients(self, params_grads):
         params_grads = sorted(params_grads, key=lambda x: x[0].name)
         params_grads, table_param_and_grad, table_optimize_op = \
...
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import paddle
 import paddle.fluid as fluid
 import paddle.fluid.framework as framework
 import unittest

@@ -27,7 +28,7 @@ class TestTracerMode(unittest.TestCase):
     def get_tracer_mode(self):
         assert fluid.in_dygraph_mode(), "Dygraph mode must be enabled"

-    @fluid.dygraph.no_grad
+    @paddle.no_grad()
     def no_grad_func(self, a):
         self.assertEqual(self.tracer._train_mode, False)
         return a

@@ -55,13 +56,32 @@ class TestTracerMode(unittest.TestCase):
         def need_no_grad_func(a, b=1):
             return a + b

-        decorated_func = fluid.dygraph.no_grad(need_no_grad_func)
+        decorated_func = paddle.no_grad()(need_no_grad_func)
         self.assertTrue(
             str(inspect.getargspec(decorated_func)) ==
             str(inspect.getargspec(need_no_grad_func)))

         self.assertEqual(self.tracer._train_mode, self.init_mode)

+        def test_gen():
+            for i in range(3):
+                yield i
+
+        a = 0
+        for i in test_gen():
+            a += i
+
+        @paddle.no_grad()
+        def test_wrapped_gen():
+            for i in range(3):
+                yield i
+
+        b = 0
+        for i in test_wrapped_gen():
+            b += i
+
+        self.assertEqual(a, b)
+
         with fluid.dygraph.guard():
             self.check_not_support_rlt(False)
...
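
The getargspec assertion above is what the switch from functools to the decorator package buys: a functools-based wrapper reports its own (*args, **kwargs) signature, while decorator regenerates the wrapper with the original one. A standalone sketch of the difference (not part of the diff):

    import functools
    import inspect

    import decorator


    def wrap_with_functools(func):
        @functools.wraps(func)
        def inner(*args, **kwargs):
            return func(*args, **kwargs)
        return inner


    @decorator.decorator
    def wrap_with_decorator(func, *args, **kwargs):
        return func(*args, **kwargs)


    def need_no_grad_func(a, b=1):
        return a + b


    # wrapper's own argspec: (*args, **kwargs)
    print(inspect.getargspec(wrap_with_functools(need_no_grad_func)))
    # original signature (a, b=1) is preserved
    print(inspect.getargspec(wrap_with_decorator(need_no_grad_func)))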