未验证 提交 a1a95f81 编写于 作者: S songyouwei 提交者: GitHub

fix no_grad signature (#23600)

* fix no_grad signature
test=develop

* check func name instead of doc
test=develop
上级 f792d5f7
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
from ..wrapped_decorator import signature_safe_contextmanager, wrap_decorator from ..wrapped_decorator import signature_safe_contextmanager, wrap_decorator
import contextlib import contextlib
import functools
import sys import sys
import numpy as np import numpy as np
from paddle.fluid import core from paddle.fluid import core
...@@ -195,6 +196,7 @@ def no_grad(func=None): ...@@ -195,6 +196,7 @@ def no_grad(func=None):
return _switch_tracer_mode_guard_(is_train=False) return _switch_tracer_mode_guard_(is_train=False)
else: else:
@functools.wraps(func)
def __impl__(*args, **kwargs): def __impl__(*args, **kwargs):
with _switch_tracer_mode_guard_(is_train=False): with _switch_tracer_mode_guard_(is_train=False):
return func(*args, **kwargs) return func(*args, **kwargs)
......
...@@ -49,6 +49,7 @@ class TestTracerMode(unittest.TestCase): ...@@ -49,6 +49,7 @@ class TestTracerMode(unittest.TestCase):
self.tracer._train_mode = self.init_mode self.tracer._train_mode = self.init_mode
self.assertEqual(self.no_grad_func(1), 1) self.assertEqual(self.no_grad_func(1), 1)
self.assertEqual(self.no_grad_func.__name__, "no_grad_func")
self.assertEqual(self.tracer._train_mode, self.init_mode) self.assertEqual(self.tracer._train_mode, self.init_mode)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册