From 617eb67f29c149699fc0bfa799c213c0cafcf839 Mon Sep 17 00:00:00 2001 From: Yang Zhang Date: Fri, 14 Aug 2020 13:36:30 +0800 Subject: [PATCH] Upgrade `no_grad` decorator (#25472) * Upgrade `no_grad` decorator test=develop - match torch decorator usage (i.e., with parenthesis) - handle generator functions - add `paddle.no_grad` alias * Switch from `functools` to `decorator` preserves signature * Reword decorator usage note --- python/paddle/__init__.py | 1 + python/paddle/fluid/clip.py | 8 +- python/paddle/fluid/dygraph/base.py | 95 ++++++++++--------- python/paddle/fluid/dygraph/math_op_patch.py | 2 +- python/paddle/fluid/dygraph/parallel.py | 2 +- python/paddle/fluid/optimizer.py | 8 +- .../unittests/test_imperative_decorator.py | 24 ++++- 7 files changed, 83 insertions(+), 57 deletions(-) diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 14824407284..bb4a4bd2486 100644 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -239,6 +239,7 @@ from .incubate import hapi from .fluid.dygraph.base import enable_dygraph as disable_static #DEFINE_ALIAS from .fluid.dygraph.base import disable_dygraph as enable_static #DEFINE_ALIAS from .fluid.framework import in_dygraph_mode as in_dynamic_mode #DEFINE_ALIAS +from .fluid.dygraph.base import no_grad #DEFINE_ALIAS from . import jit from . import static diff --git a/python/paddle/fluid/clip.py b/python/paddle/fluid/clip.py index 7b301ac19d1..5f6594a4721 100644 --- a/python/paddle/fluid/clip.py +++ b/python/paddle/fluid/clip.py @@ -129,7 +129,7 @@ class GradientClipBase(object): def __str__(self): raise NotImplementedError() - @imperative_base.no_grad + @imperative_base.no_grad() def _dygraph_clip(self, params_grads): raise NotImplementedError @@ -258,7 +258,7 @@ class GradientClipByValue(GradientClipBase): def __str__(self): return "Gradient Clip By Value, min = %f, max=%f" % (self.min, self.max) - @imperative_base.no_grad + @imperative_base.no_grad() def _dygraph_clip(self, params_grads): params_and_grads = [] for p, g in params_grads: @@ -413,7 +413,7 @@ class GradientClipByNorm(GradientClipBase): def __str__(self): return "Gradient Clip By Norm, clip_norm=%f" % self.clip_norm - @imperative_base.no_grad + @imperative_base.no_grad() def _dygraph_clip(self, params_grads): params_and_grads = [] for p, g in params_grads: @@ -565,7 +565,7 @@ class GradientClipByGlobalNorm(GradientClipBase): def __str__(self): return "Gradient Clip By GlobalNorm, global_norm=%f" % (self.clip_norm) - @imperative_base.no_grad + @imperative_base.no_grad() def _dygraph_clip(self, params_grads): params_and_grads = [] sum_square_list = [] diff --git a/python/paddle/fluid/dygraph/base.py b/python/paddle/fluid/dygraph/base.py index 826de0588ef..9eef4719cbd 100644 --- a/python/paddle/fluid/dygraph/base.py +++ b/python/paddle/fluid/dygraph/base.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
from ..wrapped_decorator import signature_safe_contextmanager, wrap_decorator +import inspect import decorator import contextlib -import functools import sys import numpy as np from paddle.fluid import core @@ -172,28 +172,15 @@ def disable_dygraph(): _functional_dygraph_context_manager = None -@signature_safe_contextmanager -def _switch_tracer_mode_guard_(is_train=True): - tracer = framework._dygraph_tracer() - if tracer: - mode = tracer._train_mode - tracer._train_mode = is_train - try: - yield - finally: - tracer._train_mode = mode - else: - yield - - -def no_grad(func=None): +class no_grad: """ :api_attr: imperative Create a context which disables dygraph gradient calculation. - In this mode, the result of every computation will have `stop_gradient=True`. + In this mode, the result of every computation will have `stop_gradient` set + to `True`. - Also functions as a decorator. (Make sure to instantiate without parenthesis.) + Also functions as a decorator. (Make sure to use an instance.) Examples: @@ -202,47 +189,65 @@ def no_grad(func=None): import numpy as np import paddle.fluid as fluid + paddle.enable_imperative() + # use as generator data = np.array([[2, 3], [4, 5]]).astype('float32') - with fluid.dygraph.guard(): - l0 = fluid.Linear(2, 2) # l0.weight.gradient() is None - l1 = fluid.Linear(2, 2) - with fluid.dygraph.no_grad(): - # l1.weight.stop_gradient is False - tmp = l1.weight * 2 # tmp.stop_gradient is True - x = fluid.dygraph.to_variable(data) - y = l0(x) + tmp - o = l1(y) - o.backward() - print(tmp.gradient() is None) # True - print(l0.weight.gradient() is None) # False + l0 = fluid.Linear(2, 2) # l0.weight.gradient() is None + l1 = fluid.Linear(2, 2) + with fluid.no_grad(): + # l1.weight.stop_gradient is False + tmp = l1.weight * 2 # tmp.stop_gradient is True + x = fluid.dygraph.to_variable(data) + y = l0(x) + tmp + o = l1(y) + o.backward() + print(tmp.gradient() is None) # True + print(l0.weight.gradient() is None) # False # use as decorator - @fluid.dygraph.no_grad + @fluid.no_grad() def test_layer(): - with fluid.dygraph.guard(): - inp = np.ones([3, 1024], dtype='float32') - t = fluid.dygraph.base.to_variable(inp) - linear1 = fluid.Linear(1024, 4, bias_attr=False) - linear2 = fluid.Linear(4, 4) - ret = linear1(t) - dy_ret = linear2(ret) + inp = np.ones([3, 1024], dtype='float32') + t = fluid.dygraph.base.to_variable(inp) + linear1 = fluid.Linear(1024, 4, bias_attr=False) + linear2 = fluid.Linear(4, 4) + ret = linear1(t) + dy_ret = linear2(ret) test_layer() - """ - if func is None: - return _switch_tracer_mode_guard_(is_train=False) - else: + def __call__(self, func): @decorator.decorator - def __impl__(func, *args, **kwargs): - with _switch_tracer_mode_guard_(is_train=False): + def _decorate_function(func, *args, **kwargs): + with self: return func(*args, **kwargs) - return __impl__(func) + @decorator.decorator + def _decorate_generator(func, *args, **kwargs): + gen = func(*args, **kwargs) + with self: + for x in gen: + yield x + + if inspect.isgeneratorfunction(func): + return _decorate_generator(func) + else: + return _decorate_function(func) + + def __enter__(self): + tracer = framework._dygraph_tracer() + if tracer: + self.orig = tracer._train_mode + tracer._train_mode = False + + def __exit__(self, *args): + tracer = framework._dygraph_tracer() + if tracer: + tracer._train_mode = self.orig @signature_safe_contextmanager diff --git a/python/paddle/fluid/dygraph/math_op_patch.py b/python/paddle/fluid/dygraph/math_op_patch.py index d2c779a8549..4ee5e9895a7 100644 --- 
a/python/paddle/fluid/dygraph/math_op_patch.py +++ b/python/paddle/fluid/dygraph/math_op_patch.py @@ -37,7 +37,7 @@ def monkey_patch_math_varbase(): The difference is, in dygraph mode, use auto-generated op functions for better performance. """ - @no_grad + @no_grad() def create_tensor(value, dtype, shape): out = _varbase_creator(dtype=dtype) out = core.ops.fill_constant(out, 'dtype', dtype, 'shape', shape, diff --git a/python/paddle/fluid/dygraph/parallel.py b/python/paddle/fluid/dygraph/parallel.py index 804076f608e..24e7f64cebf 100644 --- a/python/paddle/fluid/dygraph/parallel.py +++ b/python/paddle/fluid/dygraph/parallel.py @@ -380,7 +380,7 @@ class DataParallel(layers.Layer): self._reshape_inplace(x=g_var, shape=g_shape) assert g_var.shape == g_shape - @no_grad + @no_grad() def apply_collective_grads(self): """ AllReduce the Parameters' gradient. diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py index c84d2ac3796..740db0d4b9e 100644 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -60,7 +60,7 @@ class Optimizer(object): but need to use one of it's implementation. """ - @imperative_base.no_grad + @imperative_base.no_grad() def __init__(self, learning_rate, parameter_list=None, @@ -863,7 +863,7 @@ class Optimizer(object): if p.trainable: p.clear_gradient() - @imperative_base.no_grad + @imperative_base.no_grad() def minimize(self, loss, startup_program=None, @@ -981,7 +981,7 @@ class SGDOptimizer(Optimizer): name=name) self.type = "sgd" - @no_grad + @no_grad() def _append_optimize_op(self, block, param_and_grad): lr = self._create_param_lr(param_and_grad) if framework.in_dygraph_mode(): @@ -1518,7 +1518,7 @@ class DGCMomentumOptimizer(Optimizer): dgc_op._set_attr(op_maker.kOpRoleVarAttrName(), [param_var.name, grad_var.name]) - @imperative_base.no_grad + @imperative_base.no_grad() def apply_gradients(self, params_grads): params_grads = sorted(params_grads, key=lambda x: x[0].name) params_grads, table_param_and_grad, table_optimize_op = \ diff --git a/python/paddle/fluid/tests/unittests/test_imperative_decorator.py b/python/paddle/fluid/tests/unittests/test_imperative_decorator.py index 82e81d72f9a..820206a3ce6 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_decorator.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_decorator.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import paddle import paddle.fluid as fluid import paddle.fluid.framework as framework import unittest @@ -27,7 +28,7 @@ class TestTracerMode(unittest.TestCase): def get_tracer_mode(self): assert fluid.in_dygraph_mode(), "Dygraph mode must be enabled" - @fluid.dygraph.no_grad + @paddle.no_grad() def no_grad_func(self, a): self.assertEqual(self.tracer._train_mode, False) return a @@ -55,13 +56,32 @@ class TestTracerMode(unittest.TestCase): def need_no_grad_func(a, b=1): return a + b - decorated_func = fluid.dygraph.no_grad(need_no_grad_func) + decorated_func = paddle.no_grad()(need_no_grad_func) self.assertTrue( str(inspect.getargspec(decorated_func)) == str(inspect.getargspec(need_no_grad_func))) self.assertEqual(self.tracer._train_mode, self.init_mode) + def test_gen(): + for i in range(3): + yield i + + a = 0 + for i in test_gen(): + a += i + + @paddle.no_grad() + def test_wrapped_gen(): + for i in range(3): + yield i + + b = 0 + for i in test_wrapped_gen(): + b += i + + self.assertEqual(a, b) + with fluid.dygraph.guard(): self.check_not_support_rlt(False) -- GitLab
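
Note on the call-site changes above (`@imperative_base.no_grad` becoming `@imperative_base.no_grad()`, and the test switching to `@paddle.no_grad()`): `no_grad` is now a class, so the decorator form takes an instance, matching the torch-style spelling described in the commit message. A minimal usage sketch, assuming a Paddle build that already contains this patch; `evaluate` and `linear` are illustrative names, not part of the change:

    import numpy as np
    import paddle
    import paddle.fluid as fluid

    paddle.enable_imperative()

    linear = fluid.dygraph.Linear(2, 2)

    @paddle.no_grad()                      # new style: decorate with an instance
    def evaluate(data):
        return linear(fluid.dygraph.to_variable(data))

    out = evaluate(np.ones([1, 2], dtype='float32'))
    print(out.stop_gradient)               # True: results computed here do not track gradients

    with paddle.no_grad():                 # the same object also works as a context manager
        tmp = linear.weight * 2
    print(tmp.stop_gradient)               # True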
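
The new class in python/paddle/fluid/dygraph/base.py special-cases generator functions because a plain function wrapper would leave the `with` block as soon as the generator object is created, before any of its body has run; the added test_wrapped_gen test exercises exactly that path. Below is a self-contained sketch of the same mechanism outside Paddle — `toy_no_grad` and `grad_enabled` are stand-ins for the class and the tracer's train mode, not Paddle APIs:

    import inspect

    import decorator

    grad_enabled = True  # stand-in for the tracer's train mode


    class toy_no_grad:
        def __call__(self, func):
            @decorator.decorator
            def _decorate_function(func, *args, **kwargs):
                with self:
                    return func(*args, **kwargs)

            @decorator.decorator
            def _decorate_generator(func, *args, **kwargs):
                gen = func(*args, **kwargs)
                # Keep the context open across every `yield`, not just the
                # call that creates the generator object.
                with self:
                    for item in gen:
                        yield item

            if inspect.isgeneratorfunction(func):
                return _decorate_generator(func)
            return _decorate_function(func)

        def __enter__(self):
            global grad_enabled
            self._orig = grad_enabled
            grad_enabled = False

        def __exit__(self, *args):
            global grad_enabled
            grad_enabled = self._orig


    @toy_no_grad()
    def numbers():
        for i in range(3):
            # With only the plain-function wrapper, the `with` block would
            # already have exited before this body runs.
            yield grad_enabled


    print(list(numbers()))  # [False, False, False]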
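
The "Switch from `functools` to `decorator` preserves signature" bullet, and the getargspec comparison kept in the unit test, reflect the difference sketched here; this is a standalone illustration rather than Paddle code, and the helper names are made up:

    import functools
    import inspect

    import decorator


    def wrapped_with_functools(func):
        @functools.wraps(func)
        def inner(*args, **kwargs):
            return func(*args, **kwargs)

        return inner


    @decorator.decorator
    def wrapped_with_decorator(func, *args, **kwargs):
        return func(*args, **kwargs)


    def need_no_grad_func(a, b=1):
        return a + b


    # functools.wraps copies metadata such as __name__ and __doc__, but the
    # argspec that inspect reports is still the wrapper's (*args, **kwargs);
    # the decorator package rebuilds a wrapper with the original argument
    # spec, which is what the updated unit-test assertion relies on.
    print(inspect.getfullargspec(wrapped_with_functools(need_no_grad_func)))
    print(inspect.getfullargspec(wrapped_with_decorator(need_no_grad_func)))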