Unverified commit 75a17cdb authored by Zhanlue Yang, committed by GitHub

Skip DoubleGrad-related unit tests under eager mode (#41380)

Parent 5b8c5b7b
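
The change applies one mechanical pattern across the affected test files: every DoubleGrad-dependent `test_*` method is renamed to `func_*` (so the unittest runner no longer collects it directly), and a new `test_all_cases` entry point invokes the group only when `_in_legacy_dygraph()` is true, so the whole group is skipped under eager mode. A minimal sketch of the idiom, with an illustrative class and method name (not taken from the diff):

import unittest

import paddle
from paddle.fluid.framework import _in_legacy_dygraph


class TestDoubleGradSkipExample(unittest.TestCase):
    # Illustrative only: mirrors the rename-and-gate pattern in this diff.

    def func_double_grad_case(self):
        # Former `test_*` body, renamed to `func_*` so the runner no
        # longer discovers it on its own.
        x = paddle.rand([3, 3])
        x.stop_gradient = False
        y = paddle.sum(paddle.matmul(x, x))
        grad, = paddle.grad(y, x, create_graph=True)  # needs double grad
        self.assertIsNotNone(grad)

    def test_all_cases(self):
        # Only exercise the DoubleGrad-dependent body in legacy dygraph;
        # under eager mode the group is skipped entirely.
        if _in_legacy_dygraph():
            self.func_double_grad_case()
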
@@ -21,6 +21,7 @@ import paddle
 import paddle.compat as cpt
 import paddle.nn.functional as F
 from paddle.autograd.functional import _as_tensors
+from paddle.fluid.framework import _test_eager_guard, _in_legacy_dygraph, _in_eager_without_dygraph_check
 import config
 import utils
@@ -145,7 +146,7 @@ class TestAutogradFunctional(unittest.TestCase):


 class TestVJP(TestAutogradFunctional):
-    def test_vjp_i1o1(self):
+    def func_vjp_i1o1(self):
         test_cases = [
             [reduce, 'A'], # noqa
             [reduce_dim, 'A'], # noqa
@@ -155,7 +156,7 @@ class TestVJP(TestAutogradFunctional):
             vjp_result, grad_result = vjp(), grad()
             self.check_results(grad_result, vjp_result)

-    def test_vjp_i2o1(self):
+    def func_vjp_i2o1(self):
         test_cases = [
             [matmul, ['A', 'B']], # noqa
             [mul, ['b', 'c']], # noqa
@@ -165,7 +166,7 @@ class TestVJP(TestAutogradFunctional):
             vjp_result, grad_result = vjp(), grad()
             self.check_results(grad_result, vjp_result)

-    def test_vjp_i2o2(self):
+    def func_vjp_i2o2(self):
         test_cases = [
             [o2, ['A', 'A']], # noqa
         ] # noqa
@@ -176,7 +177,7 @@ class TestVJP(TestAutogradFunctional):
             vjp_result, grad_result = vjp(), grad()
             self.check_results(grad_result, vjp_result)

-    def test_vjp_i2o2_omitting_v(self):
+    def func_vjp_i2o2_omitting_v(self):
         test_cases = [
             [o2, ['A', 'A']], # noqa
         ] # noqa
@@ -186,7 +187,7 @@ class TestVJP(TestAutogradFunctional):
             vjp_result, grad_result = vjp(), grad()
             self.check_results(grad_result, vjp_result)

-    def test_vjp_nested(self):
+    def func_vjp_nested(self):
         x = self.gen_input('a')
         test_cases = [
             [nested(x), 'a'], # noqa
@@ -196,13 +197,22 @@ class TestVJP(TestAutogradFunctional):
             vjp_result, grad_result = vjp(), grad()
             self.check_results(grad_result, vjp_result)

-    def test_vjp_aliased_input(self):
+    def func_vjp_aliased_input(self):
         x = self.gen_input('a')
         ref = self.gen_test_pairs(nested(x), 'a')[0]
         aliased = self.gen_test_pairs(nested(x), x)[0]
         ref_result, aliased_result = ref(), aliased()
         self.check_results(ref_result, aliased_result)
+
+    def test_all_cases(self):
+        if _in_legacy_dygraph():
+            self.func_vjp_i1o1()
+            self.func_vjp_i2o1()
+            self.func_vjp_i2o2()
+            self.func_vjp_i2o2_omitting_v()
+            self.func_vjp_nested()
+            self.func_vjp_aliased_input()


 @utils.place(config.DEVICES)
 @utils.parameterize(
@@ -210,12 +220,16 @@
     ('v_shape_not_equal_ys', utils.square, np.random.rand(3),
      np.random.rand(1), RuntimeError), ))
 class TestVJPException(unittest.TestCase):
-    def test_vjp(self):
+    def func_vjp(self):
         with self.assertRaises(self.expected_exception):
             paddle.autograd.vjp(self.fun,
                                 paddle.to_tensor(self.xs),
                                 paddle.to_tensor(self.v))
+
+    def test_all_cases(self):
+        if _in_legacy_dygraph():
+            self.func_vjp()


 def jac(grad_fn, f, inputs):
     assert grad_fn in [paddle.autograd.vjp, paddle.autograd.jvp]
@@ -246,7 +260,7 @@ def jac(grad_fn, f, inputs):


 class TestJVP(TestAutogradFunctional):
-    def test_jvp_i1o1(self):
+    def func_jvp_i1o1(self):
         test_cases = [
             [reduce, 'A'], # noqa
             [reduce_dim, 'A'], # noqa
@@ -257,7 +271,7 @@ class TestJVP(TestAutogradFunctional):
             reverse_jac = jac(paddle.autograd.vjp, f, inputs)
             self.check_results(forward_jac, reverse_jac)

-    def test_jvp_i2o1(self):
+    def func_jvp_i2o1(self):
         test_cases = [ # noqa
             [matmul, ['A', 'B']], # noqa
         ] # noqa
@@ -267,7 +281,7 @@ class TestJVP(TestAutogradFunctional):
             reverse_jac = jac(paddle.autograd.vjp, f, inputs)
             self.check_results(forward_jac, reverse_jac)

-    def test_jvp_i2o2(self):
+    def func_jvp_i2o2(self):
         test_cases = [ # noqa
             [o2, ['A', 'A']], # noqa
         ] # noqa
@@ -277,7 +291,7 @@ class TestJVP(TestAutogradFunctional):
             reverse_jac = jac(paddle.autograd.vjp, f, inputs)
             self.check_results(forward_jac, reverse_jac)

-    def test_jvp_i2o2_omitting_v(self):
+    def func_jvp_i2o2_omitting_v(self):
         test_cases = [ # noqa
             [o2, ['A', 'A']], # noqa
         ] # noqa
@@ -288,6 +302,13 @@ class TestJVP(TestAutogradFunctional):
             results_with_v = paddle.autograd.jvp(f, inputs, v)
             self.check_results(results_omitting_v, results_with_v)
+
+    def test_all_cases(self):
+        if _in_legacy_dygraph():
+            self.func_jvp_i1o1()
+            self.func_jvp_i2o1()
+            self.func_jvp_i2o2()
+            self.func_jvp_i2o2_omitting_v()


 @utils.place(config.DEVICES)
 @utils.parameterize((utils.TEST_CASE_NAME, 'func', 'xs'), (
@@ -312,7 +333,7 @@ class TestJacobianClassNoBatch(unittest.TestCase):
         self._actual = paddle.autograd.Jacobian(self.func, self.xs, False)
         self._expected = self._expected()

-    def test_jacobian(self):
+    def func_jacobian(self):
         Index = collections.namedtuple('Index', ('type', 'value'))
         indexes = (Index('all', (slice(0, None, None), slice(0, None, None))),
                    Index('row', (0, slice(0, None, None))),
@@ -333,6 +354,10 @@ class TestJacobianClassNoBatch(unittest.TestCase):
                                                 self._dtype)
         return utils._np_concat_matrix_sequence(jac, utils.MatrixFormat.NM)
+
+    def test_all_cases(self):
+        if _in_legacy_dygraph():
+            self.func_jacobian()


 @utils.place(config.DEVICES)
 @utils.parameterize((utils.TEST_CASE_NAME, 'func', 'xs'), (
@@ -355,7 +380,7 @@ class TestJacobianClassBatchFirst(unittest.TestCase):
         self._actual = paddle.autograd.Jacobian(self.func, self.xs, True)
         self._expected = self._expected()

-    def test_jacobian(self):
+    def func_jacobian(self):
         Index = collections.namedtuple('Index', ('type', 'value'))
         indexes = (
             Index('all', (slice(0, None, None), slice(0, None, None),
@@ -384,6 +409,10 @@ class TestJacobianClassBatchFirst(unittest.TestCase):
         return utils._np_transpose_matrix_format(jac, utils.MatrixFormat.NBM,
                                                  utils.MatrixFormat.BNM)
+
+    def test_all_cases(self):
+        if _in_legacy_dygraph():
+            self.func_jacobian()


 class TestHessianClassNoBatch(unittest.TestCase):
     @classmethod
@@ -400,7 +429,7 @@ class TestHessianClassNoBatch(unittest.TestCase):
         self.x = paddle.rand(shape=self.shape, dtype=self.dtype)
         self.y = paddle.rand(shape=self.shape, dtype=self.dtype)

-    def test_single_input(self):
+    def func_single_input(self):
         def func(x):
             return paddle.sum(paddle.matmul(x, x))
@@ -413,7 +442,7 @@ class TestHessianClassNoBatch(unittest.TestCase):
         np.testing.assert_allclose(hessian[:].numpy(), numerical_hessian,
                                    self.rtol, self.atol)

-    def test_multi_input(self):
+    def func_multi_input(self):
         def func(x, y):
             return paddle.sum(paddle.matmul(x, y))
@@ -429,7 +458,7 @@ class TestHessianClassNoBatch(unittest.TestCase):
             rtol=self.rtol,
             atol=self.atol)

-    def test_allow_unused_true(self):
+    def func_allow_unused_true(self):
         def func(x, y):
             return paddle.sum(paddle.matmul(x, x))
@@ -442,7 +471,7 @@ class TestHessianClassNoBatch(unittest.TestCase):
         np.testing.assert_allclose(hessian[:].numpy(), numerical_hessian,
                                    self.rtol, self.atol)

-    def test_create_graph_true(self):
+    def func_create_graph_true(self):
         def func(x):
             return paddle.sum(F.sigmoid(x))
@@ -455,13 +484,21 @@ class TestHessianClassNoBatch(unittest.TestCase):
         np.testing.assert_allclose(hessian[:].numpy(), numerical_hessian,
                                    self.rtol, self.atol)

-    def test_out_not_single(self):
+    def func_out_not_single(self):
         def func(x):
             return x * x

         with self.assertRaises(RuntimeError):
             paddle.autograd.Hessian(func, paddle.ones([3]))
+
+    def test_all_cases(self):
+        if _in_legacy_dygraph():
+            self.func_single_input()
+            self.func_multi_input()
+            self.func_allow_unused_true()
+            self.func_create_graph_true()
+            self.func_out_not_single()


 class TestHessianClassBatchFirst(unittest.TestCase):
     @classmethod
@@ -482,7 +519,7 @@ class TestHessianClassBatchFirst(unittest.TestCase):
         self.weight = paddle.rand(shape=self.weight_shape, dtype=self.dtype)
         self.y = paddle.rand(shape=self.y_shape, dtype=self.dtype)

-    def test_single_input(self):
+    def func_single_input(self):
         def func(x):
             return paddle.matmul(x * x, self.weight)[:, 0:1]
@@ -496,7 +533,7 @@ class TestHessianClassBatchFirst(unittest.TestCase):
         np.testing.assert_allclose(actual, expected, self.rtol, self.atol)

-    def test_multi_input(self):
+    def func_multi_input(self):
         def func(x, y):
             return paddle.matmul(x * x * y * y, self.weight)[:, 0:1]
@@ -517,7 +554,7 @@ class TestHessianClassBatchFirst(unittest.TestCase):
         np.testing.assert_allclose(actual, expected, self.rtol, self.atol)

-    def test_allow_unused(self):
+    def func_allow_unused(self):
         def func(x, y):
             return paddle.matmul(x * x, self.weight)[:, 0:1]
@@ -538,7 +575,7 @@ class TestHessianClassBatchFirst(unittest.TestCase):
         np.testing.assert_allclose(
             actual, expected, rtol=self.rtol, atol=self.atol)

-    def test_stop_gradient(self):
+    def func_stop_gradient(self):
         def func(x):
             return paddle.matmul(x * x, self.weight)[:, 0:1]
@@ -554,13 +591,21 @@ class TestHessianClassBatchFirst(unittest.TestCase):
         np.testing.assert_allclose(actual, expected, self.rtol, self.atol)

-    def test_out_not_single(self):
+    def func_out_not_single(self):
         def func(x):
             return (x * x)

         with self.assertRaises(RuntimeError):
             paddle.autograd.Hessian(func, paddle.ones((3, 3)), is_batched=True)
+
+    def test_all_cases(self):
+        if _in_legacy_dygraph():
+            self.func_single_input()
+            self.func_multi_input()
+            self.func_allow_unused()
+            self.func_stop_gradient()
+            self.func_out_not_single()


 class TestHessian(unittest.TestCase):
     @classmethod
@@ -577,7 +622,7 @@ class TestHessian(unittest.TestCase):
         self.x = paddle.rand(shape=self.shape, dtype=self.dtype)
         self.y = paddle.rand(shape=self.shape, dtype=self.dtype)

-    def test_single_input(self):
+    def func_single_input(self):
         def func(x):
             return paddle.sum(paddle.matmul(x, x))
@@ -589,7 +634,7 @@ class TestHessian(unittest.TestCase):
         np.testing.assert_allclose(hessian.numpy(), numerical_hessian[0][0],
                                    self.rtol, self.atol)

-    def test_multi_input(self):
+    def func_multi_input(self):
         def func(x, y):
             return paddle.sum(paddle.matmul(x, y))
@@ -605,7 +650,7 @@ class TestHessian(unittest.TestCase):
                                            numerical_hessian[i][j], self.rtol,
                                            self.atol)

-    def test_allow_unused_false(self):
+    def func_allow_unused_false(self):
         def func(x, y):
             return paddle.sum(paddle.matmul(x, x))
@@ -617,7 +662,7 @@ class TestHessian(unittest.TestCase):
             error_msg = cpt.get_exception_message(e)
             assert error_msg.find("allow_unused") > 0

-    def test_allow_unused_true(self):
+    def func_allow_unused_true(self):
         def func(x, y):
             return paddle.sum(paddle.matmul(x, x))
@@ -636,7 +681,7 @@ class TestHessian(unittest.TestCase):
                 else:
                     assert hessian[i][j] is None

-    def test_create_graph_false(self):
+    def func_create_graph_false(self):
         def func(x):
             return paddle.sum(paddle.matmul(x, x))
@@ -653,7 +698,7 @@ class TestHessian(unittest.TestCase):
             error_msg = cpt.get_exception_message(e)
             assert error_msg.find("has no gradient") > 0

-    def test_create_graph_true(self):
+    def func_create_graph_true(self):
         def func(x):
             return paddle.sum(F.sigmoid(x))
@@ -667,6 +712,15 @@ class TestHessian(unittest.TestCase):
         triple_grad = paddle.grad(hessian, self.x)
         assert triple_grad is not None
+
+    def test_all_cases(self):
+        if _in_legacy_dygraph():
+            self.func_single_input()
+            self.func_multi_input()
+            self.func_allow_unused_false()
+            self.func_allow_unused_true()
+            self.func_create_graph_false()
+            self.func_create_graph_true()


 class TestHessianFloat64(TestHessian):
     @classmethod
@@ -702,7 +756,7 @@ class TestBatchHessian(unittest.TestCase):
         self.weight = paddle.rand(shape=self.weight_shape, dtype=self.dtype)
         self.y = paddle.rand(shape=self.y_shape, dtype=self.dtype)

-    def test_single_input(self):
+    def func_single_input(self):
         def func(x):
             return paddle.matmul(x * x, self.weight)[:, 0:1]
@@ -713,7 +767,7 @@ class TestBatchHessian(unittest.TestCase):
         np.testing.assert_allclose(hessian, numerical_hessian, self.rtol,
                                    self.atol)

-    def test_multi_input(self):
+    def func_multi_input(self):
         def func(x, y):
             return paddle.matmul(x * x * y * y, self.weight)[:, 0:1]
@@ -729,7 +783,7 @@ class TestBatchHessian(unittest.TestCase):
         np.testing.assert_allclose(hessian_reshape, numerical_hessian,
                                    self.rtol, self.atol)

-    def test_allow_unused_false(self):
+    def func_allow_unused_false(self):
         def func(x, y):
             return paddle.matmul(x * x, self.weight)[:, 0:1]
@@ -741,7 +795,7 @@ class TestBatchHessian(unittest.TestCase):
             error_msg = cpt.get_exception_message(e)
             assert error_msg.find("allow_unused") > 0

-    def test_allow_unused_true(self):
+    def func_allow_unused_true(self):
         def func(x, y):
             return paddle.matmul(x * x, self.weight)[:, 0:1]
@@ -763,7 +817,7 @@ class TestBatchHessian(unittest.TestCase):
                 else:
                     assert hessian[i][j] is None

-    def test_create_graph_false(self):
+    def func_create_graph_false(self):
         def func(x):
             return paddle.matmul(x * x, self.weight)[:, 0:1]
@@ -780,7 +834,7 @@ class TestBatchHessian(unittest.TestCase):
             error_msg = cpt.get_exception_message(e)
             assert error_msg.find("has no gradient") > 0

-    def test_create_graph_true(self):
+    def func_create_graph_true(self):
         def func(x):
             return paddle.matmul(x * x, self.weight)[:, 0:1]
@@ -794,6 +848,15 @@ class TestBatchHessian(unittest.TestCase):
         triple_grad = paddle.grad(hessian, self.x)
         assert triple_grad is not None
+
+    def test_all_cases(self):
+        if _in_legacy_dygraph():
+            self.func_single_input()
+            self.func_multi_input()
+            self.func_allow_unused_false()
+            self.func_allow_unused_true()
+            self.func_create_graph_false()
+            self.func_create_graph_true()


 class TestBatchHessianFloat64(TestBatchHessian):
     @classmethod
@@ -831,7 +894,7 @@ class TestVHP(unittest.TestCase):
         self.vx = paddle.rand(shape=self.shape, dtype=self.dtype)
         self.vy = paddle.rand(shape=self.shape, dtype=self.dtype)

-    def test_single_input(self):
+    def func_single_input(self):
         def func(x):
             return paddle.sum(paddle.matmul(x, x))
@@ -846,7 +909,7 @@ class TestVHP(unittest.TestCase):
         np.testing.assert_allclose(vhp[0].numpy(), numerical_vhp[0], self.rtol,
                                    self.atol)

-    def test_multi_input(self):
+    def func_multi_input(self):
         def func(x, y):
             return paddle.sum(paddle.matmul(x, y))
@@ -865,7 +928,7 @@ class TestVHP(unittest.TestCase):
             np.testing.assert_allclose(vhp[i].numpy(), numerical_vhp[i],
                                        self.rtol, self.atol)

-    def test_v_default(self):
+    def func_v_default(self):
         def func(x, y):
             return paddle.sum(paddle.matmul(x, y))
@@ -885,7 +948,7 @@ class TestVHP(unittest.TestCase):
             np.testing.assert_allclose(vhp[i].numpy(), numerical_vhp[i],
                                        self.rtol, self.atol)

-    def test_allow_unused_true(self):
+    def func_allow_unused_true(self):
         def func(x, y):
             return paddle.sum(paddle.matmul(x, x))
@@ -903,7 +966,7 @@ class TestVHP(unittest.TestCase):
         np.testing.assert_allclose(vhp[0].numpy(), numerical_vhp[0], self.rtol,
                                    self.atol)

-    def test_create_graph_true(self):
+    def func_create_graph_true(self):
         def func(x):
             return paddle.sum(F.sigmoid(x))
@@ -921,6 +984,14 @@ class TestVHP(unittest.TestCase):
         triple_grad = paddle.grad(vhp, self.x)
         assert triple_grad is not None
+
+    def test_all_cases(self):
+        if _in_legacy_dygraph():
+            self.func_v_default()
+            self.func_multi_input()
+            self.func_single_input()
+            self.func_allow_unused_true()
+            self.func_create_graph_true()


 class TestJacobian(unittest.TestCase):
     @classmethod
@@ -934,7 +1005,7 @@ class TestJacobian(unittest.TestCase):
         self.x = paddle.rand(shape=self.shape, dtype=self.dtype)
         self.y = paddle.rand(shape=self.shape, dtype=self.dtype)

-    def test_single_input_and_single_output(self):
+    def func_single_input_and_single_output(self):
         def func(x):
             return paddle.matmul(x, x)
@@ -945,7 +1016,7 @@ class TestJacobian(unittest.TestCase):
         np.testing.assert_allclose(jacobian.numpy(), numerical_jacobian[0][0],
                                    self.rtol, self.atol)

-    def test_single_input_and_multi_output(self):
+    def func_single_input_and_multi_output(self):
         def func(x):
             return paddle.matmul(x, x), x * x
@@ -958,7 +1029,7 @@ class TestJacobian(unittest.TestCase):
                                        numerical_jacobian[i][0], self.rtol,
                                        self.atol)

-    def test_multi_input_and_single_output(self):
+    def func_multi_input_and_single_output(self):
         def func(x, y):
             return paddle.matmul(x, y)
@@ -972,7 +1043,7 @@ class TestJacobian(unittest.TestCase):
                                        numerical_jacobian[0][j], self.rtol,
                                        self.atol)

-    def test_multi_input_and_multi_output(self):
+    def func_multi_input_and_multi_output(self):
         def func(x, y):
             return paddle.matmul(x, y), x * y
@@ -987,7 +1058,7 @@ class TestJacobian(unittest.TestCase):
                                        numerical_jacobian[i][j], self.rtol,
                                        self.atol)

-    def test_allow_unused_false(self):
+    def func_allow_unused_false(self):
         def func(x, y):
             return paddle.matmul(x, x)
@@ -999,7 +1070,7 @@ class TestJacobian(unittest.TestCase):
             error_msg = cpt.get_exception_message(e)
             assert error_msg.find("allow_unused") > 0

-    def test_allow_unused_true(self):
+    def func_allow_unused_true(self):
         def func(x, y):
             return paddle.matmul(x, x)
@@ -1013,7 +1084,7 @@ class TestJacobian(unittest.TestCase):
             jacobian[0].numpy(), numerical_jacobian[0][0], self.rtol, self.atol)
         assert jacobian[1] is None

-    def test_create_graph_false(self):
+    def func_create_graph_false(self):
         def func(x, y):
             return paddle.matmul(x, y)
@@ -1033,7 +1104,7 @@ class TestJacobian(unittest.TestCase):
             error_msg = cpt.get_exception_message(e)
             assert error_msg.find("has no gradient") > 0

-    def test_create_graph_true(self):
+    def func_create_graph_true(self):
         def func(x, y):
             return paddle.matmul(x, y)
@@ -1051,6 +1122,17 @@ class TestJacobian(unittest.TestCase):
         double_grad = paddle.grad(jacobian[0], [self.x, self.y])
         assert double_grad is not None
+
+    def test_all_cases(self):
+        if _in_legacy_dygraph():
+            self.func_multi_input_and_multi_output()
+            self.func_multi_input_and_single_output()
+            self.func_single_input_and_multi_output()
+            self.func_single_input_and_single_output()
+            self.func_allow_unused_false()
+            self.func_allow_unused_true()
+            self.func_create_graph_false()
+            self.func_create_graph_true()


 class TestJacobianFloat64(TestJacobian):
     @classmethod
@@ -1080,7 +1162,7 @@ class TestJacobianBatch(unittest.TestCase):
         self.weight = paddle.rand(shape=self.weight_shape, dtype=self.dtype)
         self.y = paddle.rand(shape=self.y_shape, dtype=self.dtype)

-    def test_batch_single_input_and_batch_single_output(self):
+    def func_batch_single_input_and_batch_single_output(self):
         def func(x):
             return paddle.matmul(paddle.matmul(x, self.weight), self.y)
@@ -1096,7 +1178,7 @@ class TestJacobianBatch(unittest.TestCase):
             np.allclose(batch_jacobian.numpy().all(), numerical_jacobian[0][0]
                         .all()))

-    def test_batch_single_input_and_batch_multi_output(self):
+    def func_batch_single_input_and_batch_multi_output(self):
         def func(x):
             return paddle.matmul(paddle.matmul(x, self.weight), self.y), x * x
@@ -1113,7 +1195,7 @@ class TestJacobianBatch(unittest.TestCase):
                                        numerical_jacobian[i][0], self.rtol,
                                        self.atol)

-    def test_batch_multi_input_and_batch_single_output(self):
+    def func_batch_multi_input_and_batch_single_output(self):
         def func(x, y):
             return x * y
@@ -1129,7 +1211,7 @@ class TestJacobianBatch(unittest.TestCase):
                                        numerical_jacobian[0][j], self.rtol,
                                        self.atol)

-    def test_batch_multi_input_and_batch_multi_output(self):
+    def func_batch_multi_input_and_batch_multi_output(self):
         def func(x, y):
             return x * y, x * y
@@ -1144,7 +1226,7 @@ class TestJacobianBatch(unittest.TestCase):
             np.testing.assert_allclose(batch_jacobian[i], numerical_jacobian[i],
                                        self.rtol, self.atol)

-    def test_allow_unused_false(self):
+    def func_allow_unused_false(self):
         def func(x, y):
             return x * x
@@ -1156,7 +1238,7 @@ class TestJacobianBatch(unittest.TestCase):
             error_msg = cpt.get_exception_message(e)
             assert error_msg.find("allow_unused") > 0

-    def test_allow_unused_true(self):
+    def func_allow_unused_true(self):
         def func(x, y):
             return x * x
@@ -1171,7 +1253,7 @@ class TestJacobianBatch(unittest.TestCase):
             jacobian[0].numpy(), numerical_jacobian[0][0], self.rtol, self.atol)
         assert jacobian[1] is None

-    def test_create_graph_false(self):
+    def func_create_graph_false(self):
         def func(x, y):
             return x * y
@@ -1191,7 +1273,7 @@ class TestJacobianBatch(unittest.TestCase):
             error_msg = cpt.get_exception_message(e)
             assert error_msg.find("has no gradient") > 0

-    def test_create_graph_true(self):
+    def func_create_graph_true(self):
         def func(x, y):
             return x * y
@@ -1209,6 +1291,17 @@ class TestJacobianBatch(unittest.TestCase):
         double_grad = paddle.grad(jacobian[0], [self.x, self.y])
         assert double_grad is not None
+
+    def test_all_cases(self):
+        if _in_legacy_dygraph():
+            self.func_batch_single_input_and_batch_single_output()
+            self.func_batch_single_input_and_batch_multi_output()
+            self.func_batch_multi_input_and_batch_single_output()
+            self.func_batch_multi_input_and_batch_multi_output()
+            self.func_allow_unused_false()
+            self.func_allow_unused_true()
+            self.func_create_graph_false()
+            self.func_create_graph_true()


 class TestJacobianBatchFloat64(TestJacobianBatch):
     @classmethod
......
@@ -17,6 +17,7 @@ import paddle.fluid as fluid
 import numpy as np
 import unittest
 from paddle import _C_ops
+from paddle.fluid.framework import _test_eager_guard, _in_legacy_dygraph, _in_eager_without_dygraph_check

 if fluid.is_compiled_with_cuda():
     fluid.core.globals()['FLAGS_cudnn_deterministic'] = True
@@ -583,7 +584,7 @@ class StaticGraphTrainModel(object):
 class TestStarGANWithGradientPenalty(unittest.TestCase):
-    def test_main(self):
+    def func_main(self):
         self.place_test(fluid.CPUPlace())

         if fluid.is_compiled_with_cuda():
@@ -615,6 +616,10 @@ class TestStarGANWithGradientPenalty(unittest.TestCase):
             self.assertEqual(g_loss_s, g_loss_d)
             self.assertEqual(d_loss_s, d_loss_d)
+
+    def test_all_cases(self):
+        if _in_legacy_dygraph():
+            self.func_main()


 if __name__ == '__main__':
     paddle.enable_static()
......
@@ -19,6 +19,7 @@ from paddle.vision.models import resnet50, resnet101
 import unittest
 from unittest import TestCase
 import numpy as np
+from paddle.fluid.framework import _test_eager_guard, _in_legacy_dygraph, _in_eager_without_dygraph_check


 def _dygraph_guard_(func):
@@ -65,7 +66,7 @@ class TestDygraphTripleGrad(TestCase):
             allow_unused=allow_unused)

     @dygraph_guard
-    def test_exception(self):
+    def func_exception(self):
         with self.assertRaises(AssertionError):
             self.grad(None, None)
@@ -95,7 +96,7 @@ class TestDygraphTripleGrad(TestCase):
             self.grad([random_var(shape)], [random_var(shape)], no_grad_vars=1)

     @dygraph_guard
-    def test_example_with_gradient_and_create_graph(self):
+    def func_example_with_gradient_and_create_graph(self):
         x = random_var(self.shape)
         x_np = x.numpy()
         x.stop_gradient = False
@@ -145,6 +146,11 @@ class TestDygraphTripleGrad(TestCase):
         dddx_grad_actual = x.gradient()

         self.assertTrue(np.allclose(dddx_grad_actual, dddx_expected))
+
+    def test_all_cases(self):
+        if _in_legacy_dygraph():
+            self.func_exception()
+            self.func_example_with_gradient_and_create_graph()


 class TestDygraphTripleGradBradcastCase(TestCase):
     def setUp(self):
@@ -172,7 +178,7 @@ class TestDygraphTripleGradBradcastCase(TestCase):
             allow_unused=allow_unused)

     @dygraph_guard
-    def test_example_with_gradient_and_create_graph(self):
+    def func_example_with_gradient_and_create_graph(self):
         x = random_var(self.x_shape)
         x_np = x.numpy()
         x.stop_gradient = False
@@ -227,6 +233,10 @@ class TestDygraphTripleGradBradcastCase(TestCase):
         dddx_grad_actual = x.gradient()

         self.assertTrue(np.allclose(dddx_grad_actual, dddx_expected))
+
+    def test_all_cases(self):
+        if _in_legacy_dygraph():
+            self.func_example_with_gradient_and_create_graph()


 if __name__ == '__main__':
     unittest.main()
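
One detail worth noting: each touched file also imports `_test_eager_guard` and `_in_eager_without_dygraph_check`, even though the hunks above only branch on `_in_legacy_dygraph()`. In Paddle's test suite, `_test_eager_guard` is the companion idiom for running the same `func_*` bodies under eager mode once they pass there. A hedged sketch of how a gated case is typically re-enabled (illustrative class and method names, not part of this commit):

import unittest

from paddle.fluid.framework import _test_eager_guard


class TestReenabledExample(unittest.TestCase):
    # Illustrative: the usual shape of test_all_cases once a case also
    # works under eager mode.

    def func_case(self):
        pass  # the real assertions would live here

    def test_all_cases(self):
        # Run the body under eager mode first...
        with _test_eager_guard():
            self.func_case()
        # ...then again under the default (legacy dygraph) mode.
        self.func_case()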