未验证 提交 8f63a3ec 编写于 作者: S songyouwei 提交者: GitHub

fix no_grad argspec (#23790)

test=develop
上级 9549b786
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from ..wrapped_decorator import signature_safe_contextmanager, wrap_decorator from ..wrapped_decorator import signature_safe_contextmanager, wrap_decorator
import decorator
import contextlib import contextlib
import functools import functools
import sys import sys
...@@ -196,12 +197,12 @@ def no_grad(func=None): ...@@ -196,12 +197,12 @@ def no_grad(func=None):
return _switch_tracer_mode_guard_(is_train=False) return _switch_tracer_mode_guard_(is_train=False)
else: else:
@functools.wraps(func) @decorator.decorator
def __impl__(*args, **kwargs): def __impl__(func, *args, **kwargs):
with _switch_tracer_mode_guard_(is_train=False): with _switch_tracer_mode_guard_(is_train=False):
return func(*args, **kwargs) return func(*args, **kwargs)
return __impl__ return __impl__(func)
@signature_safe_contextmanager @signature_safe_contextmanager
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.fluid.framework as framework import paddle.fluid.framework as framework
import unittest import unittest
import inspect
from test_imperative_base import new_program_scope from test_imperative_base import new_program_scope
...@@ -51,6 +52,14 @@ class TestTracerMode(unittest.TestCase): ...@@ -51,6 +52,14 @@ class TestTracerMode(unittest.TestCase):
self.assertEqual(self.no_grad_func(1), 1) self.assertEqual(self.no_grad_func(1), 1)
self.assertEqual(self.no_grad_func.__name__, "no_grad_func") self.assertEqual(self.no_grad_func.__name__, "no_grad_func")
def need_no_grad_func(a, b=1):
return a + b
decorated_func = fluid.dygraph.no_grad(need_no_grad_func)
self.assertTrue(
str(inspect.getargspec(decorated_func)) ==
str(inspect.getargspec(need_no_grad_func)))
self.assertEqual(self.tracer._train_mode, self.init_mode) self.assertEqual(self.tracer._train_mode, self.init_mode)
with fluid.dygraph.guard(): with fluid.dygraph.guard():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册