diff --git a/paddle/operators/nccl_op.cc b/paddle/operators/nccl_op.cc index ec7a89d5ff4a66e0527d8365027bc2c5b4e93aaf..5b6c9bec70178f8bc7cb26011ee236a562b22ed0 100644 --- a/paddle/operators/nccl_op.cc +++ b/paddle/operators/nccl_op.cc @@ -93,6 +93,10 @@ class NCCLReduceOp : public framework::OperatorWithKernel { " Input(Communicator) of Reduce op input should not be NULL"); PADDLE_ENFORCE(ctx->HasOutput("Out"), " Input(X) of Reduce op input should not be NULL"); + + auto x_dims = ctx->GetInputsDim("X"); + ctx->SetOutputsDim("Out", x_dims); + ctx->ShareLoD("X", /*->*/ "Out"); } }; diff --git a/paddle/operators/nccl_op_test.cu b/paddle/operators/nccl_op_test.cu index 8c54a3dcba2d3a15815a0e528ac1c7b44809c5f3..0eda0c6b57d5091546a4aba33860fdeb94abca4f 100644 --- a/paddle/operators/nccl_op_test.cu +++ b/paddle/operators/nccl_op_test.cu @@ -16,6 +16,7 @@ #include #include +#include #include #include #include @@ -150,16 +151,41 @@ TEST_F(NCCLTester, ncclAllReduceOp) { op2->SetInput("Communicator", {"comm"}); op2->SetOutput("Out", {"rt"}); + std::vector dev_scopes; + std::vector ths; + for (size_t i = 0; i < gpu_list.size(); ++i) { + dev_scopes.emplace_back(&g_scope.NewScope()); std::thread th(&NCCLTester::PerThreadProgram, this, gpu_list[i], - *op2.get(), &g_scope.NewScope()); + *op2.get(), dev_scopes[i]); ths.emplace_back(std::move(th)); } for (size_t i = 0; i < gpu_list.size(); ++i) { ths[i].join(); } + + // check results + float result = 0; + std::accumulate(gpu_list.begin(), gpu_list.end(), result); + for (size_t i = 0; i < dev_scopes.size(); ++i) { + auto &recv_tensor = dev_scopes[i]->FindVar("rt")->Get(); + auto *rt = recv_tensor.data(); + + p::CPUPlace cpu_place; + auto *result_tensor = dev_scopes[i]->Var("ct")->GetMutable(); + result_tensor->Resize(kDims); + auto *ct = result_tensor->mutable_data(cpu_place); + + paddle::memory::Copy( + cpu_place, ct, p::GPUPlace(gpu_list[i]), rt, + recv_tensor.numel() * sizeof(float), + static_cast(dev_ctxs[i])->stream()); + for (size_t j = 0; j < f::product(kDims); ++j) { + ASSERT_NEAR(ct[j], result, 1e-5); + } + } } // ncclReduceOp with desc @@ -170,24 +196,76 @@ TEST(NCCL, ncclReduceOp) { op2->SetInput("Communicator", {"comm"}); op2->SetOutput("Out", {"rt"}); + std::vector dev_scopes; + std::vector ths; for (size_t i = 0; i < gpu_list.size(); ++i) { + dev_scopes.emplace_back(&g_scope.NewScope()); std::thread th(&NCCLTester::PerThreadProgram, this, gpu_list[i], - *op2.get(), &g_scope.NewScope()); + *op2.get(), dev_scopes[i]); ths.emplace_back(std::move(th)); } for (size_t i = 0; i < gpu_list.size(); ++i) { ths[i].join(); } + + // check results + float result = 0; + std::accumulate(gpu_list.begin(), gpu_list.end(), result); + for (size_t i = 0; i < dev_scopes.size(); ++i) { + auto &recv_tensor = dev_scopes[i]->FindVar("rt")->Get(); + auto *rt = recv_tensor.data(); + + p::CPUPlace cpu_place; + auto *result_tensor = dev_scopes[i]->Var("ct")->GetMutable(); + result_tensor->Resize(kDims); + auto *ct = result_tensor->mutable_data(cpu_place); + + paddle::memory::Copy( + cpu_place, ct, p::GPUPlace(gpu_list[i]), rt, + recv_tensor.numel() * sizeof(float), + static_cast(dev_ctxs[i])->stream()); + for (size_t j = 0; j < f::product(kDims); ++j) { + ASSERT_NEAR(ct[j], result, 1e-5); + } + } } // ncclBcastOp with desc -// TEST(NCCL, ncclBcastOp) { +TEST(NCCL, ncclBcastOp) { + std::unique_ptr op1(new f::OpDescBind); + op1->SetType("ncclBcastSend"); + op1->SetInput("X", {"st"}); + op1->SetInput("Communicator", {"comm"}); + + std::unique_ptr op2(new f::OpDescBind); + op2->SetType("ncclBcastRecv"); + op2->SetInput("Communicator", {"comm"}); + op2->SetOutput("Out", {"rt"}); + + std::vector ths; + for (size_t i = 1; i < gpu_list.size(); ++i) { + std::thread th(&NCCLTester::PerThreadProgram, this, gpu_list[i], + *op2.get(), &g_scope.NewScope()); + ths.emplace_back(std::move(th)); + } + + for (size_t i = 0; i < gpu_list.size(); ++i) { + ths[i].join(); + } +} + +// joint ncclBcastOp and ncclReduceOp +// TEST(NCCL, MultipleOp) { // std::unique_ptr op2(new f::OpDescBind); // op2->SetType("ncclBcastSend"); // op2->SetInput("X", {"st"}); // op2->SetInput("Communicator", {"comm"}); + +// std::unique_ptr op2(new f::OpDescBind); +// op2->SetType("ncclBcastRecv"); +// op2->SetInput("Communicator", {"comm"}); // op2->SetOutput("Out", {"rt"}); // std::vector ths;