From 3a6ead24ad72ad82aa383b748ffff4d34f28f88d Mon Sep 17 00:00:00 2001 From: Zeng Jinle <32832641+sneaxiy@users.noreply.github.com> Date: Mon, 3 Jun 2019 14:23:26 +0800 Subject: [PATCH] Add no_grad decorator to dygraph (#17790) * add no_grad decorator to dygraph, test=develop * add unittest,test=develop --- python/paddle/fluid/dygraph/base.py | 33 ++++++++++++- python/paddle/fluid/dygraph_grad_clip.py | 5 ++ python/paddle/fluid/optimizer.py | 2 + .../unittests/test_imperative_decorator.py | 48 +++++++++++++++++++ 4 files changed, 86 insertions(+), 2 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/test_imperative_decorator.py diff --git a/python/paddle/fluid/dygraph/base.py b/python/paddle/fluid/dygraph/base.py index bf484b35c7b..c19a2f926e0 100644 --- a/python/paddle/fluid/dygraph/base.py +++ b/python/paddle/fluid/dygraph/base.py @@ -11,20 +11,49 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from ..wrapped_decorator import signature_safe_contextmanager +from ..wrapped_decorator import signature_safe_contextmanager, wrap_decorator +import contextlib import numpy as np from paddle.fluid import core from paddle.fluid import framework from .tracer import Tracer -__all__ = ['enabled', 'guard', 'to_variable'] +__all__ = [ + 'enabled', + 'no_grad', + 'guard', + 'to_variable', +] def enabled(): return framework.in_dygraph_mode() +@contextlib.contextmanager +def _switch_tracer_mode_guard_(is_train=True): + tracer = framework._dygraph_tracer() + if tracer: + mode = tracer._train_mode + tracer._train_mode = is_train + yield + tracer._train_mode = mode + else: + yield + + +def _no_grad_(func): + def __impl__(*args, **kwargs): + with _switch_tracer_mode_guard_(is_train=False): + return func(*args, **kwargs) + + return __impl__ + + +no_grad = wrap_decorator(_no_grad_) + + @signature_safe_contextmanager def guard(place=None): train = framework.Program() diff --git a/python/paddle/fluid/dygraph_grad_clip.py b/python/paddle/fluid/dygraph_grad_clip.py index bcc307511ec..826f918f36e 100644 --- a/python/paddle/fluid/dygraph_grad_clip.py +++ b/python/paddle/fluid/dygraph_grad_clip.py @@ -22,6 +22,7 @@ import functools from . import layers from . import framework from . import core +from .dygraph import base as imperative_base __all__ = [ 'GradClipByValue', @@ -37,6 +38,7 @@ class GradClipBase(object): def _clip(self, para_and_grad): raise NotImplementedError + @imperative_base.no_grad def __call__(self, para_and_grad): return self._clip(para_and_grad) @@ -86,6 +88,7 @@ class GradClipByValue(GradClipBase): """ + @imperative_base.no_grad def __init__(self, min_value, max_value=None): if min_value is None: @@ -164,6 +167,7 @@ class GradClipByNorm(GradClipBase): """ + @imperative_base.no_grad def __init__(self, clip_norm): self.clip_norm = clip_norm @@ -243,6 +247,7 @@ class GradClipByGlobalNorm(GradClipBase): """ + @imperative_base.no_grad def __init__(self, max_global_norm): self.max_global_norm = layers.fill_constant( shape=[1], dtype='float32', value=max_global_norm) diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py index c3d136c29a3..50705bba2e9 100644 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -55,6 +55,7 @@ class Optimizer(object): but need to use one of it's implementation. """ + @imperative_base.no_grad def __init__(self, learning_rate, regularization=None, name=None): if framework.in_dygraph_mode(): if not isinstance(learning_rate, float) and \ @@ -472,6 +473,7 @@ class Optimizer(object): optimize_ops = self.apply_gradients(params_grads) return optimize_ops + @imperative_base.no_grad def minimize(self, loss, startup_program=None, diff --git a/python/paddle/fluid/tests/unittests/test_imperative_decorator.py b/python/paddle/fluid/tests/unittests/test_imperative_decorator.py new file mode 100644 index 00000000000..2b92bcb38be --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_imperative_decorator.py @@ -0,0 +1,48 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle.fluid as fluid +import paddle.fluid.framework as framework +import unittest + + +class TestTracerMode(unittest.TestCase): + def setUp(self): + self.init_mode = True + + def get_tracer_mode(self): + assert fluid.dygraph.enabled(), "Dygraph mode must be enabled" + + @fluid.dygraph.no_grad + def no_grad_func(self, a): + self.assertEqual(self.tracer._train_mode, False) + return a + + def test_main(self): + with fluid.dygraph.guard(): + self.tracer = framework._dygraph_tracer() + self.tracer._train_mode = self.init_mode + + self.assertEqual(self.no_grad_func(1), 1) + + self.assertEqual(self.tracer._train_mode, self.init_mode) + + +class TestTracerMode2(TestTracerMode): + def setUp(self): + self.init_mode = False + + +if __name__ == '__main__': + unittest.main() -- GitLab