diff --git a/paddle/operators/nccl/nccl_gpu_common.h b/paddle/operators/nccl/nccl_gpu_common.h index 5ca6a9e05efd833044618a06267ce9ff906e92dd..d10688b12708ec0563cb8d183ba0fff863f20d1e 100644 --- a/paddle/operators/nccl/nccl_gpu_common.h +++ b/paddle/operators/nccl/nccl_gpu_common.h @@ -79,7 +79,22 @@ struct Communicator { streams_.resize(gpus.size()); events_.resize(gpus.size()); } - // Communicator(int num_device): comms_.resize(num_device) {} + + ~Communicator() { + for (size_t i = 0; i < gpus_.size(); ++i) { + int gid = gpus_[i]; + platform::SetDeviceId(gid); + + int idx = gid % gpus_.size(); + // wait finish + PADDLE_ENFORCE( + cudaStreamWaitEvent(comm->streams_[idx], comm->events_[idx], 0)); + + PADDLE_ENFORCE(cudaEventDestroy(comm->events_[idx])); + + PADDLE_ENFORCE(ncclCommDestroy(comm->comms_[idx])); + } + } inline int get_root_gpu() const { return root_gpu; } diff --git a/paddle/operators/nccl/nccl_ops.cc b/paddle/operators/nccl/nccl_ops.cc index f1a83c1e1e344841e2d95984400c2c7d4c5c41f7..5cad44dc9fad98fa3e324d8f814b69693e38cc7a 100644 --- a/paddle/operators/nccl/nccl_ops.cc +++ b/paddle/operators/nccl/nccl_ops.cc @@ -14,7 +14,33 @@ namespace paddle { namespace operators { -// AllreduceOp +// NCCLinitOp +class NCCLInitOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasOutput("Communicator"), + " Input(X) of AllReduce op input should not be NULL"); + } +}; + +class NCCLInitOpMaker : public framework::OpProtoAndCheckerMaker { + public: + NCCLInitOpMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddAttr>("gpus", "gpu id lists"); + AddOutput("Communicator", + "Create Communicator for communicating between gpus"); + AddComment(R"DOC( + create communicator. + )DOC"); + } +}; + +// AllReduceOp class NCCLAllReduceOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -23,6 +49,9 @@ class NCCLAllReduceOp : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext *ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), " Input(X) of AllReduce op input should not be NULL"); + PADDLE_ENFORCE( + ctx->HasInput("Communicator"), + " Input(Communicator) of AllReduce op input should not be NULL"); PADDLE_ENFORCE(ctx->HasOutput("Out"), " Input(X) of AllReduce op input should not be NULL"); @@ -45,6 +74,7 @@ class NCCLAllReduceOpMaker : public framework::OpProtoAndCheckerMaker { framework::OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "The input of AllReduce op"); + AddInput("Communicator", "Communicator for communicating between gpus"); AddOutput("Out", "The output of AllReduce op"); AddAttr("reduction", "{'ncclmin', 'ncclmax', 'ncclprod', 'ncclsum'}."); @@ -55,31 +85,31 @@ class NCCLAllReduceOpMaker : public framework::OpProtoAndCheckerMaker { } }; -// BcastSendOp -class NCCLBcastSendOpMaker : public framework::OpProtoAndCheckerMaker { - public: - NCCLAllReduceOpMaker(framework::OpProto *proto, - framework::OpAttrChecker *op_checker) - : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("X", "The input of BcastSend op"); - AddComment(R"DOC( - BcastSend the tensors. - )DOC"); - } -}; +// // BcastSendOp +// class NCCLBcastSendOpMaker : public framework::OpProtoAndCheckerMaker { +// public: +// NCCLAllReduceOpMaker(framework::OpProto *proto, +// framework::OpAttrChecker *op_checker) +// : OpProtoAndCheckerMaker(proto, op_checker) { +// AddInput("X", "The input of BcastSend op"); +// AddComment(R"DOC( +// BcastSend the tensors. +// )DOC"); +// } +// }; -// BcastRecvOp -class NCCLBcastRecvOpMaker : public framework::OpProtoAndCheckerMaker { - public: - NCCLAllReduceOpMaker(framework::OpProto *proto, - framework::OpAttrChecker *op_checker) - : OpProtoAndCheckerMaker(proto, op_checker) { - AddOutput("Out", "The output of BcastRecv op"); - AddComment(R"DOC( - BcastRecv the tensors. - )DOC"); - } -}; +// // BcastRecvOp +// class NCCLBcastRecvOpMaker : public framework::OpProtoAndCheckerMaker { +// public: +// NCCLAllReduceOpMaker(framework::OpProto *proto, +// framework::OpAttrChecker *op_checker) +// : OpProtoAndCheckerMaker(proto, op_checker) { +// AddOutput("Out", "The output of BcastRecv op"); +// AddComment(R"DOC( +// BcastRecv the tensors. +// )DOC"); +// } +// }; } // namespace operators } // namespace paddle diff --git a/paddle/operators/nccl/nccl_ops.h b/paddle/operators/nccl/nccl_ops.h index c46fdd7d44f789ca178ea5671b7126fe8eaa8d67..a7a74a0e41c073ee5dc77c0f9e449a1fb297a1ed 100644 --- a/paddle/operators/nccl/nccl_ops.h +++ b/paddle/operators/nccl/nccl_ops.h @@ -35,6 +35,16 @@ class NCCLTypeWrapper { static const ncclDataType_t type = ncclDouble; }; +class NCCLInitOp : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto gpus = ctx.Input>("gpus"); + auto* comm = ctx.Output("Communicator"); + comm->mutable_data(CPUPlace()); + comm = NCCLManager::GetCommunicator(gpus); + } +}; + template class NCCLAllReduceKernel : public framework::OpKernel { public: @@ -54,13 +64,15 @@ class NCCLAllReduceKernel : public framework::OpKernel { op_type = ncclMax; } + auto* comm = ctx.Input("Communicator"); + auto dev_ctx = static_cast(ctx.device_context()); - platform::NCCLManager* m = platform::NCCLManager::Get(); + // platform::NCCLManager* m = platform::NCCLManager::Get(); - auto* comm = m->GetCommunicator(gpus); - comm->wg_.Add(1); + // auto* comm = m->GetCommunicator(gpus); + // comm->wg_.Add(1); auto stream = dev_ctx.stream(); @@ -76,14 +88,14 @@ class NCCLAllReduceKernel : public framework::OpKernel { op_type, comm->comms_[idx], comm->streams_[idx])); PADDLE_ENFORCE(cudaEventRecord(comm->events_[idx], comm->streams_[idx])); - // wait finish - PADDLE_ENFORCE( - cudaStreamWaitEvent(comm->streams_[idx], comm->events_[idx], 0)); + // // wait finish + // PADDLE_ENFORCE( + // cudaStreamWaitEvent(comm->streams_[idx], comm->events_[idx], 0)); } - comm->wg_.Done(); + // comm->wg_.Done(); - comm->wg_.Wait(); + // comm->wg_.Wait(); } };