未验证 提交 6129b0e2 编写于 作者: Y Yang Zhang 提交者: GitHub

Revert `no_grad` changes and add new implementation (#26826)

上级 d067e66d
......@@ -259,7 +259,7 @@ from .device import get_device
from .fluid.dygraph.base import enable_dygraph as disable_static #DEFINE_ALIAS
from .fluid.dygraph.base import disable_dygraph as enable_static #DEFINE_ALIAS
from .fluid.framework import in_dygraph_mode as in_dynamic_mode #DEFINE_ALIAS
from .fluid.dygraph.base import no_grad #DEFINE_ALIAS
from .fluid.dygraph.base import no_grad_ as no_grad #DEFINE_ALIAS
from . import jit
from . import static
......
......@@ -129,7 +129,7 @@ class GradientClipBase(object):
def __str__(self):
raise NotImplementedError()
@imperative_base.no_grad()
@imperative_base.no_grad
def _dygraph_clip(self, params_grads):
raise NotImplementedError
......@@ -258,7 +258,7 @@ class GradientClipByValue(GradientClipBase):
def __str__(self):
return "Gradient Clip By Value, min = %f, max=%f" % (self.min, self.max)
@imperative_base.no_grad()
@imperative_base.no_grad
def _dygraph_clip(self, params_grads):
params_and_grads = []
for p, g in params_grads:
......@@ -413,7 +413,7 @@ class GradientClipByNorm(GradientClipBase):
def __str__(self):
return "Gradient Clip By Norm, clip_norm=%f" % self.clip_norm
@imperative_base.no_grad()
@imperative_base.no_grad
def _dygraph_clip(self, params_grads):
params_and_grads = []
for p, g in params_grads:
......@@ -565,7 +565,7 @@ class GradientClipByGlobalNorm(GradientClipBase):
def __str__(self):
return "Gradient Clip By GlobalNorm, global_norm=%f" % (self.clip_norm)
@imperative_base.no_grad()
@imperative_base.no_grad
def _dygraph_clip(self, params_grads):
params_and_grads = []
sum_square_list = []
......
......@@ -12,9 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from ..wrapped_decorator import signature_safe_contextmanager, wrap_decorator
import inspect
import decorator
import contextlib
import functools
import inspect
import sys
import numpy as np
from paddle.fluid import core
......@@ -26,8 +27,8 @@ import objgraph
from ..data_feeder import convert_dtype
__all__ = [
'no_grad', 'grad', 'guard', 'enable_dygraph', 'disable_dygraph', 'enabled',
'to_variable'
'no_grad', 'no_grad_', 'grad', 'guard', 'enable_dygraph', 'disable_dygraph',
'enabled', 'to_variable'
]
......@@ -167,7 +168,80 @@ def disable_dygraph():
_functional_dygraph_context_manager = None
class no_grad:
@signature_safe_contextmanager
def _switch_tracer_mode_guard_(is_train=True):
tracer = framework._dygraph_tracer()
if tracer:
mode = tracer._train_mode
tracer._train_mode = is_train
try:
yield
finally:
tracer._train_mode = mode
else:
yield
def no_grad(func=None):
"""
:api_attr: imperative
Create a context which disables dygraph gradient calculation.
In this mode, the result of every computation will have `stop_gradient=True`.
Also functions as a decorator. (Make sure to instantiate without parenthesis.)
Examples:
.. code-block:: python
import numpy as np
import paddle.fluid as fluid
# use as generator
data = np.array([[2, 3], [4, 5]]).astype('float32')
with fluid.dygraph.guard():
l0 = fluid.Linear(2, 2) # l0.weight.gradient() is None
l1 = fluid.Linear(2, 2)
with fluid.dygraph.no_grad():
# l1.weight.stop_gradient is False
tmp = l1.weight * 2 # tmp.stop_gradient is True
x = fluid.dygraph.to_variable(data)
y = l0(x) + tmp
o = l1(y)
o.backward()
print(tmp.gradient() is None) # True
print(l0.weight.gradient() is None) # False
# use as decorator
@fluid.dygraph.no_grad
def test_layer():
with fluid.dygraph.guard():
inp = np.ones([3, 1024], dtype='float32')
t = fluid.dygraph.base.to_variable(inp)
linear1 = fluid.Linear(1024, 4, bias_attr=False)
linear2 = fluid.Linear(4, 4)
ret = linear1(t)
dy_ret = linear2(ret)
test_layer()
"""
if func is None:
return _switch_tracer_mode_guard_(is_train=False)
else:
@decorator.decorator
def __impl__(func, *args, **kwargs):
with _switch_tracer_mode_guard_(is_train=False):
return func(*args, **kwargs)
return __impl__(func)
class no_grad_:
"""
:api_attr: imperative
......
......@@ -41,7 +41,7 @@ def monkey_patch_math_varbase():
The difference is, in dygraph mode, use auto-generated op functions for better performance.
"""
@no_grad()
@no_grad
def create_tensor(value, dtype, shape):
out = _varbase_creator(dtype=dtype)
out = core.ops.fill_constant(out, 'dtype', dtype, 'shape', shape,
......
......@@ -445,7 +445,7 @@ class DataParallel(layers.Layer):
self._reshape_inplace(x=g_var, shape=g_shape)
assert g_var.shape == g_shape
@no_grad()
@no_grad
def apply_collective_grads(self):
"""
AllReduce the Parameters' gradient.
......
......@@ -61,7 +61,7 @@ class Optimizer(object):
but need to use one of it's implementation.
"""
@imperative_base.no_grad()
@imperative_base.no_grad
def __init__(self,
learning_rate,
parameter_list=None,
......@@ -897,7 +897,7 @@ class Optimizer(object):
if p.trainable:
p.clear_gradient()
@imperative_base.no_grad()
@imperative_base.no_grad
def minimize(self,
loss,
startup_program=None,
......@@ -1015,7 +1015,7 @@ class SGDOptimizer(Optimizer):
name=name)
self.type = "sgd"
@no_grad()
@no_grad
def _append_optimize_op(self, block, param_and_grad):
lr = self._create_param_lr(param_and_grad)
if framework.in_dygraph_mode():
......@@ -1552,7 +1552,7 @@ class DGCMomentumOptimizer(Optimizer):
dgc_op._set_attr(op_maker.kOpRoleVarAttrName(),
[param_var.name, grad_var.name])
@imperative_base.no_grad()
@imperative_base.no_grad
def apply_gradients(self, params_grads):
params_grads = sorted(params_grads, key=lambda x: x[0].name)
params_grads, table_param_and_grad, table_optimize_op = \
......
......@@ -28,7 +28,7 @@ class TestTracerMode(unittest.TestCase):
def get_tracer_mode(self):
assert fluid.in_dygraph_mode(), "Dygraph mode must be enabled"
@paddle.no_grad()
@fluid.dygraph.no_grad
def no_grad_func(self, a):
self.assertEqual(self.tracer._train_mode, False)
return a
......@@ -56,35 +56,17 @@ class TestTracerMode(unittest.TestCase):
def need_no_grad_func(a, b=1):
return a + b
decorated_func = paddle.no_grad()(need_no_grad_func)
decorated_func = fluid.dygraph.no_grad(need_no_grad_func)
self.assertTrue(
str(inspect.getargspec(decorated_func)) ==
str(inspect.getargspec(need_no_grad_func)))
self.assertEqual(self.tracer._train_mode, self.init_mode)
def test_gen():
for i in range(3):
yield i
a = 0
for i in test_gen():
a += i
@paddle.no_grad()
def test_wrapped_gen():
for i in range(3):
yield i
b = 0
for i in test_wrapped_gen():
b += i
self.assertEqual(a, b)
with fluid.dygraph.guard():
self.check_not_support_rlt(False)
paddle.enable_static()
with new_program_scope():
self.check_not_support_rlt(True)
......@@ -94,5 +76,48 @@ class TestTracerMode2(TestTracerMode):
self.init_mode = False
class TestNoGradClass(unittest.TestCase):
@paddle.no_grad()
def no_grad_func(self, a):
self.assertEqual(self.tracer._train_mode, False)
return a
def test_main(self):
paddle.disable_static()
self.tracer = framework._dygraph_tracer()
self.tracer._train_mode = True
self.assertEqual(self.no_grad_func(1), 1)
self.assertEqual(self.no_grad_func.__name__, "no_grad_func")
def need_no_grad_func(a, b=1):
return a + b
decorated_func = paddle.no_grad()(need_no_grad_func)
self.assertEqual(
str(inspect.getargspec(decorated_func)),
str(inspect.getargspec(need_no_grad_func)))
def test_gen():
for i in range(3):
yield i
a = 0
for i in test_gen():
a += i
@paddle.no_grad()
def test_wrapped_gen():
for i in range(3):
yield i
b = 0
for i in test_wrapped_gen():
b += i
self.assertEqual(a, b)
if __name__ == '__main__':
unittest.main()
......@@ -97,7 +97,7 @@ class Optimizer(object):
"""
@imperative_base.no_grad()
@imperative_base.no_grad
def __init__(self,
learning_rate,
parameters=None,
......@@ -815,7 +815,7 @@ class Optimizer(object):
if p.trainable:
p.clear_gradient()
@imperative_base.no_grad()
@imperative_base.no_grad
def minimize(self,
loss,
startup_program=None,
......
......@@ -85,7 +85,7 @@ class SGD(Optimizer):
name=name)
self.type = "sgd"
@no_grad()
@no_grad
def _append_optimize_op(self, block, param_and_grad):
lr = self._create_param_lr(param_and_grad)
if framework.in_dygraph_mode():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册