Unverified commit a6dcaf64 authored by W Wen Sun, committed by GitHub

Replace `assert np.allclose` with `np.testing.assert_allclose` in collective communication unittests (#49195)

* refactor: replace `assert` with `assert_allclose`

* chore: add coverage conf

* revert: remove incorrect coverage conf
Parent 10d3c096
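For context, the standalone sketch below (not taken from the Paddle test suite; the arrays and values are made up for illustration) shows why `np.testing.assert_allclose` is generally preferred in unittests: a failing bare `assert np.allclose(...)` reports nothing beyond the assertion itself, while `assert_allclose` raises an `AssertionError` whose message describes the mismatched elements and the maximum absolute and relative differences. The two calls also have different defaults (`np.allclose` uses `rtol=1e-05, atol=1e-08`; `np.testing.assert_allclose` uses `rtol=1e-07, atol=0`), which may be why the tests in this diff pass `rtol` and `atol` explicitly.

```python
# Minimal illustrative sketch; the sample arrays are hypothetical.
import numpy as np

expected = np.array([1.0, 2.0, 3.0])
actual = np.array([1.0, 2.0, 3.1])

# Old style: on failure this raises a bare AssertionError with no detail
# about which elements differ or by how much.
try:
    assert np.allclose(actual, expected, rtol=1e-05, atol=1e-05)
except AssertionError:
    print("assert np.allclose: failed, no diagnostics in the message")

# New style: the raised AssertionError reports the mismatched elements and
# the max absolute/relative differences, which makes CI failures actionable.
try:
    np.testing.assert_allclose(actual, expected, rtol=1e-05, atol=1e-05)
except AssertionError as e:
    print(e)
```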
@@ -59,7 +59,7 @@ class StreamAllgatherTestCase:
)
if not self._sync_op:
task.wait()
assert np.allclose(
np.testing.assert_allclose(
empty_tensor_list, test_data_list, rtol=1e-05, atol=1e-05
)
@@ -73,7 +73,7 @@ class StreamAllgatherTestCase:
)
if not self._sync_op:
task.wait()
assert np.allclose(
np.testing.assert_allclose(
full_tensor_list, test_data_list, rtol=1e-05, atol=1e-05
)
@@ -90,7 +90,9 @@ class StreamAllgatherTestCase:
)
if not self._sync_op:
task.wait()
assert np.allclose(out_tensor, result_tensor, rtol=1e-05, atol=1e-05)
np.testing.assert_allclose(
out_tensor, result_tensor, rtol=1e-05, atol=1e-05
)
if __name__ == "__main__":
@@ -58,7 +58,7 @@ class StreamAllReduceTestCase:
for i in range(1, len(test_data_list)):
result += test_data_list[i]
assert np.allclose(tensor, result, rtol=1e-05, atol=1e-05)
np.testing.assert_allclose(tensor, result, rtol=1e-05, atol=1e-05)
if __name__ == "__main__":
@@ -75,11 +75,11 @@ class StreamAllToAllTestCase:
task.wait()
result_tensor_list = np.vstack(empty_tensor_list)
if rank == 0:
assert np.allclose(
np.testing.assert_allclose(
result_tensor_list, result1, rtol=1e-05, atol=1e-05
)
else:
assert np.allclose(
np.testing.assert_allclose(
result_tensor_list, result2, rtol=1e-05, atol=1e-05
)
@@ -95,11 +95,11 @@ class StreamAllToAllTestCase:
task.wait()
result_tensor_list = np.vstack(full_tensor_list)
if rank == 0:
assert np.allclose(
np.testing.assert_allclose(
result_tensor_list, result1, rtol=1e-05, atol=1e-05
)
else:
assert np.allclose(
np.testing.assert_allclose(
result_tensor_list, result2, rtol=1e-05, atol=1e-05
)
@@ -114,9 +114,13 @@ class StreamAllToAllTestCase:
if not self._sync_op:
task.wait()
if rank == 0:
assert np.allclose(out_tensor, result1, rtol=1e-05, atol=1e-05)
np.testing.assert_allclose(
out_tensor, result1, rtol=1e-05, atol=1e-05
)
else:
assert np.allclose(out_tensor, result2, rtol=1e-05, atol=1e-05)
np.testing.assert_allclose(
out_tensor, result2, rtol=1e-05, atol=1e-05
)
if __name__ == "__main__":
@@ -72,9 +72,13 @@ class StreamAllToAllSingleTestCase:
if not self._sync_op:
task.wait()
if rank == 0:
assert np.allclose(out_tensor, result1, rtol=1e-05, atol=1e-05)
np.testing.assert_allclose(
out_tensor, result1, rtol=1e-05, atol=1e-05
)
else:
assert np.allclose(out_tensor, result2, rtol=1e-05, atol=1e-05)
np.testing.assert_allclose(
out_tensor, result2, rtol=1e-05, atol=1e-05
)
if __name__ == "__main__":
@@ -52,7 +52,7 @@ class StreamBroadcastTestCase:
if not self._sync_op:
task.wait()
assert np.allclose(tensor, result, rtol=1e-05, atol=1e-05)
np.testing.assert_allclose(tensor, result, rtol=1e-05, atol=1e-05)
if __name__ == "__main__":
@@ -59,9 +59,9 @@ class StreamReduceTestCase:
result = sum(test_data_list)
if rank == 1:
assert np.allclose(tensor, result, rtol=1e-05, atol=1e-05)
np.testing.assert_allclose(tensor, result, rtol=1e-05, atol=1e-05)
else:
assert np.allclose(
np.testing.assert_allclose(
tensor, test_data_list[rank], rtol=1e-05, atol=1e-05
)
@@ -67,9 +67,13 @@ class StreamReduceScatterTestCase:
if not self._sync_op:
task.wait()
if rank == 0:
assert np.allclose(result_tensor, result1, rtol=1e-05, atol=1e-05)
np.testing.assert_allclose(
result_tensor, result1, rtol=1e-05, atol=1e-05
)
else:
assert np.allclose(result_tensor, result2, rtol=1e-05, atol=1e-05)
np.testing.assert_allclose(
result_tensor, result2, rtol=1e-05, atol=1e-05
)
# case 2: pass a pre-sized tensor
result_tensor = paddle.empty_like(t1)
@@ -82,9 +86,13 @@ class StreamReduceScatterTestCase:
if not self._sync_op:
task.wait()
if rank == 0:
assert np.allclose(result_tensor, result1, rtol=1e-05, atol=1e-05)
np.testing.assert_allclose(
result_tensor, result1, rtol=1e-05, atol=1e-05
)
else:
assert np.allclose(result_tensor, result2, rtol=1e-05, atol=1e-05)
np.testing.assert_allclose(
result_tensor, result2, rtol=1e-05, atol=1e-05
)
# case 3: test the legacy API
result_tensor = paddle.empty_like(t1)
@@ -97,9 +105,13 @@ class StreamReduceScatterTestCase:
if not self._sync_op:
task.wait()
if rank == 0:
assert np.allclose(result_tensor, result1, rtol=1e-05, atol=1e-05)
np.testing.assert_allclose(
result_tensor, result1, rtol=1e-05, atol=1e-05
)
else:
assert np.allclose(result_tensor, result2, rtol=1e-05, atol=1e-05)
np.testing.assert_allclose(
result_tensor, result2, rtol=1e-05, atol=1e-05
)
if __name__ == "__main__":
@@ -66,9 +66,9 @@ class StreamScatterTestCase:
if not self._sync_op:
task.wait()
if rank == src_rank:
assert np.allclose(t1, result2, rtol=1e-05, atol=1e-05)
np.testing.assert_allclose(t1, result2, rtol=1e-05, atol=1e-05)
else:
assert np.allclose(t1, result1, rtol=1e-05, atol=1e-05)
np.testing.assert_allclose(t1, result1, rtol=1e-05, atol=1e-05)
# case 2: pass a pre-sized tensor
tensor = paddle.to_tensor(src_data)
@@ -83,9 +83,9 @@ class StreamScatterTestCase:
if not self._sync_op:
task.wait()
if rank == src_rank:
assert np.allclose(t1, result2, rtol=1e-05, atol=1e-05)
np.testing.assert_allclose(t1, result2, rtol=1e-05, atol=1e-05)
else:
assert np.allclose(t1, result1, rtol=1e-05, atol=1e-05)
np.testing.assert_allclose(t1, result1, rtol=1e-05, atol=1e-05)
if __name__ == "__main__":
@@ -69,7 +69,7 @@ class StreamSendRecvTestCase:
task.wait()
result = test_data_list[src_rank]
assert np.allclose(tensor, result, rtol=1e-05, atol=1e-05)
np.testing.assert_allclose(tensor, result, rtol=1e-05, atol=1e-05)
if __name__ == "__main__":
@@ -396,7 +396,7 @@ class TestDistBase(unittest.TestCase):
for i in range(result_data.shape[0]):
for j in range(result_data.shape[1]):
data = result_data[i][j]
assert np.allclose(
np.testing.assert_allclose(
tr0_out[1][i][j], need_result[data], atol=1e-08
)
elif col_type == "row_parallel_linear":