Unverified commit 75a17cdb, authored by Zhanlue Yang, committed by GitHub

Skip DoubleGrad-related unit tests under eager mode (#41380)

Parent 5b8c5b7b
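In both files below, the change follows the same pattern: each DoubleGrad-related `test_*` method is renamed to `func_*` so that unittest no longer discovers it directly, and a new `test_all_cases` entry point invokes the renamed methods only when `_in_legacy_dygraph()` returns True, so the cases are effectively skipped under eager mode. Here is a minimal, self-contained sketch of that pattern; the `SimpleDoubleGradTest` class and its `y = x * x` double-grad check are illustrative examples, not code from this commit.

```python
import unittest

import numpy as np
import paddle
from paddle.fluid.framework import _in_legacy_dygraph


class SimpleDoubleGradTest(unittest.TestCase):
    def func_double_grad(self):
        # y = x * x, so dy/dx = 2x and d2y/dx2 = 2 everywhere.
        x = paddle.to_tensor(np.array([3.0], dtype='float32'))
        x.stop_gradient = False
        y = x * x
        (dx,) = paddle.grad([y], [x], create_graph=True)
        (ddx,) = paddle.grad([dx], [x])
        self.assertTrue(np.allclose(ddx.numpy(), [2.0]))

    def test_all_cases(self):
        # Only exercise the double-grad case under legacy dygraph;
        # under eager mode it is skipped, mirroring this commit.
        if _in_legacy_dygraph():
            self.func_double_grad()


if __name__ == '__main__':
    unittest.main()
```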
@@ -17,6 +17,7 @@ import paddle.fluid as fluid
 import numpy as np
 import unittest
 from paddle import _C_ops
+from paddle.fluid.framework import _test_eager_guard, _in_legacy_dygraph, _in_eager_without_dygraph_check

 if fluid.is_compiled_with_cuda():
     fluid.core.globals()['FLAGS_cudnn_deterministic'] = True
@@ -583,7 +584,7 @@ class StaticGraphTrainModel(object):

 class TestStarGANWithGradientPenalty(unittest.TestCase):
-    def test_main(self):
+    def func_main(self):
         self.place_test(fluid.CPUPlace())

         if fluid.is_compiled_with_cuda():
@@ -615,6 +616,10 @@ class TestStarGANWithGradientPenalty(unittest.TestCase):
             self.assertEqual(g_loss_s, g_loss_d)
             self.assertEqual(d_loss_s, d_loss_d)

+    def test_all_cases(self):
+        if _in_legacy_dygraph():
+            self.func_main()
+

 if __name__ == '__main__':
     paddle.enable_static()
...
@@ -19,6 +19,7 @@ from paddle.vision.models import resnet50, resnet101
 import unittest
 from unittest import TestCase
 import numpy as np
+from paddle.fluid.framework import _test_eager_guard, _in_legacy_dygraph, _in_eager_without_dygraph_check


 def _dygraph_guard_(func):
@@ -65,7 +66,7 @@ class TestDygraphTripleGrad(TestCase):
             allow_unused=allow_unused)

     @dygraph_guard
-    def test_exception(self):
+    def func_exception(self):
         with self.assertRaises(AssertionError):
             self.grad(None, None)
@@ -95,7 +96,7 @@ class TestDygraphTripleGrad(TestCase):
             self.grad([random_var(shape)], [random_var(shape)], no_grad_vars=1)

     @dygraph_guard
-    def test_example_with_gradient_and_create_graph(self):
+    def func_example_with_gradient_and_create_graph(self):
         x = random_var(self.shape)
         x_np = x.numpy()
         x.stop_gradient = False
@@ -145,6 +146,11 @@ class TestDygraphTripleGrad(TestCase):
         dddx_grad_actual = x.gradient()
         self.assertTrue(np.allclose(dddx_grad_actual, dddx_expected))

+    def test_all_cases(self):
+        if _in_legacy_dygraph():
+            self.func_exception()
+            self.func_example_with_gradient_and_create_graph()
+

 class TestDygraphTripleGradBradcastCase(TestCase):
     def setUp(self):
@@ -172,7 +178,7 @@ class TestDygraphTripleGradBradcastCase(TestCase):
             allow_unused=allow_unused)

     @dygraph_guard
-    def test_example_with_gradient_and_create_graph(self):
+    def func_example_with_gradient_and_create_graph(self):
         x = random_var(self.x_shape)
         x_np = x.numpy()
         x.stop_gradient = False
@@ -227,6 +233,10 @@ class TestDygraphTripleGradBradcastCase(TestCase):
         dddx_grad_actual = x.gradient()
         self.assertTrue(np.allclose(dddx_grad_actual, dddx_expected))

+    def test_all_cases(self):
+        if _in_legacy_dygraph():
+            self.func_example_with_gradient_and_create_graph()
+

 if __name__ == '__main__':
     unittest.main()
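Both hunks also import `_test_eager_guard`, which the new `test_all_cases` wrappers do not use yet. In other Paddle unit tests the common convention is to run each `func_*` case once inside the guard (eager mode) and once outside it, so presumably these skipped cases can be re-enabled the same way once DoubleGrad is supported in eager mode. A hedged sketch under that assumption; the class name and the trivial gradient check are illustrative only, not part of this commit.

```python
import unittest

import paddle
from paddle.fluid.framework import _test_eager_guard


class EagerAndLegacyCase(unittest.TestCase):
    def func_simple_grad(self):
        # A trivial first-order gradient check: d(3x)/dx == 3.
        x = paddle.to_tensor([1.0], stop_gradient=False)
        (dx,) = paddle.grad([3.0 * x], [x])
        self.assertAlmostEqual(float(dx), 3.0)

    def test_all_cases(self):
        # Assumed future form of the wrapper: run once inside
        # _test_eager_guard() (eager mode) and once outside it
        # (legacy dygraph), instead of gating on _in_legacy_dygraph().
        with _test_eager_guard():
            self.func_simple_grad()
        self.func_simple_grad()


if __name__ == '__main__':
    unittest.main()
```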