diff --git a/python/paddle/fluid/tests/unittests/test_jit_save_load.py b/python/paddle/fluid/tests/unittests/test_jit_save_load.py index eef38182f6edf69280b0eafd8e3d0794dc0e5f12..81db84a5262fb65914ab9a32688208a3f50cbc62 100644 --- a/python/paddle/fluid/tests/unittests/test_jit_save_load.py +++ b/python/paddle/fluid/tests/unittests/test_jit_save_load.py @@ -1155,7 +1155,12 @@ class TestJitSaveLoadFinetuneLoad(unittest.TestCase): self.assertTrue(float(((result_01 - result_11)).abs().max()) < 1e-5) -class TestJitSaveLoadFunction(unittest.TestCase): +# NOTE(weixin): When there are multiple test functions in an +# `unittest.TestCase`, functions will affect each other, +# and there is a risk of random failure. +# So divided into three TestCase: TestJitSaveLoadFunctionCase1, +# TestJitSaveLoadFunctionCase2, TestJitSaveLoadFunctionCase3. +class TestJitSaveLoadFunctionCase1(unittest.TestCase): def setUp(self): paddle.disable_static() @@ -1174,6 +1179,11 @@ class TestJitSaveLoadFunction(unittest.TestCase): load_result = load_func(inps) self.assertTrue((load_result - origin).abs().max() < 1e-10) + +class TestJitSaveLoadFunctionCase2(unittest.TestCase): + def setUp(self): + paddle.disable_static() + def test_jit_save_load_function_input_spec(self): @paddle.jit.to_static(input_spec=[ InputSpec( @@ -1191,6 +1201,11 @@ class TestJitSaveLoadFunction(unittest.TestCase): load_result = load_func(inps) self.assertTrue((load_result - origin).abs().max() < 1e-10) + +class TestJitSaveLoadFunctionCase3(unittest.TestCase): + def setUp(self): + paddle.disable_static() + def test_jit_save_load_function_function(self): def fun(inputs): return paddle.tanh(inputs)