未验证 提交 e450823b 编写于 作者: W WangXi 提交者: GitHub

Fix nccl op test failure, test=develop (#28172)

上级 ee4309e6
...@@ -174,10 +174,11 @@ void NCCLTester::testNcclAllReduceOp() { ...@@ -174,10 +174,11 @@ void NCCLTester::testNcclAllReduceOp() {
result_tensor->Resize(kDims); result_tensor->Resize(kDims);
auto *ct = result_tensor->mutable_data<float>(cpu_place); auto *ct = result_tensor->mutable_data<float>(cpu_place);
paddle::memory::Copy( auto *dev_ctx = static_cast<p::CUDADeviceContext *>(dev_ctxs_[i]);
cpu_place, ct, p::CUDAPlace(gpu_list_[i]), rt, paddle::memory::Copy(cpu_place, ct, p::CUDAPlace(gpu_list_[i]), rt,
recv_tensor.numel() * sizeof(float), recv_tensor.numel() * sizeof(float),
static_cast<p::CUDADeviceContext *>(dev_ctxs_[i])->stream()); dev_ctx->stream());
dev_ctx->Wait();
for (int64_t j = 0; j < f::product(kDims); ++j) { for (int64_t j = 0; j < f::product(kDims); ++j) {
ASSERT_NEAR(ct[j], expected_result, 1e-5); ASSERT_NEAR(ct[j], expected_result, 1e-5);
...@@ -272,10 +273,10 @@ void NCCLTester::testNcclBcastOp() { ...@@ -272,10 +273,10 @@ void NCCLTester::testNcclBcastOp() {
result_tensor->Resize(kDims); result_tensor->Resize(kDims);
auto *ct = result_tensor->mutable_data<float>(cpu_place); auto *ct = result_tensor->mutable_data<float>(cpu_place);
paddle::memory::Copy( auto *dev_ctx = static_cast<p::CUDADeviceContext *>(dev_ctxs_[idx]);
cpu_place, ct, p::CUDAPlace(gpu_list_[idx]), rt, paddle::memory::Copy(cpu_place, ct, p::CUDAPlace(gpu_list_[idx]), rt,
recv_tensor.numel() * sizeof(float), recv_tensor.numel() * sizeof(float), dev_ctx->stream());
static_cast<p::CUDADeviceContext *>(dev_ctxs_[idx])->stream()); dev_ctx->Wait();
for (int64_t j = 0; j < f::product(kDims); ++j) { for (int64_t j = 0; j < f::product(kDims); ++j) {
ASSERT_NEAR(ct[j], result, 1e-5); ASSERT_NEAR(ct[j], result, 1e-5);
...@@ -288,13 +289,9 @@ TEST_F(NCCLTester, ncclInitOp) {} ...@@ -288,13 +289,9 @@ TEST_F(NCCLTester, ncclInitOp) {}
TEST_F(NCCLTester, ncclOp) { TEST_F(NCCLTester, ncclOp) {
// Serial execution is required for the same nccl comm. // Serial execution is required for the same nccl comm.
// ncclAllReduceOp with desc
// TODO(helin): https://github.com/PaddlePaddle/Paddle/issues/9367
testNcclReduceOp(); testNcclReduceOp();
testNcclAllReduceOp(); testNcclAllReduceOp();
// ncclBcastOp with desc
// TODO(helin): https://github.com/PaddlePaddle/Paddle/issues/9540
testNcclBcastOp(); testNcclBcastOp();
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册