未验证 提交 75a17cdb 编写于 作者: Z Zhanlue Yang 提交者: GitHub

Skip DoubleGrad-related unit tests under eager mode (#41380)

上级 5b8c5b7b
......@@ -17,6 +17,7 @@ import paddle.fluid as fluid
import numpy as np
import unittest
from paddle import _C_ops
from paddle.fluid.framework import _test_eager_guard, _in_legacy_dygraph, _in_eager_without_dygraph_check
if fluid.is_compiled_with_cuda():
fluid.core.globals()['FLAGS_cudnn_deterministic'] = True
......@@ -583,7 +584,7 @@ class StaticGraphTrainModel(object):
class TestStarGANWithGradientPenalty(unittest.TestCase):
def test_main(self):
def func_main(self):
self.place_test(fluid.CPUPlace())
if fluid.is_compiled_with_cuda():
......@@ -615,6 +616,10 @@ class TestStarGANWithGradientPenalty(unittest.TestCase):
self.assertEqual(g_loss_s, g_loss_d)
self.assertEqual(d_loss_s, d_loss_d)
def test_all_cases(self):
if _in_legacy_dygraph():
self.func_main()
if __name__ == '__main__':
paddle.enable_static()
......
......@@ -19,6 +19,7 @@ from paddle.vision.models import resnet50, resnet101
import unittest
from unittest import TestCase
import numpy as np
from paddle.fluid.framework import _test_eager_guard, _in_legacy_dygraph, _in_eager_without_dygraph_check
def _dygraph_guard_(func):
......@@ -65,7 +66,7 @@ class TestDygraphTripleGrad(TestCase):
allow_unused=allow_unused)
@dygraph_guard
def test_exception(self):
def func_exception(self):
with self.assertRaises(AssertionError):
self.grad(None, None)
......@@ -95,7 +96,7 @@ class TestDygraphTripleGrad(TestCase):
self.grad([random_var(shape)], [random_var(shape)], no_grad_vars=1)
@dygraph_guard
def test_example_with_gradient_and_create_graph(self):
def func_example_with_gradient_and_create_graph(self):
x = random_var(self.shape)
x_np = x.numpy()
x.stop_gradient = False
......@@ -145,6 +146,11 @@ class TestDygraphTripleGrad(TestCase):
dddx_grad_actual = x.gradient()
self.assertTrue(np.allclose(dddx_grad_actual, dddx_expected))
def test_all_cases(self):
if _in_legacy_dygraph():
self.func_exception()
self.func_example_with_gradient_and_create_graph()
class TestDygraphTripleGradBradcastCase(TestCase):
def setUp(self):
......@@ -172,7 +178,7 @@ class TestDygraphTripleGradBradcastCase(TestCase):
allow_unused=allow_unused)
@dygraph_guard
def test_example_with_gradient_and_create_graph(self):
def func_example_with_gradient_and_create_graph(self):
x = random_var(self.x_shape)
x_np = x.numpy()
x.stop_gradient = False
......@@ -227,6 +233,10 @@ class TestDygraphTripleGradBradcastCase(TestCase):
dddx_grad_actual = x.gradient()
self.assertTrue(np.allclose(dddx_grad_actual, dddx_expected))
def test_all_cases(self):
if _in_legacy_dygraph():
self.func_example_with_gradient_and_create_graph()
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册