From 8f63a3ecffff7d3734ff405c13a1df0f6fe0486c Mon Sep 17 00:00:00 2001 From: songyouwei Date: Tue, 14 Apr 2020 14:41:58 +0800 Subject: [PATCH] fix no_grad argspec (#23790) test=develop --- python/paddle/fluid/dygraph/base.py | 7 ++++--- .../fluid/tests/unittests/test_imperative_decorator.py | 9 +++++++++ 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/python/paddle/fluid/dygraph/base.py b/python/paddle/fluid/dygraph/base.py index 8c5c5bd791..97f6c64947 100644 --- a/python/paddle/fluid/dygraph/base.py +++ b/python/paddle/fluid/dygraph/base.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from ..wrapped_decorator import signature_safe_contextmanager, wrap_decorator +import decorator import contextlib import functools import sys @@ -196,12 +197,12 @@ def no_grad(func=None): return _switch_tracer_mode_guard_(is_train=False) else: - @functools.wraps(func) - def __impl__(*args, **kwargs): + @decorator.decorator + def __impl__(func, *args, **kwargs): with _switch_tracer_mode_guard_(is_train=False): return func(*args, **kwargs) - return __impl__ + return __impl__(func) @signature_safe_contextmanager diff --git a/python/paddle/fluid/tests/unittests/test_imperative_decorator.py b/python/paddle/fluid/tests/unittests/test_imperative_decorator.py index 2c1481c890..82e81d72f9 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_decorator.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_decorator.py @@ -15,6 +15,7 @@ import paddle.fluid as fluid import paddle.fluid.framework as framework import unittest +import inspect from test_imperative_base import new_program_scope @@ -51,6 +52,14 @@ class TestTracerMode(unittest.TestCase): self.assertEqual(self.no_grad_func(1), 1) self.assertEqual(self.no_grad_func.__name__, "no_grad_func") + def need_no_grad_func(a, b=1): + return a + b + + decorated_func = fluid.dygraph.no_grad(need_no_grad_func) + self.assertTrue( + str(inspect.getargspec(decorated_func)) == + str(inspect.getargspec(need_no_grad_func))) + self.assertEqual(self.tracer._train_mode, self.init_mode) with fluid.dygraph.guard(): -- GitLab