未验证 提交 7586cdd5 编写于 作者: J Jiabin Yang 提交者: GitHub

Hide no support (#18515)

* test=develop, fix docker with paddle nccl problem

* test=develop, hide no_support api and add ut for it
上级 43e17c79
...@@ -21,7 +21,7 @@ import functools ...@@ -21,7 +21,7 @@ import functools
from . import layers from . import layers
from . import framework from . import framework
from . import core from . import core
from .dygraph import not_support from .dygraph.base import _not_support
__all__ = [ __all__ = [
'ErrorClipByValue', 'ErrorClipByValue',
...@@ -336,7 +336,7 @@ class GradientClipByGlobalNorm(BaseGradientClipAttr): ...@@ -336,7 +336,7 @@ class GradientClipByGlobalNorm(BaseGradientClipAttr):
return param, new_grad return param, new_grad
@not_support @_not_support
def set_gradient_clip(clip, param_list=None, program=None): def set_gradient_clip(clip, param_list=None, program=None):
""" """
To specify parameters that require gradient clip. To specify parameters that require gradient clip.
......
...@@ -24,7 +24,6 @@ import logging ...@@ -24,7 +24,6 @@ import logging
__all__ = [ __all__ = [
'enabled', 'enabled',
'no_grad', 'no_grad',
'not_support',
'guard', 'guard',
'to_variable', 'to_variable',
] ]
...@@ -91,7 +90,7 @@ def _no_grad_(func): ...@@ -91,7 +90,7 @@ def _no_grad_(func):
no_grad = wrap_decorator(_no_grad_) no_grad = wrap_decorator(_no_grad_)
not_support = wrap_decorator(_dygraph_not_support_) _not_support = wrap_decorator(_dygraph_not_support_)
@signature_safe_contextmanager @signature_safe_contextmanager
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.fluid.framework as framework import paddle.fluid.framework as framework
import unittest import unittest
from test_imperative_base import new_program_scope
class TestTracerMode(unittest.TestCase): class TestTracerMode(unittest.TestCase):
...@@ -29,6 +30,18 @@ class TestTracerMode(unittest.TestCase): ...@@ -29,6 +30,18 @@ class TestTracerMode(unittest.TestCase):
self.assertEqual(self.tracer._train_mode, False) self.assertEqual(self.tracer._train_mode, False)
return a return a
@fluid.dygraph.base._not_support
def not_support_func(self):
return True
def check_not_support_rlt(self, ans):
try:
rlt = self.not_support_func()
except AssertionError:
rlt = False
finally:
self.assertEqual(rlt, ans)
def test_main(self): def test_main(self):
with fluid.dygraph.guard(): with fluid.dygraph.guard():
self.tracer = framework._dygraph_tracer() self.tracer = framework._dygraph_tracer()
...@@ -38,6 +51,12 @@ class TestTracerMode(unittest.TestCase): ...@@ -38,6 +51,12 @@ class TestTracerMode(unittest.TestCase):
self.assertEqual(self.tracer._train_mode, self.init_mode) self.assertEqual(self.tracer._train_mode, self.init_mode)
with fluid.dygraph.guard():
self.check_not_support_rlt(False)
with new_program_scope():
self.check_not_support_rlt(True)
class TestTracerMode2(TestTracerMode): class TestTracerMode2(TestTracerMode):
def setUp(self): def setUp(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册