diff --git a/python/paddle/fluid/clip.py b/python/paddle/fluid/clip.py index 1c51ef296c6a2d0d1cc21bf55187c1f0722570ff..09d7975a6b0d5dc69bd14a127a64f068fbf1f910 100644 --- a/python/paddle/fluid/clip.py +++ b/python/paddle/fluid/clip.py @@ -21,7 +21,7 @@ import functools from . import layers from . import framework from . import core -from .dygraph import not_support +from .dygraph.base import _not_support __all__ = [ 'ErrorClipByValue', @@ -336,7 +336,7 @@ class GradientClipByGlobalNorm(BaseGradientClipAttr): return param, new_grad -@not_support +@_not_support def set_gradient_clip(clip, param_list=None, program=None): """ To specify parameters that require gradient clip. diff --git a/python/paddle/fluid/dygraph/base.py b/python/paddle/fluid/dygraph/base.py index 133eb6a19c2e20287ef6588cc2c4f780ec7dbdd4..1c19fcb3eba8cceabcb8ddb4337b93d2dbd803c6 100644 --- a/python/paddle/fluid/dygraph/base.py +++ b/python/paddle/fluid/dygraph/base.py @@ -24,7 +24,6 @@ import logging __all__ = [ 'enabled', 'no_grad', - 'not_support', 'guard', 'to_variable', ] @@ -91,7 +90,7 @@ def _no_grad_(func): no_grad = wrap_decorator(_no_grad_) -not_support = wrap_decorator(_dygraph_not_support_) +_not_support = wrap_decorator(_dygraph_not_support_) @signature_safe_contextmanager diff --git a/python/paddle/fluid/tests/unittests/test_imperative_decorator.py b/python/paddle/fluid/tests/unittests/test_imperative_decorator.py index 2b92bcb38befcd7d1d2932e2681dc167fa6777ee..c821a2e4bc80bba3a3f63d3d8016fd957021ae7e 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_decorator.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_decorator.py @@ -15,6 +15,7 @@ import paddle.fluid as fluid import paddle.fluid.framework as framework import unittest +from test_imperative_base import new_program_scope class TestTracerMode(unittest.TestCase): @@ -29,6 +30,18 @@ class TestTracerMode(unittest.TestCase): self.assertEqual(self.tracer._train_mode, False) return a + @fluid.dygraph.base._not_support + def not_support_func(self): + return True + + def check_not_support_rlt(self, ans): + try: + rlt = self.not_support_func() + except AssertionError: + rlt = False + finally: + self.assertEqual(rlt, ans) + def test_main(self): with fluid.dygraph.guard(): self.tracer = framework._dygraph_tracer() @@ -38,6 +51,12 @@ class TestTracerMode(unittest.TestCase): self.assertEqual(self.tracer._train_mode, self.init_mode) + with fluid.dygraph.guard(): + self.check_not_support_rlt(False) + + with new_program_scope(): + self.check_not_support_rlt(True) + class TestTracerMode2(TestTracerMode): def setUp(self):