未验证 提交 1e5437de 编写于 作者: W WeiXin 提交者: GitHub

Fix test_jit_save_load random failure. (#34004)

* Fix test_jit_save_load random failure.

* Since CI is not activated, recommit the code.

* delete temp file.
上级 77a5b8b0
......@@ -1155,7 +1155,12 @@ class TestJitSaveLoadFinetuneLoad(unittest.TestCase):
self.assertTrue(float(((result_01 - result_11)).abs().max()) < 1e-5)
class TestJitSaveLoadFunction(unittest.TestCase):
# NOTE(weixin): When there are multiple test functions in an
# `unittest.TestCase`, functions will affect each other,
# and there is a risk of random failure.
# So divided into three TestCase: TestJitSaveLoadFunctionCase1,
# TestJitSaveLoadFunctionCase2, TestJitSaveLoadFunctionCase3.
class TestJitSaveLoadFunctionCase1(unittest.TestCase):
def setUp(self):
paddle.disable_static()
......@@ -1174,6 +1179,11 @@ class TestJitSaveLoadFunction(unittest.TestCase):
load_result = load_func(inps)
self.assertTrue((load_result - origin).abs().max() < 1e-10)
class TestJitSaveLoadFunctionCase2(unittest.TestCase):
def setUp(self):
paddle.disable_static()
def test_jit_save_load_function_input_spec(self):
@paddle.jit.to_static(input_spec=[
InputSpec(
......@@ -1191,6 +1201,11 @@ class TestJitSaveLoadFunction(unittest.TestCase):
load_result = load_func(inps)
self.assertTrue((load_result - origin).abs().max() < 1e-10)
class TestJitSaveLoadFunctionCase3(unittest.TestCase):
def setUp(self):
paddle.disable_static()
def test_jit_save_load_function_function(self):
def fun(inputs):
return paddle.tanh(inputs)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册