未验证 提交 41a7ce83 编写于 作者: L Leo Chen 提交者: GitHub

fix random failure of test_buffer_sahred_memory_reuse_pass (#27551)

上级 a2e0b7cb
...@@ -115,8 +115,15 @@ class InplaceTestBase(unittest.TestCase): ...@@ -115,8 +115,15 @@ class InplaceTestBase(unittest.TestCase):
fetch_val2, = exe.run(compiled_prog, fetch_val2, = exe.run(compiled_prog,
feed=feed_dict, feed=feed_dict,
fetch_list=[fetch_var]) fetch_list=[fetch_var])
#NOTE(zhiqiu): Temporally changed from array_equal to allclose.
self.assertTrue(np.array_equal(fetch_val1, fetch_val2)) # The real root is fuse_all_reduce and fuse_all_optimizer_opss may
# result in diff because of the instruction set on the virtual machine.
# And the related unit tests: test_fuse_all_reduce_pass and test_fuse_optimizer_pass use "almostEqual" in their checks.
# There are also some related issues:
# https://github.com/PaddlePaddle/Paddle/issues/21270
# https://github.com/PaddlePaddle/Paddle/issues/21046
# https://github.com/PaddlePaddle/Paddle/issues/21045
self.assertTrue(np.allclose(fetch_val1, fetch_val2))
def check_multi_card_fetch_var(self): def check_multi_card_fetch_var(self):
if self.is_invalid_test(): if self.is_invalid_test():
...@@ -160,7 +167,8 @@ class InplaceTestBase(unittest.TestCase): ...@@ -160,7 +167,8 @@ class InplaceTestBase(unittest.TestCase):
fetch_vals.append(fetch_val) fetch_vals.append(fetch_val)
for item in fetch_vals: for item in fetch_vals:
self.assertTrue(np.array_equal(fetch_vals[0], item)) # save above
self.assertTrue(np.allclose(fetch_vals[0], item))
class CUDAInplaceTest(InplaceTestBase): class CUDAInplaceTest(InplaceTestBase):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册