未验证 提交 3a6ead24 编写于 作者: Z Zeng Jinle 提交者: GitHub

Add no_grad decorator to dygraph (#17790)

* add no_grad decorator to dygraph, test=develop

* add unittest,test=develop
上级 53920f5e
...@@ -11,20 +11,49 @@ ...@@ -11,20 +11,49 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from ..wrapped_decorator import signature_safe_contextmanager from ..wrapped_decorator import signature_safe_contextmanager, wrap_decorator
import contextlib
import numpy as np import numpy as np
from paddle.fluid import core from paddle.fluid import core
from paddle.fluid import framework from paddle.fluid import framework
from .tracer import Tracer from .tracer import Tracer
__all__ = ['enabled', 'guard', 'to_variable'] __all__ = [
'enabled',
'no_grad',
'guard',
'to_variable',
]
def enabled(): def enabled():
return framework.in_dygraph_mode() return framework.in_dygraph_mode()
@contextlib.contextmanager
def _switch_tracer_mode_guard_(is_train=True):
tracer = framework._dygraph_tracer()
if tracer:
mode = tracer._train_mode
tracer._train_mode = is_train
yield
tracer._train_mode = mode
else:
yield
def _no_grad_(func):
def __impl__(*args, **kwargs):
with _switch_tracer_mode_guard_(is_train=False):
return func(*args, **kwargs)
return __impl__
no_grad = wrap_decorator(_no_grad_)
@signature_safe_contextmanager @signature_safe_contextmanager
def guard(place=None): def guard(place=None):
train = framework.Program() train = framework.Program()
......
...@@ -22,6 +22,7 @@ import functools ...@@ -22,6 +22,7 @@ import functools
from . import layers from . import layers
from . import framework from . import framework
from . import core from . import core
from .dygraph import base as imperative_base
__all__ = [ __all__ = [
'GradClipByValue', 'GradClipByValue',
...@@ -37,6 +38,7 @@ class GradClipBase(object): ...@@ -37,6 +38,7 @@ class GradClipBase(object):
def _clip(self, para_and_grad): def _clip(self, para_and_grad):
raise NotImplementedError raise NotImplementedError
@imperative_base.no_grad
def __call__(self, para_and_grad): def __call__(self, para_and_grad):
return self._clip(para_and_grad) return self._clip(para_and_grad)
...@@ -86,6 +88,7 @@ class GradClipByValue(GradClipBase): ...@@ -86,6 +88,7 @@ class GradClipByValue(GradClipBase):
""" """
@imperative_base.no_grad
def __init__(self, min_value, max_value=None): def __init__(self, min_value, max_value=None):
if min_value is None: if min_value is None:
...@@ -164,6 +167,7 @@ class GradClipByNorm(GradClipBase): ...@@ -164,6 +167,7 @@ class GradClipByNorm(GradClipBase):
""" """
@imperative_base.no_grad
def __init__(self, clip_norm): def __init__(self, clip_norm):
self.clip_norm = clip_norm self.clip_norm = clip_norm
...@@ -243,6 +247,7 @@ class GradClipByGlobalNorm(GradClipBase): ...@@ -243,6 +247,7 @@ class GradClipByGlobalNorm(GradClipBase):
""" """
@imperative_base.no_grad
def __init__(self, max_global_norm): def __init__(self, max_global_norm):
self.max_global_norm = layers.fill_constant( self.max_global_norm = layers.fill_constant(
shape=[1], dtype='float32', value=max_global_norm) shape=[1], dtype='float32', value=max_global_norm)
......
...@@ -55,6 +55,7 @@ class Optimizer(object): ...@@ -55,6 +55,7 @@ class Optimizer(object):
but need to use one of it's implementation. but need to use one of it's implementation.
""" """
@imperative_base.no_grad
def __init__(self, learning_rate, regularization=None, name=None): def __init__(self, learning_rate, regularization=None, name=None):
if framework.in_dygraph_mode(): if framework.in_dygraph_mode():
if not isinstance(learning_rate, float) and \ if not isinstance(learning_rate, float) and \
...@@ -472,6 +473,7 @@ class Optimizer(object): ...@@ -472,6 +473,7 @@ class Optimizer(object):
optimize_ops = self.apply_gradients(params_grads) optimize_ops = self.apply_gradients(params_grads)
return optimize_ops return optimize_ops
@imperative_base.no_grad
def minimize(self, def minimize(self,
loss, loss,
startup_program=None, startup_program=None,
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
import paddle.fluid.framework as framework
import unittest
class TestTracerMode(unittest.TestCase):
def setUp(self):
self.init_mode = True
def get_tracer_mode(self):
assert fluid.dygraph.enabled(), "Dygraph mode must be enabled"
@fluid.dygraph.no_grad
def no_grad_func(self, a):
self.assertEqual(self.tracer._train_mode, False)
return a
def test_main(self):
with fluid.dygraph.guard():
self.tracer = framework._dygraph_tracer()
self.tracer._train_mode = self.init_mode
self.assertEqual(self.no_grad_func(1), 1)
self.assertEqual(self.tracer._train_mode, self.init_mode)
class TestTracerMode2(TestTracerMode):
def setUp(self):
self.init_mode = False
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册