Unverified commit 89ef2915, authored by Yang Zhang, committed by GitHub

Revert `no_grad` changes and add new implementation (#26902)

Parent dd28cada
@@ -259,7 +259,7 @@ from .device import get_device
 from .fluid.dygraph.base import enable_dygraph as disable_static #DEFINE_ALIAS
 from .fluid.dygraph.base import disable_dygraph as enable_static #DEFINE_ALIAS
 from .fluid.framework import in_dygraph_mode as in_dynamic_mode #DEFINE_ALIAS
-from .fluid.dygraph.base import no_grad #DEFINE_ALIAS
+from .fluid.dygraph.base import no_grad_ as no_grad #DEFINE_ALIAS
 from . import jit
 from . import static
...
@@ -129,7 +129,7 @@ class GradientClipBase(object):
     def __str__(self):
         raise NotImplementedError()

-    @imperative_base.no_grad()
+    @imperative_base.no_grad
     def _dygraph_clip(self, params_grads):
         raise NotImplementedError
@@ -258,7 +258,7 @@ class GradientClipByValue(GradientClipBase):
     def __str__(self):
         return "Gradient Clip By Value, min = %f, max=%f" % (self.min, self.max)

-    @imperative_base.no_grad()
+    @imperative_base.no_grad
     def _dygraph_clip(self, params_grads):
         params_and_grads = []
         for p, g in params_grads:
@@ -413,7 +413,7 @@ class GradientClipByNorm(GradientClipBase):
     def __str__(self):
         return "Gradient Clip By Norm, clip_norm=%f" % self.clip_norm

-    @imperative_base.no_grad()
+    @imperative_base.no_grad
     def _dygraph_clip(self, params_grads):
         params_and_grads = []
         for p, g in params_grads:
@@ -565,7 +565,7 @@ class GradientClipByGlobalNorm(GradientClipBase):
     def __str__(self):
         return "Gradient Clip By GlobalNorm, global_norm=%f" % (self.clip_norm)

-    @imperative_base.no_grad()
+    @imperative_base.no_grad
     def _dygraph_clip(self, params_grads):
         params_and_grads = []
         sum_square_list = []
...
@@ -12,9 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from ..wrapped_decorator import signature_safe_contextmanager, wrap_decorator
-import inspect
 import decorator
 import contextlib
+import functools
+import inspect
 import sys
 import numpy as np
 from paddle.fluid import core
@@ -26,8 +27,8 @@ import objgraph
 from ..data_feeder import convert_dtype

 __all__ = [
-    'no_grad', 'grad', 'guard', 'enable_dygraph', 'disable_dygraph', 'enabled',
-    'to_variable'
+    'no_grad', 'no_grad_', 'grad', 'guard', 'enable_dygraph', 'disable_dygraph',
+    'enabled', 'to_variable'
 ]
@@ -167,7 +168,80 @@ def disable_dygraph():
         _functional_dygraph_context_manager = None


-class no_grad:
+@signature_safe_contextmanager
+def _switch_tracer_mode_guard_(is_train=True):
+    tracer = framework._dygraph_tracer()
+    if tracer:
+        mode = tracer._train_mode
+        tracer._train_mode = is_train
+        try:
+            yield
+        finally:
+            tracer._train_mode = mode
+    else:
+        yield
+
+
+def no_grad(func=None):
+    """
+    :api_attr: imperative
+
+    Create a context which disables dygraph gradient calculation.
+    In this mode, the result of every computation will have `stop_gradient=True`.
+
+    Also functions as a decorator. (Make sure to instantiate without parenthesis.)
+
+    Examples:
+
+    .. code-block:: python
+
+        import numpy as np
+        import paddle.fluid as fluid
+
+        # use as generator
+
+        data = np.array([[2, 3], [4, 5]]).astype('float32')
+        with fluid.dygraph.guard():
+            l0 = fluid.Linear(2, 2) # l0.weight.gradient() is None
+            l1 = fluid.Linear(2, 2)
+            with fluid.dygraph.no_grad():
+                # l1.weight.stop_gradient is False
+                tmp = l1.weight * 2 # tmp.stop_gradient is True
+            x = fluid.dygraph.to_variable(data)
+            y = l0(x) + tmp
+            o = l1(y)
+            o.backward()
+            print(tmp.gradient() is None) # True
+            print(l0.weight.gradient() is None) # False
+
+        # use as decorator
+
+        @fluid.dygraph.no_grad
+        def test_layer():
+            with fluid.dygraph.guard():
+                inp = np.ones([3, 1024], dtype='float32')
+                t = fluid.dygraph.base.to_variable(inp)
+                linear1 = fluid.Linear(1024, 4, bias_attr=False)
+                linear2 = fluid.Linear(4, 4)
+                ret = linear1(t)
+                dy_ret = linear2(ret)
+
+        test_layer()
+
+    """
+    if func is None:
+        return _switch_tracer_mode_guard_(is_train=False)
+    else:
+
+        @decorator.decorator
+        def __impl__(func, *args, **kwargs):
+            with _switch_tracer_mode_guard_(is_train=False):
+                return func(*args, **kwargs)

+        return __impl__(func)
+
+
+class no_grad_:
     """
     :api_attr: imperative
...
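Editorial note: the hunk above restores the function-style `fluid.dygraph.no_grad`. Called with no argument it returns a tracer-mode guard for `with` blocks; applied directly to a callable (no parentheses) it wraps that callable. Below is a minimal, paddle-free sketch of that dispatch; the `_FakeTracer` stand-in and the use of `functools.wraps` in place of `decorator.decorator` are illustrative assumptions, not part of the commit.

```python
import contextlib
import functools


class _FakeTracer:
    """Hypothetical stand-in for paddle's dygraph tracer and its _train_mode flag."""
    _train_mode = True


_tracer = _FakeTracer()


@contextlib.contextmanager
def _switch_tracer_mode_guard_(is_train=True):
    # Save the current mode, flip it for the duration of the block, restore it after.
    mode = _tracer._train_mode
    _tracer._train_mode = is_train
    try:
        yield
    finally:
        _tracer._train_mode = mode


def no_grad(func=None):
    if func is None:
        # `with no_grad():` usage -- return the guard itself.
        return _switch_tracer_mode_guard_(is_train=False)

    # `@no_grad` usage (no parentheses) -- wrap the callable.
    @functools.wraps(func)
    def __impl__(*args, **kwargs):
        with _switch_tracer_mode_guard_(is_train=False):
            return func(*args, **kwargs)

    return __impl__


@no_grad
def forward():
    return _tracer._train_mode  # False while the guard is active


assert forward() is False
with no_grad():
    assert _tracer._train_mode is False
assert _tracer._train_mode is True
```

The real code keeps the `decorator.decorator` dependency rather than `functools.wraps` so that `inspect.getargspec` of the wrapper matches the wrapped function exactly, which the updated test further down asserts.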
@@ -41,7 +41,7 @@ def monkey_patch_math_varbase():
     The difference is, in dygraph mode, use auto-generated op functions for better performance.
     """

-    @no_grad()
+    @no_grad
     def create_tensor(value, dtype, shape):
         out = _varbase_creator(dtype=dtype)
         out = core.ops.fill_constant(out, 'dtype', dtype, 'shape', shape,
...
@@ -445,7 +445,7 @@ class DataParallel(layers.Layer):
                 self._reshape_inplace(x=g_var, shape=g_shape)
                 assert g_var.shape == g_shape

-    @no_grad()
+    @no_grad
     def apply_collective_grads(self):
         """
         AllReduce the Parameters' gradient.
...
@@ -61,7 +61,7 @@ class Optimizer(object):
     but need to use one of it's implementation.
     """

-    @imperative_base.no_grad()
+    @imperative_base.no_grad
     def __init__(self,
                  learning_rate,
                  parameter_list=None,
@@ -897,7 +897,7 @@ class Optimizer(object):
             if p.trainable:
                 p.clear_gradient()

-    @imperative_base.no_grad()
+    @imperative_base.no_grad
     def minimize(self,
                  loss,
                  startup_program=None,
@@ -1015,7 +1015,7 @@ class SGDOptimizer(Optimizer):
             name=name)
         self.type = "sgd"

-    @no_grad()
+    @no_grad
     def _append_optimize_op(self, block, param_and_grad):
         lr = self._create_param_lr(param_and_grad)
         if framework.in_dygraph_mode():
@@ -1552,7 +1552,7 @@ class DGCMomentumOptimizer(Optimizer):
         dgc_op._set_attr(op_maker.kOpRoleVarAttrName(),
                          [param_var.name, grad_var.name])

-    @imperative_base.no_grad()
+    @imperative_base.no_grad
     def apply_gradients(self, params_grads):
         params_grads = sorted(params_grads, key=lambda x: x[0].name)
         params_grads, table_param_and_grad, table_optimize_op = \
...
@@ -28,7 +28,7 @@ class TestTracerMode(unittest.TestCase):
     def get_tracer_mode(self):
         assert fluid.in_dygraph_mode(), "Dygraph mode must be enabled"

-    @paddle.no_grad()
+    @fluid.dygraph.no_grad
     def no_grad_func(self, a):
         self.assertEqual(self.tracer._train_mode, False)
         return a
@@ -56,35 +56,17 @@ class TestTracerMode(unittest.TestCase):
         def need_no_grad_func(a, b=1):
             return a + b

-        decorated_func = paddle.no_grad()(need_no_grad_func)
+        decorated_func = fluid.dygraph.no_grad(need_no_grad_func)
         self.assertTrue(
             str(inspect.getargspec(decorated_func)) ==
             str(inspect.getargspec(need_no_grad_func)))

         self.assertEqual(self.tracer._train_mode, self.init_mode)

-        def test_gen():
-            for i in range(3):
-                yield i
-
-        a = 0
-        for i in test_gen():
-            a += i
-
-        @paddle.no_grad()
-        def test_wrapped_gen():
-            for i in range(3):
-                yield i
-
-        b = 0
-        for i in test_wrapped_gen():
-            b += i
-
-        self.assertEqual(a, b)
-
         with fluid.dygraph.guard():
             self.check_not_support_rlt(False)

-        paddle.enable_static()
         with new_program_scope():
             self.check_not_support_rlt(True)
@@ -94,5 +76,48 @@ class TestTracerMode2(TestTracerMode):
         self.init_mode = False


+class TestNoGradClass(unittest.TestCase):
+    @paddle.no_grad()
+    def no_grad_func(self, a):
+        self.assertEqual(self.tracer._train_mode, False)
+        return a
+
+    def test_main(self):
+        paddle.disable_static()
+
+        self.tracer = framework._dygraph_tracer()
+        self.tracer._train_mode = True
+
+        self.assertEqual(self.no_grad_func(1), 1)
+        self.assertEqual(self.no_grad_func.__name__, "no_grad_func")
+
+        def need_no_grad_func(a, b=1):
+            return a + b
+
+        decorated_func = paddle.no_grad()(need_no_grad_func)
+        self.assertEqual(
+            str(inspect.getargspec(decorated_func)),
+            str(inspect.getargspec(need_no_grad_func)))
+
+        def test_gen():
+            for i in range(3):
+                yield i
+
+        a = 0
+        for i in test_gen():
+            a += i
+
+        @paddle.no_grad()
+        def test_wrapped_gen():
+            for i in range(3):
+                yield i
+
+        b = 0
+        for i in test_wrapped_gen():
+            b += i
+
+        self.assertEqual(a, b)
+
+
 if __name__ == '__main__':
     unittest.main()
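Editorial note: the new `TestNoGradClass` above exercises the class-based API that `paddle.no_grad` now points to (the `no_grad_` alias added in the first hunk): used with parentheses as a decorator, it keeps the wrapped function's argspec and also works on generator functions. A short usage sketch follows, assuming a Paddle build that already contains this commit; the names `add` and `gen` are made up for illustration.

```python
import paddle
from paddle.fluid import framework

# Same setup as TestNoGradClass.test_main: switch to dygraph and grab the tracer.
paddle.disable_static()
tracer = framework._dygraph_tracer()


@paddle.no_grad()  # note the parentheses: the no_grad_ class is instantiated
def add(a, b=1):
    # Inside the decorated call the tracer is switched out of train mode,
    # mirroring the assertion in TestNoGradClass.no_grad_func.
    assert not tracer._train_mode
    return a + b


@paddle.no_grad()  # generator functions are covered by the test as well
def gen():
    for i in range(3):
        yield i


print(add(1))      # 2
print(sum(gen()))  # 3
```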
@@ -98,7 +98,7 @@ class Optimizer(object):
     """

-    @imperative_base.no_grad()
+    @imperative_base.no_grad
     def __init__(self,
                  learning_rate,
                  parameters=None,
@@ -812,7 +812,7 @@ class Optimizer(object):
             if p.trainable:
                 p.clear_gradient()

-    @imperative_base.no_grad()
+    @imperative_base.no_grad
     def minimize(self,
                  loss,
                  startup_program=None,
...
@@ -85,7 +85,7 @@ class SGD(Optimizer):
             name=name)
         self.type = "sgd"

-    @no_grad()
+    @no_grad
     def _append_optimize_op(self, block, param_and_grad):
         lr = self._create_param_lr(param_and_grad)
         if framework.in_dygraph_mode():
...