From 1e5437de3fdf89acb85324214ce40ace51f8f370 Mon Sep 17 00:00:00 2001 From: WeiXin Date: Thu, 8 Jul 2021 10:36:23 +0800 Subject: [PATCH] Fix test_jit_save_load random failure. (#34004) * Fix test_jit_save_load random failure. * Since CI is not activated, recommit the code. * delete temp file. --- .../fluid/tests/unittests/test_jit_save_load.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/tests/unittests/test_jit_save_load.py b/python/paddle/fluid/tests/unittests/test_jit_save_load.py index eef38182f6e..81db84a5262 100644 --- a/python/paddle/fluid/tests/unittests/test_jit_save_load.py +++ b/python/paddle/fluid/tests/unittests/test_jit_save_load.py @@ -1155,7 +1155,12 @@ class TestJitSaveLoadFinetuneLoad(unittest.TestCase): self.assertTrue(float(((result_01 - result_11)).abs().max()) < 1e-5) -class TestJitSaveLoadFunction(unittest.TestCase): +# NOTE(weixin): When there are multiple test functions in an +# `unittest.TestCase`, functions will affect each other, +# and there is a risk of random failure. +# So divided into three TestCase: TestJitSaveLoadFunctionCase1, +# TestJitSaveLoadFunctionCase2, TestJitSaveLoadFunctionCase3. +class TestJitSaveLoadFunctionCase1(unittest.TestCase): def setUp(self): paddle.disable_static() @@ -1174,6 +1179,11 @@ class TestJitSaveLoadFunction(unittest.TestCase): load_result = load_func(inps) self.assertTrue((load_result - origin).abs().max() < 1e-10) + +class TestJitSaveLoadFunctionCase2(unittest.TestCase): + def setUp(self): + paddle.disable_static() + def test_jit_save_load_function_input_spec(self): @paddle.jit.to_static(input_spec=[ InputSpec( @@ -1191,6 +1201,11 @@ class TestJitSaveLoadFunction(unittest.TestCase): load_result = load_func(inps) self.assertTrue((load_result - origin).abs().max() < 1e-10) + +class TestJitSaveLoadFunctionCase3(unittest.TestCase): + def setUp(self): + paddle.disable_static() + def test_jit_save_load_function_function(self): def fun(inputs): return paddle.tanh(inputs) -- GitLab