Unverified commit 3a6ead24, authored by Zeng Jinle, committed via GitHub

Add no_grad decorator to dygraph (#17790)

* add no_grad decorator to dygraph, test=develop

* add unittest,test=develop
Parent commit: 53920f5e
......@@ -11,20 +11,49 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..wrapped_decorator import signature_safe_contextmanager
from ..wrapped_decorator import signature_safe_contextmanager, wrap_decorator
import contextlib
import numpy as np
from paddle.fluid import core
from paddle.fluid import framework
from .tracer import Tracer
__all__ = ['enabled', 'guard', 'to_variable']
__all__ = [
'enabled',
'no_grad',
'guard',
'to_variable',
]
def enabled():
    """Return True when the program is currently executing in dygraph mode."""
    is_dygraph_active = framework.in_dygraph_mode()
    return is_dygraph_active
@contextlib.contextmanager
def _switch_tracer_mode_guard_(is_train=True):
    """Context manager that temporarily sets the dygraph tracer's train mode.

    Args:
        is_train (bool): the ``_train_mode`` value to use inside the ``with``
            block. Defaults to True.

    The caller's previous mode is always restored on exit — even if the body
    raises. When no tracer is active, this is a no-op.
    """
    tracer = framework._dygraph_tracer()
    if tracer:
        prev_mode = tracer._train_mode
        tracer._train_mode = is_train
        try:
            yield
        finally:
            # Bug fix: restore the mode even when the body raises, so an
            # exception inside a no_grad-decorated function cannot leave the
            # tracer permanently stuck with _train_mode == is_train.
            tracer._train_mode = prev_mode
    else:
        yield
def _no_grad_(func):
    """Decorator implementation: invoke ``func`` with gradient tracing off.

    The wrapped callable runs with the tracer's train mode disabled for the
    duration of the call only; the previous mode is restored afterwards.
    """

    def _call_without_grad(*args, **kwargs):
        with _switch_tracer_mode_guard_(is_train=False):
            return func(*args, **kwargs)

    return _call_without_grad


# Public decorator; ``wrap_decorator`` preserves the decorated function's
# original signature.
no_grad = wrap_decorator(_no_grad_)
@signature_safe_contextmanager
def guard(place=None):
train = framework.Program()
......
......@@ -22,6 +22,7 @@ import functools
from . import layers
from . import framework
from . import core
from .dygraph import base as imperative_base
__all__ = [
'GradClipByValue',
......@@ -37,6 +38,7 @@ class GradClipBase(object):
def _clip(self, para_and_grad):
raise NotImplementedError
@imperative_base.no_grad
def __call__(self, para_and_grad):
return self._clip(para_and_grad)
......@@ -86,6 +88,7 @@ class GradClipByValue(GradClipBase):
"""
@imperative_base.no_grad
def __init__(self, min_value, max_value=None):
if min_value is None:
......@@ -164,6 +167,7 @@ class GradClipByNorm(GradClipBase):
"""
@imperative_base.no_grad
def __init__(self, clip_norm):
self.clip_norm = clip_norm
......@@ -243,6 +247,7 @@ class GradClipByGlobalNorm(GradClipBase):
"""
@imperative_base.no_grad
def __init__(self, max_global_norm):
self.max_global_norm = layers.fill_constant(
shape=[1], dtype='float32', value=max_global_norm)
......
......@@ -55,6 +55,7 @@ class Optimizer(object):
but need to use one of it's implementation.
"""
@imperative_base.no_grad
def __init__(self, learning_rate, regularization=None, name=None):
if framework.in_dygraph_mode():
if not isinstance(learning_rate, float) and \
......@@ -472,6 +473,7 @@ class Optimizer(object):
optimize_ops = self.apply_gradients(params_grads)
return optimize_ops
@imperative_base.no_grad
def minimize(self,
loss,
startup_program=None,
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
import paddle.fluid.framework as framework
import unittest
class TestTracerMode(unittest.TestCase):
    """Verify @no_grad disables the tracer's train mode only inside the call."""

    def setUp(self):
        # Mode the tracer starts in — and must be restored to afterwards.
        self.init_mode = True

    def get_tracer_mode(self):
        assert fluid.dygraph.enabled(), "Dygraph mode must be enabled"

    @fluid.dygraph.no_grad
    def no_grad_func(self, a):
        # While inside a no_grad-decorated call, train mode must be off.
        self.assertEqual(False, self.tracer._train_mode)
        return a

    def test_main(self):
        with fluid.dygraph.guard():
            self.tracer = framework._dygraph_tracer()
            self.tracer._train_mode = self.init_mode

            # The decorated call still returns its argument unchanged...
            self.assertEqual(1, self.no_grad_func(1))
            # ...and the tracer's mode is restored after the call.
            self.assertEqual(self.init_mode, self.tracer._train_mode)
class TestTracerMode2(TestTracerMode):
    """Re-run TestTracerMode's checks starting with train mode disabled."""

    def setUp(self):
        self.init_mode = False
# Run the test suite when this file is executed directly.
if __name__ == '__main__':
    unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register.