From 7c73a68fee4b6c55c3faef7e04917e64c3818aed Mon Sep 17 00:00:00 2001 From: Jiabin Yang Date: Mon, 8 Jul 2019 15:06:27 +0800 Subject: [PATCH] test=release/1.5, cherry-pick hide not_support for dygraph (#18528) * test=release/1.5, cherry-pick hide not_support for dygraph * test=release/1.5, cherry-pick hide not_support for dygraph --- python/paddle/fluid/clip.py | 4 ++-- python/paddle/fluid/dygraph/base.py | 3 +-- .../unittests/test_imperative_decorator.py | 19 +++++++++++++++++++ 3 files changed, 22 insertions(+), 4 deletions(-) diff --git a/python/paddle/fluid/clip.py b/python/paddle/fluid/clip.py index 1c51ef296c6..09d7975a6b0 100644 --- a/python/paddle/fluid/clip.py +++ b/python/paddle/fluid/clip.py @@ -21,7 +21,7 @@ import functools from . import layers from . import framework from . import core -from .dygraph import not_support +from .dygraph.base import _not_support __all__ = [ 'ErrorClipByValue', @@ -336,7 +336,7 @@ class GradientClipByGlobalNorm(BaseGradientClipAttr): return param, new_grad -@not_support +@_not_support def set_gradient_clip(clip, param_list=None, program=None): """ To specify parameters that require gradient clip. diff --git a/python/paddle/fluid/dygraph/base.py b/python/paddle/fluid/dygraph/base.py index 133eb6a19c2..1c19fcb3eba 100644 --- a/python/paddle/fluid/dygraph/base.py +++ b/python/paddle/fluid/dygraph/base.py @@ -24,7 +24,6 @@ import logging __all__ = [ 'enabled', 'no_grad', - 'not_support', 'guard', 'to_variable', ] @@ -91,7 +90,7 @@ def _no_grad_(func): no_grad = wrap_decorator(_no_grad_) -not_support = wrap_decorator(_dygraph_not_support_) +_not_support = wrap_decorator(_dygraph_not_support_) @signature_safe_contextmanager diff --git a/python/paddle/fluid/tests/unittests/test_imperative_decorator.py b/python/paddle/fluid/tests/unittests/test_imperative_decorator.py index 2b92bcb38be..c821a2e4bc8 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_decorator.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_decorator.py @@ -15,6 +15,7 @@ import paddle.fluid as fluid import paddle.fluid.framework as framework import unittest +from test_imperative_base import new_program_scope class TestTracerMode(unittest.TestCase): @@ -29,6 +30,18 @@ class TestTracerMode(unittest.TestCase): self.assertEqual(self.tracer._train_mode, False) return a + @fluid.dygraph.base._not_support + def not_support_func(self): + return True + + def check_not_support_rlt(self, ans): + try: + rlt = self.not_support_func() + except AssertionError: + rlt = False + finally: + self.assertEqual(rlt, ans) + def test_main(self): with fluid.dygraph.guard(): self.tracer = framework._dygraph_tracer() @@ -38,6 +51,12 @@ class TestTracerMode(unittest.TestCase): self.assertEqual(self.tracer._train_mode, self.init_mode) + with fluid.dygraph.guard(): + self.check_not_support_rlt(False) + + with new_program_scope(): + self.check_not_support_rlt(True) + class TestTracerMode2(TestTracerMode): def setUp(self): -- GitLab