diff --git a/python/paddle/fluid/dygraph/base.py b/python/paddle/fluid/dygraph/base.py
index bf484b35c7bf9a2b17126789ff247bd73095fe7b..c19a2f926e068c930a2ab3dc6d610da6d6ca145b 100644
--- a/python/paddle/fluid/dygraph/base.py
+++ b/python/paddle/fluid/dygraph/base.py
@@ -11,20 +11,49 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from ..wrapped_decorator import signature_safe_contextmanager
+from ..wrapped_decorator import signature_safe_contextmanager, wrap_decorator
+import contextlib
 import numpy as np
 
 from paddle.fluid import core
 from paddle.fluid import framework
 from .tracer import Tracer
 
-__all__ = ['enabled', 'guard', 'to_variable']
+__all__ = [
+    'enabled',
+    'no_grad',
+    'guard',
+    'to_variable',
+]
 
 
 def enabled():
     return framework.in_dygraph_mode()
 
 
+@contextlib.contextmanager
+def _switch_tracer_mode_guard_(is_train=True):
+    tracer = framework._dygraph_tracer()
+    if tracer:
+        mode = tracer._train_mode
+        tracer._train_mode = is_train
+        yield
+        tracer._train_mode = mode
+    else:
+        yield
+
+
+def _no_grad_(func):
+    def __impl__(*args, **kwargs):
+        with _switch_tracer_mode_guard_(is_train=False):
+            return func(*args, **kwargs)
+
+    return __impl__
+
+
+no_grad = wrap_decorator(_no_grad_)
+
+
 @signature_safe_contextmanager
 def guard(place=None):
     train = framework.Program()
diff --git a/python/paddle/fluid/dygraph_grad_clip.py b/python/paddle/fluid/dygraph_grad_clip.py
index bcc307511ecc6bc1625044fdb190c44335f07712..826f918f36ece2eab5ddf17c1c0b3c86ca4e6438 100644
--- a/python/paddle/fluid/dygraph_grad_clip.py
+++ b/python/paddle/fluid/dygraph_grad_clip.py
@@ -22,6 +22,7 @@ import functools
 from . import layers
 from . import framework
 from . import core
+from .dygraph import base as imperative_base
 
 __all__ = [
     'GradClipByValue',
@@ -37,6 +38,7 @@ class GradClipBase(object):
     def _clip(self, para_and_grad):
         raise NotImplementedError
 
+    @imperative_base.no_grad
     def __call__(self, para_and_grad):
         return self._clip(para_and_grad)
 
@@ -86,6 +88,7 @@ class GradClipByValue(GradClipBase):
 
     """
 
+    @imperative_base.no_grad
     def __init__(self, min_value, max_value=None):
 
         if min_value is None:
@@ -164,6 +167,7 @@ class GradClipByNorm(GradClipBase):
 
     """
 
+    @imperative_base.no_grad
     def __init__(self, clip_norm):
         self.clip_norm = clip_norm
 
@@ -243,6 +247,7 @@ class GradClipByGlobalNorm(GradClipBase):
 
     """
 
+    @imperative_base.no_grad
     def __init__(self, max_global_norm):
         self.max_global_norm = layers.fill_constant(
             shape=[1], dtype='float32', value=max_global_norm)
diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py
index c3d136c29a30187b35da6e172821d425d8d44f1c..50705bba2e93ece695991212f635ab7ab8010b8c 100644
--- a/python/paddle/fluid/optimizer.py
+++ b/python/paddle/fluid/optimizer.py
@@ -55,6 +55,7 @@ class Optimizer(object):
     but need to use one of it's implementation.
     """
 
+    @imperative_base.no_grad
     def __init__(self, learning_rate, regularization=None, name=None):
         if framework.in_dygraph_mode():
             if not isinstance(learning_rate, float) and \
@@ -472,6 +473,7 @@
             optimize_ops = self.apply_gradients(params_grads)
         return optimize_ops
 
+    @imperative_base.no_grad
     def minimize(self,
                  loss,
                  startup_program=None,
diff --git a/python/paddle/fluid/tests/unittests/test_imperative_decorator.py b/python/paddle/fluid/tests/unittests/test_imperative_decorator.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b92bcb38befcd7d1d2932e2681dc167fa6777ee
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_imperative_decorator.py
@@ -0,0 +1,48 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle.fluid as fluid
+import paddle.fluid.framework as framework
+import unittest
+
+
+class TestTracerMode(unittest.TestCase):
+    def setUp(self):
+        self.init_mode = True
+
+    def get_tracer_mode(self):
+        assert fluid.dygraph.enabled(), "Dygraph mode must be enabled"
+
+    @fluid.dygraph.no_grad
+    def no_grad_func(self, a):
+        self.assertEqual(self.tracer._train_mode, False)
+        return a
+
+    def test_main(self):
+        with fluid.dygraph.guard():
+            self.tracer = framework._dygraph_tracer()
+            self.tracer._train_mode = self.init_mode
+
+            self.assertEqual(self.no_grad_func(1), 1)
+
+            self.assertEqual(self.tracer._train_mode, self.init_mode)
+
+
+class TestTracerMode2(TestTracerMode):
+    def setUp(self):
+        self.init_mode = False
+
+
+if __name__ == '__main__':
+    unittest.main()